VictorLJZ committed on
Commit
e4e9fae
·
2 Parent(s): fd330d9 373615b

Merge branch 'main' into tool-changes

Files changed (48)
  1. api.py +342 -0
  2. benchmarking/llm_providers/medrax_provider.py +2 -1
  3. interface.py +46 -56
  4. main.py +374 -62
  5. medrax/agent/agent.py +13 -106
  6. medrax/docs/system_prompts.txt +3 -1
  7. medrax/llava/conversation.py +1 -3
  8. medrax/llava/eval/eval_multimodal_chat_gpt_score.py +3 -6
  9. medrax/llava/eval/llm.py +8 -23
  10. medrax/llava/eval/model_vqa.py +2 -8
  11. medrax/llava/eval/summarize_gpt_review.py +3 -7
  12. medrax/llava/mm_utils.py +4 -14
  13. medrax/llava/model/builder.py +4 -12
  14. medrax/llava/model/language_model/llava_mistral.py +1 -3
  15. medrax/llava/model/llava_arch.py +13 -39
  16. medrax/llava/model/multimodal_encoder/builder.py +2 -8
  17. medrax/llava/model/multimodal_projector/builder.py +1 -3
  18. medrax/llava/serve/cli.py +1 -3
  19. medrax/llava/serve/controller.py +3 -6
  20. medrax/llava/serve/gradio_web_server.py +4 -12
  21. medrax/llava/serve/model_worker.py +6 -14
  22. medrax/llava/serve/test_message.py +2 -6
  23. medrax/llava/utils.py +1 -3
  24. medrax/models/model_factory.py +5 -12
  25. medrax/rag/rag.py +3 -9
  26. medrax/tools/browsing/__init__.py +3 -3
  27. medrax/tools/browsing/duckduckgo.py +12 -33
  28. medrax/tools/browsing/web_browser.py +3 -9
  29. medrax/tools/classification/__init__.py +1 -6
  30. medrax/tools/classification/arcplus.py +5 -17
  31. medrax/tools/classification/torchxrayvision.py +1 -3
  32. medrax/tools/dicom.py +1 -3
  33. medrax/tools/grounding.py +4 -13
  34. medrax/tools/rag.py +1 -1
  35. medrax/tools/report_generation.py +4 -14
  36. medrax/tools/segmentation/__init__.py +1 -7
  37. medrax/tools/segmentation/medsam2.py +69 -79
  38. medrax/tools/segmentation/segmentation.py +10 -30
  39. medrax/tools/utils.py +5 -15
  40. medrax/tools/vqa/__init__.py +4 -4
  41. medrax/tools/vqa/llava_med.py +4 -12
  42. medrax/tools/vqa/medgemma/medgemma.py +51 -11
  43. medrax/tools/vqa/medgemma/medgemma_client.py +12 -4
  44. medrax/tools/vqa/medgemma/medgemma_requirements_standard.txt +1 -1
  45. medrax/tools/vqa/medgemma/medgemma_setup.py +91 -4
  46. medrax/tools/vqa/xray_vqa.py +6 -12
  47. medrax/tools/xray_generation.py +12 -23
  48. pyproject.toml +11 -9
api.py ADDED
@@ -0,0 +1,342 @@
1
+ """
2
+ MedRAX API Module
3
+
4
+ This module provides a FastAPI-based REST API for the MedRAX medical imaging AI assistant.
5
+ It offers endpoints for processing medical images with text queries using the same agent
6
+ architecture as the Gradio interface.
7
+
8
+ The API supports:
9
+ - Text-only queries
10
+ - Single or multiple image inputs
11
+ - Optional custom system prompts
12
+ - Automatic thread management for each request
13
+ - Tool execution and result aggregation
14
+ """
15
+
16
+ import uuid
17
+ import base64
18
+ from pathlib import Path
19
+ from typing import List, Optional, Dict, Any
20
+ import re
21
+ import time
22
+
23
+ from fastapi import FastAPI, HTTPException, UploadFile, File, Form
24
+ from fastapi.middleware.cors import CORSMiddleware
25
+ from pydantic import BaseModel, Field
26
+ from langchain_core.messages import AIMessage, ToolMessage
27
+
28
+ # Import MedRAX components
29
+ from medrax.agent import Agent
30
+
31
+
32
+ class QueryRequest(BaseModel):
33
+ """
34
+ Request model for text-only queries.
35
+
36
+ Attributes:
37
+ question (str): The question or query to ask the agent
38
+ system_prompt (Optional[str]): Custom system prompt to override default
39
+ thread_id (Optional[str]): Optional thread ID for conversation continuity
40
+ """
41
+
42
+ question: str = Field(..., description="The question or query to ask the agent")
43
+ system_prompt: Optional[str] = Field(None, description="Custom system prompt to override default")
44
+ thread_id: Optional[str] = Field(None, description="Optional thread ID for conversation continuity")
45
+
46
+
47
+ class QueryResponse(BaseModel):
48
+ """
49
+ Response model for API queries.
50
+
51
+ Attributes:
52
+ response (str): The agent's text response
53
+ thread_id (str): The thread ID used for this conversation
54
+ tools_used (List[str]): List of tools that were executed
55
+ processing_time (float): Time taken to process the request in seconds
56
+ """
57
+
58
+ response: str = Field(..., description="The agent's text response")
59
+ thread_id: str = Field(..., description="The thread ID used for this conversation")
60
+ tools_used: List[str] = Field(..., description="List of tools that were executed")
61
+ processing_time: float = Field(..., description="Time taken to process the request in seconds")
62
+
63
+
64
+ class MedRAXAPI:
65
+ """
66
+ FastAPI application wrapper for the MedRAX agent.
67
+
68
+ This class provides a clean interface for creating and managing the API endpoints
69
+ while maintaining separation of concerns from the core agent functionality.
70
+ """
71
+
72
+ def __init__(self, agent: Agent, tools_dict: Dict[str, Any], temp_dir: str = "temp_api"):
73
+ """
74
+ Initialize the MedRAX API.
75
+
76
+ Args:
77
+ agent (Agent): The initialized MedRAX agent
78
+ tools_dict (Dict[str, Any]): Dictionary of available tools
79
+ temp_dir (str): Directory for temporary file storage
80
+ """
81
+ self.agent = agent
82
+ self.tools_dict = tools_dict
83
+ self.temp_dir = Path(temp_dir)
84
+ self.temp_dir.mkdir(exist_ok=True)
85
+
86
+ # Create FastAPI app
87
+ self.app = FastAPI(
88
+ title="MedRAX API",
89
+ description="Medical Reasoning Agent for Chest X-ray Analysis",
90
+ version="2.0.0",
91
+ docs_url="/docs",
92
+ redoc_url="/redoc",
93
+ )
94
+
95
+ # Add CORS middleware
96
+ self.app.add_middleware(
97
+ CORSMiddleware,
98
+ allow_origins=["*"],
99
+ allow_credentials=True,
100
+ allow_methods=["*"],
101
+ allow_headers=["*"],
102
+ )
103
+
104
+ # Register routes
105
+ self._register_routes()
106
+
107
+ def _register_routes(self):
108
+ """Register all API routes."""
109
+
110
+ @self.app.get("/health")
111
+ async def health_check():
112
+ """Health check endpoint."""
113
+ return {"status": "healthy", "service": "MedRAX API"}
114
+
115
+ @self.app.get("/tools")
116
+ async def list_tools():
117
+ """List available tools."""
118
+ return {"available_tools": list(self.tools_dict.keys()), "total_count": len(self.tools_dict)}
119
+
120
+ @self.app.post("/query", response_model=QueryResponse)
121
+ async def query_text_only(request: QueryRequest):
122
+ """
123
+ Process a text-only query without images.
124
+
125
+ Args:
126
+ request (QueryRequest): The query request
127
+
128
+ Returns:
129
+ QueryResponse: The agent's response
130
+ """
131
+ return await self._process_query(
132
+ question=request.question, system_prompt=request.system_prompt, thread_id=request.thread_id, images=None
133
+ )
134
+
135
+ @self.app.post("/query-with-images", response_model=QueryResponse)
136
+ async def query_with_images(
137
+ question: str = Form(..., description="The question or query to ask the agent"),
138
+ system_prompt: Optional[str] = Form(None, description="Custom system prompt to override default"),
139
+ thread_id: Optional[str] = Form(None, description="Optional thread ID for conversation continuity"),
140
+ images: List[UploadFile] = File(..., description="One or more medical images to analyze"),
141
+ ):
142
+ """
143
+ Process a query with one or more images.
144
+
145
+ Args:
146
+ question (str): The question or query to ask the agent
147
+ system_prompt (Optional[str]): Custom system prompt to override default
148
+ thread_id (Optional[str]): Optional thread ID for conversation continuity
149
+ images (List[UploadFile]): List of uploaded image files
150
+
151
+ Returns:
152
+ QueryResponse: The agent's response
153
+ """
154
+ # Validate image files
155
+ if not images or len(images) == 0:
156
+ raise HTTPException(status_code=400, detail="At least one image is required")
157
+
158
+ # Validate file types
159
+ allowed_types = {"image/jpeg", "image/jpg", "image/png", "image/bmp", "image/tiff", "application/dicom"}
160
+ for image in images:
161
+ if image.content_type not in allowed_types:
162
+ raise HTTPException(
163
+ status_code=400,
164
+ detail=f"Unsupported file type: {image.content_type}. Allowed types: {allowed_types}",
165
+ )
166
+
167
+ return await self._process_query(
168
+ question=question, system_prompt=system_prompt, thread_id=thread_id, images=images
169
+ )
170
+
171
+ async def _process_query(
172
+ self,
173
+ question: str,
174
+ system_prompt: Optional[str] = None,
175
+ thread_id: Optional[str] = None,
176
+ images: Optional[List[UploadFile]] = None,
177
+ ) -> QueryResponse:
178
+ """
179
+ Internal method to process queries through the agent.
180
+
181
+ Args:
182
+ question (str): The question to ask
183
+ system_prompt (Optional[str]): Custom system prompt
184
+ thread_id (Optional[str]): Thread ID for conversation
185
+ images (Optional[List[UploadFile]]): List of images
186
+
187
+ Returns:
188
+ QueryResponse: The processed response
189
+ """
190
+ start_time = time.time()
191
+
192
+ # Generate thread ID if not provided
193
+ if not thread_id:
194
+ thread_id = str(uuid.uuid4())
195
+
196
+ try:
197
+ # Prepare messages
198
+ messages = []
199
+ image_paths = []
200
+
201
+ # Handle image uploads
202
+ if images:
203
+ for i, image in enumerate(images):
204
+ # Save uploaded file temporarily
205
+ temp_path = self.temp_dir / f"{thread_id}_{i}_{image.filename}"
206
+
207
+ with open(temp_path, "wb") as buffer:
208
+ content = await image.read()
209
+ buffer.write(content)
210
+
211
+ image_paths.append(str(temp_path))
212
+
213
+ # Add image path for tools
214
+ messages.append({"role": "user", "content": f"image_path: {temp_path}"})
215
+
216
+ # Add base64 encoded image for multimodal processing
217
+ image_base64 = base64.b64encode(content).decode("utf-8")
218
+
219
+ # Determine MIME type
220
+ mime_type = "image/jpeg" # Default
221
+ if image.content_type:
222
+ mime_type = image.content_type
223
+ elif temp_path.suffix.lower() in [".png"]:
224
+ mime_type = "image/png"
225
+
226
+ messages.append(
227
+ {
228
+ "role": "user",
229
+ "content": [
230
+ {
231
+ "type": "image_url",
232
+ "image_url": {"url": f"data:{mime_type};base64,{image_base64}"},
233
+ }
234
+ ],
235
+ }
236
+ )
237
+
238
+ # Add text question
239
+ messages.append({"role": "user", "content": [{"type": "text", "text": question}]})
240
+
241
+ # Process through agent workflow
242
+ response_text = ""
243
+ tools_used = []
244
+
245
+ # Temporarily update system prompt if provided
246
+ original_prompt = None
247
+ if system_prompt:
248
+ original_prompt = self.agent.system_prompt
249
+ self.agent.system_prompt = system_prompt
250
+
251
+ try:
252
+ async for chunk in self._stream_agent_response(messages, thread_id):
253
+ if chunk.get("type") == "text":
254
+ response_text += chunk.get("content", "")
255
+ elif chunk.get("type") == "tool":
256
+ tools_used.append(chunk.get("tool_name", ""))
257
+ finally:
258
+ # Restore original system prompt
259
+ if original_prompt is not None:
260
+ self.agent.system_prompt = original_prompt
261
+
262
+ # Clean up temporary files
263
+ for image_path in image_paths:
264
+ try:
265
+ Path(image_path).unlink(missing_ok=True)
266
+ except Exception:
267
+ pass # Ignore cleanup errors
268
+
269
+ processing_time = time.time() - start_time
270
+
271
+ return QueryResponse(
272
+ response=response_text.strip(),
273
+ thread_id=thread_id,
274
+ tools_used=list(set(tools_used)), # Remove duplicates
275
+ processing_time=processing_time,
276
+ )
277
+
278
+ except Exception as e:
279
+ # Clean up on error
280
+ for image_path in image_paths:
281
+ try:
282
+ Path(image_path).unlink(missing_ok=True)
283
+ except Exception:
284
+ pass
285
+
286
+ raise HTTPException(status_code=500, detail=f"Error processing query: {str(e)}")
287
+
288
+ async def _stream_agent_response(self, messages: List[Dict], thread_id: str):
289
+ """
290
+ Stream responses from the agent workflow.
291
+
292
+ Args:
293
+ messages (List[Dict]): Messages to process
294
+ thread_id (str): Thread ID for the conversation
295
+
296
+ Yields:
297
+ Dict: Response chunks with type and content
298
+ """
299
+ try:
300
+ for chunk in self.agent.workflow.stream(
301
+ {"messages": messages},
302
+ {"configurable": {"thread_id": thread_id}},
303
+ stream_mode="updates",
304
+ ):
305
+ if not isinstance(chunk, dict):
306
+ continue
307
+
308
+ for node_name, node_output in chunk.items():
309
+ if "messages" not in node_output:
310
+ continue
311
+
312
+ for msg in node_output["messages"]:
313
+ if isinstance(msg, AIMessage) and msg.content:
314
+ # Clean up temp paths from response
315
+ clean_content = re.sub(r"temp[^\s]*", "", msg.content).strip()
316
+ if clean_content:
317
+ yield {"type": "text", "content": clean_content}
318
+
319
+ elif isinstance(msg, ToolMessage):
320
+ # Extract tool name from the message
321
+ tool_call_id = msg.tool_call_id
322
+ # We'll track tool usage but not include detailed output in API response
323
+ yield {"type": "tool", "tool_name": "tool_executed"}
324
+
325
+ except Exception as e:
326
+ yield {"type": "error", "content": str(e)}
327
+
328
+
329
+ def create_api(agent: Agent, tools_dict: Dict[str, Any], temp_dir: str = "temp_api") -> FastAPI:
330
+ """
331
+ Create and configure the MedRAX FastAPI application.
332
+
333
+ Args:
334
+ agent (Agent): The initialized MedRAX agent
335
+ tools_dict (Dict[str, Any]): Dictionary of available tools
336
+ temp_dir (str): Directory for temporary file storage
337
+
338
+ Returns:
339
+ FastAPI: Configured FastAPI application
340
+ """
341
+ api = MedRAXAPI(agent, tools_dict, temp_dir)
342
+ return api.app
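
For orientation, a minimal client sketch for the two new endpoints (illustrative only, not part of this commit). It assumes the API server is reachable at http://localhost:8000 (the --api-port default added in main.py), that the requests package is installed, and that example_cxr.png is a local file; adjust these to your deployment.

import requests

BASE_URL = "http://localhost:8000"  # assumed local deployment

# Text-only query against /query (QueryRequest model)
resp = requests.post(
    f"{BASE_URL}/query",
    json={"question": "What findings would suggest a pneumothorax on a chest X-ray?"},
)
resp.raise_for_status()
body = resp.json()
print(body["response"], body["thread_id"], body["tools_used"])

# Query with an image against /query-with-images (multipart form data)
with open("example_cxr.png", "rb") as f:  # hypothetical local file
    resp = requests.post(
        f"{BASE_URL}/query-with-images",
        data={"question": "Is there cardiomegaly in this image?"},
        files=[("images", ("example_cxr.png", f, "image/png"))],
    )
print(resp.json()["response"])

Reusing the returned thread_id in a follow-up request keeps the conversation on the same checkpointer thread.
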
benchmarking/llm_providers/medrax_provider.py CHANGED
@@ -1,5 +1,6 @@
1
  """MedRAX LLM provider implementation."""
2
 
 
3
  import time
4
  import re
5
  import uuid
@@ -68,7 +69,7 @@ class MedRAXProvider(LLMProvider):
68
  tools_to_use=selected_tools,
69
  model_dir="/home/lijunzh3/scratch/MedRAX2/model-weights",
70
  temp_dir="temp", # Change this to the path of the temporary directory
71
- device="cuda:0",
72
  model=self.model_name, # Change this to the model you want to use, e.g. gpt-4.1-2025-04-14, gemini-2.5-pro
73
  temperature=self.temperature,
74
  top_p=self.top_p,
 
1
  """MedRAX LLM provider implementation."""
2
 
3
+ import os
4
  import time
5
  import re
6
  import uuid
 
69
  tools_to_use=selected_tools,
70
  model_dir="/home/lijunzh3/scratch/MedRAX2/model-weights",
71
  temp_dir="temp", # Change this to the path of the temporary directory
72
+ device=os.getenv("MEDRAX_DEVICE", "cuda:0"),
73
  model=self.model_name, # Change this to the model you want to use, e.g. gpt-4.1-2025-04-14, gemini-2.5-pro
74
  temperature=self.temperature,
75
  top_p=self.top_p,
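
The provider now reads its device from the environment instead of hard-coding cuda:0, so the same script can target a different GPU (or the CPU) per job. A small sketch of the pattern, with an assumed value set from a job script or .env file:

import os

# Set MEDRAX_DEVICE before launching to retarget the provider; os.getenv
# falls back to cuda:0 when the variable is unset, matching the change above.
os.environ.setdefault("MEDRAX_DEVICE", "cuda:1")  # assumed example value
device = os.getenv("MEDRAX_DEVICE", "cuda:0")
print(f"MedRAX provider device: {device}")
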
interface.py CHANGED
@@ -29,7 +29,7 @@ class ChatInterface:
29
  """
30
  self.agent = agent
31
  self.tools_dict = tools_dict
32
- self.upload_dir = Path("temp")
33
  self.upload_dir.mkdir(exist_ok=True)
34
  self.current_thread_id = None
35
  # Separate storage for original and display paths
@@ -68,9 +68,7 @@ class ChatInterface:
68
 
69
  return self.display_file_path
70
 
71
- def add_message(
72
- self, message: str, display_image: str, history: List[dict]
73
- ) -> Tuple[List[dict], gr.Textbox]:
74
  """
75
  Add a new message to the chat history.
76
 
@@ -155,9 +153,7 @@ class ChatInterface:
155
  if isinstance(msg, AIMessageChunk) and msg.content:
156
  accumulated_content += msg.content
157
  if final_message is None:
158
- final_message = ChatMessage(
159
- role="assistant", content=accumulated_content
160
- )
161
  chat_history.append(final_message)
162
  else:
163
  final_message.content = accumulated_content
@@ -169,9 +165,7 @@ class ChatInterface:
169
  if final_message:
170
  final_message.content = final_content
171
  else:
172
- chat_history.append(
173
- ChatMessage(role="assistant", content=final_content)
174
- )
175
  yield chat_history, self.display_file_path, ""
176
 
177
  if msg.tool_calls:
@@ -190,21 +184,25 @@ class ChatInterface:
190
  pending_call = self.pending_tool_calls.pop(tool_call_id)
191
  tool_name = pending_call["name"]
192
  tool_args = pending_call["args"]
193
-
194
  try:
195
- # Handle case where tool returns tuple (output, metadata)
196
- content = msg.content
197
- content_tuple = ast.literal_eval(content)
198
- content = json.dumps(content_tuple[0])
199
- tool_output_json = json.loads(content)
200
- tool_output_str = json.dumps(tool_output_json, indent=2)
201
- except (json.JSONDecodeError, TypeError):
202
- tool_output_str = str(msg.content)
 
 
 
 
 
203
 
 
204
  tool_args_str = json.dumps(tool_args, indent=2)
205
-
206
  description = f"**Input:**\n```json\n{tool_args_str}\n```\n\n**Output:**\n```json\n{tool_output_str}\n```"
207
-
208
  metadata = {
209
  "title": f"⚒️ Tool: {tool_name}",
210
  "description": description,
@@ -217,32 +215,33 @@ class ChatInterface:
217
  metadata=metadata,
218
  )
219
  )
220
- yield chat_history, self.display_file_path, ""
221
 
 
222
  if tool_name == "image_visualizer":
 
223
  try:
224
- # Handle case where tool returns tuple (output, metadata)
225
- content = msg.content
226
- content_tuple = ast.literal_eval(content)
227
- result = content_tuple[0]
228
-
229
- if isinstance(result, dict) and "image_path" in result:
230
- self.display_file_path = result["image_path"]
231
- chat_history.append(
232
- ChatMessage(
233
- role="assistant",
234
- content={"path": self.display_file_path},
235
- )
 
236
  )
237
- yield chat_history, self.display_file_path, ""
238
- except (json.JSONDecodeError, TypeError):
239
- pass
 
240
 
241
  except Exception as e:
242
  chat_history.append(
243
- ChatMessage(
244
- role="assistant", content=f"❌ Error: {str(e)}", metadata={"title": "Error"}
245
- )
246
  )
247
  yield chat_history, self.display_file_path, ""
248
 
@@ -293,9 +292,7 @@ def create_demo(agent, tools_dict):
293
  )
294
 
295
  with gr.Column(scale=3):
296
- image_display = gr.Image(
297
- label="Image", type="filepath", height=600, container=True
298
- )
299
  with gr.Row():
300
  upload_button = gr.UploadButton(
301
  "📎 Upload X-Ray",
@@ -306,25 +303,19 @@ def create_demo(agent, tools_dict):
306
  file_types=["file"],
307
  )
308
  with gr.Row():
309
- clear_btn = gr.Button("Clear Chat")
310
- new_thread_btn = gr.Button("New Thread")
311
 
312
  # Event handlers
313
- def clear_chat():
314
  interface.original_file_path = None
315
  interface.display_file_path = None
316
- return [], None
317
-
318
- def new_thread():
319
  interface.current_thread_id = str(time.time())
320
- return [], interface.display_file_path
321
 
322
  def handle_file_upload(file):
323
  return interface.handle_upload(file.name)
324
 
325
- chat_msg = txt.submit(
326
- interface.add_message, inputs=[txt, image_display, chatbot], outputs=[chatbot, txt]
327
- )
328
  bot_msg = chat_msg.then(
329
  interface.process_message,
330
  inputs=[txt, image_display, chatbot],
@@ -336,7 +327,6 @@ def create_demo(agent, tools_dict):
336
 
337
  dicom_upload.upload(handle_file_upload, inputs=dicom_upload, outputs=image_display)
338
 
339
- clear_btn.click(clear_chat, outputs=[chatbot, image_display])
340
- new_thread_btn.click(new_thread, outputs=[chatbot, image_display])
341
 
342
- return demo
 
29
  """
30
  self.agent = agent
31
  self.tools_dict = tools_dict
32
+ self.upload_dir = Path(f"temp/{time.time()}")
33
  self.upload_dir.mkdir(exist_ok=True)
34
  self.current_thread_id = None
35
  # Separate storage for original and display paths
 
68
 
69
  return self.display_file_path
70
 
71
+ def add_message(self, message: str, display_image: str, history: List[dict]) -> Tuple[List[dict], gr.Textbox]:
 
 
72
  """
73
  Add a new message to the chat history.
74
 
 
153
  if isinstance(msg, AIMessageChunk) and msg.content:
154
  accumulated_content += msg.content
155
  if final_message is None:
156
+ final_message = ChatMessage(role="assistant", content=accumulated_content)
 
 
157
  chat_history.append(final_message)
158
  else:
159
  final_message.content = accumulated_content
 
165
  if final_message:
166
  final_message.content = final_content
167
  else:
168
+ chat_history.append(ChatMessage(role="assistant", content=final_content))
 
 
169
  yield chat_history, self.display_file_path, ""
170
 
171
  if msg.tool_calls:
 
184
  pending_call = self.pending_tool_calls.pop(tool_call_id)
185
  tool_name = pending_call["name"]
186
  tool_args = pending_call["args"]
187
+ # Parse content
188
  try:
189
+ # Try JSON parsing first
190
+ result = json.loads(msg.content)
191
+ tool_output_str = json.dumps(result, indent=2)
192
+ except json.JSONDecodeError:
193
+ try:
194
+ # Use ast.literal_eval as safe fallback for Python literals
195
+ content_tuple = ast.literal_eval(msg.content)
196
+ result = content_tuple[0]
197
+ tool_output_str = json.dumps(result, indent=2)
198
+ except (ValueError, SyntaxError):
199
+ # Fall back to treating as plain string
200
+ result = msg.content
201
+ tool_output_str = str(msg.content)
202
 
203
+ # Display tool usage card
204
  tool_args_str = json.dumps(tool_args, indent=2)
 
205
  description = f"**Input:**\n```json\n{tool_args_str}\n```\n\n**Output:**\n```json\n{tool_output_str}\n```"
 
206
  metadata = {
207
  "title": f"⚒️ Tool: {tool_name}",
208
  "description": description,
 
215
  metadata=metadata,
216
  )
217
  )
 
218
 
219
+ # Special handling for image_visualizer
220
  if tool_name == "image_visualizer":
221
+ image_path = None
222
  try:
223
+ image_path = result["image_path"]
224
+ except (TypeError, KeyError):
225
+ try:
226
+ image_path = result[0]["image_path"]
227
+ except (TypeError, KeyError, IndexError):
228
+ pass
229
+
230
+ if image_path:
231
+ self.display_file_path = image_path
232
+ chat_history.append(
233
+ ChatMessage(
234
+ role="assistant",
235
+ content={"path": self.display_file_path},
236
  )
237
+ )
238
+
239
+ # Yield a single update for this tool event
240
+ yield chat_history, self.display_file_path, ""
241
 
242
  except Exception as e:
243
  chat_history.append(
244
+ ChatMessage(role="assistant", content=f"❌ Error: {str(e)}", metadata={"title": "Error"})
 
 
245
  )
246
  yield chat_history, self.display_file_path, ""
247
 
 
292
  )
293
 
294
  with gr.Column(scale=3):
295
+ image_display = gr.Image(label="Image", type="filepath", height=600, container=True)
 
 
296
  with gr.Row():
297
  upload_button = gr.UploadButton(
298
  "📎 Upload X-Ray",
 
303
  file_types=["file"],
304
  )
305
  with gr.Row():
306
+ new_chat_btn = gr.Button("New Chat")
 
307
 
308
  # Event handlers
309
+ def new_chat():
310
  interface.original_file_path = None
311
  interface.display_file_path = None
 
 
 
312
  interface.current_thread_id = str(time.time())
313
+ return [], None
314
 
315
  def handle_file_upload(file):
316
  return interface.handle_upload(file.name)
317
 
318
+ chat_msg = txt.submit(interface.add_message, inputs=[txt, image_display, chatbot], outputs=[chatbot, txt])
 
 
319
  bot_msg = chat_msg.then(
320
  interface.process_message,
321
  inputs=[txt, image_display, chatbot],
 
327
 
328
  dicom_upload.upload(handle_file_upload, inputs=dicom_upload, outputs=image_display)
329
 
330
+ new_chat_btn.click(new_chat, outputs=[chatbot, image_display])
 
331
 
332
+ return demo
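
The tool-output handling above now tries JSON first and only falls back to ast.literal_eval for tools that return a Python (output, metadata) tuple, with plain text as the last resort. A standalone sketch of that parsing order, using assumed sample inputs:

import ast
import json

def render_tool_output(content: str) -> str:
    # Well-formed JSON: pretty-print it directly.
    try:
        return json.dumps(json.loads(content), indent=2)
    except json.JSONDecodeError:
        pass
    # Python literal, e.g. an (output, metadata) tuple returned by a tool.
    try:
        parsed = ast.literal_eval(content)
        if isinstance(parsed, tuple):
            parsed = parsed[0]
        return json.dumps(parsed, indent=2, default=str)
    except (ValueError, SyntaxError):
        # Anything else is shown as plain text.
        return str(content)

print(render_tool_output('{"finding": "effusion"}'))
print(render_tool_output("({'image_path': 'temp/seg.png'}, {'elapsed': 1.2})"))
print(render_tool_output("no structured output"))
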
main.py CHANGED
@@ -11,6 +11,10 @@ with different model weights, tools, and parameters.
11
 
12
  import warnings
13
  import os
 
 
 
 
14
  from typing import Dict, List, Optional, Any
15
  from dotenv import load_dotenv
16
  from transformers import logging
@@ -19,6 +23,7 @@ from langgraph.checkpoint.memory import MemorySaver
19
  from medrax.models import ModelFactory
20
 
21
  from interface import create_demo
 
22
  from medrax.agent import *
23
  from medrax.tools import *
24
  from medrax.utils import *
@@ -31,19 +36,93 @@ logging.set_verbosity_error()
31
  _ = load_dotenv()
32
 
33
 
34
  def initialize_agent(
35
  prompt_file: str,
36
  tools_to_use: Optional[List[str]] = None,
37
- model_dir: str = "model-weights",
38
  temp_dir: str = "temp",
39
- device: str = "cpu",
40
- model: str = "gemini-2.5-pro",
41
  temperature: float = 1.0,
42
  top_p: float = 0.95,
43
  max_tokens: int = 5000,
44
  rag_config: Optional[RAGConfig] = None,
45
  model_kwargs: Dict[str, Any] = {},
46
  system_prompt: str = "MEDICAL_ASSISTANT",
 
47
  ):
48
  """Initialize the MedRAX agent with specified tools and configuration.
49
 
@@ -55,7 +134,6 @@ def initialize_agent(
55
  device (str, optional): Device to run models on. Defaults to "cuda".
56
  model (str, optional): Model to use. Defaults to "gpt-4o".
57
  temperature (float, optional): Temperature for the model. Defaults to 0.7.
58
- top_p (float, optional): Top P for the model. Defaults to 0.95.
59
  rag_config (RAGConfig, optional): Configuration for the RAG tool. Defaults to None.
60
  model_kwargs (dict, optional): Additional keyword arguments for model.
61
  system_prompt (str, optional): System prompt to use. Defaults to "MEDICAL_ASSISTANT".
@@ -68,18 +146,13 @@ def initialize_agent(
68
  prompts = load_prompts_from_file(prompt_file)
69
  prompt = prompts[system_prompt]
70
 
71
- # Define the URL of the MedGemma FastAPI service.
72
- MEDGEMMA_API_URL = os.getenv("MEDGEMMA_API_URL", "http://localhost:8002")
73
-
74
  all_tools = {
75
  "TorchXRayVisionClassifierTool": lambda: TorchXRayVisionClassifierTool(device=device),
76
  "ArcPlusClassifierTool": lambda: ArcPlusClassifierTool(cache_dir=model_dir, device=device),
77
  "ChestXRaySegmentationTool": lambda: ChestXRaySegmentationTool(device=device),
78
  "LlavaMedTool": lambda: LlavaMedTool(cache_dir=model_dir, device=device, load_in_8bit=True),
79
  "CheXagentXRayVQATool": lambda: CheXagentXRayVQATool(cache_dir=model_dir, device=device),
80
- "ChestXRayReportGeneratorTool": lambda: ChestXRayReportGeneratorTool(
81
- cache_dir=model_dir, device=device
82
- ),
83
  "XRayPhraseGroundingTool": lambda: XRayPhraseGroundingTool(
84
  cache_dir=model_dir, temp_dir=temp_dir, load_in_8bit=True, device=device
85
  ),
@@ -91,18 +164,21 @@ def initialize_agent(
91
  "MedicalRAGTool": lambda: RAGTool(config=rag_config),
92
  "WebBrowserTool": lambda: WebBrowserTool(),
93
  "DuckDuckGoSearchTool": lambda: DuckDuckGoSearchTool(),
94
- "MedSAM2Tool": lambda: MedSAM2Tool(
95
- device=device, cache_dir=model_dir, temp_dir=temp_dir
 
 
 
 
96
  ),
97
- "MedGemmaVQATool": lambda: MedGemmaAPIClientTool(cache_dir=model_dir, device=device, api_url=MEDGEMMA_API_URL)
98
- }
99
 
100
  # Initialize only selected tools or all if none specified
101
  tools_dict: Dict[str, BaseTool] = {}
102
 
103
  if tools_to_use is None:
104
  tools_to_use = []
105
-
106
  for tool_name in tools_to_use:
107
  if tool_name == "PythonSandboxTool":
108
  try:
@@ -112,7 +188,6 @@ def initialize_agent(
112
  print("Skipping PythonSandboxTool")
113
  if tool_name in all_tools:
114
  tools_dict[tool_name] = all_tools[tool_name]()
115
-
116
 
117
  # Set up checkpointing for conversation state
118
  checkpointer = MemorySaver()
@@ -130,8 +205,6 @@ def initialize_agent(
130
  agent = Agent(
131
  llm,
132
  tools=list(tools_dict.values()),
133
- log_tools=True,
134
- log_dir="logs",
135
  system_prompt=prompt,
136
  checkpointer=checkpointer,
137
  )
@@ -140,50 +213,262 @@ def initialize_agent(
140
  return agent, tools_dict
141
 
142
 
143
  if __name__ == "__main__":
144
  """
145
  This is the main entry point for the MedRAX application.
146
- It initializes the agent with the selected tools and creates the demo.
147
  """
148
- print("Starting server...")
149
-
150
- # Example: initialize with only specific tools
151
- # Here three tools are commented out, you can uncomment them to use them
152
- selected_tools = [
153
- "ImageVisualizerTool", # For displaying images in the UI
154
- # "DicomProcessorTool", # For processing DICOM medical image files
155
- # "ChestXRayGeneratorTool", # For generating synthetic chest X-rays
156
- "ChestXRayReportGeneratorTool", # For generating medical reports from X-rays
157
- "TorchXRayVisionClassifierTool", # For classifying chest X-ray images using TorchXRayVision
158
- "ArcPlusClassifierTool", # For advanced chest X-ray classification using ArcPlus
159
- "MedGemmaVQATool" # Google MedGemma VQA tool
160
- "XRayVQATool", # For visual question answering on X-rays
161
- # "LlavaMedTool", # For multimodal medical image understanding
162
- "XRayPhraseGroundingTool", # For locating described features in X-rays
163
- "ChestXRaySegmentationTool", # For segmenting anatomical regions in chest X-rays
164
- # "MedSAM2Tool", # For advanced medical image segmentation using MedSAM2
165
- # "WebBrowserTool", # For web browsing and search capabilities
166
- "DuckDuckGoSearchTool", # For privacy-focused web search using DuckDuckGo
167
- # "MedicalRAGTool", # For retrieval-augmented generation with medical knowledge
168
- # "PythonSandboxTool", # Add the Python sandbox tool
169
- ]
170
 
171
  # Setup the MedGemma environment if the MedGemmaVQATool is selected
 
 
172
  if "MedGemmaVQATool" in selected_tools:
173
- setup_medgemma_env()
174
 
175
  # Configure the Retrieval Augmented Generation (RAG) system
176
  # This allows the agent to access and use medical knowledge documents
177
  rag_config = RAGConfig(
178
- model="command-a-03-2025", # Chat model for generating responses
179
- embedding_model="embed-v4.0", # Embedding model for the RAG system
180
- rerank_model="rerank-v3.5", # Reranking model for the RAG system
181
- temperature=0.3,
182
- pinecone_index_name="medrax2", # Name for the Pinecone index
183
- chunk_size=1500,
184
- chunk_overlap=300,
185
- retriever_k=3,
186
- local_docs_dir="rag_docs", # Change this to the path of the documents for RAG
187
  huggingface_datasets=["VictorLJZ/medrax2"], # List of HuggingFace datasets to load
188
  dataset_split="train", # Which split of the datasets to use
189
  )
@@ -192,19 +477,46 @@ if __name__ == "__main__":
192
  model_kwargs = {}
193
 
194
  agent, tools_dict = initialize_agent(
195
- prompt_file="medrax/docs/system_prompts.txt",
196
  tools_to_use=selected_tools,
197
- model_dir="model-weights",
198
- temp_dir="temp", # Change this to the path of the temporary directory
199
- device="cpu",
200
- model="gemini-2.5-pro", # Change this to the model you want to use, e.g. gpt-4.1-2025-04-14, gemini-2.5-pro
201
- temperature=1.0,
202
- top_p=0.95,
203
  model_kwargs=model_kwargs,
204
  rag_config=rag_config,
205
- system_prompt="MEDICAL_ASSISTANT",
 
206
  )
207
 
208
- # Create and launch the web interface
209
- demo = create_demo(agent, tools_dict)
210
- demo.launch(server_name="0.0.0.0", server_port=8585, share=True)
 
11
 
12
  import warnings
13
  import os
14
+ import argparse
15
+ from pyngrok import ngrok
16
+ import threading
17
+ import uvicorn
18
  from typing import Dict, List, Optional, Any
19
  from dotenv import load_dotenv
20
  from transformers import logging
 
23
  from medrax.models import ModelFactory
24
 
25
  from interface import create_demo
26
+ from api import create_api
27
  from medrax.agent import *
28
  from medrax.tools import *
29
  from medrax.utils import *
 
36
  _ = load_dotenv()
37
 
38
 
39
+ def resolve_medgemma_api_url_from_value(value: Optional[str]) -> str:
40
+ """Resolve the MedGemma API base URL using CLI value, env var, and SLURM-aware fallback.
41
+
42
+ Resolution order:
43
+ 1) Explicit provided value (e.g., CLI flag)
44
+ 2) MEDGEMMA_API_URL environment variable
45
+ 3) If on SLURM, require explicit URL (raise)
46
+ 4) Otherwise, default to localhost for single-box setups
47
+ """
48
+ if value:
49
+ return value
50
+
51
+ env_url = os.getenv("MEDGEMMA_API_URL")
52
+ if env_url:
53
+ return env_url
54
+
55
+ if os.getenv("SLURM_JOB_ID") or os.getenv("SLURM_NODEID"):
56
+ raise RuntimeError(
57
+ "MEDGEMMA_API_URL not set and --medgemma-api-url not provided. "
58
+ "On SLURM, the client usually runs on a different node, "
59
+ "so you must point to the server’s reachable IP, e.g. http://<node-ip>:8002"
60
+ )
61
+
62
+ return "http://127.0.0.1:8002"
63
+
64
+
65
+ def resolve_medgemma_api_url(args) -> str:
66
+ """Helper that reads from an argparse Namespace if available."""
67
+ return resolve_medgemma_api_url_from_value(getattr(args, "medgemma_api_url", None))
68
+
69
+
70
+ def resolve_auth_credentials(args) -> Optional[tuple]:
71
+ """Resolve authentication credentials from CLI args or environment variables.
72
+
73
+ Resolution order:
74
+ 1) Explicit --no-auth flag (returns None, no warnings)
75
+ 2) Explicit --auth USERNAME PASSWORD (returns credentials tuple)
76
+ 3) MEDRAX_AUTH_USERNAME and MEDRAX_AUTH_PASSWORD environment variables
77
+ 4) Default to None with warning messages
78
+
79
+ Args:
80
+ args: Parsed command-line arguments
81
+
82
+ Returns:
83
+ Optional[tuple]: (username, password) tuple if auth is enabled, None otherwise
84
+ """
85
+ if args.no_auth:
86
+ print("⚠️ Authentication disabled (public access)")
87
+ return None
88
+
89
+ if args.auth:
90
+ username, password = args.auth
91
+ print(f"✅ Authentication enabled for user: {username}")
92
+ return (username, password)
93
+
94
+ # Try to read from environment variables
95
+ auth_username = os.getenv("MEDRAX_AUTH_USERNAME")
96
+ auth_password = os.getenv("MEDRAX_AUTH_PASSWORD")
97
+
98
+ if auth_username and auth_password:
99
+ print(f"✅ Authentication enabled from environment for user: {auth_username}")
100
+ return (auth_username, auth_password)
101
+
102
+ # No auth specified anywhere - default to no auth with warning
103
+ print("⚠️ No authentication configured!")
104
+ print("⚠️ Running without authentication (public access)")
105
+ print("⚠️ To enable auth, either:")
106
+ print(" - Use --auth USERNAME PASSWORD")
107
+ print(" - Set MEDRAX_AUTH_USERNAME and MEDRAX_AUTH_PASSWORD in .env")
108
+ print(" - Or explicitly use --no-auth to suppress this warning")
109
+ return None
110
+
111
+
112
  def initialize_agent(
113
  prompt_file: str,
114
  tools_to_use: Optional[List[str]] = None,
115
+ model_dir: str = "/model-weights",
116
  temp_dir: str = "temp",
117
+ device: str = "cuda",
118
+ model: str = "gpt-4.1",
119
  temperature: float = 1.0,
120
  top_p: float = 0.95,
121
  max_tokens: int = 5000,
122
  rag_config: Optional[RAGConfig] = None,
123
  model_kwargs: Dict[str, Any] = {},
124
  system_prompt: str = "MEDICAL_ASSISTANT",
125
+ medgemma_api_url: Optional[str] = None,
126
  ):
127
  """Initialize the MedRAX agent with specified tools and configuration.
128
 
 
134
  device (str, optional): Device to run models on. Defaults to "cuda".
135
  model (str, optional): Model to use. Defaults to "gpt-4o".
136
  temperature (float, optional): Temperature for the model. Defaults to 0.7.
 
137
  rag_config (RAGConfig, optional): Configuration for the RAG tool. Defaults to None.
138
  model_kwargs (dict, optional): Additional keyword arguments for model.
139
  system_prompt (str, optional): System prompt to use. Defaults to "MEDICAL_ASSISTANT".
 
146
  prompts = load_prompts_from_file(prompt_file)
147
  prompt = prompts[system_prompt]
148
 
 
 
 
149
  all_tools = {
150
  "TorchXRayVisionClassifierTool": lambda: TorchXRayVisionClassifierTool(device=device),
151
  "ArcPlusClassifierTool": lambda: ArcPlusClassifierTool(cache_dir=model_dir, device=device),
152
  "ChestXRaySegmentationTool": lambda: ChestXRaySegmentationTool(device=device),
153
  "LlavaMedTool": lambda: LlavaMedTool(cache_dir=model_dir, device=device, load_in_8bit=True),
154
  "CheXagentXRayVQATool": lambda: CheXagentXRayVQATool(cache_dir=model_dir, device=device),
155
+ "ChestXRayReportGeneratorTool": lambda: ChestXRayReportGeneratorTool(cache_dir=model_dir, device=device),
 
 
156
  "XRayPhraseGroundingTool": lambda: XRayPhraseGroundingTool(
157
  cache_dir=model_dir, temp_dir=temp_dir, load_in_8bit=True, device=device
158
  ),
 
164
  "MedicalRAGTool": lambda: RAGTool(config=rag_config),
165
  "WebBrowserTool": lambda: WebBrowserTool(),
166
  "DuckDuckGoSearchTool": lambda: DuckDuckGoSearchTool(),
167
+ "MedSAM2Tool": lambda: MedSAM2Tool(device=device, cache_dir=model_dir, temp_dir=temp_dir),
168
+ "MedGemmaVQATool": lambda: MedGemmaAPIClientTool(
169
+ cache_dir=model_dir,
170
+ device=device,
171
+ load_in_8bit=True,
172
+ api_url=resolve_medgemma_api_url_from_value(medgemma_api_url),
173
  ),
174
+ }
 
175
 
176
  # Initialize only selected tools or all if none specified
177
  tools_dict: Dict[str, BaseTool] = {}
178
 
179
  if tools_to_use is None:
180
  tools_to_use = []
181
+
182
  for tool_name in tools_to_use:
183
  if tool_name == "PythonSandboxTool":
184
  try:
 
188
  print("Skipping PythonSandboxTool")
189
  if tool_name in all_tools:
190
  tools_dict[tool_name] = all_tools[tool_name]()
 
191
 
192
  # Set up checkpointing for conversation state
193
  checkpointer = MemorySaver()
 
205
  agent = Agent(
206
  llm,
207
  tools=list(tools_dict.values()),
 
 
208
  system_prompt=prompt,
209
  checkpointer=checkpointer,
210
  )
 
213
  return agent, tools_dict
214
 
215
 
216
+ def run_gradio_interface(agent, tools_dict, host="0.0.0.0", port=8686,
217
+ auth=None, share=False):
218
+ """
219
+ Run the Gradio web interface.
220
+
221
+ Args:
222
+ agent: The initialized MedRAX agent
223
+ tools_dict: Dictionary of available tools
224
+ host (str): Host to bind the server to
225
+ port (int): Port to run the server on
226
+ auth: Authentication credentials (tuple)
227
+ share (bool): Whether to create a shareable public link
228
+ """
229
+ print(f"Starting Gradio interface on {host}:{port}")
230
+
231
+ if auth:
232
+ print(f"🔐 Authentication enabled for user: {auth[0]}")
233
+ else:
234
+ print("⚠️ Running without authentication (public access)")
235
+
236
+ if share:
237
+ print("🌍 Creating shareable public link (expires in 1 week)...")
238
+
239
+ demo = create_demo(agent, tools_dict)
240
+
241
+ # Prepare launch parameters
242
+ launch_kwargs = {
243
+ "server_name": host,
244
+ "server_port": port,
245
+ "share": share
246
+ }
247
+
248
+ if auth:
249
+ launch_kwargs["auth"] = auth
250
+
251
+ demo.launch(**launch_kwargs)
252
+
253
+
254
+ def run_api_server(agent, tools_dict, host="0.0.0.0", port=8585, public=False):
255
+ """
256
+ Run the FastAPI server.
257
+
258
+ Args:
259
+ agent: The initialized MedRAX agent
260
+ tools_dict: Dictionary of available tools
261
+ host (str): Host to bind the server to
262
+ port (int): Port to run the server on
263
+ public (bool): Whether to expose via ngrok tunnel
264
+ """
265
+ print(f"Starting API server on {host}:{port}")
266
+
267
+ if public:
268
+ try:
269
+ public_tunnel = ngrok.connect(port)
270
+ public_url = public_tunnel.public_url
271
+ print(
272
+ f"🌍 Public URL: {public_url}\n🌍 API Documentation: {public_url}/docs\n🌍 Share this URL with your friend!\n{'=' * 60}"
273
+ )
274
+ except ImportError:
275
+ print("⚠️ pyngrok not installed. Install with: pip install pyngrok\nRunning locally only...")
276
+ public = False
277
+ except Exception as e:
278
+ print(f"⚠️ Failed to create public tunnel: {e}\nRunning locally only...")
279
+ public = False
280
+
281
+ app = create_api(agent, tools_dict)
282
+
283
+ try:
284
+ uvicorn.run(app, host=host, port=port)
285
+ finally:
286
+ if public:
287
+ try:
288
+ ngrok.disconnect(public_tunnel.public_url)
289
+ ngrok.kill()
290
+ except:
291
+ pass
292
+
293
+
294
+ def parse_arguments():
295
+ """Parse command line arguments."""
296
+ parser = argparse.ArgumentParser(description="MedRAX - Medical Reasoning Agent for Chest X-ray")
297
+
298
+ # Run mode
299
+ parser.add_argument(
300
+ "--mode",
301
+ choices=["gradio", "api", "both"],
302
+ default="gradio",
303
+ help="Run mode: 'gradio' for web interface, 'api' for REST API, 'both' for both services",
304
+ )
305
+
306
+ # Gradio interface options
307
+ parser.add_argument("--gradio-host", default="0.0.0.0", help="Gradio host address")
308
+ parser.add_argument("--gradio-port", type=int, default=8686, help="Gradio port")
309
+ parser.add_argument("--auth", nargs=2, metavar=("USERNAME", "PASSWORD"),
310
+ default=None,
311
+ help="Enable password authentication with specified username and password")
312
+ parser.add_argument("--no-auth", action="store_true",
313
+ help="Disable authentication (public access)")
314
+ parser.add_argument("--share", action="store_true",
315
+ help="Create a temporary shareable link (expires in 1 week)")
316
+
317
+ # API server options
318
+ parser.add_argument("--api-host", default="0.0.0.0", help="API host address")
319
+ parser.add_argument("--api-port", type=int, default=8000, help="API port")
320
+ parser.add_argument("--public", action="store_true", help="Make API publicly accessible via ngrok tunnel")
321
+
322
+ # Model and system configuration
323
+ parser.add_argument(
324
+ "--model-dir",
325
+ default="/model-weights",
326
+ help="Directory containing model weights (default: uses MODEL_WEIGHTS_DIR env var or '/model-weights')",
327
+ )
328
+ parser.add_argument(
329
+ "--device", default="cuda", help="Device to run models on (default: uses MEDRAX_DEVICE env var or 'cuda:1')"
330
+ )
331
+ parser.add_argument(
332
+ "--model",
333
+ default="gpt-4.1",
334
+ help="Model to use (default: gpt-4.1). Examples: gpt-4.1-2025-04-14, gemini-2.5-pro, gpt-5",
335
+ )
336
+ parser.add_argument("--temperature", type=float, default=1.0, help="Temperature for the model (default: 1.0)")
337
+ parser.add_argument("--temp-dir", default="temp2", help="Directory for temporary files (default: temp2)")
338
+ parser.add_argument(
339
+ "--prompt-file",
340
+ default="medrax/docs/system_prompts.txt",
341
+ help="Path to file containing system prompts (default: medrax/docs/system_prompts.txt)",
342
+ )
343
+ parser.add_argument(
344
+ "--system-prompt", default="MEDICAL_ASSISTANT", help="System prompt to use (default: MEDICAL_ASSISTANT)"
345
+ )
346
+
347
+ # RAG configuration
348
+ parser.add_argument(
349
+ "--rag-model", default="command-a-03-2025", help="Chat model for RAG responses (default: command-a-03-2025)"
350
+ )
351
+ parser.add_argument(
352
+ "--rag-embedding-model", default="embed-v4.0", help="Embedding model for RAG system (default: embed-v4.0)"
353
+ )
354
+ parser.add_argument(
355
+ "--rag-rerank-model", default="rerank-v3.5", help="Reranking model for RAG system (default: rerank-v3.5)"
356
+ )
357
+ parser.add_argument("--rag-temperature", type=float, default=0.3, help="Temperature for RAG model (default: 0.3)")
358
+ parser.add_argument("--pinecone-index", default="medrax2", help="Pinecone index name (default: medrax2)")
359
+ parser.add_argument("--chunk-size", type=int, default=1500, help="RAG chunk size (default: 1500)")
360
+ parser.add_argument("--chunk-overlap", type=int, default=300, help="RAG chunk overlap (default: 300)")
361
+ parser.add_argument("--retriever-k", type=int, default=3, help="Number of documents to retrieve (default: 3)")
362
+ parser.add_argument("--rag-docs-dir", default="rag_docs", help="Directory for RAG documents (default: rag_docs)")
363
+
364
+ # Tools configuration
365
+ parser.add_argument(
366
+ "--tools",
367
+ nargs="*",
368
+ help="Specific tools to enable (if not provided, uses default set). Available tools: "
369
+ + "ImageVisualizerTool, DicomProcessorTool, MedSAM2Tool, ChestXRaySegmentationTool, "
370
+ + "ChestXRayGeneratorTool, TorchXRayVisionClassifierTool, ArcPlusClassifierTool, "
371
+ + "ChestXRayReportGeneratorTool, XRayPhraseGroundingTool, MedGemmaVQATool, "
372
+ + "XRayVQATool, LlavaMedTool, MedicalRAGTool, WebBrowserTool, DuckDuckGoSearchTool, "
373
+ + "PythonSandboxTool",
374
+ )
375
+
376
+ # MedGemma API configuration
377
+ parser.add_argument(
378
+ "--medgemma-api-url",
379
+ default=None,
380
+ help="MedGemma API base URL, e.g. http://127.0.0.1:8002 or http://<node-ip>:8002"
381
+ )
382
+
383
+ return parser.parse_args()
384
+
385
+
386
  if __name__ == "__main__":
387
  """
388
  This is the main entry point for the MedRAX application.
389
+ It initializes the agent with the selected tools and creates the demo/API.
390
  """
391
+ args = parse_arguments()
392
+ print(f"Starting MedRAX in {args.mode} mode...")
393
+
394
+ # Configure tools based on arguments
395
+ if args.tools is not None:
396
+ # Use tools specified via command line
397
+ selected_tools = args.tools
398
+ else:
399
+ # Use default tools selection
400
+ selected_tools = [
401
+ # Image Processing Tools
402
+ "ImageVisualizerTool", # For displaying images in the UI
403
+ # "DicomProcessorTool", # For processing DICOM medical image files
404
+ # Segmentation Tools
405
+ "MedSAM2Tool", # For advanced medical image segmentation using MedSAM2
406
+ "ChestXRaySegmentationTool", # For segmenting anatomical regions in chest X-rays
407
+ # Generation Tools
408
+ # "ChestXRayGeneratorTool", # For generating synthetic chest X-rays
409
+ # Classification Tools
410
+ "TorchXRayVisionClassifierTool", # For classifying chest X-ray images using TorchXRayVision
411
+ "ArcPlusClassifierTool", # For advanced chest X-ray classification using ArcPlus
412
+ # Report Generation Tools
413
+ "ChestXRayReportGeneratorTool", # For generating medical reports from X-rays
414
+ # Grounding Tools
415
+ "XRayPhraseGroundingTool", # For locating described features in X-rays
416
+ # VQA Tools
417
+ # "MedGemmaVQATool", # Google MedGemma VQA tool
418
+ "XRayVQATool", # For visual question answering on X-rays
419
+ # "LlavaMedTool", # For multimodal medical image understanding
420
+ # RAG Tools
421
+ "MedicalRAGTool", # For retrieval-augmented generation with medical knowledge
422
+ # Search Tools
423
+ # "WebBrowserTool", # For web browsing and search capabilities
424
+ "DuckDuckGoSearchTool", # For privacy-focused web search using DuckDuckGo
425
+ # Development Tools
426
+ # "PythonSandboxTool", # Add the Python sandbox tool
427
+ ]
428
+
429
+ # Configure model directory and device
430
+ model_dir = args.model_dir or os.getenv("MODEL_WEIGHTS_DIR", "/model-weights")
431
+ device = args.device or os.getenv("MEDRAX_DEVICE", "cuda:0")
432
+
433
+ print(f"Using model directory: {model_dir}")
434
+ print(f"Using device: {device}")
435
+ print(f"Using model: {args.model}")
436
+ print(f"Selected tools: {selected_tools}")
437
+ print(f"Using system prompt: {args.system_prompt}")
438
+
439
+ # Set up authentication (reads from CLI, env vars, or requires explicit choice)
440
+ auth_credentials = resolve_auth_credentials(args)
441
 
442
  # Setup the MedGemma environment if the MedGemmaVQATool is selected
443
+ medgemma_base_url_from_setup: Optional[str] = None
444
+ medgemma_api_url_effective: Optional[str] = args.medgemma_api_url
445
  if "MedGemmaVQATool" in selected_tools:
446
+ # Launch server and capture its URL if no explicit URL/ENV provided
447
+ try:
448
+ if medgemma_api_url_effective is None and os.getenv("MEDGEMMA_API_URL") is None:
449
+ medgemma_base_url_from_setup = setup_medgemma_env(cache_dir=model_dir, device=device)
450
+ # If we auto-launched, use this URL unless overridden later
451
+ if medgemma_base_url_from_setup:
452
+ medgemma_api_url_effective = medgemma_base_url_from_setup
453
+ print(f"MedGemma API auto-launched at {medgemma_api_url_effective}")
454
+ else:
455
+ # Still ensure environment is set up; it will bind to provided host/port
456
+ setup_medgemma_env(cache_dir=model_dir, device=device)
457
+ except Exception as e:
458
+ print(f"Warning: Failed to launch MedGemma service automatically: {e}")
459
 
460
  # Configure the Retrieval Augmented Generation (RAG) system
461
  # This allows the agent to access and use medical knowledge documents
462
  rag_config = RAGConfig(
463
+ model=args.rag_model,
464
+ embedding_model=args.rag_embedding_model,
465
+ rerank_model=args.rag_rerank_model,
466
+ temperature=args.rag_temperature,
467
+ pinecone_index_name=args.pinecone_index,
468
+ chunk_size=args.chunk_size,
469
+ chunk_overlap=args.chunk_overlap,
470
+ retriever_k=args.retriever_k,
471
+ local_docs_dir=args.rag_docs_dir,
472
  huggingface_datasets=["VictorLJZ/medrax2"], # List of HuggingFace datasets to load
473
  dataset_split="train", # Which split of the datasets to use
474
  )
 
477
  model_kwargs = {}
478
 
479
  agent, tools_dict = initialize_agent(
480
+ prompt_file=args.prompt_file,
481
  tools_to_use=selected_tools,
482
+ model_dir=model_dir,
483
+ temp_dir=args.temp_dir,
484
+ device=device,
485
+ model=args.model,
486
+ temperature=args.temperature,
 
487
  model_kwargs=model_kwargs,
488
  rag_config=rag_config,
489
+ system_prompt=args.system_prompt,
490
+ medgemma_api_url=medgemma_api_url_effective,
491
  )
492
 
493
+ # Launch based on selected mode
494
+ if args.mode == "gradio":
495
+ run_gradio_interface(
496
+ agent, tools_dict,
497
+ host=args.gradio_host,
498
+ port=args.gradio_port,
499
+ auth=auth_credentials,
500
+ share=args.share
501
+ )
502
+
503
+ elif args.mode == "api":
504
+ run_api_server(agent, tools_dict, args.api_host, args.api_port, args.public)
505
+
506
+ elif args.mode == "both":
507
+ # Run both services in separate threads
508
+ api_thread = threading.Thread(
509
+ target=run_api_server,
510
+ args=(agent, tools_dict, args.api_host, args.api_port, args.public)
511
+ )
512
+ api_thread.daemon = True
513
+ api_thread.start()
514
+
515
+ # Run Gradio in main thread with authentication and sharing
516
+ run_gradio_interface(
517
+ agent, tools_dict,
518
+ host=args.gradio_host,
519
+ port=args.gradio_port,
520
+ auth=auth_credentials,
521
+ share=args.share
522
+ )
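
The new entry point is driven entirely by CLI flags, for example python main.py --mode both --auth admin secret --share, or python main.py --mode api --public --tools ImageVisualizerTool DuckDuckGoSearchTool. The two resolution helpers can also be exercised on their own; a small sketch, assuming it is run from the repository root with MedRAX's dependencies installed:

import os
from argparse import Namespace

from main import resolve_auth_credentials, resolve_medgemma_api_url_from_value

# MedGemma URL: an explicit value wins over MEDGEMMA_API_URL, which wins over
# the localhost default (SLURM jobs must pass an explicit URL instead).
print(resolve_medgemma_api_url_from_value("http://10.0.0.5:8002"))  # explicit value
os.environ["MEDGEMMA_API_URL"] = "http://gpu-node:8002"             # assumed host name
print(resolve_medgemma_api_url_from_value(None))                    # env var fallback

# Auth: --auth beats the MEDRAX_AUTH_* env vars; --no-auth disables it silently.
print(resolve_auth_credentials(Namespace(no_auth=False, auth=("admin", "secret"))))
print(resolve_auth_credentials(Namespace(no_auth=True, auth=None)))  # -> None
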
medrax/agent/agent.py CHANGED
@@ -1,37 +1,17 @@
1
- import json
2
  import operator
3
- from pathlib import Path
4
- from dotenv import load_dotenv
5
- from datetime import datetime
6
  from typing import List, Dict, Any, TypedDict, Annotated, Optional
 
7
 
8
  from langgraph.graph import StateGraph, END
9
  from langchain_core.messages import AnyMessage, SystemMessage, ToolMessage, HumanMessage
 
 
10
  from langchain_core.language_models import BaseLanguageModel
11
  from langchain_core.tools import BaseTool
12
 
13
  _ = load_dotenv()
14
 
15
 
16
- class ToolCallLog(TypedDict):
17
- """
18
- A TypedDict representing a log entry for a tool call.
19
-
20
- Attributes:
21
- timestamp (str): The timestamp of when the tool call was made.
22
- tool_call_id (str): The unique identifier for the tool call.
23
- name (str): The name of the tool that was called.
24
- args (Any): The arguments passed to the tool.
25
- content (str): The content or result of the tool call.
26
- """
27
-
28
- timestamp: str
29
- tool_call_id: str
30
- name: str
31
- args: Any
32
- content: str
33
-
34
-
35
  class AgentState(TypedDict):
36
  """
37
  A TypedDict representing the state of an agent.
@@ -48,16 +28,14 @@ class AgentState(TypedDict):
48
  class Agent:
49
  """
50
  A class representing an agent that processes requests and executes tools based on
51
- language model responses.
52
 
53
  Attributes:
54
  model (BaseLanguageModel): The language model used for processing.
55
- tools (Dict[str, BaseTool]): A dictionary of available tools.
56
  checkpointer (Any): Manages and persists the agent's state.
57
  system_prompt (str): The system instructions for the agent.
58
  workflow (StateGraph): The compiled workflow for the agent's processing.
59
- log_tools (bool): Whether to log tool calls.
60
- log_path (Path): Path to save tool call logs.
61
  """
62
 
63
  def __init__(
@@ -66,8 +44,6 @@ class Agent:
66
  tools: List[BaseTool],
67
  checkpointer: Any = None,
68
  system_prompt: str = "",
69
- log_tools: bool = True,
70
- log_dir: Optional[str] = "logs",
71
  ):
72
  """
73
  Initialize the Agent.
@@ -77,28 +53,21 @@ class Agent:
77
  tools (List[BaseTool]): A list of available tools.
78
  checkpointer (Any, optional): State persistence manager. Defaults to None.
79
  system_prompt (str, optional): System instructions. Defaults to "".
80
- log_tools (bool, optional): Whether to log tool calls. Defaults to True.
81
- log_dir (str, optional): Directory to save logs. Defaults to 'logs'.
82
  """
83
  self.system_prompt = system_prompt
84
- self.log_tools = log_tools
85
 
86
- if self.log_tools:
87
- self.log_path = Path(log_dir or "logs")
88
- self.log_path.mkdir(exist_ok=True)
89
 
90
- # Define the agent workflow
91
  workflow = StateGraph(AgentState)
92
- workflow.add_node("process", self.process_request)
93
- workflow.add_node("execute", self.execute_tools)
94
- workflow.add_conditional_edges(
95
- "process", self.has_tool_calls, {True: "execute", False: END}
96
- )
97
- workflow.add_edge("execute", "process")
98
- workflow.set_entry_point("process")
99
 
100
  self.workflow = workflow.compile(checkpointer=checkpointer)
101
- self.tools = {t.name: t for t in tools}
102
  self.model = model.bind_tools(tools)
103
 
104
  def process_request(self, state: AgentState) -> Dict[str, List[AnyMessage]]:
@@ -148,65 +117,3 @@ class Agent:
148
  """
149
  response = state["messages"][-1]
150
  return len(response.tool_calls) > 0
151
-
152
- def execute_tools(self, state: AgentState) -> Dict[str, List[ToolMessage]]:
153
- """
154
- Execute tool calls from the model's response.
155
-
156
- Args:
157
- state (AgentState): The current state of the agent.
158
-
159
- Returns:
160
- Dict[str, List[ToolMessage]]: A dictionary containing tool execution results.
161
- """
162
- tool_calls = state["messages"][-1].tool_calls
163
- results = []
164
-
165
- for call in tool_calls:
166
- print(f"Executing tool: {call}")
167
- if call["name"] not in self.tools:
168
- print("\n....invalid tool....")
169
- result = "invalid tool, please retry"
170
- else:
171
- result = self.tools[call["name"]].invoke(call["args"])
172
-
173
- results.append(
174
- ToolMessage(
175
- tool_call_id=call["id"],
176
- name=call["name"],
177
- args=call["args"],
178
- content=str(result),
179
- )
180
- )
181
-
182
- self._save_tool_calls(results)
183
- print("Returning to model processing!")
184
-
185
- return {"messages": results}
186
-
187
- def _save_tool_calls(self, tool_calls: List[ToolMessage]) -> None:
188
- """
189
- Save tool calls to a JSON file with timestamp-based naming.
190
-
191
- Args:
192
- tool_calls (List[ToolMessage]): List of tool calls to save.
193
- """
194
- if not self.log_tools:
195
- return
196
-
197
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
198
- filename = self.log_path / f"tool_calls_{timestamp}.json"
199
-
200
- logs: List[ToolCallLog] = []
201
- for call in tool_calls:
202
- log_entry = {
203
- "tool_call_id": call.tool_call_id,
204
- "name": call.name,
205
- "args": call.args,
206
- "content": call.content,
207
- "timestamp": datetime.now().isoformat(),
208
- }
209
- logs.append(log_entry)
210
-
211
- with open(filename, "w") as f:
212
- json.dump(logs, f, indent=4)
 
 
1
  import operator
 
 
 
2
  from typing import List, Dict, Any, TypedDict, Annotated, Optional
3
+ from dotenv import load_dotenv
4
 
5
  from langgraph.graph import StateGraph, END
6
  from langchain_core.messages import AnyMessage, SystemMessage, ToolMessage, HumanMessage
7
+ from langgraph.prebuilt import ToolNode
8
+ from langchain_core.messages import AnyMessage, SystemMessage
9
  from langchain_core.language_models import BaseLanguageModel
10
  from langchain_core.tools import BaseTool
11
 
12
  _ = load_dotenv()
13
 
14
 
 
15
  class AgentState(TypedDict):
16
  """
17
  A TypedDict representing the state of an agent.
 
28
  class Agent:
29
  """
30
  A class representing an agent that processes requests and executes tools based on
31
+ language model responses with parallel tool execution capabilities.
32
 
33
  Attributes:
34
  model (BaseLanguageModel): The language model used for processing.
35
+ tool_node (ToolNode): The parallel tool execution node.
36
  checkpointer (Any): Manages and persists the agent's state.
37
  system_prompt (str): The system instructions for the agent.
38
  workflow (StateGraph): The compiled workflow for the agent's processing.
 
 
39
  """
40
 
41
  def __init__(
 
44
  tools: List[BaseTool],
45
  checkpointer: Any = None,
46
  system_prompt: str = "",
 
 
47
  ):
48
  """
49
  Initialize the Agent.
 
53
  tools (List[BaseTool]): A list of available tools.
54
  checkpointer (Any, optional): State persistence manager. Defaults to None.
55
  system_prompt (str, optional): System instructions. Defaults to "".
 
 
56
  """
57
  self.system_prompt = system_prompt
 
58
 
59
+ # Create the parallel tool execution node
60
+ self.tool_node = ToolNode(tools)
 
61
 
62
+ # Define the agent workflow with parallel tool execution
63
  workflow = StateGraph(AgentState)
64
+ workflow.add_node("agent", self.process_request)
65
+ workflow.add_node("tools", self.tool_node)
66
+ workflow.add_conditional_edges("agent", self.has_tool_calls, {True: "tools", False: END})
67
+ workflow.add_edge("tools", "agent")
68
+ workflow.set_entry_point("agent")
 
 
69
 
70
  self.workflow = workflow.compile(checkpointer=checkpointer)
 
71
  self.model = model.bind_tools(tools)
72
 
73
  def process_request(self, state: AgentState) -> Dict[str, List[AnyMessage]]:
 
117
  """
118
  response = state["messages"][-1]
119
  return len(response.tool_calls) > 0
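
# --- Hypothetical usage sketch (not part of this commit) ---------------------
# A minimal example of how the rewired, ToolNode-based Agent above might be
# instantiated and invoked. ChatOpenAI, MemorySaver, the tool list, and the
# import path are illustrative assumptions, not confirmed by the diff.
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langchain_core.messages import HumanMessage

from medrax.agent.agent import Agent  # assumed import path for the class above

agent = Agent(
    model=ChatOpenAI(model="gpt-4o"),   # any chat model that supports bind_tools
    tools=[],                           # plug in MedRAX tools here
    checkpointer=MemorySaver(),         # persists AgentState between turns
    system_prompt="You are a medical imaging assistant.",
)

# Each thread_id keeps its own conversation history in the checkpointer.
result = agent.workflow.invoke(
    {"messages": [HumanMessage(content="Summarize the findings on this chest X-ray.")]},
    config={"configurable": {"thread_id": "demo-thread"}},
)
print(result["messages"][-1].content)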
medrax/docs/system_prompts.txt CHANGED
@@ -33,4 +33,6 @@ Your final response for a multiple-choice question must strictly follow this for
33
  3. **Critical Thinking & Tool Use:** [Show your reasoning, including how you used tools and evaluated their output]
34
  4. **Final Answer:** \boxed{A}
35
 
36
- Do not provide a definitive diagnosis or treatment plan for a patient. Your purpose is to assist medical professionals with your analysis, not to replace them. You must maintain this persona and adhere to all instructions.
 
 
 
33
  3. **Critical Thinking & Tool Use:** [Show your reasoning, including how you used tools and evaluated their output]
34
  4. **Final Answer:** \boxed{A}
35
 
36
+ Do not provide a definitive diagnosis or treatment plan for a patient. Your purpose is to assist medical professionals with your analysis, not to replace them. You must maintain this persona and adhere to all instructions.
37
+
38
+
medrax/llava/conversation.py CHANGED
@@ -230,9 +230,7 @@ class Conversation:
230
  buffered = BytesIO()
231
  image.save(buffered, format="JPEG")
232
  img_b64_str = base64.b64encode(buffered.getvalue()).decode()
233
- img_str = (
234
- f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
235
- )
236
  msg = img_str + msg.replace("<image>", "").strip()
237
  ret.append([msg, None])
238
  else:
 
230
  buffered = BytesIO()
231
  image.save(buffered, format="JPEG")
232
  img_b64_str = base64.b64encode(buffered.getvalue()).decode()
233
+ img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
 
 
234
  msg = img_str + msg.replace("<image>", "").strip()
235
  ret.append([msg, None])
236
  else:
medrax/llava/eval/eval_multimodal_chat_gpt_score.py CHANGED
@@ -14,6 +14,7 @@ INSTRUCT_PROMPT = """We would like to request your feedback on the performance o
14
  Please first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space. In the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."""
15
  ROLE = "Assistant"
16
 
 
17
  # Generate instruction for GPT-4 to score the two answers.
18
  def conv_to_str(fig_label, fig_caption, fig_context, question, ans1, ans2):
19
  return (
@@ -127,17 +128,13 @@ def main(args):
127
 
128
  if __name__ == "__main__":
129
  parser = argparse.ArgumentParser("GPT-4 Multimodal Chat Scoring", add_help=True)
130
- parser.add_argument(
131
- "--answers-file", default="", metavar="FILE", help="path to model answer file"
132
- )
133
  parser.add_argument(
134
  "--question-file",
135
  default="data/questions/llava_med_eval_qa50_qa.jsonl",
136
  metavar="FILE",
137
  help="path to multichat questions file",
138
  )
139
- parser.add_argument(
140
- "--scores-file", default="", metavar="FILE", help="path to save gpt-4 score file"
141
- )
142
  args = parser.parse_args()
143
  main(args)
 
14
  Please first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space. In the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."""
15
  ROLE = "Assistant"
16
 
17
+
18
  # Generate instruction for GPT-4 to score the two answers.
19
  def conv_to_str(fig_label, fig_caption, fig_context, question, ans1, ans2):
20
  return (
 
128
 
129
  if __name__ == "__main__":
130
  parser = argparse.ArgumentParser("GPT-4 Multimodal Chat Scoring", add_help=True)
131
+ parser.add_argument("--answers-file", default="", metavar="FILE", help="path to model answer file")
 
 
132
  parser.add_argument(
133
  "--question-file",
134
  default="data/questions/llava_med_eval_qa50_qa.jsonl",
135
  metavar="FILE",
136
  help="path to multichat questions file",
137
  )
138
+ parser.add_argument("--scores-file", default="", metavar="FILE", help="path to save gpt-4 score file")
 
 
139
  args = parser.parse_args()
140
  main(args)
medrax/llava/eval/llm.py CHANGED
@@ -21,9 +21,7 @@ class LLM(abc.ABC):
21
  raise NotImplementedError("Subclasses should implement this!")
22
 
23
  @abstractmethod
24
- def split_input(
25
- self, fixed_instruction, few_shot_examples, splittable_input, input_header, output_header
26
- ):
27
  raise NotImplementedError("Subclasses should implement this!")
28
 
29
 
@@ -49,9 +47,7 @@ class GPT(LLM):
49
  def __init__(self, model_id):
50
  self.temperature = 0.0
51
  self.top_k = 1
52
- self.encoding = tiktoken.encoding_for_model(
53
- "-".join(model_id.split("-", 2)[:2]).replace("5", ".5")
54
- )
55
  self.openai_api = "default"
56
  self.model_id = model_id
57
  self.max_length = self.deployment_max_length_dict[model_id]
@@ -61,9 +57,7 @@ class GPT(LLM):
61
  azure_endpoint=self.openai_cxn_dict[self.openai_api]["endpoint"],
62
  )
63
 
64
- def gen_messages(
65
- self, fixed_instruction, few_shot_examples, input, input_header, output_header
66
- ):
67
  messages = [
68
  {
69
  "role": "system",
@@ -120,18 +114,13 @@ class GPT(LLM):
120
  ):
121
  return asyncio.run(self.dispatch_openai_requests(messages_list))
122
 
123
- def split_input(
124
- self, fixed_instruction, few_shot_examples, splittable_input, input_header, output_header
125
- ):
126
  # Tokenize fixed_prompt
127
  fixed_token_ids = self.encoding.encode(
128
- fixed_instruction
129
- + " ".join([x["user"] + " " + x["assistant"] for x in few_shot_examples])
130
  )
131
  # Calculate remaining token length
132
- remaining_token_len = math.ceil(
133
- (self.prompt_percent * self.max_length) - len(fixed_token_ids)
134
- )
135
 
136
  # Tokenize splittable_input
137
  split_token_ids = self.encoding.encode(splittable_input)
@@ -141,14 +130,10 @@ class GPT(LLM):
141
  split_token_ids[i : i + remaining_token_len + 10]
142
  for i in range(0, len(split_token_ids), remaining_token_len)
143
  ]
144
- split_input_list = [
145
- self.encoding.decode(split_token_ids) for split_token_ids in split_token_ids_list
146
- ]
147
 
148
  # Take the fixed_prompt, few_shot_examples, splitted inputs, and input/output headers and generate list of prompt strings.
149
  return [
150
- self.gen_messages(
151
- fixed_instruction, few_shot_examples, split_input, input_header, output_header
152
- )
153
  for split_input in split_input_list
154
  ]
 
21
  raise NotImplementedError("Subclasses should implement this!")
22
 
23
  @abstractmethod
24
+ def split_input(self, fixed_instruction, few_shot_examples, splittable_input, input_header, output_header):
 
 
25
  raise NotImplementedError("Subclasses should implement this!")
26
 
27
 
 
47
  def __init__(self, model_id):
48
  self.temperature = 0.0
49
  self.top_k = 1
50
+ self.encoding = tiktoken.encoding_for_model("-".join(model_id.split("-", 2)[:2]).replace("5", ".5"))
 
 
51
  self.openai_api = "default"
52
  self.model_id = model_id
53
  self.max_length = self.deployment_max_length_dict[model_id]
 
57
  azure_endpoint=self.openai_cxn_dict[self.openai_api]["endpoint"],
58
  )
59
 
60
+ def gen_messages(self, fixed_instruction, few_shot_examples, input, input_header, output_header):
 
 
61
  messages = [
62
  {
63
  "role": "system",
 
114
  ):
115
  return asyncio.run(self.dispatch_openai_requests(messages_list))
116
 
117
+ def split_input(self, fixed_instruction, few_shot_examples, splittable_input, input_header, output_header):
 
 
118
  # Tokenize fixed_prompt
119
  fixed_token_ids = self.encoding.encode(
120
+ fixed_instruction + " ".join([x["user"] + " " + x["assistant"] for x in few_shot_examples])
 
121
  )
122
  # Calculate remaining token length
123
+ remaining_token_len = math.ceil((self.prompt_percent * self.max_length) - len(fixed_token_ids))
 
 
124
 
125
  # Tokenize splittable_input
126
  split_token_ids = self.encoding.encode(splittable_input)
 
130
  split_token_ids[i : i + remaining_token_len + 10]
131
  for i in range(0, len(split_token_ids), remaining_token_len)
132
  ]
133
+ split_input_list = [self.encoding.decode(split_token_ids) for split_token_ids in split_token_ids_list]
 
 
134
 
135
  # Take the fixed_prompt, few_shot_examples, splitted inputs, and input/output headers and generate list of prompt strings.
136
  return [
137
+ self.gen_messages(fixed_instruction, few_shot_examples, split_input, input_header, output_header)
 
 
138
  for split_input in split_input_list
139
  ]
medrax/llava/eval/model_vqa.py CHANGED
@@ -45,9 +45,7 @@ def eval_model(args):
45
  disable_torch_init()
46
  model_path = os.path.expanduser(args.model_path)
47
  model_name = get_model_name_from_path(model_path)
48
- tokenizer, model, image_processor, context_len = load_pretrained_model(
49
- model_path, args.model_base, model_name
50
- )
51
 
52
  questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")]
53
  questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
@@ -69,11 +67,7 @@ def eval_model(args):
69
  conv.append_message(conv.roles[1], None)
70
  prompt = conv.get_prompt()
71
 
72
- input_ids = (
73
- tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
74
- .unsqueeze(0)
75
- .cuda()
76
- )
77
 
78
  image = Image.open(os.path.join(args.image_folder, image_file))
79
  image_tensor = process_images([image], image_processor, model.config)[0]
 
45
  disable_torch_init()
46
  model_path = os.path.expanduser(args.model_path)
47
  model_name = get_model_name_from_path(model_path)
48
+ tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)
 
 
49
 
50
  questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")]
51
  questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
 
67
  conv.append_message(conv.roles[1], None)
68
  prompt = conv.get_prompt()
69
 
70
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).cuda()
 
 
 
 
71
 
72
  image = Image.open(os.path.join(args.image_folder, image_file))
73
  image_tensor = process_images([image], image_processor, model.config)[0]
medrax/llava/eval/summarize_gpt_review.py CHANGED
@@ -14,8 +14,7 @@ def get_domain(x):
14
  def main(args):
15
  scores_data = util.load_file_jsonl(args.scores_file)
16
  predictions = [
17
- (x["question_id"], x["type"], get_domain(x), x["gpt_eval"].split("\n")[0].split(" "))
18
- for x in scores_data
19
  ]
20
 
21
  score_type_dict = defaultdict(lambda: defaultdict(list))
@@ -33,8 +32,7 @@ def main(args):
33
  result[q_type]["gpt4_score"] = util.get_avg(score_dict[1])
34
  result[q_type]["pred_score"] = util.get_avg(score_dict[2])
35
  result[q_type]["pred_relative_score"] = (
36
- util.get_avg([float(s2) / float(s1) for s1, s2 in zip(score_dict[1], score_dict[2])])
37
- * 100
38
  )
39
  result[q_type]["data_size"] = len(score_dict[1])
40
 
@@ -55,8 +53,6 @@ def main(args):
55
 
56
  if __name__ == "__main__":
57
  parser = argparse.ArgumentParser("GPT-4 Multimodal Chat Eval Postprocessing", add_help=True)
58
- parser.add_argument(
59
- "--scores-file", default="", metavar="FILE", help="input path to gpt-4 score file"
60
- )
61
  args = parser.parse_args()
62
  main(args)
 
14
  def main(args):
15
  scores_data = util.load_file_jsonl(args.scores_file)
16
  predictions = [
17
+ (x["question_id"], x["type"], get_domain(x), x["gpt_eval"].split("\n")[0].split(" ")) for x in scores_data
 
18
  ]
19
 
20
  score_type_dict = defaultdict(lambda: defaultdict(list))
 
32
  result[q_type]["gpt4_score"] = util.get_avg(score_dict[1])
33
  result[q_type]["pred_score"] = util.get_avg(score_dict[2])
34
  result[q_type]["pred_relative_score"] = (
35
+ util.get_avg([float(s2) / float(s1) for s1, s2 in zip(score_dict[1], score_dict[2])]) * 100
 
36
  )
37
  result[q_type]["data_size"] = len(score_dict[1])
38
 
 
53
 
54
  if __name__ == "__main__":
55
  parser = argparse.ArgumentParser("GPT-4 Multimodal Chat Eval Postprocessing", add_help=True)
56
+ parser.add_argument("--scores-file", default="", metavar="FILE", help="input path to gpt-4 score file")
 
 
57
  args = parser.parse_args()
58
  main(args)
medrax/llava/mm_utils.py CHANGED
@@ -35,9 +35,7 @@ def process_images(images, image_processor, model_cfg):
35
  for image in images:
36
  if image_aspect_ratio == "pad":
37
  if image.mode == "L":
38
- background_color = int(
39
- 255 * sum(image_processor.image_mean) / len(image_processor.image_mean)
40
- )
41
  else:
42
  background_color = tuple(int(x * 255) for x in image_processor.image_mean)
43
  image = expand2square(image, background_color)
@@ -48,9 +46,7 @@ def process_images(images, image_processor, model_cfg):
48
  return new_images
49
 
50
 
51
- def tokenizer_image_token(
52
- prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None
53
- ):
54
  prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>")]
55
 
56
  def insert_separator(X, sep):
@@ -58,11 +54,7 @@ def tokenizer_image_token(
58
 
59
  input_ids = []
60
  offset = 0
61
- if (
62
- len(prompt_chunks) > 0
63
- and len(prompt_chunks[0]) > 0
64
- and prompt_chunks[0][0] == tokenizer.bos_token_id
65
- ):
66
  offset = 1
67
  input_ids.append(prompt_chunks[0][0])
68
 
@@ -100,9 +92,7 @@ class KeywordsStoppingCriteria(StoppingCriteria):
100
  self.tokenizer = tokenizer
101
  self.start_len = input_ids.shape[1]
102
 
103
- def call_for_batch(
104
- self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
105
- ) -> bool:
106
  offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
107
  self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
108
  for keyword_id in self.keyword_ids:
 
35
  for image in images:
36
  if image_aspect_ratio == "pad":
37
  if image.mode == "L":
38
+ background_color = int(255 * sum(image_processor.image_mean) / len(image_processor.image_mean))
 
 
39
  else:
40
  background_color = tuple(int(x * 255) for x in image_processor.image_mean)
41
  image = expand2square(image, background_color)
 
46
  return new_images
47
 
48
 
49
+ def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
 
 
50
  prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>")]
51
 
52
  def insert_separator(X, sep):
 
54
 
55
  input_ids = []
56
  offset = 0
57
+ if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
 
 
 
 
58
  offset = 1
59
  input_ids.append(prompt_chunks[0][0])
60
 
 
92
  self.tokenizer = tokenizer
93
  self.start_len = input_ids.shape[1]
94
 
95
+ def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
 
 
96
  offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
97
  self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
98
  for keyword_id in self.keyword_ids:
medrax/llava/model/builder.py CHANGED
@@ -59,9 +59,7 @@ def load_pretrained_model(
59
  # PEFT model
60
  from peft import PeftModel
61
 
62
- tokenizer = AutoTokenizer.from_pretrained(
63
- model_base, use_fast=False, cache_dir=cache_dir
64
- )
65
  model = AutoModelForCausalLM.from_pretrained(
66
  model_base,
67
  low_cpu_mem_usage=True,
@@ -78,9 +76,7 @@ def load_pretrained_model(
78
  else:
79
  use_fast = False
80
  if "mpt" in model_name.lower():
81
- tokenizer = AutoTokenizer.from_pretrained(
82
- model_path, use_fast=True, cache_dir=cache_dir
83
- )
84
  model = AutoModelForCausalLM.from_pretrained(
85
  model_path,
86
  low_cpu_mem_usage=True,
@@ -90,9 +86,7 @@ def load_pretrained_model(
90
  **kwargs,
91
  )
92
  else:
93
- tokenizer = AutoTokenizer.from_pretrained(
94
- model_path, use_fast=False, cache_dir=cache_dir
95
- )
96
  model = AutoModelForCausalLM.from_pretrained(
97
  model_path,
98
  low_cpu_mem_usage=True,
@@ -109,9 +103,7 @@ def load_pretrained_model(
109
  if mm_use_im_patch_token:
110
  tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
111
  if mm_use_im_start_end:
112
- tokenizer.add_tokens(
113
- [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
114
- )
115
  model.resize_token_embeddings(len(tokenizer))
116
 
117
  vision_tower = model.get_vision_tower()
 
59
  # PEFT model
60
  from peft import PeftModel
61
 
62
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False, cache_dir=cache_dir)
 
 
63
  model = AutoModelForCausalLM.from_pretrained(
64
  model_base,
65
  low_cpu_mem_usage=True,
 
76
  else:
77
  use_fast = False
78
  if "mpt" in model_name.lower():
79
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True, cache_dir=cache_dir)
 
 
80
  model = AutoModelForCausalLM.from_pretrained(
81
  model_path,
82
  low_cpu_mem_usage=True,
 
86
  **kwargs,
87
  )
88
  else:
89
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, cache_dir=cache_dir)
 
 
90
  model = AutoModelForCausalLM.from_pretrained(
91
  model_path,
92
  low_cpu_mem_usage=True,
 
103
  if mm_use_im_patch_token:
104
  tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
105
  if mm_use_im_start_end:
106
+ tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
 
 
107
  model.resize_token_embeddings(len(tokenizer))
108
 
109
  vision_tower = model.get_vision_tower()
medrax/llava/model/language_model/llava_mistral.py CHANGED
@@ -125,9 +125,7 @@ class LlavaMistralForCausalLM(MistralForCausalLM, LlavaMetaForCausalLM):
125
  **kwargs,
126
  )
127
 
128
- def prepare_inputs_for_generation(
129
- self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
130
- ):
131
  images = kwargs.pop("images", None)
132
  image_sizes = kwargs.pop("image_sizes", None)
133
  inputs = super().prepare_inputs_for_generation(
 
125
  **kwargs,
126
  )
127
 
128
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
 
 
129
  images = kwargs.pop("images", None)
130
  image_sizes = kwargs.pop("image_sizes", None)
131
  inputs = super().prepare_inputs_for_generation(
medrax/llava/model/llava_arch.py CHANGED
@@ -104,9 +104,7 @@ class LlavaMetaModel:
104
  checkpoint_folder = os.path.dirname(pretrain_mm_mlp_adapter)
105
  ckpts = glob(f"{checkpoint_folder}/checkpoint-*", recursive=False)
106
  if len(ckpts) > 0:
107
- vision_module_weights = torch.load(
108
- f"{ckpts[-1]}/mm_projector.bin", map_location="cpu"
109
- )
110
  model_dict = get_w(vision_module_weights, "vision_tower")
111
  print(f"Loading vision module weights from {ckpts[-1]}/mm_projector.bin")
112
  # print keys in model_dict
@@ -170,9 +168,7 @@ class LlavaMetaForCausalLM(ABC):
170
  image_features = self.encode_images(images).to(self.device)
171
 
172
  # TODO: image start / end is not implemented here to support pretraining.
173
- if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(
174
- self.config, "mm_use_im_start_end", False
175
- ):
176
  raise NotImplementedError
177
 
178
  # Let's just add dummy tensors if they do not exist,
@@ -188,21 +184,15 @@ class LlavaMetaForCausalLM(ABC):
188
  else:
189
  attention_mask = attention_mask.bool()
190
  if position_ids is None:
191
- position_ids = torch.arange(
192
- 0, input_ids.shape[1], dtype=torch.long, device=input_ids.device
193
- )
194
 
195
  if labels is None:
196
  labels = torch.full_like(input_ids, IGNORE_INDEX)
197
 
198
  input_ids = [
199
- cur_input_ids[cur_attention_mask]
200
- for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)
201
- ]
202
- labels = [
203
- cur_labels[cur_attention_mask]
204
- for cur_labels, cur_attention_mask in zip(labels, attention_mask)
205
  ]
 
206
 
207
  new_input_embeds = []
208
  new_labels = []
@@ -219,20 +209,14 @@ class LlavaMetaForCausalLM(ABC):
219
  continue
220
 
221
  image_token_indices = (
222
- [-1]
223
- + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist()
224
- + [cur_input_ids.shape[0]]
225
  )
226
  cur_input_ids_noim = []
227
  cur_labels = labels[batch_idx]
228
  cur_labels_noim = []
229
  for i in range(len(image_token_indices) - 1):
230
- cur_input_ids_noim.append(
231
- cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]]
232
- )
233
- cur_labels_noim.append(
234
- cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]]
235
- )
236
 
237
  split_sizes = [x.shape[0] for x in cur_labels_noim]
238
  cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
@@ -279,12 +263,8 @@ class LlavaMetaForCausalLM(ABC):
279
  dtype=new_labels[0].dtype,
280
  device=new_labels[0].device,
281
  )
282
- attention_mask = torch.zeros(
283
- (batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device
284
- )
285
- position_ids = torch.zeros(
286
- (batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device
287
- )
288
 
289
  for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
290
  cur_len = cur_new_embed.shape[0]
@@ -351,9 +331,7 @@ class LlavaMetaForCausalLM(ABC):
351
  self.resize_token_embeddings(len(tokenizer))
352
 
353
  if model_args.mm_use_im_start_end:
354
- num_new_tokens = tokenizer.add_tokens(
355
- [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
356
- )
357
  self.resize_token_embeddings(len(tokenizer))
358
 
359
  if num_new_tokens > 0:
@@ -361,9 +339,7 @@ class LlavaMetaForCausalLM(ABC):
361
  output_embeddings = self.get_output_embeddings().weight.data
362
 
363
  input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
364
- output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
365
- dim=0, keepdim=True
366
- )
367
 
368
  input_embeddings[-num_new_tokens:] = input_embeddings_avg
369
  output_embeddings[-num_new_tokens:] = output_embeddings_avg
@@ -375,9 +351,7 @@ class LlavaMetaForCausalLM(ABC):
375
  p.requires_grad = False
376
 
377
  if model_args.pretrain_mm_mlp_adapter:
378
- mm_projector_weights = torch.load(
379
- model_args.pretrain_mm_mlp_adapter, map_location="cpu"
380
- )
381
  embed_tokens_weight = mm_projector_weights["model.embed_tokens.weight"]
382
  assert num_new_tokens == 2
383
  if input_embeddings.shape == embed_tokens_weight.shape:
 
104
  checkpoint_folder = os.path.dirname(pretrain_mm_mlp_adapter)
105
  ckpts = glob(f"{checkpoint_folder}/checkpoint-*", recursive=False)
106
  if len(ckpts) > 0:
107
+ vision_module_weights = torch.load(f"{ckpts[-1]}/mm_projector.bin", map_location="cpu")
 
 
108
  model_dict = get_w(vision_module_weights, "vision_tower")
109
  print(f"Loading vision module weights from {ckpts[-1]}/mm_projector.bin")
110
  # print keys in model_dict
 
168
  image_features = self.encode_images(images).to(self.device)
169
 
170
  # TODO: image start / end is not implemented here to support pretraining.
171
+ if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(self.config, "mm_use_im_start_end", False):
 
 
172
  raise NotImplementedError
173
 
174
  # Let's just add dummy tensors if they do not exist,
 
184
  else:
185
  attention_mask = attention_mask.bool()
186
  if position_ids is None:
187
+ position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
 
 
188
 
189
  if labels is None:
190
  labels = torch.full_like(input_ids, IGNORE_INDEX)
191
 
192
  input_ids = [
193
+ cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)
 
 
 
 
 
194
  ]
195
+ labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
196
 
197
  new_input_embeds = []
198
  new_labels = []
 
209
  continue
210
 
211
  image_token_indices = (
212
+ [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
 
 
213
  )
214
  cur_input_ids_noim = []
215
  cur_labels = labels[batch_idx]
216
  cur_labels_noim = []
217
  for i in range(len(image_token_indices) - 1):
218
+ cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]])
219
+ cur_labels_noim.append(cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]])
 
 
 
 
220
 
221
  split_sizes = [x.shape[0] for x in cur_labels_noim]
222
  cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
 
263
  dtype=new_labels[0].dtype,
264
  device=new_labels[0].device,
265
  )
266
+ attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
267
+ position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
 
 
 
 
268
 
269
  for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
270
  cur_len = cur_new_embed.shape[0]
 
331
  self.resize_token_embeddings(len(tokenizer))
332
 
333
  if model_args.mm_use_im_start_end:
334
+ num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
 
 
335
  self.resize_token_embeddings(len(tokenizer))
336
 
337
  if num_new_tokens > 0:
 
339
  output_embeddings = self.get_output_embeddings().weight.data
340
 
341
  input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
342
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
 
 
343
 
344
  input_embeddings[-num_new_tokens:] = input_embeddings_avg
345
  output_embeddings[-num_new_tokens:] = output_embeddings_avg
 
351
  p.requires_grad = False
352
 
353
  if model_args.pretrain_mm_mlp_adapter:
354
+ mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location="cpu")
 
 
355
  embed_tokens_weight = mm_projector_weights["model.embed_tokens.weight"]
356
  assert num_new_tokens == 2
357
  if input_embeddings.shape == embed_tokens_weight.shape:
medrax/llava/model/multimodal_encoder/builder.py CHANGED
@@ -3,13 +3,7 @@ from .clip_encoder import CLIPVisionTower
3
 
4
 
5
  def build_vision_tower(vision_tower_cfg, **kwargs):
6
- vision_tower = getattr(
7
- vision_tower_cfg, "mm_vision_tower", getattr(vision_tower_cfg, "vision_tower", None)
8
- )
9
  is_absolute_path_exists = os.path.exists(vision_tower)
10
- if (
11
- is_absolute_path_exists
12
- or vision_tower.startswith("openai")
13
- or vision_tower.startswith("laion")
14
- ):
15
  return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
 
3
 
4
 
5
  def build_vision_tower(vision_tower_cfg, **kwargs):
6
+ vision_tower = getattr(vision_tower_cfg, "mm_vision_tower", getattr(vision_tower_cfg, "vision_tower", None))
 
 
7
  is_absolute_path_exists = os.path.exists(vision_tower)
8
+ if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion"):
 
 
 
 
9
  return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
medrax/llava/model/multimodal_projector/builder.py CHANGED
@@ -19,9 +19,7 @@ class SimpleResBlock(nn.Module):
19
  super().__init__()
20
  self.pre_norm = nn.LayerNorm(channels)
21
 
22
- self.proj = nn.Sequential(
23
- nn.Linear(channels, channels), nn.GELU(), nn.Linear(channels, channels)
24
- )
25
 
26
  def forward(self, x):
27
  x = self.pre_norm(x)
 
19
  super().__init__()
20
  self.pre_norm = nn.LayerNorm(channels)
21
 
22
+ self.proj = nn.Sequential(nn.Linear(channels, channels), nn.GELU(), nn.Linear(channels, channels))
 
 
23
 
24
  def forward(self, x):
25
  x = self.pre_norm(x)
medrax/llava/serve/cli.py CHANGED
@@ -94,9 +94,7 @@ def main(args):
94
  if image is not None:
95
  # first message
96
  if model.config.mm_use_im_start_end:
97
- inp = (
98
- DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + "\n" + inp
99
- )
100
  else:
101
  inp = DEFAULT_IMAGE_TOKEN + "\n" + inp
102
  conv.append_message(conv.roles[0], inp)
 
94
  if image is not None:
95
  # first message
96
  if model.config.mm_use_im_start_end:
97
+ inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + "\n" + inp
 
 
98
  else:
99
  inp = DEFAULT_IMAGE_TOKEN + "\n" + inp
100
  conv.append_message(conv.roles[0], inp)
medrax/llava/serve/controller.py CHANGED
@@ -2,6 +2,7 @@
2
  A controller manages distributed workers.
3
  It sends worker addresses to clients.
4
  """
 
5
  import argparse
6
  import dataclasses
7
  from enum import Enum, auto
@@ -199,9 +200,7 @@ class Controller:
199
  yield json.dumps(ret).encode() + b"\0"
200
 
201
  try:
202
- response = requests.post(
203
- worker_addr + "/worker_generate_stream", json=params, stream=True, timeout=5
204
- )
205
  for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
206
  if chunk:
207
  yield chunk + b"\0"
@@ -240,9 +239,7 @@ app = FastAPI()
240
  @app.post("/register_worker")
241
  async def register_worker(request: Request):
242
  data = await request.json()
243
- controller.register_worker(
244
- data["worker_name"], data["check_heart_beat"], data.get("worker_status", None)
245
- )
246
 
247
 
248
  @app.post("/refresh_all_workers")
 
2
  A controller manages distributed workers.
3
  It sends worker addresses to clients.
4
  """
5
+
6
  import argparse
7
  import dataclasses
8
  from enum import Enum, auto
 
200
  yield json.dumps(ret).encode() + b"\0"
201
 
202
  try:
203
+ response = requests.post(worker_addr + "/worker_generate_stream", json=params, stream=True, timeout=5)
 
 
204
  for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
205
  if chunk:
206
  yield chunk + b"\0"
 
239
  @app.post("/register_worker")
240
  async def register_worker(request: Request):
241
  data = await request.json()
242
+ controller.register_worker(data["worker_name"], data["check_heart_beat"], data.get("worker_status", None))
 
 
243
 
244
 
245
  @app.post("/refresh_all_workers")
medrax/llava/serve/gradio_web_server.py CHANGED
@@ -216,9 +216,7 @@ def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request:
216
  all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
217
  for image, hash in zip(all_images, all_image_hash):
218
  t = datetime.datetime.now()
219
- filename = os.path.join(
220
- LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg"
221
- )
222
  if not os.path.isfile(filename):
223
  os.makedirs(os.path.dirname(filename), exist_ok=True)
224
  image.save(filename)
@@ -230,9 +228,7 @@ def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request:
230
  "temperature": float(temperature),
231
  "top_p": float(top_p),
232
  "max_new_tokens": min(int(max_new_tokens), 1536),
233
- "stop": state.sep
234
- if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT]
235
- else state.sep2,
236
  "images": f"List of {len(state.get_images())} images: {all_image_hash}",
237
  }
238
  logger.info(f"==== request ====\n{pload}")
@@ -330,9 +326,7 @@ block_css = """
330
 
331
 
332
  def build_demo(embed_mode):
333
- textbox = gr.Textbox(
334
- show_label=False, placeholder="Enter text and press ENTER", container=False
335
- )
336
  with gr.Blocks(title="LLaVA", theme=gr.themes.Default(), css=block_css) as demo:
337
  state = gr.State()
338
 
@@ -468,9 +462,7 @@ def build_demo(embed_mode):
468
  [state, chatbot] + btn_list,
469
  )
470
 
471
- clear_btn.click(
472
- clear_history, None, [state, chatbot, textbox, imagebox] + btn_list, queue=False
473
- )
474
 
475
  textbox.submit(
476
  add_text,
 
216
  all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
217
  for image, hash in zip(all_images, all_image_hash):
218
  t = datetime.datetime.now()
219
+ filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg")
 
 
220
  if not os.path.isfile(filename):
221
  os.makedirs(os.path.dirname(filename), exist_ok=True)
222
  image.save(filename)
 
228
  "temperature": float(temperature),
229
  "top_p": float(top_p),
230
  "max_new_tokens": min(int(max_new_tokens), 1536),
231
+ "stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2,
 
 
232
  "images": f"List of {len(state.get_images())} images: {all_image_hash}",
233
  }
234
  logger.info(f"==== request ====\n{pload}")
 
326
 
327
 
328
  def build_demo(embed_mode):
329
+ textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
 
 
330
  with gr.Blocks(title="LLaVA", theme=gr.themes.Default(), css=block_css) as demo:
331
  state = gr.State()
332
 
 
462
  [state, chatbot] + btn_list,
463
  )
464
 
465
+ clear_btn.click(clear_history, None, [state, chatbot, textbox, imagebox] + btn_list, queue=False)
 
 
466
 
467
  textbox.submit(
468
  add_text,
medrax/llava/serve/model_worker.py CHANGED
@@ -1,6 +1,7 @@
1
  """
2
  A model worker executes the model.
3
  """
 
4
  import argparse
5
  import asyncio
6
  import json
@@ -155,9 +156,7 @@ class ModelWorker:
155
  if images is not None and len(images) > 0 and self.is_multimodal:
156
  if len(images) > 0:
157
  if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
158
- raise ValueError(
159
- "Number of images does not match number of <image> tokens in prompt"
160
- )
161
 
162
  images = [load_image_from_base64(image) for image in images]
163
  images = process_images(images, image_processor, model.config)
@@ -172,9 +171,7 @@ class ModelWorker:
172
  replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
173
  prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
174
 
175
- num_image_tokens = (
176
- prompt.count(replace_token) * model.get_vision_tower().num_patches
177
- )
178
  else:
179
  images = None
180
  image_args = {"images": images}
@@ -196,19 +193,14 @@ class ModelWorker:
196
  )
197
  keywords = [stop_str]
198
  stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
199
- streamer = TextIteratorStreamer(
200
- tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15
201
- )
202
 
203
- max_new_tokens = min(
204
- max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens
205
- )
206
 
207
  if max_new_tokens < 1:
208
  yield json.dumps(
209
  {
210
- "text": ori_prompt
211
- + "Exceeds max token length. Please start a new conversation, thanks.",
212
  "error_code": 0,
213
  }
214
  ).encode() + b"\0"
 
1
  """
2
  A model worker executes the model.
3
  """
4
+
5
  import argparse
6
  import asyncio
7
  import json
 
156
  if images is not None and len(images) > 0 and self.is_multimodal:
157
  if len(images) > 0:
158
  if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
159
+ raise ValueError("Number of images does not match number of <image> tokens in prompt")
 
 
160
 
161
  images = [load_image_from_base64(image) for image in images]
162
  images = process_images(images, image_processor, model.config)
 
171
  replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
172
  prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
173
 
174
+ num_image_tokens = prompt.count(replace_token) * model.get_vision_tower().num_patches
 
 
175
  else:
176
  images = None
177
  image_args = {"images": images}
 
193
  )
194
  keywords = [stop_str]
195
  stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
196
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
 
 
197
 
198
+ max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)
 
 
199
 
200
  if max_new_tokens < 1:
201
  yield json.dumps(
202
  {
203
+ "text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.",
 
204
  "error_code": 0,
205
  }
206
  ).encode() + b"\0"
medrax/llava/serve/test_message.py CHANGED
@@ -17,9 +17,7 @@ def main():
17
  models.sort()
18
  print(f"Models: {models}")
19
 
20
- ret = requests.post(
21
- controller_addr + "/get_worker_address", json={"model": args.model_name}
22
- )
23
  worker_addr = ret.json()["address"]
24
  print(f"worker_addr: {worker_addr}")
25
 
@@ -38,9 +36,7 @@ def main():
38
  "temperature": 0.7,
39
  "stop": conv.sep2,
40
  }
41
- response = requests.post(
42
- worker_addr + "/worker_generate_stream", headers=headers, json=pload, stream=True
43
- )
44
 
45
  print(prompt, end="")
46
  for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
 
17
  models.sort()
18
  print(f"Models: {models}")
19
 
20
+ ret = requests.post(controller_addr + "/get_worker_address", json={"model": args.model_name})
 
 
21
  worker_addr = ret.json()["address"]
22
  print(f"worker_addr: {worker_addr}")
23
 
 
36
  "temperature": 0.7,
37
  "stop": conv.sep2,
38
  }
39
+ response = requests.post(worker_addr + "/worker_generate_stream", headers=headers, json=pload, stream=True)
 
 
40
 
41
  print(prompt, end="")
42
  for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
medrax/llava/utils.py CHANGED
@@ -45,9 +45,7 @@ def build_logger(logger_name, logger_filename):
45
  if handler is None:
46
  os.makedirs(LOGDIR, exist_ok=True)
47
  filename = os.path.join(LOGDIR, logger_filename)
48
- handler = logging.handlers.TimedRotatingFileHandler(
49
- filename, when="D", utc=True, encoding="UTF-8"
50
- )
51
  handler.setFormatter(formatter)
52
 
53
  for name, item in logging.root.manager.loggerDict.items():
 
45
  if handler is None:
46
  os.makedirs(LOGDIR, exist_ok=True)
47
  filename = os.path.join(LOGDIR, logger_filename)
48
+ handler = logging.handlers.TimedRotatingFileHandler(filename, when="D", utc=True, encoding="UTF-8")
 
 
49
  handler.setFormatter(formatter)
50
 
51
  for name, item in logging.root.manager.loggerDict.items():
medrax/models/model_factory.py CHANGED
@@ -29,7 +29,7 @@ class ModelFactory:
29
  "base_url_key": "OPENAI_BASE_URL",
30
  },
31
  "gemini": {
32
- "class": ChatGoogleGenerativeAI,
33
  "env_key": "GOOGLE_API_KEY",
34
  "base_url_key": "GOOGLE_BASE_URL",
35
  },
@@ -42,14 +42,12 @@ class ModelFactory:
42
  "grok": {
43
  "class": ChatXAI,
44
  "env_key": "XAI_API_KEY",
45
- }
46
  # Add more providers with default configurations here
47
  }
48
 
49
  @classmethod
50
- def register_provider(
51
- cls, prefix: str, model_class: Type[BaseLanguageModel], env_key: str, **kwargs
52
- ) -> None:
53
  """Register a new model provider.
54
 
55
  Args:
@@ -81,9 +79,7 @@ class ModelFactory:
81
  ValueError: If the required API key is missing
82
  """
83
  # Find the matching provider based on model name prefix
84
- provider_prefix = next(
85
- (prefix for prefix in cls._model_providers if model_name.startswith(prefix)), None
86
- )
87
 
88
  if not provider_prefix:
89
  raise ValueError(
@@ -153,7 +149,4 @@ class ModelFactory:
153
  Dict[str, Dict[str, Any]]: Dictionary of registered providers and their configurations
154
  """
155
  # Return a copy to prevent accidental modification
156
- return {
157
- k: {kk: vv for kk, vv in v.items() if kk != "class"}
158
- for k, v in cls._model_providers.items()
159
- }
 
29
  "base_url_key": "OPENAI_BASE_URL",
30
  },
31
  "gemini": {
32
+ "class": ChatGoogleGenerativeAI,
33
  "env_key": "GOOGLE_API_KEY",
34
  "base_url_key": "GOOGLE_BASE_URL",
35
  },
 
42
  "grok": {
43
  "class": ChatXAI,
44
  "env_key": "XAI_API_KEY",
45
+ },
46
  # Add more providers with default configurations here
47
  }
48
 
49
  @classmethod
50
+ def register_provider(cls, prefix: str, model_class: Type[BaseLanguageModel], env_key: str, **kwargs) -> None:
 
 
51
  """Register a new model provider.
52
 
53
  Args:
 
79
  ValueError: If the required API key is missing
80
  """
81
  # Find the matching provider based on model name prefix
82
+ provider_prefix = next((prefix for prefix in cls._model_providers if model_name.startswith(prefix)), None)
 
 
83
 
84
  if not provider_prefix:
85
  raise ValueError(
 
149
  Dict[str, Dict[str, Any]]: Dictionary of registered providers and their configurations
150
  """
151
  # Return a copy to prevent accidental modification
152
+ return {k: {kk: vv for kk, vv in v.items() if kk != "class"} for k, v in cls._model_providers.items()}
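
# --- Hypothetical usage sketch (not part of this commit) ---------------------
# Registering an additional provider prefix with ModelFactory, following the
# register_provider signature shown above. The provider class and environment
# variable below are illustrative assumptions.
from langchain_anthropic import ChatAnthropic

from medrax.models.model_factory import ModelFactory  # assumed import path

ModelFactory.register_provider(
    prefix="claude",              # model names starting with "claude" resolve to this provider
    model_class=ChatAnthropic,    # LangChain chat model class to instantiate
    env_key="ANTHROPIC_API_KEY",  # environment variable holding the API key
)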
 
 
 
medrax/rag/rag.py CHANGED
@@ -107,9 +107,7 @@ class CohereRAG:
107
  # Initialize Pinecone
108
  self.pinecone_api_key = os.getenv("PINECONE_API_KEY")
109
  if not self.pinecone_api_key:
110
- raise ValueError(
111
- "PINECONE_API_KEY environment variable not set. Please get a key from app.pinecone.io"
112
- )
113
  self.pinecone = Pinecone(api_key=self.pinecone_api_key)
114
  self.index_name = self.config.pinecone_index_name
115
 
@@ -161,9 +159,7 @@ class CohereRAG:
161
  )
162
 
163
  print(f"Connecting to existing Pinecone index: {self.index_name}")
164
- vectorstore = PineconeVectorStore.from_existing_index(
165
- index_name=self.index_name, embedding=self.embeddings
166
- )
167
 
168
  # Check if the index is empty and needs to be populated
169
  try:
@@ -329,9 +325,7 @@ class CohereRAG:
329
  )
330
  documents.append(doc)
331
 
332
- print(
333
- f"Loaded {len(documents)} document chunks from HuggingFace dataset: {dataset_name}"
334
- )
335
  return documents
336
 
337
  except Exception as e:
 
107
  # Initialize Pinecone
108
  self.pinecone_api_key = os.getenv("PINECONE_API_KEY")
109
  if not self.pinecone_api_key:
110
+ raise ValueError("PINECONE_API_KEY environment variable not set. Please get a key from app.pinecone.io")
 
 
111
  self.pinecone = Pinecone(api_key=self.pinecone_api_key)
112
  self.index_name = self.config.pinecone_index_name
113
 
 
159
  )
160
 
161
  print(f"Connecting to existing Pinecone index: {self.index_name}")
162
+ vectorstore = PineconeVectorStore.from_existing_index(index_name=self.index_name, embedding=self.embeddings)
 
 
163
 
164
  # Check if the index is empty and needs to be populated
165
  try:
 
325
  )
326
  documents.append(doc)
327
 
328
+ print(f"Loaded {len(documents)} document chunks from HuggingFace dataset: {dataset_name}")
 
 
329
  return documents
330
 
331
  except Exception as e:
medrax/tools/browsing/__init__.py CHANGED
@@ -6,8 +6,8 @@ from .web_browser import WebBrowserTool, WebBrowserSchema, SearchQuerySchema, Vi
6
  __all__ = [
7
  "DuckDuckGoSearchTool",
8
  "WebSearchInput",
9
- "WebBrowserTool",
10
  "WebBrowserSchema",
11
  "SearchQuerySchema",
12
- "VisitUrlSchema"
13
- ]
 
6
  __all__ = [
7
  "DuckDuckGoSearchTool",
8
  "WebSearchInput",
9
+ "WebBrowserTool",
10
  "WebBrowserSchema",
11
  "SearchQuerySchema",
12
+ "VisitUrlSchema",
13
+ ]
medrax/tools/browsing/duckduckgo.py CHANGED
@@ -95,18 +95,12 @@ class DuckDuckGoSearchTool(BaseTool):
95
  super().__init__(**kwargs)
96
 
97
  if DDGS is None:
98
- logger.error(
99
- "duckduckgo-search package not installed. Install with: pip install duckduckgo-search"
100
- )
101
- raise ImportError(
102
- "duckduckgo-search package is required for web search functionality"
103
- )
104
 
105
  logger.info("DuckDuckGo search tool initialized successfully")
106
 
107
- def _perform_search_sync(
108
- self, query: str, max_results: int = 5, region: str = "us-en"
109
- ) -> Dict[str, Any]:
110
  """
111
  Perform the actual web search using DuckDuckGo synchronously.
112
 
@@ -118,9 +112,7 @@ class DuckDuckGoSearchTool(BaseTool):
118
  Returns:
119
  Dict[str, Any]: Structured search results.
120
  """
121
- logger.info(
122
- f"Performing web search: '{query}' (max_results={max_results}, region={region})"
123
- )
124
 
125
  try:
126
  # Initialize DDGS with error handling
@@ -158,9 +150,7 @@ class DuckDuckGoSearchTool(BaseTool):
158
  summary = f"No results found for '{query}'"
159
 
160
  # Log successful completion
161
- logger.info(
162
- f"Web search completed successfully: {len(formatted_results)} results"
163
- )
164
 
165
  return {
166
  "query": query,
@@ -217,7 +207,7 @@ class DuckDuckGoSearchTool(BaseTool):
217
 
218
  try:
219
  result = self._perform_search_sync(query, max_results, region)
220
-
221
  # Check if search was successful
222
  if "error" in result:
223
  metadata["analysis_status"] = "failed"
@@ -239,7 +229,7 @@ class DuckDuckGoSearchTool(BaseTool):
239
  }
240
  metadata["analysis_status"] = "failed"
241
  metadata["error_details"] = str(e)
242
-
243
  return error_result, metadata
244
 
245
  async def _arun(
@@ -296,9 +286,7 @@ class DuckDuckGoSearchTool(BaseTool):
296
 
297
  # Use asyncio to run sync search in executor
298
  loop = asyncio.get_event_loop()
299
- result, metadata = await loop.run_in_executor(
300
- None, self._run, query, max_results, region
301
- )
302
 
303
  if writer:
304
  # Parse result to get count for progress update
@@ -333,7 +321,7 @@ class DuckDuckGoSearchTool(BaseTool):
333
  "search_engine": "DuckDuckGo",
334
  "timestamp": datetime.now().isoformat(),
335
  }
336
-
337
  metadata = {
338
  "query": query,
339
  "max_results": max_results,
@@ -344,12 +332,10 @@ class DuckDuckGoSearchTool(BaseTool):
344
  "analysis_status": "failed",
345
  "error_details": str(e),
346
  }
347
-
348
  return error_result, metadata
349
 
350
- def get_search_summary(
351
- self, query: str, max_results: int = 3
352
- ) -> dict[str, str | list[str]]:
353
  """
354
  Get a quick summary of search results for a given query.
355
 
@@ -375,14 +361,7 @@ class DuckDuckGoSearchTool(BaseTool):
375
  results = result.get("results", [])
376
  titles = [r["title"] for r in results]
377
  urls = [r["url"] for r in results]
378
- snippets = [
379
- (
380
- r["snippet"][:100] + "..."
381
- if len(r["snippet"]) > 100
382
- else r["snippet"]
383
- )
384
- for r in results
385
- ]
386
 
387
  return {
388
  "query": query,
 
95
  super().__init__(**kwargs)
96
 
97
  if DDGS is None:
98
+ logger.error("duckduckgo-search package not installed. Install with: pip install duckduckgo-search")
99
+ raise ImportError("duckduckgo-search package is required for web search functionality")
 
 
 
 
100
 
101
  logger.info("DuckDuckGo search tool initialized successfully")
102
 
103
+ def _perform_search_sync(self, query: str, max_results: int = 5, region: str = "us-en") -> Dict[str, Any]:
 
 
104
  """
105
  Perform the actual web search using DuckDuckGo synchronously.
106
 
 
112
  Returns:
113
  Dict[str, Any]: Structured search results.
114
  """
115
+ logger.info(f"Performing web search: '{query}' (max_results={max_results}, region={region})")
 
 
116
 
117
  try:
118
  # Initialize DDGS with error handling
 
150
  summary = f"No results found for '{query}'"
151
 
152
  # Log successful completion
153
+ logger.info(f"Web search completed successfully: {len(formatted_results)} results")
 
 
154
 
155
  return {
156
  "query": query,
 
207
 
208
  try:
209
  result = self._perform_search_sync(query, max_results, region)
210
+
211
  # Check if search was successful
212
  if "error" in result:
213
  metadata["analysis_status"] = "failed"
 
229
  }
230
  metadata["analysis_status"] = "failed"
231
  metadata["error_details"] = str(e)
232
+
233
  return error_result, metadata
234
 
235
  async def _arun(
 
286
 
287
  # Use asyncio to run sync search in executor
288
  loop = asyncio.get_event_loop()
289
+ result, metadata = await loop.run_in_executor(None, self._run, query, max_results, region)
 
 
290
 
291
  if writer:
292
  # Parse result to get count for progress update
 
321
  "search_engine": "DuckDuckGo",
322
  "timestamp": datetime.now().isoformat(),
323
  }
324
+
325
  metadata = {
326
  "query": query,
327
  "max_results": max_results,
 
332
  "analysis_status": "failed",
333
  "error_details": str(e),
334
  }
335
+
336
  return error_result, metadata
337
 
338
+ def get_search_summary(self, query: str, max_results: int = 3) -> dict[str, str | list[str]]:
 
 
339
  """
340
  Get a quick summary of search results for a given query.
341
 
 
361
  results = result.get("results", [])
362
  titles = [r["title"] for r in results]
363
  urls = [r["url"] for r in results]
364
+ snippets = [(r["snippet"][:100] + "..." if len(r["snippet"]) > 100 else r["snippet"]) for r in results]
 
 
 
 
 
 
 
365
 
366
  return {
367
  "query": query,
medrax/tools/browsing/web_browser.py CHANGED
@@ -78,9 +78,7 @@ class WebBrowserTool(BaseTool):
78
  max_results: int = 5
79
  args_schema: Type[BaseModel] = WebBrowserSchema
80
 
81
- def __init__(
82
- self, search_api_key: Optional[str] = None, search_engine_id: Optional[str] = None, **kwargs
83
- ):
84
  """Initialize the web browser tool with optional search API credentials.
85
 
86
  Args:
@@ -145,9 +143,7 @@ class WebBrowserTool(BaseTool):
145
  except Exception as e:
146
  return {"error": f"Search failed: {str(e)}"}
147
 
148
- def visit_url(
149
- self, url: str, max_content_length: int = 5000, max_links: int = 5
150
- ) -> Dict[str, Any]:
151
  """Visit a URL and extract its content with comprehensive parsing.
152
 
153
  Args:
@@ -218,9 +214,7 @@ class WebBrowserTool(BaseTool):
218
  return {
219
  "title": title,
220
  "content": (
221
- text_content[:max_content_length]
222
- if len(text_content) > max_content_length
223
- else text_content
224
  ),
225
  "url": url,
226
  "links": links[:max_links], # Limit to max_links
 
78
  max_results: int = 5
79
  args_schema: Type[BaseModel] = WebBrowserSchema
80
 
81
+ def __init__(self, search_api_key: Optional[str] = None, search_engine_id: Optional[str] = None, **kwargs):
 
 
82
  """Initialize the web browser tool with optional search API credentials.
83
 
84
  Args:
 
143
  except Exception as e:
144
  return {"error": f"Search failed: {str(e)}"}
145
 
146
+ def visit_url(self, url: str, max_content_length: int = 5000, max_links: int = 5) -> Dict[str, Any]:
 
 
147
  """Visit a URL and extract its content with comprehensive parsing.
148
 
149
  Args:
 
214
  return {
215
  "title": title,
216
  "content": (
217
+ text_content[:max_content_length] if len(text_content) > max_content_length else text_content
 
 
218
  ),
219
  "url": url,
220
  "links": links[:max_links], # Limit to max_links
medrax/tools/classification/__init__.py CHANGED
@@ -3,9 +3,4 @@
3
  from .torchxrayvision import TorchXRayVisionClassifierTool, TorchXRayVisionInput
4
  from .arcplus import ArcPlusClassifierTool, ArcPlusInput
5
 
6
- __all__ = [
7
- "TorchXRayVisionClassifierTool",
8
- "TorchXRayVisionInput",
9
- "ArcPlusClassifierTool",
10
- "ArcPlusInput"
11
- ]
 
3
  from .torchxrayvision import TorchXRayVisionClassifierTool, TorchXRayVisionInput
4
  from .arcplus import ArcPlusClassifierTool, ArcPlusInput
5
 
6
+ __all__ = ["TorchXRayVisionClassifierTool", "TorchXRayVisionInput", "ArcPlusClassifierTool", "ArcPlusInput"]
 
 
 
 
 
medrax/tools/classification/arcplus.py CHANGED
@@ -38,9 +38,7 @@ class OmniSwinTransformer(SwinTransformer):
38
 
39
  self.omni_heads = []
40
  for num_classes in num_classes_list:
41
- self.omni_heads.append(
42
- nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
43
- )
44
  self.omni_heads = nn.ModuleList(self.omni_heads)
45
 
46
  def forward(self, x, head_n=None):
@@ -62,9 +60,7 @@ class OmniSwinTransformer(SwinTransformer):
62
  class ArcPlusInput(BaseModel):
63
  """Input for ArcPlus chest X-ray analysis tool. Only supports JPG or PNG images."""
64
 
65
- image_path: str = Field(
66
- ..., description="Path to the radiology image file, only supports JPG or PNG images"
67
- )
68
 
69
 
70
  class ArcPlusClassifierTool(BaseTool):
@@ -249,11 +245,7 @@ class ArcPlusClassifierTool(BaseTool):
249
 
250
  # Remove "module." prefix if present (improved logic from example)
251
  if any([True if "module." in k else False for k in state_dict.keys()]):
252
- state_dict = {
253
- k.replace("module.", ""): v
254
- for k, v in state_dict.items()
255
- if k.startswith("module.")
256
- }
257
 
258
  # Load the model weights
259
  msg = self.model.load_state_dict(state_dict, strict=False)
@@ -342,14 +334,10 @@ class ArcPlusClassifierTool(BaseTool):
342
 
343
  # Map predictions to disease names
344
  if len(predictions) != len(self.disease_list):
345
- print(
346
- f"Warning: Expected {len(self.disease_list)} predictions, got {len(predictions)}"
347
- )
348
  # Pad or truncate as needed
349
  if len(predictions) < len(self.disease_list):
350
- predictions = np.pad(
351
- predictions, (0, len(self.disease_list) - len(predictions))
352
- )
353
  else:
354
  predictions = predictions[: len(self.disease_list)]
355
 
 
38
 
39
  self.omni_heads = []
40
  for num_classes in num_classes_list:
41
+ self.omni_heads.append(nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())
 
 
42
  self.omni_heads = nn.ModuleList(self.omni_heads)
43
 
44
  def forward(self, x, head_n=None):
 
60
  class ArcPlusInput(BaseModel):
61
  """Input for ArcPlus chest X-ray analysis tool. Only supports JPG or PNG images."""
62
 
63
+ image_path: str = Field(..., description="Path to the radiology image file, only supports JPG or PNG images")
 
 
64
 
65
 
66
  class ArcPlusClassifierTool(BaseTool):
 
245
 
246
  # Remove "module." prefix if present (improved logic from example)
247
  if any([True if "module." in k else False for k in state_dict.keys()]):
248
+ state_dict = {k.replace("module.", ""): v for k, v in state_dict.items() if k.startswith("module.")}
 
 
 
 
249
 
250
  # Load the model weights
251
  msg = self.model.load_state_dict(state_dict, strict=False)
 
334
 
335
  # Map predictions to disease names
336
  if len(predictions) != len(self.disease_list):
337
+ print(f"Warning: Expected {len(self.disease_list)} predictions, got {len(predictions)}")
 
 
338
  # Pad or truncate as needed
339
  if len(predictions) < len(self.disease_list):
340
+ predictions = np.pad(predictions, (0, len(self.disease_list) - len(predictions)))
 
 
341
  else:
342
  predictions = predictions[: len(self.disease_list)]
343
 
medrax/tools/classification/torchxrayvision.py CHANGED
@@ -19,9 +19,7 @@ from medrax.utils.utils import preprocess_medical_image
19
  class TorchXRayVisionInput(BaseModel):
20
  """Input for TorchXRayVision chest X-ray analysis tools. Only supports JPG or PNG images."""
21
 
22
- image_path: str = Field(
23
- ..., description="Path to the radiology image file, only supports JPG or PNG images"
24
- )
25
 
26
 
27
  class TorchXRayVisionClassifierTool(BaseTool):
 
19
  class TorchXRayVisionInput(BaseModel):
20
  """Input for TorchXRayVision chest X-ray analysis tools. Only supports JPG or PNG images."""
21
 
22
+ image_path: str = Field(..., description="Path to the radiology image file, only supports JPG or PNG images")
 
 
23
 
24
 
25
  class TorchXRayVisionClassifierTool(BaseTool):
medrax/tools/dicom.py CHANGED
@@ -14,9 +14,7 @@ class DicomProcessorInput(BaseModel):
     """Input schema for the DICOM Processor Tool."""
 
     dicom_path: str = Field(..., description="Path to the DICOM file")
-    window_center: Optional[float] = Field(
-        None, description="Window center for contrast adjustment"
-    )
+    window_center: Optional[float] = Field(None, description="Window center for contrast adjustment")
     window_width: Optional[float] = Field(None, description="Window width for contrast adjustment")
 
 
medrax/tools/grounding.py CHANGED
@@ -90,11 +90,8 @@ class XRayPhraseGroundingTool(BaseTool):
             trust_remote_code=True,
             quantization_config=quantization_config,
         )
-        self.processor = AutoProcessor.from_pretrained(
-            model_path, cache_dir=cache_dir, trust_remote_code=True
-        )
+        self.processor = AutoProcessor.from_pretrained(model_path, cache_dir=cache_dir, trust_remote_code=True)
 
-
         self.model = self.model.eval()
 
         self.temp_dir = Path(temp_dir if temp_dir else tempfile.mkdtemp())
@@ -176,12 +173,8 @@ class XRayPhraseGroundingTool(BaseTool):
         )
 
         prompt_length = inputs["input_ids"].shape[-1]
-        decoded_text = self.processor.decode(
-            output[0][prompt_length:], skip_special_tokens=True
-        )
-        predictions = self.processor.convert_output_to_plaintext_or_grounded_sequence(
-            decoded_text
-        )
+        decoded_text = self.processor.decode(output[0][prompt_length:], skip_special_tokens=True)
+        predictions = self.processor.convert_output_to_plaintext_or_grounded_sequence(decoded_text)
 
         metadata = {
             "image_path": image_path,
@@ -208,9 +201,7 @@ class XRayPhraseGroundingTool(BaseTool):
         # Convert model bboxes to list format and get original image bboxes
         model_bboxes = [list(bbox) for bbox in pred_bboxes]
         original_bboxes = [
-            self.processor.adjust_box_for_original_image_size(
-                bbox, width=image.size[0], height=image.size[1]
-            )
+            self.processor.adjust_box_for_original_image_size(bbox, width=image.size[0], height=image.size[1])
             for bbox in model_bboxes
         ]
 
medrax/tools/rag.py CHANGED
@@ -14,7 +14,7 @@ class RAGTool(BaseTool):
 
     The knowledge base includes:
     - Medical textbooks and reference materials
-    - Research papers and clinical studies
+    - Research papers and clinical studies
     - Medical manuals and guidelines
     - Specialized medical literature
 
medrax/tools/report_generation.py CHANGED
@@ -23,9 +23,7 @@ from transformers import (
 class ChestXRayInput(BaseModel):
     """Input for chest X-ray analysis tools. Only supports JPG or PNG images."""
 
-    image_path: str = Field(
-        ..., description="Path to the radiology image file, only supports JPG or PNG images"
-    )
+    image_path: str = Field(..., description="Path to the radiology image file, only supports JPG or PNG images")
 
 
 class ChestXRayReportGeneratorTool(BaseTool):
@@ -180,12 +178,8 @@ class ChestXRayReportGeneratorTool(BaseTool):
         """
         try:
             # Process image for both models
-            findings_pixels = self._process_image(
-                image_path, self.findings_processor, self.findings_model
-            )
-            impression_pixels = self._process_image(
-                image_path, self.impression_processor, self.impression_model
-            )
+            findings_pixels = self._process_image(image_path, self.findings_processor, self.findings_model)
+            impression_pixels = self._process_image(image_path, self.impression_processor, self.impression_model)
 
             # Generate both sections
             with torch.inference_mode():
@@ -197,11 +191,7 @@ class ChestXRayReportGeneratorTool(BaseTool):
             )
 
             # Combine into formatted report
-            report = (
-                "CHEST X-RAY REPORT\n\n"
-                f"FINDINGS:\n{findings_text}\n\n"
-                f"IMPRESSION:\n{impression_text}"
-            )
+            report = "CHEST X-RAY REPORT\n\n" f"FINDINGS:\n{findings_text}\n\n" f"IMPRESSION:\n{impression_text}"
 
             output = {
                 "report": report,
medrax/tools/segmentation/__init__.py CHANGED
@@ -3,10 +3,4 @@
 from .segmentation import ChestXRaySegmentationTool, ChestXRaySegmentationInput, OrganMetrics
 from .medsam2 import MedSAM2Tool, MedSAM2Input
 
-__all__ = [
-    "ChestXRaySegmentationTool",
-    "ChestXRaySegmentationInput",
-    "OrganMetrics",
-    "MedSAM2Tool",
-    "MedSAM2Input"
-]
+__all__ = ["ChestXRaySegmentationTool", "ChestXRaySegmentationInput", "OrganMetrics", "MedSAM2Tool", "MedSAM2Input"]
medrax/tools/segmentation/medsam2.py CHANGED
@@ -26,7 +26,6 @@ from hydra import initialize_config_dir
 from hydra.core.global_hydra import GlobalHydra
 
 
-
 class MedSAM2Input(BaseModel):
     """Input schema for the MedSAM2 Tool."""
 
@@ -47,7 +46,7 @@ class MedSAM2Input(BaseModel):
 
 class MedSAM2Tool(BaseTool):
     """Advanced medical image segmentation tool using MedSAM2.
-    
+
     This tool provides state-of-the-art medical image segmentation capabilities using
     the MedSAM2 model, which is specifically adapted for medical imaging from Meta's SAM2.
     Supports interactive prompting with boxes, points, or automatic segmentation.
@@ -92,22 +91,17 @@ class MedSAM2Tool(BaseTool):
             # This works around the issue with initialize_config_module in sam2
             if GlobalHydra.instance().is_initialized():
                 GlobalHydra.instance().clear()
-            
+
             config_dir = Path(__file__).parent.parent.parent.parent / "MedSAM2" / "sam2" / "configs"
             initialize_config_dir(config_dir=str(config_dir), version_base="1.2")
-            
+
             hf_hub_download(
-                repo_id=model_path,
-                filename=model_file,
-                local_dir=self.cache_dir,
-                local_dir_use_symlinks=False
+                repo_id=model_path, filename=model_file, local_dir=self.cache_dir, local_dir_use_symlinks=False
             )
 
-            config_path = model_cfg.replace('.yaml', '')
+            config_path = model_cfg.replace(".yaml", "")
             sam2_model = build_sam2(config_path, str(self.cache_dir / model_file), device=device)
             self.predictor = SAM2ImagePredictor(sam2_model)
-            
-            print(f"MedSAM2 model loaded successfully on {device}")
 
         except Exception as e:
             raise RuntimeError(f"Failed to initialize MedSAM2: {str(e)}")
@@ -116,10 +110,10 @@ class MedSAM2Tool(BaseTool):
         """Load and preprocess image for medical analysis."""
         try:
             # Handle different image formats
-            if image_path.lower().endswith('.dcm'):
+            if image_path.lower().endswith(".dcm"):
                 # DICOM files - would need DICOM processor
                 raise ValueError("DICOM files not directly supported. Please convert to standard image format first.")
-            
+
             # Load standard image formats
             image = Image.open(image_path)
 
@@ -131,29 +125,29 @@ class MedSAM2Tool(BaseTool):
                 image = Image.fromarray(img_normalized, mode='L')
 
             # For medical images, convert to grayscale first if needed, then to RGB
-            if image.mode == 'L':  # Grayscale
+            if image.mode == "L":  # Grayscale
                 # Convert grayscale to RGB for SAM2
-                image = image.convert('RGB')
-            elif image.mode != 'RGB':
-                if image.mode == 'RGBA':
+                image = image.convert("RGB")
+            elif image.mode != "RGB":
+                if image.mode == "RGBA":
                     # Create white background for RGBA
-                    background = Image.new('RGB', image.size, (255, 255, 255))
+                    background = Image.new("RGB", image.size, (255, 255, 255))
                     background.paste(image, mask=image.split()[-1])
                     image = background
                 else:
-                    image = image.convert('RGB')
-            
+                    image = image.convert("RGB")
+
             # Convert to numpy array
             image_np = np.array(image)
-            
+
             # Ensure image is in proper range [0, 255]
             if image_np.max() <= 1.0:
                 image_np = (image_np * 255).astype(np.uint8)
             else:
                 image_np = image_np.astype(np.uint8)
-            
+
             return image_np
-            
+
         except Exception as e:
             raise ValueError(f"Failed to load image {image_path}: {str(e)}")
 
@@ -161,55 +155,53 @@ class MedSAM2Tool(BaseTool):
         """Process and validate prompts."""
         if prompt_type == "auto":
             return None, None, None
-        
+
         if prompt_coords is None:
             if prompt_type != "auto":
                 raise ValueError(f"Prompt coordinates required for prompt type '{prompt_type}'")
             return None, None, None
-        
+
         if prompt_type == "box":
             if len(prompt_coords) != 4:
                 raise ValueError("Box prompt requires 4 coordinates: [x1,y1,x2,y2]")
-            
+
             x1, y1, x2, y2 = prompt_coords
             # Validate coordinates
             if x1 >= x2 or y1 >= y2:
                 raise ValueError("Invalid box coordinates: x1 < x2 and y1 < y2 required")
-            
+
             input_box = np.array([[x1, y1, x2, y2]])
             return input_box, None, None
-        
+
         elif prompt_type == "point":
             if len(prompt_coords) != 2:
                 raise ValueError("Point prompt requires 2 coordinates: [x,y]")
-            
+
             x, y = prompt_coords
             input_point = np.array([[x, y]])
             input_label = np.array([1])  # Positive point
             return None, input_point, input_label
-        
+
         else:
             raise ValueError(f"Unknown prompt type: {prompt_type}")
 
     def _create_visualization(self, image: np.ndarray, masks: np.ndarray, prompt_info: Dict) -> str:
         """Create visualization of segmentation results."""
         plt.figure(figsize=(10, 10))
-        
+
         # Convert RGB image to grayscale for background display
         if len(image.shape) == 3:
             # Convert RGB to grayscale using standard luminance formula
-            gray_image = 0.299 * image[:,:,0] + 0.587 * image[:,:,1] + 0.114 * image[:,:,2]
+            gray_image = 0.299 * image[:, :, 0] + 0.587 * image[:, :, 1] + 0.114 * image[:, :, 2]
         else:
             gray_image = image
-        
+
         # Display grayscale background
-        plt.imshow(
-            gray_image, cmap="gray", extent=[0, image.shape[1], image.shape[0], 0]
-        )
-        
+        plt.imshow(gray_image, cmap="gray", extent=[0, image.shape[1], image.shape[0], 0])
+
         # Generate color palette for multiple masks
         colors = plt.cm.rainbow(np.linspace(0, 1, len(masks)))
-        
+
         # Process and overlay each mask
         for idx, (mask, color) in enumerate(zip(masks, colors)):
             if mask.sum() > 0:
@@ -217,33 +209,31 @@ class MedSAM2Tool(BaseTool):
                 mask_bool = mask.astype(bool)
                 colored_mask = np.zeros((*mask_bool.shape, 4))
                 colored_mask[mask_bool] = (*color[:3], 0.3)  # 30% transparency like segmentation tool
-                plt.imshow(
-                    colored_mask, extent=[0, image.shape[1], image.shape[0], 0]
-                )
-                
+                plt.imshow(colored_mask, extent=[0, image.shape[1], image.shape[0], 0])
+
                 # Add legend entry for each mask
                 mask_label = f"Mask {idx + 1} (score: {prompt_info.get('scores', [0])[idx] if idx < len(prompt_info.get('scores', [])) else 0:.3f})"
                 plt.plot([], [], color=color, label=mask_label, linewidth=3)
-        
+
         # Add prompt visualization with consistent styling
-        if prompt_info.get('box') is not None:
-            box = prompt_info['box'][0]
+        if prompt_info.get("box") is not None:
+            box = prompt_info["box"][0]
             x1, y1, x2, y2 = box
-            plt.plot([x1, x2, x2, x1, x1], [y1, y1, y2, y2, y1], 'g-', linewidth=2, label='Box Prompt')
-        
-        if prompt_info.get('point') is not None:
-            point = prompt_info['point'][0]
-            plt.plot(point[0], point[1], 'go', markersize=10, label='Point Prompt')
-        
+            plt.plot([x1, x2, x2, x1, x1], [y1, y1, y2, y2, y1], "g-", linewidth=2, label="Box Prompt")
+
+        if prompt_info.get("point") is not None:
+            point = prompt_info["point"][0]
+            plt.plot(point[0], point[1], "go", markersize=10, label="Point Prompt")
+
         plt.title("Segmentation Overlay")
         plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
         plt.axis("off")
-        
+
         # Save visualization with higher DPI like segmentation tool
         viz_path = self.temp_dir / f"medsam2_result_{uuid.uuid4().hex[:8]}.png"
-        plt.savefig(viz_path, bbox_inches='tight', dpi=300)
+        plt.savefig(viz_path, bbox_inches="tight", dpi=300)
         plt.close()
-        
+
         return str(viz_path)
 
     def _run(
@@ -258,28 +248,28 @@ class MedSAM2Tool(BaseTool):
         try:
             # Load image
             image = self._load_image(image_path)
-            
+
             # Set image for predictor
            self.predictor.set_image(image)
-            
+
             # Process prompts
-            input_box, input_point, input_label = self._process_prompts(
-                prompt_type, prompt_coords, image.shape[:2]
-            )
-            
+            input_box, input_point, input_label = self._process_prompts(prompt_type, prompt_coords, image.shape[:2])
+
             # Run inference
             if prompt_type == "auto":
                 # For auto segmentation, try multiple approaches and select best result
                 h, w = image.shape[:2]
-                
+
                 # Try multiple points in key areas for medical images
-                sample_points = np.array([
-                    [w//3, h//3],      # Upper left lung area
-                    [2*w//3, h//3],    # Upper right lung area
-                    [w//2, 2*h//3],    # Lower center area
-                ])
+                sample_points = np.array(
+                    [
+                        [w // 3, h // 3],  # Upper left lung area
+                        [2 * w // 3, h // 3],  # Upper right lung area
+                        [w // 2, 2 * h // 3],  # Lower center area
+                    ]
+                )
                 sample_labels = np.array([1, 1, 1])  # All positive points
-                
+
                 masks, scores, logits = self.predictor.predict(
                     point_coords=sample_points,
                     point_labels=sample_labels,
@@ -292,29 +282,29 @@ class MedSAM2Tool(BaseTool):
                     box=input_box,
                     multimask_output=True,
                 )
-            
+
             # Create visualization
             prompt_info = {
-                'box': input_box,
-                'point': input_point,
-                'type': prompt_type,
-                'scores': scores  # Add scores for legend display
+                "box": input_box,
+                "point": input_point,
+                "type": prompt_type,
+                "scores": scores,  # Add scores for legend display
             }
             viz_path = self._create_visualization(image, masks, prompt_info)
-            
+
             # Create output dictionary (main results)
             output = {
                 "segmentation_image_path": viz_path,
-                "confidence_scores": scores.tolist() if hasattr(scores, 'tolist') else list(scores),
+                "confidence_scores": scores.tolist() if hasattr(scores, "tolist") else list(scores),
                 "num_masks": len(masks),
                 "best_mask_score": float(scores[0]) if len(scores) > 0 else 0.0,
                 "mask_summary": {
                     "total_masks": len(masks),
                     "mask_shapes": [list(mask.shape) for mask in masks],
-                    "segmented_area_pixels": [int(mask.sum()) for mask in masks]
+                    "segmented_area_pixels": [int(mask.sum()) for mask in masks],
                 },
             }
-            
+
             # Create metadata dictionary
             metadata = {
                 "image_path": image_path,
@@ -326,9 +316,9 @@ class MedSAM2Tool(BaseTool):
                 "num_masks_generated": len(masks),
                 "analysis_status": "completed",
             }
-            
+
             return output, metadata
-            
+
         except Exception as e:
             error_output = {"error": str(e)}
             error_metadata = {
@@ -347,4 +337,4 @@ class MedSAM2Tool(BaseTool):
         run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
     ) -> Tuple[Dict[str, Any], Dict]:
         """Async version of _run."""
-        return self._run(image_path, prompt_type, prompt_coords, slice_index, run_manager)
+        return self._run(image_path, prompt_type, prompt_coords, slice_index, run_manager)
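
For orientation, a minimal usage sketch of the refactored tool follows; the constructor arguments (cache_dir, temp_dir, device) are assumptions made for illustration, while the _run parameters and output keys come from the diff above.

    # Hypothetical driver for MedSAM2Tool with a box prompt (a sketch, not part of this commit)
    from medrax.tools.segmentation import MedSAM2Tool

    tool = MedSAM2Tool(cache_dir="model_weights", temp_dir="temp", device="cuda")  # constructor args assumed
    output, metadata = tool._run(
        image_path="example_cxr.png",
        prompt_type="box",
        prompt_coords=[100, 100, 400, 400],  # x1, y1, x2, y2 in pixel coordinates
    )
    print(output["segmentation_image_path"], output["best_mask_score"])
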
medrax/tools/segmentation/segmentation.py CHANGED
@@ -43,9 +43,7 @@ class OrganMetrics(BaseModel):
     area_pixels: int = Field(..., description="Area in pixels")
     area_cm2: float = Field(..., description="Approximate area in cm²")
     centroid: Tuple[float, float] = Field(..., description="(y, x) coordinates of centroid")
-    bbox: Tuple[int, int, int, int] = Field(
-        ..., description="Bounding box coordinates (min_y, min_x, max_y, max_x)"
-    )
+    bbox: Tuple[int, int, int, int] = Field(..., description="Bounding box coordinates (min_y, min_x, max_y, max_x)")
 
     # Size metrics
     width: int = Field(..., description="Width of the organ in pixels")
@@ -53,9 +51,7 @@ class OrganMetrics(BaseModel):
     aspect_ratio: float = Field(..., description="Height/width ratio")
 
     # Position metrics
-    relative_position: Dict[str, float] = Field(
-        ..., description="Position relative to image boundaries (0-1 scale)"
-    )
+    relative_position: Dict[str, float] = Field(..., description="Position relative to image boundaries (0-1 scale)")
 
     # Analysis metrics
     mean_intensity: float = Field(..., description="Mean pixel intensity in the organ region")
@@ -92,9 +88,7 @@ class ChestXRaySegmentationTool(BaseTool):
         self.model = self.model.to(self.device)
         self.model.eval()
 
-        self.transform = torchvision.transforms.Compose(
-            [xrv.datasets.XRayCenterCrop(), xrv.datasets.XRayResizer(512)]
-        )
+        self.transform = torchvision.transforms.Compose([xrv.datasets.XRayCenterCrop(), xrv.datasets.XRayResizer(512)])
 
         self.temp_dir = temp_dir if isinstance(temp_dir, Path) else Path(temp_dir)
         self.temp_dir.mkdir(exist_ok=True)
@@ -117,9 +111,7 @@ class ChestXRaySegmentationTool(BaseTool):
         "Spine": 13,
     }
 
-    def _align_mask_to_original(
-        self, mask: np.ndarray, original_shape: Tuple[int, int]
-    ) -> np.ndarray:
+    def _align_mask_to_original(self, mask: np.ndarray, original_shape: Tuple[int, int]) -> np.ndarray:
         """
         Align a mask from the transformed (cropped/resized) space back to the full original image.
         Assumes that the transform does a center crop to a square of side = min(original height, width)
@@ -172,23 +164,17 @@ class ChestXRaySegmentationTool(BaseTool):
             bbox=tuple(map(int, props.bbox)),
             width=int(props.bbox[3] - props.bbox[1]),
             height=int(props.bbox[2] - props.bbox[0]),
-            aspect_ratio=float(
-                (props.bbox[2] - props.bbox[0]) / max(1, props.bbox[3] - props.bbox[1])
-            ),
+            aspect_ratio=float((props.bbox[2] - props.bbox[0]) / max(1, props.bbox[3] - props.bbox[1])),
             relative_position=relative_pos,
             mean_intensity=float(mean_intensity),
             std_intensity=float(std_intensity),
            confidence_score=float(confidence),
         )
 
-    def _save_visualization(
-        self, original_img: np.ndarray, pred_masks: torch.Tensor, organ_indices: List[int]
-    ) -> str:
+    def _save_visualization(self, original_img: np.ndarray, pred_masks: torch.Tensor, organ_indices: List[int]) -> str:
         """Save visualization of original image with segmentation masks overlaid."""
         plt.figure(figsize=(10, 10))
-        plt.imshow(
-            original_img, cmap="gray", extent=[0, original_img.shape[1], original_img.shape[0], 0]
-        )
+        plt.imshow(original_img, cmap="gray", extent=[0, original_img.shape[1], original_img.shape[0], 0])
 
         # Generate color palette for organs
         colors = plt.cm.rainbow(np.linspace(0, 1, len(organ_indices)))
@@ -204,14 +190,10 @@ class ChestXRaySegmentationTool(BaseTool):
             # Create a colored overlay with transparency
             colored_mask = np.zeros((*original_img.shape, 4))
             colored_mask[mask > 0] = (*color[:3], 0.3)
-            plt.imshow(
-                colored_mask, extent=[0, original_img.shape[1], original_img.shape[0], 0]
-            )
+            plt.imshow(colored_mask, extent=[0, original_img.shape[1], original_img.shape[0], 0])
 
             # Add legend entry for the organ
-            organ_name = list(self.organ_map.keys())[
-                list(self.organ_map.values()).index(organ_idx)
-            ]
+            organ_name = list(self.organ_map.keys())[list(self.organ_map.values()).index(organ_idx)]
             plt.plot([], [], color=color, label=organ_name, linewidth=3)
 
         plt.title("Segmentation Overlay")
@@ -269,9 +251,7 @@ class ChestXRaySegmentationTool(BaseTool):
         for idx, organ_name in zip(organ_indices, organs):
             mask = pred_masks[0, idx].cpu().numpy()
             if mask.sum() > 0:
-                metrics = self._compute_organ_metrics(
-                    mask, original_img, float(pred_probs[0, idx].mean().cpu())
-                )
+                metrics = self._compute_organ_metrics(mask, original_img, float(pred_probs[0, idx].mean().cpu()))
                 if metrics:
                     results[organ_name] = metrics
 
medrax/tools/utils.py CHANGED
@@ -16,18 +16,10 @@ class ImageVisualizerInput(BaseModel):
 
     image_path: str = Field(..., description="Path to the image file to display, only supports JPG or PNG images")
     title: Optional[str] = Field(None, description="Optional title to display above the image")
-    description: Optional[str] = Field(
-        None, description="Optional description to display below the image"
-    )
-    width: Optional[int] = Field(
-        10, description="Optional figure width in inches"
-    )
-    height: Optional[int] = Field(
-        10, description="Optional figure height in inches"
-    )
-    cmap: Optional[str] = Field(
-        "rgb", description="Optional colormap to use for displaying the image"
-    )
+    description: Optional[str] = Field(None, description="Optional description to display below the image")
+    width: Optional[int] = Field(10, description="Optional figure width in inches")
+    height: Optional[int] = Field(10, description="Optional figure height in inches")
+    cmap: Optional[str] = Field("rgb", description="Optional colormap to use for displaying the image")
 
 
 class ImageVisualizerTool(BaseTool):
@@ -65,9 +57,7 @@ class ImageVisualizerTool(BaseTool):
 
         # Add description if provided
         if description:
-            plt.figtext(
-                0.5, 0.01, description, wrap=True, horizontalalignment="center", fontsize=10
-            )
+            plt.figtext(0.5, 0.01, description, wrap=True, horizontalalignment="center", fontsize=10)
 
         # Adjust margins to minimize whitespace while preventing overlap
         plt.subplots_adjust(top=0.95, bottom=0.05, left=0.05, right=0.95)
medrax/tools/vqa/__init__.py CHANGED
@@ -1,16 +1,16 @@
 """Visual Question Answering tools for medical images."""
 
 from .llava_med import LlavaMedTool, LlavaMedInput
-from .xray_vqa import CheXagentXRayVQATool, XRayVQAToolInput
+from .xray_vqa import CheXagentXRayVQATool, XRayVQAToolInput
 from .medgemma.medgemma_client import MedGemmaAPIClientTool, MedGemmaVQAInput
 from .medgemma.medgemma_setup import setup_medgemma_env
 
 __all__ = [
     "LlavaMedTool",
     "LlavaMedInput",
-    "CheXagentXRayVQATool",
+    "CheXagentXRayVQATool",
     "XRayVQAToolInput",
     "MedGemmaAPIClientTool",
     "MedGemmaVQAInput",
-    "setup_medgemma_env"
-]
+    "setup_medgemma_env",
+]
medrax/tools/vqa/llava_med.py CHANGED
@@ -84,13 +84,7 @@ class LlavaMedTool(BaseTool):
         self, question: str, image_path: Optional[str] = None
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         if self.model.config.mm_use_im_start_end:
-            question = (
-                DEFAULT_IM_START_TOKEN
-                + DEFAULT_IMAGE_TOKEN
-                + DEFAULT_IM_END_TOKEN
-                + "\n"
-                + question
-            )
+            question = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + "\n" + question
         else:
             question = DEFAULT_IMAGE_TOKEN + "\n" + question
 
@@ -100,9 +94,7 @@ class LlavaMedTool(BaseTool):
         prompt = conv.get_prompt()
 
         input_ids = (
-            tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
-            .unsqueeze(0)
-            .cuda()
+            tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).cuda()
         )
 
         image_tensor = None
@@ -156,11 +148,11 @@ class LlavaMedTool(BaseTool):
         )
 
         answer = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
-        
+
         output = {
             "answer": answer,
         }
-        
+
         metadata = {
             "question": question,
             "image_path": image_path,
medrax/tools/vqa/medgemma/medgemma.py CHANGED
@@ -1,7 +1,6 @@
 import asyncio
 import os
 from pathlib import Path
-import sys
 import traceback
 from typing import Any, Dict, List, Optional, Tuple
 import uuid
@@ -22,6 +21,7 @@ UPLOAD_DIR = "./medgemma_images"
 # Create directories if they don't exist
 os.makedirs(UPLOAD_DIR, exist_ok=True)
 
+
 # Pydantic Models for API
 class VQAInput(BaseModel):
     """Input schema for the MedGemma VQA API endpoint.
@@ -100,7 +100,7 @@ class MedGemmaModel:
         device: Optional[str] = "cuda",
         dtype: torch.dtype = torch.bfloat16,
         cache_dir: Optional[str] = None,
-        load_in_4bit: bool = True,
+        load_in_8bit: bool = True,
         **kwargs: Any,
     ) -> None:
         """Initialize the MedGemmaModel.
@@ -110,7 +110,7 @@ class MedGemmaModel:
             device: Device to run model on - "cuda" or "cpu" (default: "cuda")
             dtype: Data type for model weights - bfloat16 recommended for efficiency (default: torch.bfloat16)
             cache_dir: Directory to cache downloaded models (default: None)
-            load_in_4bit: Whether to load model in 4-bit quantization for memory efficiency (default: True)
+            load_in_8bit: Whether to load model in 8-bit quantization for memory efficiency (default: True)
            **kwargs: Additional arguments passed to the model pipeline
 
        Raises:
@@ -140,8 +140,8 @@ class MedGemmaModel:
             "use_cache": True,
         }
 
-        if load_in_4bit:
-            model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)
+        if load_in_8bit:
+            model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
             model_kwargs["device_map"] = {"": self.device}
 
         try:
@@ -298,6 +298,12 @@ app = FastAPI(
 )
 
 medgemma_model: Optional[MedGemmaModel] = None
+inference_semaphore: Optional[asyncio.Semaphore] = None
+
+@app.get("/health")
+async def health():
+    """Health check endpoint."""
+    return {"status": "ok"}
 
 @app.on_event("startup")
 async def startup_event():
@@ -316,7 +322,32 @@ async def startup_event():
     """
     global medgemma_model
     try:
-        medgemma_model = MedGemmaModel()
+        # Allow overriding Hugging Face cache directory and device via env vars
+        cache_dir_env = os.getenv("MEDGEMMA_CACHE_DIR")
+        device_env = os.getenv("MEDGEMMA_DEVICE")
+        max_concurrency_env = os.getenv("MEDGEMMA_MAX_CONCURRENCY", "1")
+
+        # Ensure the cache directory is writable; if not, fall back to a user cache
+        if cache_dir_env:
+            try:
+                os.makedirs(cache_dir_env, exist_ok=True)
+                if not os.access(cache_dir_env, os.W_OK):
+                    raise PermissionError("Cache dir not writable")
+            except Exception:
+                fallback = os.path.join(Path.home(), ".cache", "medrax", "medgemma")
+                os.makedirs(fallback, exist_ok=True)
+                print(f"Warning: MEDGEMMA_CACHE_DIR '{cache_dir_env}' not writable. Falling back to '{fallback}'.")
+                cache_dir_env = fallback
+
+        medgemma_model = MedGemmaModel(cache_dir=cache_dir_env, device=device_env)
+        # Initialize concurrency gate
+        try:
+            max_concurrency = int(max_concurrency_env)
+        except ValueError:
+            max_concurrency = 1
+        max_concurrency = max(1, max_concurrency)
+        global inference_semaphore
+        inference_semaphore = asyncio.Semaphore(max_concurrency)
         print("MedGemma model loaded successfully.")
     except RuntimeError as e:
         print(f"Error loading MedGemma model: {e}")
@@ -389,8 +420,12 @@ async def analyze_images(
         raise HTTPException(status_code=500, detail=f"Failed to save uploaded image: {str(e)}")
 
     try:
-        # Generate AI analysis
-        response_text = await medgemma_model.aget_response(image_paths, prompt, system_prompt, max_new_tokens)
+        # Generate AI analysis with concurrency gating to avoid GPU contention timeouts
+        global inference_semaphore
+        if inference_semaphore is None:
+            inference_semaphore = asyncio.Semaphore(1)
+        async with inference_semaphore:
+            response_text = await medgemma_model.aget_response(image_paths, prompt, system_prompt, max_new_tokens)
 
         # Prepare success response
         metadata = {
@@ -428,7 +463,12 @@ async def analyze_images(
 if __name__ == "__main__":
     """Launch the MedGemma VQA API server.
 
-    Starts the FastAPI application with uvicorn server, binding to all
-    network interfaces on port 8002.
+    Reads MEDGEMMA_HOST and MEDGEMMA_PORT if provided; otherwise defaults
+    to 0.0.0.0:8002.
     """
-    uvicorn.run(app, host="0.0.0.0", port=8002)
+    host = os.getenv("MEDGEMMA_HOST", "0.0.0.0")
+    try:
+        port = int(os.getenv("MEDGEMMA_PORT", "8002"))
+    except ValueError:
+        port = 8002
+    uvicorn.run(app, host=host, port=port)
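
With the server now reading MEDGEMMA_HOST, MEDGEMMA_PORT, MEDGEMMA_CACHE_DIR, MEDGEMMA_DEVICE, and MEDGEMMA_MAX_CONCURRENCY from the environment and exposing /health, a quick smoke test might look like the sketch below; the host and port values are assumptions and should be whatever the service was actually launched with.

    # Hypothetical smoke test against the reworked service (host/port values are assumptions)
    import os
    import httpx

    base_url = f"http://{os.getenv('MEDGEMMA_HOST', '127.0.0.1')}:{os.getenv('MEDGEMMA_PORT', '8002')}"
    resp = httpx.get(f"{base_url}/health", timeout=5.0)
    print(resp.json())  # expected: {"status": "ok"} once startup_event has finished loading the model
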
medrax/tools/vqa/medgemma/medgemma_client.py CHANGED
@@ -59,15 +59,21 @@ class MedGemmaAPIClientTool(BaseTool):
 
     # API configuration
     api_url: str  # The URL of the running FastAPI service
+    cache_dir: Optional[str] = None  # Not used by the client directly, but accepted to keep a uniform constructor
+    device: Optional[str] = None
 
-    def __init__(self, api_url: str, **kwargs: Any):
+    def __init__(self, api_url: str, cache_dir: Optional[str] = None, device: Optional[str] = None, timeout_seconds: Optional[float] = None, **kwargs: Any):
         """Initialize the MedGemmaAPIClientTool.
 
         Args:
             api_url: The URL of the running MedGemma FastAPI service
+            cache_dir: Optional local cache directory for model weights (accepted for interface consistency)
+            device: Optional device spec (accepted for interface consistency)
+            timeout_seconds: Optional request timeout override (seconds)
             **kwargs: Additional arguments passed to BaseTool
         """
-        super().__init__(api_url=api_url, **kwargs)
+        super().__init__(api_url=api_url, cache_dir=cache_dir, device=device, **kwargs)
+        self._timeout_seconds = timeout_seconds
 
     def _prepare_request_data(
         self, image_paths: List[str], prompt: str, system_prompt: str, max_new_tokens: int
@@ -154,7 +160,8 @@ class MedGemmaAPIClientTool(BaseTool):
             Tuple of output dictionary and metadata
         """
         # httpx is a modern HTTP client that supports sync and async
-        timeout_config = httpx.Timeout(300.0, connect=10.0)
+        timeout_value = self._timeout_seconds if self._timeout_seconds is not None else 600.0
+        timeout_config = httpx.Timeout(timeout_value, connect=10.0)
         client = httpx.Client(timeout=timeout_config)
 
         try:
@@ -238,11 +245,12 @@ class MedGemmaAPIClientTool(BaseTool):
                 image_paths, prompt, system_prompt, max_new_tokens
             )
 
+            timeout_value = self._timeout_seconds if self._timeout_seconds is not None else 600.0
             response = await client.post(
                 f"{self.api_url}/analyze-images/",
                 data=data,
                 files=files_to_send,
-                timeout=120.0
+                timeout=timeout_value
             )
             response.raise_for_status()
 
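The client's default request timeout is now 600 s and can be overridden per instance; a short construction sketch follows, assuming a service is already listening on the given URL.

    # Hypothetical client construction with a longer timeout for large multi-image prompts
    from medrax.tools.vqa import MedGemmaAPIClientTool

    client_tool = MedGemmaAPIClientTool(api_url="http://127.0.0.1:8002", timeout_seconds=900.0)
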
medrax/tools/vqa/medgemma/medgemma_requirements_standard.txt CHANGED
@@ -52,4 +52,4 @@ typing_inspection==0.4.1
 urllib3==2.5.0
 uvicorn==0.35.0
 wcwidth==0.2.13
-zstandard==0.23.0
+zstandard==0.23.0
medrax/tools/vqa/medgemma/medgemma_setup.py CHANGED
@@ -1,9 +1,61 @@
 import os
 from pathlib import Path
 import subprocess
+import socket
+from contextlib import closing
 import venv
 
-def setup_medgemma_env():
+def _resolve_writable_cache_dir(preferred: str | None) -> str:
+    """Return a writable cache directory, falling back to user cache if needed."""
+    # Preferred path first
+    if preferred:
+        try:
+            os.makedirs(preferred, exist_ok=True)
+            if os.access(preferred, os.W_OK):
+                return preferred
+        except Exception:
+            pass
+    # Fallback path under user's home
+    fallback = os.path.join(Path.home(), ".cache", "medrax", "medgemma")
+    os.makedirs(fallback, exist_ok=True)
+    return fallback
+
+
+def _is_port_free(host: str, port: int) -> bool:
+    """Return True if (host, port) is free to bind on this machine."""
+    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
+        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+        try:
+            sock.bind((host, port))
+            return True
+        except OSError:
+            return False
+
+
+def _find_free_loopback_and_port(start_octet: int = 2, end_octet: int = 254, base_port: int = 8002, max_port_tries: int = 50) -> tuple[str, int]:
+    """Find a free 127.0.0.X address and port combination.
+
+    Tries 127.0.0.2..127.0.0.254 each with ports base_port..base_port+max_port_tries
+    until a free pair is found. Falls back to 127.0.0.1 if none found for other octets.
+    """
+    # Try alternate loopback IPs first
+    for last_octet in range(start_octet, end_octet + 1):
+        host = f"127.0.0.{last_octet}"
+        for port in range(base_port, base_port + max_port_tries):
+            if _is_port_free(host, port):
+                return host, port
+    # Fallback: use 127.0.0.1 with port scan
+    host = "127.0.0.1"
+    for port in range(base_port, base_port + max_port_tries):
+        if _is_port_free(host, port):
+            return host, port
+    # Last resort: system-chosen ephemeral on 127.0.0.1
+    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
+        sock.bind((host, 0))
+        return host, sock.getsockname()[1]
+
+
+def setup_medgemma_env(cache_dir: str | None = None, device: str | None = None) -> str:
     """Set up MedGemma virtual environment and launch the FastAPI service.
 
     This function performs the following steps:
@@ -53,12 +105,47 @@ def setup_medgemma_env():
     if not env_dir.exists():
         raise RuntimeError("Failed to create MedGemma virtual environment")
 
-    # Launch MedGemma FastAPI service
-    print("Launching MedGemma FastAPI service...")
+    # Decide host/port to avoid collisions when multiple instances run
+    medgemma_host = os.getenv("MEDGEMMA_HOST")
+    medgemma_port_env = os.getenv("MEDGEMMA_PORT")
+    chosen_host: str
+    chosen_port: int
+    if medgemma_host and medgemma_port_env:
+        try:
+            port_val = int(medgemma_port_env)
+        except ValueError:
+            port_val = 8002
+        # If explicit host/port are provided, prefer them; if taken, try incrementing the port on the same host
+        chosen_host = medgemma_host
+        chosen_port = None
+        for p in range(port_val, port_val + 50):
+            if _is_port_free(medgemma_host, p):
+                chosen_port = p
+                break
+        if chosen_port is None:
+            print(f"No free ports in range {port_val}-{port_val+49} on {medgemma_host}; selecting a free loopback IP/port...")
+            chosen_host, chosen_port = _find_free_loopback_and_port()
+    else:
+        # Auto-pick a free loopback IP and port
+        chosen_host, chosen_port = _find_free_loopback_and_port()
+
+    print(f"Launching MedGemma FastAPI service on {chosen_host}:{chosen_port} ...")
+    env = os.environ.copy()
+    resolved_cache = _resolve_writable_cache_dir(cache_dir)
+    env["MEDGEMMA_CACHE_DIR"] = resolved_cache
+    if device:
+        env["MEDGEMMA_DEVICE"] = device
+    # Pass the chosen binding to the server via env
+    env["MEDGEMMA_HOST"] = chosen_host
+    env["MEDGEMMA_PORT"] = str(chosen_port)
     subprocess.Popen([
         str(python_executable),
         str(medgemma_path)
-    ])
+    ], env=env)
+
+    # Return the base URL so callers can use it. If bound to 0.0.0.0, use 127.0.0.1 for local client access.
+    chosen_client_host = "127.0.0.1" if chosen_host in ("0.0.0.0", "::") else chosen_host
+    return f"http://{chosen_client_host}:{chosen_port}"
     # Note: stdout and stderr redirection commented out for debugging
     # stdout=subprocess.DEVNULL,
     # stderr=subprocess.DEVNULL,
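
Because setup_medgemma_env now returns the base URL of whichever loopback address and port it bound, callers can pass that straight to the API client; a minimal sketch, assuming the service finishes loading its model before the first request arrives. The cache directory and device values below are placeholders, not values from this commit.

    # Hypothetical wiring of the returned base URL into the client (a sketch, not part of this commit)
    from medrax.tools.vqa import MedGemmaAPIClientTool, setup_medgemma_env

    api_url = setup_medgemma_env(cache_dir="/tmp/medgemma_cache", device="cuda")  # e.g. "http://127.0.0.2:8002"
    tool = MedGemmaAPIClientTool(api_url=api_url)
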
medrax/tools/vqa/xray_vqa.py CHANGED
@@ -15,13 +15,9 @@ from langchain_core.tools import BaseTool
 class XRayVQAToolInput(BaseModel):
     """Input schema for the CheXagent Tool."""
 
-    image_paths: List[str] = Field(
-        ..., description="List of paths to chest X-ray images to analyze"
-    )
+    image_paths: List[str] = Field(..., description="List of paths to chest X-ray images to analyze")
     prompt: str = Field(..., description="Question or instruction about the chest X-ray images")
-    max_new_tokens: int = Field(
-        512, description="Maximum number of tokens to generate in the response"
-    )
+    max_new_tokens: int = Field(512, description="Maximum number of tokens to generate in the response")
 
 
 class CheXagentXRayVQATool(BaseTool):
@@ -99,16 +95,14 @@ class CheXagentXRayVQATool(BaseTool):
         Returns:
             str: Model's response
         """
-        query = self.tokenizer.from_list_format(
-            [*[{"image": path} for path in image_paths], {"text": prompt}]
-        )
+        query = self.tokenizer.from_list_format([*[{"image": path} for path in image_paths], {"text": prompt}])
         conv = [
             {"from": "system", "value": "You are a helpful assistant."},
             {"from": "human", "value": query},
         ]
-        input_ids = self.tokenizer.apply_chat_template(
-            conv, add_generation_prompt=True, return_tensors="pt"
-        ).to(device=self.device)
+        input_ids = self.tokenizer.apply_chat_template(conv, add_generation_prompt=True, return_tensors="pt").to(
+            device=self.device
+        )
 
         # Run inference
         with torch.inference_mode():
medrax/tools/xray_generation.py CHANGED
@@ -11,26 +11,15 @@ from langchain_core.tools import BaseTool
 
 class ChestXRayGeneratorInput(BaseModel):
     """Input schema for the Chest X-Ray Generator Tool."""
-    
+
     prompt: str = Field(
-        ...,
-        description="Description of the medical condition to generate (e.g., 'big left-sided pleural effusion')"
-    )
-    height: int = Field(
-        512,
-        description="Height of generated image in pixels"
-    )
-    width: int = Field(
-        512,
-        description="Width of generated image in pixels"
-    )
-    num_inference_steps: int = Field(
-        75,
-        description="Number of denoising steps (higher = better quality but slower)"
+        ..., description="Description of the medical condition to generate (e.g., 'big left-sided pleural effusion')"
     )
+    height: int = Field(512, description="Height of generated image in pixels")
+    width: int = Field(512, description="Width of generated image in pixels")
+    num_inference_steps: int = Field(75, description="Number of denoising steps (higher = better quality but slower)")
     guidance_scale: float = Field(
-        4.0,
-        description="How closely to follow the prompt (higher = more faithful but less diverse)"
+        4.0, description="How closely to follow the prompt (higher = more faithful but less diverse)"
     )
 
 
@@ -60,11 +49,11 @@ class ChestXRayGeneratorTool(BaseTool):
     ):
         """Initialize the chest X-ray generator tool."""
         super().__init__()
-        
+
         self.device = torch.device(device) if device else "cuda"
         self.model = StableDiffusionPipeline.from_pretrained(model_path, cache_dir=cache_dir)
         self.model = self.model.to(torch.float32).to(self.device)
-        
+
         self.temp_dir = Path(temp_dir if temp_dir else tempfile.mkdtemp())
         self.temp_dir.mkdir(exist_ok=True)
 
@@ -97,7 +86,7 @@ class ChestXRayGeneratorTool(BaseTool):
                 num_inference_steps=num_inference_steps,
                 height=height,
                 width=width,
-                guidance_scale=guidance_scale
+                guidance_scale=guidance_scale,
             )
 
             # Save generated image
@@ -107,7 +96,7 @@ class ChestXRayGeneratorTool(BaseTool):
             output = {
                 "image_path": str(image_path),
             }
-            
+
             metadata = {
                 "prompt": prompt,
                 "num_inference_steps": num_inference_steps,
@@ -126,7 +115,7 @@ class ChestXRayGeneratorTool(BaseTool):
                     "prompt": prompt,
                     "analysis_status": "failed",
                     "error_details": str(e),
-                }
+                },
             )
 
     async def _arun(
@@ -139,4 +128,4 @@ class ChestXRayGeneratorTool(BaseTool):
         run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
     ) -> Tuple[Dict[str, str], Dict]:
         """Async version of _run."""
-        return self._run(prompt, num_inference_steps, guidance_scale, height, width)
+        return self._run(prompt, num_inference_steps, guidance_scale, height, width)
pyproject.toml CHANGED
@@ -14,15 +14,15 @@ requires-python = ">=3.12"
 dependencies = [
     "requests>=2.25.0",
     "numpy>=1.19.0",
-    "langchain>=0.1.0",
-    "langchain-core>=0.1.0",
+    "langchain>=0.3.26",
+    "langchain-core>=0.3.68",
     "langchain-community>=0.0.20",
-    "langchain-openai>=0.0.2",
-    "langchain-cohere>=0.3.0,<0.4.0",
-    "langchain-anthropic>=0.0.2",
-    "langchain-xai>=0.0.1",
-    "langchain-chroma>=0.0.10",
-    "langgraph>=0.0.10",
+    "langchain-openai>=0.3.27",
+    "langchain-cohere>=0.3.5",
+    "langchain-anthropic>=0.3.17",
+    "langchain-xai>=0.2.4",
+    "langchain-chroma>=0.2.4",
+    "langgraph>=0.5.1",
     "hydra-core>=1.1.0",
     "python-dotenv>=0.19.0",
     "pandas>=1.5.0",
@@ -46,8 +46,9 @@ dependencies = [
     "gradio>=3.0.0",
     "gradio_client>=0.2.0",
     "httpx>=0.23.0",
-    "uvicorn>=0.15.0",
+    "uvicorn[standard]>=0.15.0",
     "fastapi>=0.68.0",
+    "python-multipart>=0.0.6",
     "einops>=0.3.0",
     "einops-exts>=0.0.4",
     "timm==0.5.4",
@@ -73,6 +74,7 @@ dependencies = [
     "huggingface_hub>=0.17.0",
     "iopath>=0.1.10",
     "duckduckgo-search>=4.0.0",
+    "pyngrok>=7.0.0",
 ]
 
 [project.optional-dependencies]