from typing import Any, Dict, Tuple

from langchain.chains import RetrievalQA
from langchain.tools import BaseTool

from medrax.rag.rag import RAGConfig, CohereRAG


class RAGTool(BaseTool):
"""Tool for answering medical questions using RAG with comprehensive medical knowledge base.
This tool leverages Retrieval-Augmented Generation (RAG) to answer medical questions by
retrieving relevant information from a curated knowledge base of medical textbooks, research papers,
and clinical manuals. The tool uses advanced embedding models to find contextually relevant
information and then generates comprehensive, evidence-based answers using large language models.
The knowledge base includes:
- Medical textbooks and reference materials
- Research papers and clinical studies
- Medical manuals and guidelines
- Specialized medical literature
Args:
config (RAGConfig): Configuration object containing model settings, embedding model,
knowledge base paths, and other RAG system parameters
"""
name: str = "medical_knowledge_rag"
description: str = (
"Answers medical questions using Retrieval-Augmented Generation with a comprehensive medical knowledge base. "
"Retrieves relevant information from medical textbooks, research papers, and clinical manuals, "
"then generates evidence-based answers using advanced language models. "
"Input should be a medical question or query in natural language. "
"Output includes a comprehensive answer with supporting source documents and metadata. "
)
rag: CohereRAG = None
    chain: RetrievalQA = None

    def __init__(
self,
config: RAGConfig,
):
"""Initialize RAG tool with configuration.
Args:
config (RAGConfig): Configuration for the RAG system including model settings,
embedding model, knowledge base paths, and retrieval parameters
"""
super().__init__()
self.rag = CohereRAG(config)
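        # Build the retrieval QA chain; with_memory=True is expected to attach
        # conversational memory so follow-up queries retain prior context.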
        self.chain = self.rag.initialize_rag(with_memory=True)

    def _run(self, query: str) -> Tuple[Dict[str, Any], Dict]:
"""Execute the RAG tool with the given query.
Args:
query (str): Medical question to answer
Returns:
Tuple[Dict[str, Any], Dict]: Output dictionary and metadata dictionary
"""
try:
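            # RetrievalQA returns the generated answer under "result" and, when
            # configured to return sources, the retrieved documents under
            # "source_documents".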
result = self.chain.invoke({"query": query})
output = {
"answer": result["result"],
"source_documents": [
{"content": doc.page_content, "metadata": doc.metadata}
for doc in result.get("source_documents", [])
],
}
metadata = {
"query": query,
"analysis_status": "completed",
"num_sources": len(result.get("source_documents", [])),
"model": self.rag.config.model,
"embedding_model": self.rag.config.embedding_model,
}
return output, metadata
except Exception as e:
output = {"error": str(e)}
metadata = {
"query": query,
"analysis_status": "failed",
"error_details": str(e),
}
            return output, metadata

    async def _arun(self, query: str) -> Tuple[Dict[str, Any], Dict]:
"""Async version of _run.
Args:
query (str): Medical question to answer
Returns:
Tuple[Dict[str, Any], Dict]: Output dictionary and metadata dictionary
Raises:
NotImplementedError: Async not implemented yet
"""
return self._run(query)
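

if __name__ == "__main__":
    # Minimal usage sketch, not part of the tool itself. How RAGConfig is
    # constructed (required fields, API keys for Cohere, knowledge base paths)
    # depends on medrax.rag.rag.RAGConfig; the no-argument call below assumes
    # its defaults are usable and is only illustrative.
    config = RAGConfig()
    tool = RAGTool(config=config)

    # _run returns (output, metadata); output carries "answer" and
    # "source_documents" on success, or "error" on failure.
    output, metadata = tool._run("What are the radiographic signs of a pneumothorax?")
    print(output.get("answer", output.get("error")))
    print(f"Sources retrieved: {metadata.get('num_sources', 0)}")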