Spaces:
Running
on
T4
Running
on
T4
Junzhe Li
committed on
Commit
·
fd330d9
1
Parent(s):
d38594f
fixing agent
Browse files
benchmarking/llm_providers/medrax_provider.py
CHANGED
|
@@ -39,7 +39,7 @@ class MedRAXProvider(LLMProvider):
|
|
| 39 |
"TorchXRayVisionClassifierTool", # For classifying chest X-ray images using TorchXRayVision
|
| 40 |
"ArcPlusClassifierTool", # For advanced chest X-ray classification using ArcPlus
|
| 41 |
"ChestXRayReportGeneratorTool", # For generating medical reports from X-rays
|
| 42 |
-
"XRayPhraseGroundingTool", # For locating described features in X-rays
|
| 43 |
"MedGemmaVQATool", # Google MedGemma VQA tool
|
| 44 |
"MedicalRAGTool", # For retrieval-augmented generation with medical knowledge
|
| 45 |
"WebBrowserTool", # For web browsing and search capabilities
|
|
|
|
| 39 |
"TorchXRayVisionClassifierTool", # For classifying chest X-ray images using TorchXRayVision
|
| 40 |
"ArcPlusClassifierTool", # For advanced chest X-ray classification using ArcPlus
|
| 41 |
"ChestXRayReportGeneratorTool", # For generating medical reports from X-rays
|
| 42 |
+
# "XRayPhraseGroundingTool", # For locating described features in X-rays
|
| 43 |
"MedGemmaVQATool", # Google MedGemma VQA tool
|
| 44 |
"MedicalRAGTool", # For retrieval-augmented generation with medical knowledge
|
| 45 |
"WebBrowserTool", # For web browsing and search capabilities
|
medrax/agent/agent.py
CHANGED
|
@@ -6,7 +6,7 @@ from datetime import datetime
|
|
| 6 |
from typing import List, Dict, Any, TypedDict, Annotated, Optional
|
| 7 |
|
| 8 |
from langgraph.graph import StateGraph, END
|
| 9 |
-
from langchain_core.messages import AnyMessage, SystemMessage, ToolMessage
|
| 10 |
from langchain_core.language_models import BaseLanguageModel
|
| 11 |
from langchain_core.tools import BaseTool
|
| 12 |
|
|
@@ -112,8 +112,27 @@ class Agent:
|
|
| 112 |
Dict[str, List[AnyMessage]]: A dictionary containing the model's response.
|
| 113 |
"""
|
| 114 |
messages = state["messages"]
|
| 115 |
-
|
|
|
|
|
|
|
|
|
|
| 116 |
messages = [SystemMessage(content=self.system_prompt)] + messages
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
response = self.model.invoke(messages)
|
| 118 |
return {"messages": [response]}
|
| 119 |
|
|
|
|
| 6 |
from typing import List, Dict, Any, TypedDict, Annotated, Optional
|
| 7 |
|
| 8 |
from langgraph.graph import StateGraph, END
|
| 9 |
+
from langchain_core.messages import AnyMessage, SystemMessage, ToolMessage, HumanMessage
|
| 10 |
from langchain_core.language_models import BaseLanguageModel
|
| 11 |
from langchain_core.tools import BaseTool
|
| 12 |
|
|
|
|
| 112 |
Dict[str, List[AnyMessage]]: A dictionary containing the model's response.
|
| 113 |
"""
|
| 114 |
messages = state["messages"]
|
| 115 |
+
|
| 116 |
+
# Only add system prompt if it's not already present (i.e., first call in this conversation)
|
| 117 |
+
# This avoids redundantly sending the system prompt on every model invocation
|
| 118 |
+
if self.system_prompt and (len(messages) == 0 or not isinstance(messages[0], SystemMessage)):
|
| 119 |
messages = [SystemMessage(content=self.system_prompt)] + messages
|
| 120 |
+
|
| 121 |
+
# Check if we just executed tools by checking if the last message is a ToolMessage
|
| 122 |
+
# This indicates we're in the process node immediately after the execute node
|
| 123 |
+
has_tool_results = len(messages) > 0 and isinstance(messages[-1], ToolMessage)
|
| 124 |
+
|
| 125 |
+
# If we have tool results, add explicit instruction to continue reasoning
|
| 126 |
+
# This is especially important for models like Gemini that may stop without generating output
|
| 127 |
+
# Use HumanMessage instead of SystemMessage as it's more compatible with all models
|
| 128 |
+
# The prompt allows the model to call more tools OR provide final answer, giving it flexibility
|
| 129 |
+
if has_tool_results:
|
| 130 |
+
synthesis_prompt = HumanMessage(
|
| 131 |
+
content="Review the tool results above. If you need more information, you can call additional tools. "
|
| 132 |
+
"Otherwise, provide your complete final answer synthesizing all the information."
|
| 133 |
+
)
|
| 134 |
+
messages = messages + [synthesis_prompt]
|
| 135 |
+
|
| 136 |
response = self.model.invoke(messages)
|
| 137 |
return {"messages": [response]}
|
| 138 |
|