Adibvafa committed on
Commit 9a2c640 · 1 Parent(s): f1b429f

Improve style

Files changed (3)
  1. api.py +73 -78
  2. interface.py +10 -23
  3. main.py +50 -112
api.py CHANGED
@@ -32,12 +32,13 @@ from medrax.agent import Agent
32
  class QueryRequest(BaseModel):
33
  """
34
  Request model for text-only queries.
35
-
36
  Attributes:
37
  question (str): The question or query to ask the agent
38
  system_prompt (Optional[str]): Custom system prompt to override default
39
  thread_id (Optional[str]): Optional thread ID for conversation continuity
40
  """
 
41
  question: str = Field(..., description="The question or query to ask the agent")
42
  system_prompt: Optional[str] = Field(None, description="Custom system prompt to override default")
43
  thread_id: Optional[str] = Field(None, description="Optional thread ID for conversation continuity")
@@ -46,13 +47,14 @@ class QueryRequest(BaseModel):
46
  class QueryResponse(BaseModel):
47
  """
48
  Response model for API queries.
49
-
50
  Attributes:
51
  response (str): The agent's text response
52
  thread_id (str): The thread ID used for this conversation
53
  tools_used (List[str]): List of tools that were executed
54
  processing_time (float): Time taken to process the request in seconds
55
  """
 
56
  response: str = Field(..., description="The agent's text response")
57
  thread_id: str = Field(..., description="The thread ID used for this conversation")
58
  tools_used: List[str] = Field(..., description="List of tools that were executed")
@@ -62,15 +64,15 @@ class QueryResponse(BaseModel):
62
  class MedRAXAPI:
63
  """
64
  FastAPI application wrapper for the MedRAX agent.
65
-
66
  This class provides a clean interface for creating and managing the API endpoints
67
  while maintaining separation of concerns from the core agent functionality.
68
  """
69
-
70
  def __init__(self, agent: Agent, tools_dict: Dict[str, Any], temp_dir: str = "temp_api"):
71
  """
72
  Initialize the MedRAX API.
73
-
74
  Args:
75
  agent (Agent): The initialized MedRAX agent
76
  tools_dict (Dict[str, Any]): Dictionary of available tools
@@ -80,16 +82,16 @@ class MedRAXAPI:
80
  self.tools_dict = tools_dict
81
  self.temp_dir = Path(temp_dir)
82
  self.temp_dir.mkdir(exist_ok=True)
83
-
84
  # Create FastAPI app
85
  self.app = FastAPI(
86
  title="MedRAX API",
87
  description="Medical Reasoning Agent for Chest X-ray Analysis",
88
  version="2.0.0",
89
  docs_url="/docs",
90
- redoc_url="/redoc"
91
  )
92
-
93
  # Add CORS middleware
94
  self.app.add_middleware(
95
  CORSMiddleware,
@@ -98,161 +100,154 @@ class MedRAXAPI:
98
  allow_methods=["*"],
99
  allow_headers=["*"],
100
  )
101
-
102
  # Register routes
103
  self._register_routes()
104
-
105
  def _register_routes(self):
106
  """Register all API routes."""
107
-
108
  @self.app.get("/health")
109
  async def health_check():
110
  """Health check endpoint."""
111
  return {"status": "healthy", "service": "MedRAX API"}
112
-
113
  @self.app.get("/tools")
114
  async def list_tools():
115
  """List available tools."""
116
- return {
117
- "available_tools": list(self.tools_dict.keys()),
118
- "total_count": len(self.tools_dict)
119
- }
120
-
121
  @self.app.post("/query", response_model=QueryResponse)
122
  async def query_text_only(request: QueryRequest):
123
  """
124
  Process a text-only query without images.
125
-
126
  Args:
127
  request (QueryRequest): The query request
128
-
129
  Returns:
130
  QueryResponse: The agent's response
131
  """
132
  return await self._process_query(
133
- question=request.question,
134
- system_prompt=request.system_prompt,
135
- thread_id=request.thread_id,
136
- images=None
137
  )
138
-
139
  @self.app.post("/query-with-images", response_model=QueryResponse)
140
  async def query_with_images(
141
  question: str = Form(..., description="The question or query to ask the agent"),
142
  system_prompt: Optional[str] = Form(None, description="Custom system prompt to override default"),
143
  thread_id: Optional[str] = Form(None, description="Optional thread ID for conversation continuity"),
144
- images: List[UploadFile] = File(..., description="One or more medical images to analyze")
145
  ):
146
  """
147
  Process a query with one or more images.
148
-
149
  Args:
150
  question (str): The question or query to ask the agent
151
  system_prompt (Optional[str]): Custom system prompt to override default
152
  thread_id (Optional[str]): Optional thread ID for conversation continuity
153
  images (List[UploadFile]): List of uploaded image files
154
-
155
  Returns:
156
  QueryResponse: The agent's response
157
  """
158
  # Validate image files
159
  if not images or len(images) == 0:
160
  raise HTTPException(status_code=400, detail="At least one image is required")
161
-
162
  # Validate file types
163
- allowed_types = {'image/jpeg', 'image/jpg', 'image/png', 'image/bmp', 'image/tiff', 'application/dicom'}
164
  for image in images:
165
  if image.content_type not in allowed_types:
166
  raise HTTPException(
167
- status_code=400,
168
- detail=f"Unsupported file type: {image.content_type}. Allowed types: {allowed_types}"
169
  )
170
-
171
  return await self._process_query(
172
- question=question,
173
- system_prompt=system_prompt,
174
- thread_id=thread_id,
175
- images=images
176
  )
177
-
178
  async def _process_query(
179
  self,
180
  question: str,
181
  system_prompt: Optional[str] = None,
182
  thread_id: Optional[str] = None,
183
- images: Optional[List[UploadFile]] = None
184
  ) -> QueryResponse:
185
  """
186
  Internal method to process queries through the agent.
187
-
188
  Args:
189
  question (str): The question to ask
190
  system_prompt (Optional[str]): Custom system prompt
191
  thread_id (Optional[str]): Thread ID for conversation
192
  images (Optional[List[UploadFile]]): List of images
193
-
194
  Returns:
195
  QueryResponse: The processed response
196
  """
197
  start_time = time.time()
198
-
199
  # Generate thread ID if not provided
200
  if not thread_id:
201
  thread_id = str(uuid.uuid4())
202
-
203
  try:
204
  # Prepare messages
205
  messages = []
206
  image_paths = []
207
-
208
  # Handle image uploads
209
  if images:
210
  for i, image in enumerate(images):
211
  # Save uploaded file temporarily
212
  temp_path = self.temp_dir / f"{thread_id}_{i}_{image.filename}"
213
-
214
  with open(temp_path, "wb") as buffer:
215
  content = await image.read()
216
  buffer.write(content)
217
-
218
  image_paths.append(str(temp_path))
219
-
220
  # Add image path for tools
221
  messages.append({"role": "user", "content": f"image_path: {temp_path}"})
222
-
223
  # Add base64 encoded image for multimodal processing
224
  image_base64 = base64.b64encode(content).decode("utf-8")
225
-
226
  # Determine MIME type
227
  mime_type = "image/jpeg" # Default
228
  if image.content_type:
229
  mime_type = image.content_type
230
- elif temp_path.suffix.lower() in ['.png']:
231
  mime_type = "image/png"
232
-
233
- messages.append({
234
- "role": "user",
235
- "content": [
236
- {
237
- "type": "image_url",
238
- "image_url": {"url": f"data:{mime_type};base64,{image_base64}"},
239
- }
240
- ],
241
- })
242
-
243
  # Add text question
244
  messages.append({"role": "user", "content": [{"type": "text", "text": question}]})
245
-
246
  # Process through agent workflow
247
  response_text = ""
248
  tools_used = []
249
-
250
  # Temporarily update system prompt if provided
251
  original_prompt = None
252
  if system_prompt:
253
  original_prompt = self.agent.system_prompt
254
  self.agent.system_prompt = system_prompt
255
-
256
  try:
257
  async for chunk in self._stream_agent_response(messages, thread_id):
258
  if chunk.get("type") == "text":
@@ -263,23 +258,23 @@ class MedRAXAPI:
263
  # Restore original system prompt
264
  if original_prompt is not None:
265
  self.agent.system_prompt = original_prompt
266
-
267
  # Clean up temporary files
268
  for image_path in image_paths:
269
  try:
270
  Path(image_path).unlink(missing_ok=True)
271
  except Exception:
272
  pass # Ignore cleanup errors
273
-
274
  processing_time = time.time() - start_time
275
-
276
  return QueryResponse(
277
  response=response_text.strip(),
278
  thread_id=thread_id,
279
  tools_used=list(set(tools_used)), # Remove duplicates
280
- processing_time=processing_time
281
  )
282
-
283
  except Exception as e:
284
  # Clean up on error
285
  for image_path in image_paths:
@@ -287,17 +282,17 @@ class MedRAXAPI:
287
  Path(image_path).unlink(missing_ok=True)
288
  except Exception:
289
  pass
290
-
291
  raise HTTPException(status_code=500, detail=f"Error processing query: {str(e)}")
292
-
293
  async def _stream_agent_response(self, messages: List[Dict], thread_id: str):
294
  """
295
  Stream responses from the agent workflow.
296
-
297
  Args:
298
  messages (List[Dict]): Messages to process
299
  thread_id (str): Thread ID for the conversation
300
-
301
  Yields:
302
  Dict: Response chunks with type and content
303
  """
@@ -309,24 +304,24 @@ class MedRAXAPI:
309
  ):
310
  if not isinstance(chunk, dict):
311
  continue
312
-
313
  for node_name, node_output in chunk.items():
314
  if "messages" not in node_output:
315
  continue
316
-
317
  for msg in node_output["messages"]:
318
  if isinstance(msg, AIMessage) and msg.content:
319
  # Clean up temp paths from response
320
  clean_content = re.sub(r"temp[^\s]*", "", msg.content).strip()
321
  if clean_content:
322
  yield {"type": "text", "content": clean_content}
323
-
324
  elif isinstance(msg, ToolMessage):
325
  # Extract tool name from the message
326
  tool_call_id = msg.tool_call_id
327
  # We'll track tool usage but not include detailed output in API response
328
  yield {"type": "tool", "tool_name": "tool_executed"}
329
-
330
  except Exception as e:
331
  yield {"type": "error", "content": str(e)}
332
 
@@ -334,12 +329,12 @@ class MedRAXAPI:
334
  def create_api(agent: Agent, tools_dict: Dict[str, Any], temp_dir: str = "temp_api") -> FastAPI:
335
  """
336
  Create and configure the MedRAX FastAPI application.
337
-
338
  Args:
339
  agent (Agent): The initialized MedRAX agent
340
  tools_dict (Dict[str, Any]): Dictionary of available tools
341
  temp_dir (str): Directory for temporary file storage
342
-
343
  Returns:
344
  FastAPI: Configured FastAPI application
345
  """
 
32
  class QueryRequest(BaseModel):
33
  """
34
  Request model for text-only queries.
35
+
36
  Attributes:
37
  question (str): The question or query to ask the agent
38
  system_prompt (Optional[str]): Custom system prompt to override default
39
  thread_id (Optional[str]): Optional thread ID for conversation continuity
40
  """
41
+
42
  question: str = Field(..., description="The question or query to ask the agent")
43
  system_prompt: Optional[str] = Field(None, description="Custom system prompt to override default")
44
  thread_id: Optional[str] = Field(None, description="Optional thread ID for conversation continuity")
 
47
  class QueryResponse(BaseModel):
48
  """
49
  Response model for API queries.
50
+
51
  Attributes:
52
  response (str): The agent's text response
53
  thread_id (str): The thread ID used for this conversation
54
  tools_used (List[str]): List of tools that were executed
55
  processing_time (float): Time taken to process the request in seconds
56
  """
57
+
58
  response: str = Field(..., description="The agent's text response")
59
  thread_id: str = Field(..., description="The thread ID used for this conversation")
60
  tools_used: List[str] = Field(..., description="List of tools that were executed")
 
64
  class MedRAXAPI:
65
  """
66
  FastAPI application wrapper for the MedRAX agent.
67
+
68
  This class provides a clean interface for creating and managing the API endpoints
69
  while maintaining separation of concerns from the core agent functionality.
70
  """
71
+
72
  def __init__(self, agent: Agent, tools_dict: Dict[str, Any], temp_dir: str = "temp_api"):
73
  """
74
  Initialize the MedRAX API.
75
+
76
  Args:
77
  agent (Agent): The initialized MedRAX agent
78
  tools_dict (Dict[str, Any]): Dictionary of available tools
 
82
  self.tools_dict = tools_dict
83
  self.temp_dir = Path(temp_dir)
84
  self.temp_dir.mkdir(exist_ok=True)
85
+
86
  # Create FastAPI app
87
  self.app = FastAPI(
88
  title="MedRAX API",
89
  description="Medical Reasoning Agent for Chest X-ray Analysis",
90
  version="2.0.0",
91
  docs_url="/docs",
92
+ redoc_url="/redoc",
93
  )
94
+
95
  # Add CORS middleware
96
  self.app.add_middleware(
97
  CORSMiddleware,
 
100
  allow_methods=["*"],
101
  allow_headers=["*"],
102
  )
103
+
104
  # Register routes
105
  self._register_routes()
106
+
107
  def _register_routes(self):
108
  """Register all API routes."""
109
+
110
  @self.app.get("/health")
111
  async def health_check():
112
  """Health check endpoint."""
113
  return {"status": "healthy", "service": "MedRAX API"}
114
+
115
  @self.app.get("/tools")
116
  async def list_tools():
117
  """List available tools."""
118
+ return {"available_tools": list(self.tools_dict.keys()), "total_count": len(self.tools_dict)}
119
+
120
  @self.app.post("/query", response_model=QueryResponse)
121
  async def query_text_only(request: QueryRequest):
122
  """
123
  Process a text-only query without images.
124
+
125
  Args:
126
  request (QueryRequest): The query request
127
+
128
  Returns:
129
  QueryResponse: The agent's response
130
  """
131
  return await self._process_query(
132
+ question=request.question, system_prompt=request.system_prompt, thread_id=request.thread_id, images=None
133
  )
134
+
135
  @self.app.post("/query-with-images", response_model=QueryResponse)
136
  async def query_with_images(
137
  question: str = Form(..., description="The question or query to ask the agent"),
138
  system_prompt: Optional[str] = Form(None, description="Custom system prompt to override default"),
139
  thread_id: Optional[str] = Form(None, description="Optional thread ID for conversation continuity"),
140
+ images: List[UploadFile] = File(..., description="One or more medical images to analyze"),
141
  ):
142
  """
143
  Process a query with one or more images.
144
+
145
  Args:
146
  question (str): The question or query to ask the agent
147
  system_prompt (Optional[str]): Custom system prompt to override default
148
  thread_id (Optional[str]): Optional thread ID for conversation continuity
149
  images (List[UploadFile]): List of uploaded image files
150
+
151
  Returns:
152
  QueryResponse: The agent's response
153
  """
154
  # Validate image files
155
  if not images or len(images) == 0:
156
  raise HTTPException(status_code=400, detail="At least one image is required")
157
+
158
  # Validate file types
159
+ allowed_types = {"image/jpeg", "image/jpg", "image/png", "image/bmp", "image/tiff", "application/dicom"}
160
  for image in images:
161
  if image.content_type not in allowed_types:
162
  raise HTTPException(
163
+ status_code=400,
164
+ detail=f"Unsupported file type: {image.content_type}. Allowed types: {allowed_types}",
165
  )
166
+
167
  return await self._process_query(
168
+ question=question, system_prompt=system_prompt, thread_id=thread_id, images=images
169
  )
170
+
171
  async def _process_query(
172
  self,
173
  question: str,
174
  system_prompt: Optional[str] = None,
175
  thread_id: Optional[str] = None,
176
+ images: Optional[List[UploadFile]] = None,
177
  ) -> QueryResponse:
178
  """
179
  Internal method to process queries through the agent.
180
+
181
  Args:
182
  question (str): The question to ask
183
  system_prompt (Optional[str]): Custom system prompt
184
  thread_id (Optional[str]): Thread ID for conversation
185
  images (Optional[List[UploadFile]]): List of images
186
+
187
  Returns:
188
  QueryResponse: The processed response
189
  """
190
  start_time = time.time()
191
+
192
  # Generate thread ID if not provided
193
  if not thread_id:
194
  thread_id = str(uuid.uuid4())
195
+
196
  try:
197
  # Prepare messages
198
  messages = []
199
  image_paths = []
200
+
201
  # Handle image uploads
202
  if images:
203
  for i, image in enumerate(images):
204
  # Save uploaded file temporarily
205
  temp_path = self.temp_dir / f"{thread_id}_{i}_{image.filename}"
206
+
207
  with open(temp_path, "wb") as buffer:
208
  content = await image.read()
209
  buffer.write(content)
210
+
211
  image_paths.append(str(temp_path))
212
+
213
  # Add image path for tools
214
  messages.append({"role": "user", "content": f"image_path: {temp_path}"})
215
+
216
  # Add base64 encoded image for multimodal processing
217
  image_base64 = base64.b64encode(content).decode("utf-8")
218
+
219
  # Determine MIME type
220
  mime_type = "image/jpeg" # Default
221
  if image.content_type:
222
  mime_type = image.content_type
223
+ elif temp_path.suffix.lower() in [".png"]:
224
  mime_type = "image/png"
225
+
226
+ messages.append(
227
+ {
228
+ "role": "user",
229
+ "content": [
230
+ {
231
+ "type": "image_url",
232
+ "image_url": {"url": f"data:{mime_type};base64,{image_base64}"},
233
+ }
234
+ ],
235
+ }
236
+ )
237
+
238
  # Add text question
239
  messages.append({"role": "user", "content": [{"type": "text", "text": question}]})
240
+
241
  # Process through agent workflow
242
  response_text = ""
243
  tools_used = []
244
+
245
  # Temporarily update system prompt if provided
246
  original_prompt = None
247
  if system_prompt:
248
  original_prompt = self.agent.system_prompt
249
  self.agent.system_prompt = system_prompt
250
+
251
  try:
252
  async for chunk in self._stream_agent_response(messages, thread_id):
253
  if chunk.get("type") == "text":
 
258
  # Restore original system prompt
259
  if original_prompt is not None:
260
  self.agent.system_prompt = original_prompt
261
+
262
  # Clean up temporary files
263
  for image_path in image_paths:
264
  try:
265
  Path(image_path).unlink(missing_ok=True)
266
  except Exception:
267
  pass # Ignore cleanup errors
268
+
269
  processing_time = time.time() - start_time
270
+
271
  return QueryResponse(
272
  response=response_text.strip(),
273
  thread_id=thread_id,
274
  tools_used=list(set(tools_used)), # Remove duplicates
275
+ processing_time=processing_time,
276
  )
277
+
278
  except Exception as e:
279
  # Clean up on error
280
  for image_path in image_paths:
 
282
  Path(image_path).unlink(missing_ok=True)
283
  except Exception:
284
  pass
285
+
286
  raise HTTPException(status_code=500, detail=f"Error processing query: {str(e)}")
287
+
288
  async def _stream_agent_response(self, messages: List[Dict], thread_id: str):
289
  """
290
  Stream responses from the agent workflow.
291
+
292
  Args:
293
  messages (List[Dict]): Messages to process
294
  thread_id (str): Thread ID for the conversation
295
+
296
  Yields:
297
  Dict: Response chunks with type and content
298
  """
 
304
  ):
305
  if not isinstance(chunk, dict):
306
  continue
307
+
308
  for node_name, node_output in chunk.items():
309
  if "messages" not in node_output:
310
  continue
311
+
312
  for msg in node_output["messages"]:
313
  if isinstance(msg, AIMessage) and msg.content:
314
  # Clean up temp paths from response
315
  clean_content = re.sub(r"temp[^\s]*", "", msg.content).strip()
316
  if clean_content:
317
  yield {"type": "text", "content": clean_content}
318
+
319
  elif isinstance(msg, ToolMessage):
320
  # Extract tool name from the message
321
  tool_call_id = msg.tool_call_id
322
  # We'll track tool usage but not include detailed output in API response
323
  yield {"type": "tool", "tool_name": "tool_executed"}
324
+
325
  except Exception as e:
326
  yield {"type": "error", "content": str(e)}
327
 
 
329
  def create_api(agent: Agent, tools_dict: Dict[str, Any], temp_dir: str = "temp_api") -> FastAPI:
330
  """
331
  Create and configure the MedRAX FastAPI application.
332
+
333
  Args:
334
  agent (Agent): The initialized MedRAX agent
335
  tools_dict (Dict[str, Any]): Dictionary of available tools
336
  temp_dir (str): Directory for temporary file storage
337
+
338
  Returns:
339
  FastAPI: Configured FastAPI application
340
  """
interface.py CHANGED
@@ -68,9 +68,7 @@ class ChatInterface:
68
 
69
  return self.display_file_path
70
 
71
- def add_message(
72
- self, message: str, display_image: str, history: List[dict]
73
- ) -> Tuple[List[dict], gr.Textbox]:
74
  """
75
  Add a new message to the chat history.
76
 
@@ -155,9 +153,7 @@ class ChatInterface:
155
  if isinstance(msg, AIMessageChunk) and msg.content:
156
  accumulated_content += msg.content
157
  if final_message is None:
158
- final_message = ChatMessage(
159
- role="assistant", content=accumulated_content
160
- )
161
  chat_history.append(final_message)
162
  else:
163
  final_message.content = accumulated_content
@@ -169,9 +165,7 @@ class ChatInterface:
169
  if final_message:
170
  final_message.content = final_content
171
  else:
172
- chat_history.append(
173
- ChatMessage(role="assistant", content=final_content)
174
- )
175
  yield chat_history, self.display_file_path, ""
176
 
177
  if msg.tool_calls:
@@ -204,7 +198,7 @@ class ChatInterface:
204
  except json.JSONDecodeError:
205
  result = msg.content
206
  tool_output_str = str(msg.content)
207
-
208
  # Display tool usage card
209
  tool_args_str = json.dumps(tool_args, indent=2)
210
  description = f"**Input:**\n```json\n{tool_args_str}\n```\n\n**Output:**\n```json\n{tool_output_str}\n```"
@@ -231,7 +225,7 @@ class ChatInterface:
231
  image_path = result[0]["image_path"]
232
  except (TypeError, KeyError, IndexError):
233
  pass
234
-
235
  if image_path:
236
  self.display_file_path = image_path
237
  chat_history.append(
@@ -240,16 +234,13 @@ class ChatInterface:
240
  content={"path": self.display_file_path},
241
  )
242
  )
243
-
244
  # Yield a single update for this tool event
245
  yield chat_history, self.display_file_path, ""
246
 
247
-
248
  except Exception as e:
249
  chat_history.append(
250
- ChatMessage(
251
- role="assistant", content=f"❌ Error: {str(e)}", metadata={"title": "Error"}
252
- )
253
  )
254
  yield chat_history, self.display_file_path, ""
255
 
@@ -300,9 +291,7 @@ def create_demo(agent, tools_dict):
300
  )
301
 
302
  with gr.Column(scale=3):
303
- image_display = gr.Image(
304
- label="Image", type="filepath", height=600, container=True
305
- )
306
  with gr.Row():
307
  upload_button = gr.UploadButton(
308
  "📎 Upload X-Ray",
@@ -325,9 +314,7 @@ def create_demo(agent, tools_dict):
325
  def handle_file_upload(file):
326
  return interface.handle_upload(file.name)
327
 
328
- chat_msg = txt.submit(
329
- interface.add_message, inputs=[txt, image_display, chatbot], outputs=[chatbot, txt]
330
- )
331
  bot_msg = chat_msg.then(
332
  interface.process_message,
333
  inputs=[txt, image_display, chatbot],
@@ -341,4 +328,4 @@ def create_demo(agent, tools_dict):
341
 
342
  new_chat_btn.click(new_chat, outputs=[chatbot, image_display])
343
 
344
- return demo
 
68
 
69
  return self.display_file_path
70
 
71
+ def add_message(self, message: str, display_image: str, history: List[dict]) -> Tuple[List[dict], gr.Textbox]:
72
  """
73
  Add a new message to the chat history.
74
 
 
153
  if isinstance(msg, AIMessageChunk) and msg.content:
154
  accumulated_content += msg.content
155
  if final_message is None:
156
+ final_message = ChatMessage(role="assistant", content=accumulated_content)
157
  chat_history.append(final_message)
158
  else:
159
  final_message.content = accumulated_content
 
165
  if final_message:
166
  final_message.content = final_content
167
  else:
168
+ chat_history.append(ChatMessage(role="assistant", content=final_content))
169
  yield chat_history, self.display_file_path, ""
170
 
171
  if msg.tool_calls:
 
198
  except json.JSONDecodeError:
199
  result = msg.content
200
  tool_output_str = str(msg.content)
201
+
202
  # Display tool usage card
203
  tool_args_str = json.dumps(tool_args, indent=2)
204
  description = f"**Input:**\n```json\n{tool_args_str}\n```\n\n**Output:**\n```json\n{tool_output_str}\n```"
 
225
  image_path = result[0]["image_path"]
226
  except (TypeError, KeyError, IndexError):
227
  pass
228
+
229
  if image_path:
230
  self.display_file_path = image_path
231
  chat_history.append(
 
234
  content={"path": self.display_file_path},
235
  )
236
  )
237
+
238
  # Yield a single update for this tool event
239
  yield chat_history, self.display_file_path, ""
240
 
 
241
  except Exception as e:
242
  chat_history.append(
243
+ ChatMessage(role="assistant", content=f"❌ Error: {str(e)}", metadata={"title": "Error"})
244
  )
245
  yield chat_history, self.display_file_path, ""
246
 
 
291
  )
292
 
293
  with gr.Column(scale=3):
294
+ image_display = gr.Image(label="Image", type="filepath", height=600, container=True)
295
  with gr.Row():
296
  upload_button = gr.UploadButton(
297
  "📎 Upload X-Ray",
 
314
  def handle_file_upload(file):
315
  return interface.handle_upload(file.name)
316
 
317
+ chat_msg = txt.submit(interface.add_message, inputs=[txt, image_display, chatbot], outputs=[chatbot, txt])
318
  bot_msg = chat_msg.then(
319
  interface.process_message,
320
  inputs=[txt, image_display, chatbot],
 
328
 
329
  new_chat_btn.click(new_chat, outputs=[chatbot, image_display])
330
 
331
+ return demo
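
The interface.py changes above are purely stylistic, but the wiring they touch is Gradio's event chaining: `txt.submit(interface.add_message, ...)` records the user turn, and `.then(interface.process_message, ...)` streams the assistant turn. The toy demo below is a hypothetical, self-contained illustration of that same pattern; the function bodies and the placeholder response are not MedRAX code.

```python
# Hypothetical illustration of the submit(...).then(...) chaining used in create_demo.
import gradio as gr

def add_message(message, history):
    # Append the user turn and clear the textbox, like ChatInterface.add_message.
    return history + [{"role": "user", "content": message}], ""

def respond(history):
    # Placeholder assistant turn; MedRAX streams the agent's response here instead.
    return history + [{"role": "assistant", "content": "(agent response goes here)"}]

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(type="messages")
    txt = gr.Textbox(placeholder="Ask a question...")
    txt.submit(add_message, inputs=[txt, chatbot], outputs=[chatbot, txt]).then(
        respond, inputs=[chatbot], outputs=[chatbot]
    )

if __name__ == "__main__":
    demo.launch()
```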
main.py CHANGED
@@ -76,9 +76,7 @@ def initialize_agent(
76
  "ChestXRaySegmentationTool": lambda: ChestXRaySegmentationTool(device=device),
77
  "LlavaMedTool": lambda: LlavaMedTool(cache_dir=model_dir, device=device, load_in_8bit=True),
78
  "CheXagentXRayVQATool": lambda: CheXagentXRayVQATool(cache_dir=model_dir, device=device),
79
- "ChestXRayReportGeneratorTool": lambda: ChestXRayReportGeneratorTool(
80
- cache_dir=model_dir, device=device
81
- ),
82
  "XRayPhraseGroundingTool": lambda: XRayPhraseGroundingTool(
83
  cache_dir=model_dir, temp_dir=temp_dir, load_in_8bit=True, device=device
84
  ),
@@ -90,15 +88,13 @@ def initialize_agent(
90
  "MedicalRAGTool": lambda: RAGTool(config=rag_config),
91
  "WebBrowserTool": lambda: WebBrowserTool(),
92
  "DuckDuckGoSearchTool": lambda: DuckDuckGoSearchTool(),
93
- "MedSAM2Tool": lambda: MedSAM2Tool(
94
- device=device, cache_dir=model_dir, temp_dir=temp_dir
95
- ),
96
  "MedGemmaVQATool": lambda: MedGemmaAPIClientTool(
97
  cache_dir=model_dir,
98
  device=device,
99
  load_in_8bit=True,
100
- api_url=os.getenv("MEDGEMMA_API_URL", "http://0.0.0.0:8002")
101
- )
102
  }
103
 
104
  # Initialize only selected tools or all if none specified
@@ -106,7 +102,7 @@ def initialize_agent(
106
 
107
  if tools_to_use is None:
108
  tools_to_use = []
109
-
110
  for tool_name in tools_to_use:
111
  if tool_name == "PythonSandboxTool":
112
  try:
@@ -116,16 +112,13 @@ def initialize_agent(
116
  print("Skipping PythonSandboxTool")
117
  if tool_name in all_tools:
118
  tools_dict[tool_name] = all_tools[tool_name]()
119
-
120
 
121
  # Set up checkpointing for conversation state
122
  checkpointer = MemorySaver()
123
 
124
  # Create the language model using the factory
125
  try:
126
- llm = ModelFactory.create_model(
127
- model_name=model, temperature=temperature, **model_kwargs
128
- )
129
  except ValueError as e:
130
  print(f"Error creating language model: {e}")
131
  print(f"Available model providers: {list(ModelFactory._model_providers.keys())}")
@@ -145,7 +138,7 @@ def initialize_agent(
145
  def run_gradio_interface(agent, tools_dict, host="0.0.0.0", port=8686):
146
  """
147
  Run the Gradio web interface.
148
-
149
  Args:
150
  agent: The initialized MedRAX agent
151
  tools_dict: Dictionary of available tools
@@ -160,7 +153,7 @@ def run_gradio_interface(agent, tools_dict, host="0.0.0.0", port=8686):
160
  def run_api_server(agent, tools_dict, host="0.0.0.0", port=8585, public=False):
161
  """
162
  Run the FastAPI server.
163
-
164
  Args:
165
  agent: The initialized MedRAX agent
166
  tools_dict: Dictionary of available tools
@@ -169,21 +162,23 @@ def run_api_server(agent, tools_dict, host="0.0.0.0", port=8585, public=False):
169
  public (bool): Whether to expose via ngrok tunnel
170
  """
171
  print(f"Starting API server on {host}:{port}")
172
-
173
  if public:
174
  try:
175
  public_tunnel = ngrok.connect(port)
176
  public_url = public_tunnel.public_url
177
- print(f"🌍 Public URL: {public_url}\n🌍 API Documentation: {public_url}/docs\n🌍 Share this URL with your friend!\n{'=' * 60}")
178
  except ImportError:
179
  print("⚠️ pyngrok not installed. Install with: pip install pyngrok\nRunning locally only...")
180
  public = False
181
  except Exception as e:
182
  print(f"⚠️ Failed to create public tunnel: {e}\nRunning locally only...")
183
  public = False
184
-
185
  app = create_api(agent, tools_dict)
186
-
187
  try:
188
  uvicorn.run(app, host=host, port=port)
189
  finally:
@@ -198,121 +193,74 @@ def run_api_server(agent, tools_dict, host="0.0.0.0", port=8585, public=False):
198
  def parse_arguments():
199
  """Parse command line arguments."""
200
  parser = argparse.ArgumentParser(description="MedRAX - Medical Reasoning Agent for Chest X-ray")
201
-
202
  # Server configuration
203
  parser.add_argument(
204
- "--mode",
205
- choices=["gradio", "api", "both"],
206
  default="gradio",
207
- help="Run mode: 'gradio' for web interface, 'api' for REST API, 'both' for both services"
208
  )
209
  parser.add_argument("--gradio-host", default="0.0.0.0", help="Gradio host address")
210
  parser.add_argument("--gradio-port", type=int, default=8686, help="Gradio port")
211
  parser.add_argument("--api-host", default="0.0.0.0", help="API host address")
212
  parser.add_argument("--api-port", type=int, default=8000, help="API port")
213
  parser.add_argument("--public", action="store_true", help="Make API publicly accessible via ngrok tunnel")
214
-
215
  # Model and system configuration
216
  parser.add_argument(
217
- "--model-dir",
218
  default="/model-weights",
219
- help="Directory containing model weights (default: uses MODEL_WEIGHTS_DIR env var or '/model-weights')"
220
  )
221
  parser.add_argument(
222
- "--device",
223
- default="cuda",
224
- help="Device to run models on (default: uses MEDRAX_DEVICE env var or 'cuda:1')"
225
  )
226
  parser.add_argument(
227
- "--model",
228
  default="gpt-4.1",
229
- help="Model to use (default: gpt-4.1). Examples: gpt-4.1-2025-04-14, gemini-2.5-pro, gpt-5"
230
- )
231
- parser.add_argument(
232
- "--temperature",
233
- type=float,
234
- default=1.0,
235
- help="Temperature for the model (default: 1.0)"
236
- )
237
- parser.add_argument(
238
- "--temp-dir",
239
- default="temp2",
240
- help="Directory for temporary files (default: temp2)"
241
  )
 
  parser.add_argument(
243
- "--prompt-file",
244
  default="medrax/docs/system_prompts.txt",
245
- help="Path to file containing system prompts (default: medrax/docs/system_prompts.txt)"
246
  )
247
  parser.add_argument(
248
- "--system-prompt",
249
- default="MEDICAL_ASSISTANT",
250
- help="System prompt to use (default: MEDICAL_ASSISTANT)"
251
  )
252
-
253
  # RAG configuration
254
  parser.add_argument(
255
- "--rag-model",
256
- default="command-a-03-2025",
257
- help="Chat model for RAG responses (default: command-a-03-2025)"
258
- )
259
- parser.add_argument(
260
- "--rag-embedding-model",
261
- default="embed-v4.0",
262
- help="Embedding model for RAG system (default: embed-v4.0)"
263
- )
264
- parser.add_argument(
265
- "--rag-rerank-model",
266
- default="rerank-v3.5",
267
- help="Reranking model for RAG system (default: rerank-v3.5)"
268
- )
269
- parser.add_argument(
270
- "--rag-temperature",
271
- type=float,
272
- default=0.3,
273
- help="Temperature for RAG model (default: 0.3)"
274
  )
275
  parser.add_argument(
276
- "--pinecone-index",
277
- default="medrax2",
278
- help="Pinecone index name (default: medrax2)"
279
  )
280
  parser.add_argument(
281
- "--chunk-size",
282
- type=int,
283
- default=1500,
284
- help="RAG chunk size (default: 1500)"
285
  )
286
- parser.add_argument(
287
- "--chunk-overlap",
288
- type=int,
289
- default=300,
290
- help="RAG chunk overlap (default: 300)"
291
- )
292
- parser.add_argument(
293
- "--retriever-k",
294
- type=int,
295
- default=3,
296
- help="Number of documents to retrieve (default: 3)"
297
- )
298
- parser.add_argument(
299
- "--rag-docs-dir",
300
- default="rag_docs",
301
- help="Directory for RAG documents (default: rag_docs)"
302
- )
303
-
304
  # Tools configuration
305
  parser.add_argument(
306
- "--tools",
307
  nargs="*",
308
- help="Specific tools to enable (if not provided, uses default set). Available tools: " +
309
- "ImageVisualizerTool, DicomProcessorTool, MedSAM2Tool, ChestXRaySegmentationTool, " +
310
- "ChestXRayGeneratorTool, TorchXRayVisionClassifierTool, ArcPlusClassifierTool, " +
311
- "ChestXRayReportGeneratorTool, XRayPhraseGroundingTool, MedGemmaVQATool, " +
312
- "XRayVQATool, LlavaMedTool, MedicalRAGTool, WebBrowserTool, DuckDuckGoSearchTool, " +
313
- "PythonSandboxTool"
314
  )
315
-
316
  return parser.parse_args()
317
 
318
 
@@ -334,36 +282,27 @@ if __name__ == "__main__":
334
  # Image Processing Tools
335
  "ImageVisualizerTool", # For displaying images in the UI
336
  # "DicomProcessorTool", # For processing DICOM medical image files
337
-
338
  # Segmentation Tools
339
  "MedSAM2Tool", # For advanced medical image segmentation using MedSAM2
340
  "ChestXRaySegmentationTool", # For segmenting anatomical regions in chest X-rays
341
-
342
  # Generation Tools
343
  # "ChestXRayGeneratorTool", # For generating synthetic chest X-rays
344
-
345
  # Classification Tools
346
  "TorchXRayVisionClassifierTool", # For classifying chest X-ray images using TorchXRayVision
347
  "ArcPlusClassifierTool", # For advanced chest X-ray classification using ArcPlus
348
-
349
  # Report Generation Tools
350
  "ChestXRayReportGeneratorTool", # For generating medical reports from X-rays
351
-
352
  # Grounding Tools
353
  "XRayPhraseGroundingTool", # For locating described features in X-rays
354
-
355
  # VQA Tools
356
  # "MedGemmaVQATool", # Google MedGemma VQA tool
357
  "XRayVQATool", # For visual question answering on X-rays
358
  # "LlavaMedTool", # For multimodal medical image understanding
359
-
360
  # RAG Tools
361
  "MedicalRAGTool", # For retrieval-augmented generation with medical knowledge
362
-
363
  # Search Tools
364
  # "WebBrowserTool", # For web browsing and search capabilities
365
  "DuckDuckGoSearchTool", # For privacy-focused web search using DuckDuckGo
366
-
367
  # Development Tools
368
  # "PythonSandboxTool", # Add the Python sandbox tool
369
  ]
@@ -424,11 +363,10 @@ if __name__ == "__main__":
424
  elif args.mode == "both":
425
  # Run both services in separate threads
426
  api_thread = threading.Thread(
427
- target=run_api_server,
428
- args=(agent, tools_dict, args.api_host, args.api_port, args.public)
429
  )
430
  api_thread.daemon = True
431
  api_thread.start()
432
-
433
  # Run Gradio in main thread
434
  run_gradio_interface(agent, tools_dict, args.gradio_host, args.gradio_port)
 
76
  "ChestXRaySegmentationTool": lambda: ChestXRaySegmentationTool(device=device),
77
  "LlavaMedTool": lambda: LlavaMedTool(cache_dir=model_dir, device=device, load_in_8bit=True),
78
  "CheXagentXRayVQATool": lambda: CheXagentXRayVQATool(cache_dir=model_dir, device=device),
79
+ "ChestXRayReportGeneratorTool": lambda: ChestXRayReportGeneratorTool(cache_dir=model_dir, device=device),
80
  "XRayPhraseGroundingTool": lambda: XRayPhraseGroundingTool(
81
  cache_dir=model_dir, temp_dir=temp_dir, load_in_8bit=True, device=device
82
  ),
 
88
  "MedicalRAGTool": lambda: RAGTool(config=rag_config),
89
  "WebBrowserTool": lambda: WebBrowserTool(),
90
  "DuckDuckGoSearchTool": lambda: DuckDuckGoSearchTool(),
91
+ "MedSAM2Tool": lambda: MedSAM2Tool(device=device, cache_dir=model_dir, temp_dir=temp_dir),
92
  "MedGemmaVQATool": lambda: MedGemmaAPIClientTool(
93
  cache_dir=model_dir,
94
  device=device,
95
  load_in_8bit=True,
96
+ api_url=os.getenv("MEDGEMMA_API_URL", "http://0.0.0.0:8002"),
97
+ ),
98
  }
99
 
100
  # Initialize only selected tools or all if none specified
 
102
 
103
  if tools_to_use is None:
104
  tools_to_use = []
105
+
106
  for tool_name in tools_to_use:
107
  if tool_name == "PythonSandboxTool":
108
  try:
 
112
  print("Skipping PythonSandboxTool")
113
  if tool_name in all_tools:
114
  tools_dict[tool_name] = all_tools[tool_name]()
 
115
 
116
  # Set up checkpointing for conversation state
117
  checkpointer = MemorySaver()
118
 
119
  # Create the language model using the factory
120
  try:
121
+ llm = ModelFactory.create_model(model_name=model, temperature=temperature, **model_kwargs)
122
  except ValueError as e:
123
  print(f"Error creating language model: {e}")
124
  print(f"Available model providers: {list(ModelFactory._model_providers.keys())}")
 
138
  def run_gradio_interface(agent, tools_dict, host="0.0.0.0", port=8686):
139
  """
140
  Run the Gradio web interface.
141
+
142
  Args:
143
  agent: The initialized MedRAX agent
144
  tools_dict: Dictionary of available tools
 
153
  def run_api_server(agent, tools_dict, host="0.0.0.0", port=8585, public=False):
154
  """
155
  Run the FastAPI server.
156
+
157
  Args:
158
  agent: The initialized MedRAX agent
159
  tools_dict: Dictionary of available tools
 
162
  public (bool): Whether to expose via ngrok tunnel
163
  """
164
  print(f"Starting API server on {host}:{port}")
165
+
166
  if public:
167
  try:
168
  public_tunnel = ngrok.connect(port)
169
  public_url = public_tunnel.public_url
170
+ print(
171
+ f"🌍 Public URL: {public_url}\n🌍 API Documentation: {public_url}/docs\n🌍 Share this URL with your friend!\n{'=' * 60}"
172
+ )
173
  except ImportError:
174
  print("⚠️ pyngrok not installed. Install with: pip install pyngrok\nRunning locally only...")
175
  public = False
176
  except Exception as e:
177
  print(f"⚠️ Failed to create public tunnel: {e}\nRunning locally only...")
178
  public = False
179
+
180
  app = create_api(agent, tools_dict)
181
+
182
  try:
183
  uvicorn.run(app, host=host, port=port)
184
  finally:
 
193
  def parse_arguments():
194
  """Parse command line arguments."""
195
  parser = argparse.ArgumentParser(description="MedRAX - Medical Reasoning Agent for Chest X-ray")
196
+
197
  # Server configuration
198
  parser.add_argument(
199
+ "--mode",
200
+ choices=["gradio", "api", "both"],
201
  default="gradio",
202
+ help="Run mode: 'gradio' for web interface, 'api' for REST API, 'both' for both services",
203
  )
204
  parser.add_argument("--gradio-host", default="0.0.0.0", help="Gradio host address")
205
  parser.add_argument("--gradio-port", type=int, default=8686, help="Gradio port")
206
  parser.add_argument("--api-host", default="0.0.0.0", help="API host address")
207
  parser.add_argument("--api-port", type=int, default=8000, help="API port")
208
  parser.add_argument("--public", action="store_true", help="Make API publicly accessible via ngrok tunnel")
209
+
210
  # Model and system configuration
211
  parser.add_argument(
212
+ "--model-dir",
213
  default="/model-weights",
214
+ help="Directory containing model weights (default: uses MODEL_WEIGHTS_DIR env var or '/model-weights')",
215
  )
216
  parser.add_argument(
217
+ "--device", default="cuda", help="Device to run models on (default: uses MEDRAX_DEVICE env var or 'cuda:1')"
218
  )
219
  parser.add_argument(
220
+ "--model",
221
  default="gpt-4.1",
222
+ help="Model to use (default: gpt-4.1). Examples: gpt-4.1-2025-04-14, gemini-2.5-pro, gpt-5",
223
  )
224
+ parser.add_argument("--temperature", type=float, default=1.0, help="Temperature for the model (default: 1.0)")
225
+ parser.add_argument("--temp-dir", default="temp2", help="Directory for temporary files (default: temp2)")
226
  parser.add_argument(
227
+ "--prompt-file",
228
  default="medrax/docs/system_prompts.txt",
229
+ help="Path to file containing system prompts (default: medrax/docs/system_prompts.txt)",
230
  )
231
  parser.add_argument(
232
+ "--system-prompt", default="MEDICAL_ASSISTANT", help="System prompt to use (default: MEDICAL_ASSISTANT)"
233
  )
234
+
235
  # RAG configuration
236
  parser.add_argument(
237
+ "--rag-model", default="command-a-03-2025", help="Chat model for RAG responses (default: command-a-03-2025)"
238
  )
239
  parser.add_argument(
240
+ "--rag-embedding-model", default="embed-v4.0", help="Embedding model for RAG system (default: embed-v4.0)"
241
  )
242
  parser.add_argument(
243
+ "--rag-rerank-model", default="rerank-v3.5", help="Reranking model for RAG system (default: rerank-v3.5)"
244
  )
245
+ parser.add_argument("--rag-temperature", type=float, default=0.3, help="Temperature for RAG model (default: 0.3)")
246
+ parser.add_argument("--pinecone-index", default="medrax2", help="Pinecone index name (default: medrax2)")
247
+ parser.add_argument("--chunk-size", type=int, default=1500, help="RAG chunk size (default: 1500)")
248
+ parser.add_argument("--chunk-overlap", type=int, default=300, help="RAG chunk overlap (default: 300)")
249
+ parser.add_argument("--retriever-k", type=int, default=3, help="Number of documents to retrieve (default: 3)")
250
+ parser.add_argument("--rag-docs-dir", default="rag_docs", help="Directory for RAG documents (default: rag_docs)")
251
+
252
  # Tools configuration
253
  parser.add_argument(
254
+ "--tools",
255
  nargs="*",
256
+ help="Specific tools to enable (if not provided, uses default set). Available tools: "
257
+ + "ImageVisualizerTool, DicomProcessorTool, MedSAM2Tool, ChestXRaySegmentationTool, "
258
+ + "ChestXRayGeneratorTool, TorchXRayVisionClassifierTool, ArcPlusClassifierTool, "
259
+ + "ChestXRayReportGeneratorTool, XRayPhraseGroundingTool, MedGemmaVQATool, "
260
+ + "XRayVQATool, LlavaMedTool, MedicalRAGTool, WebBrowserTool, DuckDuckGoSearchTool, "
261
+ + "PythonSandboxTool",
262
  )
263
+
264
  return parser.parse_args()
265
 
266
 
 
282
  # Image Processing Tools
283
  "ImageVisualizerTool", # For displaying images in the UI
284
  # "DicomProcessorTool", # For processing DICOM medical image files
 
285
  # Segmentation Tools
286
  "MedSAM2Tool", # For advanced medical image segmentation using MedSAM2
287
  "ChestXRaySegmentationTool", # For segmenting anatomical regions in chest X-rays
 
288
  # Generation Tools
289
  # "ChestXRayGeneratorTool", # For generating synthetic chest X-rays
 
290
  # Classification Tools
291
  "TorchXRayVisionClassifierTool", # For classifying chest X-ray images using TorchXRayVision
292
  "ArcPlusClassifierTool", # For advanced chest X-ray classification using ArcPlus
 
293
  # Report Generation Tools
294
  "ChestXRayReportGeneratorTool", # For generating medical reports from X-rays
 
295
  # Grounding Tools
296
  "XRayPhraseGroundingTool", # For locating described features in X-rays
 
297
  # VQA Tools
298
  # "MedGemmaVQATool", # Google MedGemma VQA tool
299
  "XRayVQATool", # For visual question answering on X-rays
300
  # "LlavaMedTool", # For multimodal medical image understanding
 
301
  # RAG Tools
302
  "MedicalRAGTool", # For retrieval-augmented generation with medical knowledge
 
303
  # Search Tools
304
  # "WebBrowserTool", # For web browsing and search capabilities
305
  "DuckDuckGoSearchTool", # For privacy-focused web search using DuckDuckGo
 
306
  # Development Tools
307
  # "PythonSandboxTool", # Add the Python sandbox tool
308
  ]
 
363
  elif args.mode == "both":
364
  # Run both services in separate threads
365
  api_thread = threading.Thread(
366
+ target=run_api_server, args=(agent, tools_dict, args.api_host, args.api_port, args.public)
367
  )
368
  api_thread.daemon = True
369
  api_thread.start()
370
+
371
  # Run Gradio in main thread
372
  run_gradio_interface(agent, tools_dict, args.gradio_host, args.gradio_port)
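
Taken together, the argparse cleanup above leaves the command-line surface unchanged, so the script can still be invoked along these lines. The invocations below are illustrative only; they assume main.py is run from the repository root with API keys and model weights already configured.

```bash
# Web UI only (default mode) on the default Gradio port 8686
python main.py --mode gradio

# REST API only, exposed through an ngrok tunnel
python main.py --mode api --api-port 8000 --public

# Both services, with an explicit subset of tools enabled
python main.py --mode both \
    --tools ImageVisualizerTool ChestXRaySegmentationTool MedicalRAGTool DuckDuckGoSearchTool
```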