Junzhe Li committed on
Commit
fd330d9
·
1 Parent(s): d38594f

fixing agent

Browse files
benchmarking/llm_providers/medrax_provider.py CHANGED
@@ -39,7 +39,7 @@ class MedRAXProvider(LLMProvider):
39
  "TorchXRayVisionClassifierTool", # For classifying chest X-ray images using TorchXRayVision
40
  "ArcPlusClassifierTool", # For advanced chest X-ray classification using ArcPlus
41
  "ChestXRayReportGeneratorTool", # For generating medical reports from X-rays
42
- "XRayPhraseGroundingTool", # For locating described features in X-rays
43
  "MedGemmaVQATool", # Google MedGemma VQA tool
44
  "MedicalRAGTool", # For retrieval-augmented generation with medical knowledge
45
  "WebBrowserTool", # For web browsing and search capabilities
 
39
  "TorchXRayVisionClassifierTool", # For classifying chest X-ray images using TorchXRayVision
40
  "ArcPlusClassifierTool", # For advanced chest X-ray classification using ArcPlus
41
  "ChestXRayReportGeneratorTool", # For generating medical reports from X-rays
42
+ # "XRayPhraseGroundingTool", # For locating described features in X-rays
43
  "MedGemmaVQATool", # Google MedGemma VQA tool
44
  "MedicalRAGTool", # For retrieval-augmented generation with medical knowledge
45
  "WebBrowserTool", # For web browsing and search capabilities
medrax/agent/agent.py CHANGED
@@ -6,7 +6,7 @@ from datetime import datetime
6
  from typing import List, Dict, Any, TypedDict, Annotated, Optional
7
 
8
  from langgraph.graph import StateGraph, END
9
- from langchain_core.messages import AnyMessage, SystemMessage, ToolMessage
10
  from langchain_core.language_models import BaseLanguageModel
11
  from langchain_core.tools import BaseTool
12
 
@@ -112,8 +112,27 @@ class Agent:
112
  Dict[str, List[AnyMessage]]: A dictionary containing the model's response.
113
  """
114
  messages = state["messages"]
115
- if self.system_prompt:
 
 
 
116
  messages = [SystemMessage(content=self.system_prompt)] + messages
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  response = self.model.invoke(messages)
118
  return {"messages": [response]}
119
 
 
6
  from typing import List, Dict, Any, TypedDict, Annotated, Optional
7
 
8
  from langgraph.graph import StateGraph, END
9
+ from langchain_core.messages import AnyMessage, SystemMessage, ToolMessage, HumanMessage
10
  from langchain_core.language_models import BaseLanguageModel
11
  from langchain_core.tools import BaseTool
12
 
 
112
  Dict[str, List[AnyMessage]]: A dictionary containing the model's response.
113
  """
114
  messages = state["messages"]
115
+
116
+ # Only add system prompt if it's not already present (i.e., first call in this conversation)
117
+ # This avoids redundantly sending the system prompt on every model invocation
118
+ if self.system_prompt and (len(messages) == 0 or not isinstance(messages[0], SystemMessage)):
119
  messages = [SystemMessage(content=self.system_prompt)] + messages
120
+
121
+ # Check if we just executed tools by checking if the last message is a ToolMessage
122
+ # This indicates we're in the process node immediately after the execute node
123
+ has_tool_results = len(messages) > 0 and isinstance(messages[-1], ToolMessage)
124
+
125
+ # If we have tool results, add explicit instruction to continue reasoning
126
+ # This is especially important for models like Gemini that may stop without generating output
127
+ # Use HumanMessage instead of SystemMessage as it's more compatible with all models
128
+ # The prompt allows the model to call more tools OR provide final answer, giving it flexibility
129
+ if has_tool_results:
130
+ synthesis_prompt = HumanMessage(
131
+ content="Review the tool results above. If you need more information, you can call additional tools. "
132
+ "Otherwise, provide your complete final answer synthesizing all the information."
133
+ )
134
+ messages = messages + [synthesis_prompt]
135
+
136
  response = self.model.invoke(messages)
137
  return {"messages": [response]}
138