import asyncio
from typing import Any, Dict, Optional, Tuple

from langchain.chains import RetrievalQA
from langchain.tools import BaseTool

from medrax.rag.rag import CohereRAG, RAGConfig
class RAGTool(BaseTool):
    """Tool for answering medical questions using RAG with a medical knowledge base.

    This tool leverages Retrieval-Augmented Generation (RAG) to answer medical
    questions by retrieving relevant information from a curated knowledge base
    of medical textbooks, research papers, and clinical manuals, and then
    generating an answer with a large language model.

    The knowledge base includes:
        - Medical textbooks and reference materials
        - Research papers and clinical studies
        - Medical manuals and guidelines
        - Specialized medical literature

    Args:
        config (RAGConfig): Configuration object containing model settings,
            embedding model, knowledge base paths, and other RAG system
            parameters.
    """

    name: str = "medical_knowledge_rag"
    description: str = (
        "Answers medical questions using Retrieval-Augmented Generation with a comprehensive medical knowledge base. "
        "Retrieves relevant information from medical textbooks, research papers, and clinical manuals, "
        "then generates evidence-based answers using advanced language models. "
        "Input should be a medical question or query in natural language. "
        "Output includes a comprehensive answer with supporting source documents and metadata. "
    )
    # Both fields default to None at class level and are populated in __init__.
    rag: Optional[CohereRAG] = None
    chain: Optional[RetrievalQA] = None

    def __init__(
        self,
        config: RAGConfig,
    ):
        """Initialize the RAG tool from a configuration.

        Builds a ``CohereRAG`` instance from ``config`` and asks it for a
        retrieval chain with memory enabled (``with_memory=True``).

        Args:
            config (RAGConfig): Configuration for the RAG system including
                model settings, embedding model, knowledge base paths, and
                retrieval parameters.
        """
        super().__init__()
        self.rag = CohereRAG(config)
        self.chain = self.rag.initialize_rag(with_memory=True)

    def _run(self, query: str) -> Tuple[Dict[str, Any], Dict]:
        """Execute the RAG chain for the given query.

        Any exception raised while invoking the chain or assembling the output
        is caught and reported through the returned dictionaries instead of
        propagating to the caller.

        Args:
            query (str): Medical question to answer.

        Returns:
            Tuple[Dict[str, Any], Dict]: ``(output, metadata)``.
                On success ``output`` has the keys ``"answer"`` and
                ``"source_documents"`` and ``metadata["analysis_status"]`` is
                ``"completed"``. On failure ``output`` is ``{"error": <message>}``
                and ``metadata["analysis_status"]`` is ``"failed"``.
        """
        try:
            result = self.chain.invoke({"query": query})
            # Look the documents up once; they feed both the serialized list
            # and the source count below.
            source_documents = result.get("source_documents", [])
            output = {
                "answer": result["result"],
                "source_documents": [
                    {"content": doc.page_content, "metadata": doc.metadata}
                    for doc in source_documents
                ],
            }
            metadata = {
                "query": query,
                "analysis_status": "completed",
                "num_sources": len(source_documents),
                "model": self.rag.config.model,
                "embedding_model": self.rag.config.embedding_model,
            }
            return output, metadata
        except Exception as e:
            # Tool boundary: surface the failure as data rather than raising.
            error_message = str(e)
            output = {"error": error_message}
            metadata = {
                "query": query,
                "analysis_status": "failed",
                "error_details": error_message,
            }
            return output, metadata

    async def _arun(self, query: str) -> Tuple[Dict[str, Any], Dict]:
        """Async version of ``_run``.

        Delegates to the synchronous ``_run`` on the event loop's default
        executor so the blocking chain invocation does not stall the loop.

        Args:
            query (str): Medical question to answer.

        Returns:
            Tuple[Dict[str, Any], Dict]: Output dictionary and metadata
                dictionary, exactly as returned by ``_run``.
        """
        # NOTE(review): the chain is created with memory enabled; overlapping
        # async calls now execute on worker threads - confirm the chain
        # tolerates concurrent invocation if that usage is expected.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._run, query)