Merge branch 'main' into tool-changes
- api.py +342 -0
- benchmarking/llm_providers/medrax_provider.py +2 -1
- interface.py +46 -56
- main.py +374 -62
- medrax/agent/agent.py +13 -106
- medrax/docs/system_prompts.txt +3 -1
- medrax/llava/conversation.py +1 -3
- medrax/llava/eval/eval_multimodal_chat_gpt_score.py +3 -6
- medrax/llava/eval/llm.py +8 -23
- medrax/llava/eval/model_vqa.py +2 -8
- medrax/llava/eval/summarize_gpt_review.py +3 -7
- medrax/llava/mm_utils.py +4 -14
- medrax/llava/model/builder.py +4 -12
- medrax/llava/model/language_model/llava_mistral.py +1 -3
- medrax/llava/model/llava_arch.py +13 -39
- medrax/llava/model/multimodal_encoder/builder.py +2 -8
- medrax/llava/model/multimodal_projector/builder.py +1 -3
- medrax/llava/serve/cli.py +1 -3
- medrax/llava/serve/controller.py +3 -6
- medrax/llava/serve/gradio_web_server.py +4 -12
- medrax/llava/serve/model_worker.py +6 -14
- medrax/llava/serve/test_message.py +2 -6
- medrax/llava/utils.py +1 -3
- medrax/models/model_factory.py +5 -12
- medrax/rag/rag.py +3 -9
- medrax/tools/browsing/__init__.py +3 -3
- medrax/tools/browsing/duckduckgo.py +12 -33
- medrax/tools/browsing/web_browser.py +3 -9
- medrax/tools/classification/__init__.py +1 -6
- medrax/tools/classification/arcplus.py +5 -17
- medrax/tools/classification/torchxrayvision.py +1 -3
- medrax/tools/dicom.py +1 -3
- medrax/tools/grounding.py +4 -13
- medrax/tools/rag.py +1 -1
- medrax/tools/report_generation.py +4 -14
- medrax/tools/segmentation/__init__.py +1 -7
- medrax/tools/segmentation/medsam2.py +69 -79
- medrax/tools/segmentation/segmentation.py +10 -30
- medrax/tools/utils.py +5 -15
- medrax/tools/vqa/__init__.py +4 -4
- medrax/tools/vqa/llava_med.py +4 -12
- medrax/tools/vqa/medgemma/medgemma.py +51 -11
- medrax/tools/vqa/medgemma/medgemma_client.py +12 -4
- medrax/tools/vqa/medgemma/medgemma_requirements_standard.txt +1 -1
- medrax/tools/vqa/medgemma/medgemma_setup.py +91 -4
- medrax/tools/vqa/xray_vqa.py +6 -12
- medrax/tools/xray_generation.py +12 -23
- pyproject.toml +11 -9
api.py
ADDED
@@ -0,0 +1,342 @@
"""
MedRAX API Module

This module provides a FastAPI-based REST API for the MedRAX medical imaging AI assistant.
It offers endpoints for processing medical images with text queries using the same agent
architecture as the Gradio interface.

The API supports:
- Text-only queries
- Single or multiple image inputs
- Optional custom system prompts
- Automatic thread management for each request
- Tool execution and result aggregation
"""

import uuid
import base64
from pathlib import Path
from typing import List, Optional, Dict, Any
import re
import time

from fastapi import FastAPI, HTTPException, UploadFile, File, Form
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from langchain_core.messages import AIMessage, ToolMessage

# Import MedRAX components
from medrax.agent import Agent


class QueryRequest(BaseModel):
    """
    Request model for text-only queries.

    Attributes:
        question (str): The question or query to ask the agent
        system_prompt (Optional[str]): Custom system prompt to override default
        thread_id (Optional[str]): Optional thread ID for conversation continuity
    """

    question: str = Field(..., description="The question or query to ask the agent")
    system_prompt: Optional[str] = Field(None, description="Custom system prompt to override default")
    thread_id: Optional[str] = Field(None, description="Optional thread ID for conversation continuity")


class QueryResponse(BaseModel):
    """
    Response model for API queries.

    Attributes:
        response (str): The agent's text response
        thread_id (str): The thread ID used for this conversation
        tools_used (List[str]): List of tools that were executed
        processing_time (float): Time taken to process the request in seconds
    """

    response: str = Field(..., description="The agent's text response")
    thread_id: str = Field(..., description="The thread ID used for this conversation")
    tools_used: List[str] = Field(..., description="List of tools that were executed")
    processing_time: float = Field(..., description="Time taken to process the request in seconds")


class MedRAXAPI:
    """
    FastAPI application wrapper for the MedRAX agent.

    This class provides a clean interface for creating and managing the API endpoints
    while maintaining separation of concerns from the core agent functionality.
    """

    def __init__(self, agent: Agent, tools_dict: Dict[str, Any], temp_dir: str = "temp_api"):
        """
        Initialize the MedRAX API.

        Args:
            agent (Agent): The initialized MedRAX agent
            tools_dict (Dict[str, Any]): Dictionary of available tools
            temp_dir (str): Directory for temporary file storage
        """
        self.agent = agent
        self.tools_dict = tools_dict
        self.temp_dir = Path(temp_dir)
        self.temp_dir.mkdir(exist_ok=True)

        # Create FastAPI app
        self.app = FastAPI(
            title="MedRAX API",
            description="Medical Reasoning Agent for Chest X-ray Analysis",
            version="2.0.0",
            docs_url="/docs",
            redoc_url="/redoc",
        )

        # Add CORS middleware
        self.app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

        # Register routes
        self._register_routes()

    def _register_routes(self):
        """Register all API routes."""

        @self.app.get("/health")
        async def health_check():
            """Health check endpoint."""
            return {"status": "healthy", "service": "MedRAX API"}

        @self.app.get("/tools")
        async def list_tools():
            """List available tools."""
            return {"available_tools": list(self.tools_dict.keys()), "total_count": len(self.tools_dict)}

        @self.app.post("/query", response_model=QueryResponse)
        async def query_text_only(request: QueryRequest):
            """
            Process a text-only query without images.

            Args:
                request (QueryRequest): The query request

            Returns:
                QueryResponse: The agent's response
            """
            return await self._process_query(
                question=request.question, system_prompt=request.system_prompt, thread_id=request.thread_id, images=None
            )

        @self.app.post("/query-with-images", response_model=QueryResponse)
        async def query_with_images(
            question: str = Form(..., description="The question or query to ask the agent"),
            system_prompt: Optional[str] = Form(None, description="Custom system prompt to override default"),
            thread_id: Optional[str] = Form(None, description="Optional thread ID for conversation continuity"),
            images: List[UploadFile] = File(..., description="One or more medical images to analyze"),
        ):
            """
            Process a query with one or more images.

            Args:
                question (str): The question or query to ask the agent
                system_prompt (Optional[str]): Custom system prompt to override default
                thread_id (Optional[str]): Optional thread ID for conversation continuity
                images (List[UploadFile]): List of uploaded image files

            Returns:
                QueryResponse: The agent's response
            """
            # Validate image files
            if not images or len(images) == 0:
                raise HTTPException(status_code=400, detail="At least one image is required")

            # Validate file types
            allowed_types = {"image/jpeg", "image/jpg", "image/png", "image/bmp", "image/tiff", "application/dicom"}
            for image in images:
                if image.content_type not in allowed_types:
                    raise HTTPException(
                        status_code=400,
                        detail=f"Unsupported file type: {image.content_type}. Allowed types: {allowed_types}",
                    )

            return await self._process_query(
                question=question, system_prompt=system_prompt, thread_id=thread_id, images=images
            )

    async def _process_query(
        self,
        question: str,
        system_prompt: Optional[str] = None,
        thread_id: Optional[str] = None,
        images: Optional[List[UploadFile]] = None,
    ) -> QueryResponse:
        """
        Internal method to process queries through the agent.

        Args:
            question (str): The question to ask
            system_prompt (Optional[str]): Custom system prompt
            thread_id (Optional[str]): Thread ID for conversation
            images (Optional[List[UploadFile]]): List of images

        Returns:
            QueryResponse: The processed response
        """
        start_time = time.time()

        # Generate thread ID if not provided
        if not thread_id:
            thread_id = str(uuid.uuid4())

        try:
            # Prepare messages
            messages = []
            image_paths = []

            # Handle image uploads
            if images:
                for i, image in enumerate(images):
                    # Save uploaded file temporarily
                    temp_path = self.temp_dir / f"{thread_id}_{i}_{image.filename}"

                    with open(temp_path, "wb") as buffer:
                        content = await image.read()
                        buffer.write(content)

                    image_paths.append(str(temp_path))

                    # Add image path for tools
                    messages.append({"role": "user", "content": f"image_path: {temp_path}"})

                    # Add base64 encoded image for multimodal processing
                    image_base64 = base64.b64encode(content).decode("utf-8")

                    # Determine MIME type
                    mime_type = "image/jpeg"  # Default
                    if image.content_type:
                        mime_type = image.content_type
                    elif temp_path.suffix.lower() in [".png"]:
                        mime_type = "image/png"

                    messages.append(
                        {
                            "role": "user",
                            "content": [
                                {
                                    "type": "image_url",
                                    "image_url": {"url": f"data:{mime_type};base64,{image_base64}"},
                                }
                            ],
                        }
                    )

            # Add text question
            messages.append({"role": "user", "content": [{"type": "text", "text": question}]})

            # Process through agent workflow
            response_text = ""
            tools_used = []

            # Temporarily update system prompt if provided
            original_prompt = None
            if system_prompt:
                original_prompt = self.agent.system_prompt
                self.agent.system_prompt = system_prompt

            try:
                async for chunk in self._stream_agent_response(messages, thread_id):
                    if chunk.get("type") == "text":
                        response_text += chunk.get("content", "")
                    elif chunk.get("type") == "tool":
                        tools_used.append(chunk.get("tool_name", ""))
            finally:
                # Restore original system prompt
                if original_prompt is not None:
                    self.agent.system_prompt = original_prompt

            # Clean up temporary files
            for image_path in image_paths:
                try:
                    Path(image_path).unlink(missing_ok=True)
                except Exception:
                    pass  # Ignore cleanup errors

            processing_time = time.time() - start_time

            return QueryResponse(
                response=response_text.strip(),
                thread_id=thread_id,
                tools_used=list(set(tools_used)),  # Remove duplicates
                processing_time=processing_time,
            )

        except Exception as e:
            # Clean up on error
            for image_path in image_paths:
                try:
                    Path(image_path).unlink(missing_ok=True)
                except Exception:
                    pass

            raise HTTPException(status_code=500, detail=f"Error processing query: {str(e)}")

    async def _stream_agent_response(self, messages: List[Dict], thread_id: str):
        """
        Stream responses from the agent workflow.

        Args:
            messages (List[Dict]): Messages to process
            thread_id (str): Thread ID for the conversation

        Yields:
            Dict: Response chunks with type and content
        """
        try:
            for chunk in self.agent.workflow.stream(
                {"messages": messages},
                {"configurable": {"thread_id": thread_id}},
                stream_mode="updates",
            ):
                if not isinstance(chunk, dict):
                    continue

                for node_name, node_output in chunk.items():
                    if "messages" not in node_output:
                        continue

                    for msg in node_output["messages"]:
                        if isinstance(msg, AIMessage) and msg.content:
                            # Clean up temp paths from response
                            clean_content = re.sub(r"temp[^\s]*", "", msg.content).strip()
                            if clean_content:
                                yield {"type": "text", "content": clean_content}

                        elif isinstance(msg, ToolMessage):
                            # Extract tool name from the message
                            tool_call_id = msg.tool_call_id
                            # We'll track tool usage but not include detailed output in API response
                            yield {"type": "tool", "tool_name": "tool_executed"}

        except Exception as e:
            yield {"type": "error", "content": str(e)}


def create_api(agent: Agent, tools_dict: Dict[str, Any], temp_dir: str = "temp_api") -> FastAPI:
    """
    Create and configure the MedRAX FastAPI application.

    Args:
        agent (Agent): The initialized MedRAX agent
        tools_dict (Dict[str, Any]): Dictionary of available tools
        temp_dir (str): Directory for temporary file storage

    Returns:
        FastAPI: Configured FastAPI application
    """
    api = MedRAXAPI(agent, tools_dict, temp_dir)
    return api.app
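For illustration, a minimal client sketch against the endpoints defined above. The endpoint paths and response fields come from this file; the base URL, image path, and the assumption that the server was started with --mode api on the default --api-port 8000 are placeholders.

# Minimal client sketch for /query and /query-with-images (base URL and file path are assumptions).
import requests

BASE_URL = "http://localhost:8000"  # wherever run_api_server is bound

# Text-only query against POST /query
r = requests.post(f"{BASE_URL}/query", json={"question": "What findings suggest pneumonia?"})
r.raise_for_status()
print(r.json()["response"], r.json()["thread_id"])

# Query with an image against POST /query-with-images (multipart form data)
with open("example_cxr.png", "rb") as f:  # placeholder image path
    r = requests.post(
        f"{BASE_URL}/query-with-images",
        data={"question": "Describe this chest X-ray."},
        files=[("images", ("example_cxr.png", f, "image/png"))],
    )
print(r.json()["tools_used"], r.json()["processing_time"])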
benchmarking/llm_providers/medrax_provider.py
CHANGED
@@ -1,5 +1,6 @@
 """MedRAX LLM provider implementation."""
 
+import os
 import time
 import re
 import uuid
@@ -68,7 +69,7 @@ class MedRAXProvider(LLMProvider):
             tools_to_use=selected_tools,
             model_dir="/home/lijunzh3/scratch/MedRAX2/model-weights",
             temp_dir="temp",  # Change this to the path of the temporary directory
-            device="cuda:0",
+            device=os.getenv("MEDRAX_DEVICE", "cuda:0"),
             model=self.model_name,  # Change this to the model you want to use, e.g. gpt-4.1-2025-04-14, gemini-2.5-pro
             temperature=self.temperature,
             top_p=self.top_p,
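With this change the target GPU is read from the environment instead of being hard-coded. A small sketch of the intended use; setting the variable from Python before the provider is constructed is just one option, exporting MEDRAX_DEVICE in the shell or a .env file works the same way.

# Sketch: pick the GPU for MedRAXProvider via the environment variable used above.
import os

os.environ.setdefault("MEDRAX_DEVICE", "cuda:1")  # e.g. export MEDRAX_DEVICE=cuda:1 before launching
device = os.getenv("MEDRAX_DEVICE", "cuda:0")     # -> "cuda:1" here; falls back to "cuda:0" when unset
print(device)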
interface.py
CHANGED
@@ -29,7 +29,7 @@ class ChatInterface:
         """
         self.agent = agent
         self.tools_dict = tools_dict
-        self.upload_dir = Path("temp")
+        self.upload_dir = Path(f"temp/{time.time()}")
         self.upload_dir.mkdir(exist_ok=True)
         self.current_thread_id = None
         # Separate storage for original and display paths
@@ -68,9 +68,7 @@
 
         return self.display_file_path
 
-    def add_message(
-        self, message: str, display_image: str, history: List[dict]
-    ) -> Tuple[List[dict], gr.Textbox]:
+    def add_message(self, message: str, display_image: str, history: List[dict]) -> Tuple[List[dict], gr.Textbox]:
         """
         Add a new message to the chat history.
 
@@ -155,9 +153,7 @@
                 if isinstance(msg, AIMessageChunk) and msg.content:
                     accumulated_content += msg.content
                     if final_message is None:
-                        final_message = ChatMessage(
-                            role="assistant", content=accumulated_content
-                        )
+                        final_message = ChatMessage(role="assistant", content=accumulated_content)
                         chat_history.append(final_message)
                     else:
                         final_message.content = accumulated_content
@@ -169,9 +165,7 @@
                     if final_message:
                         final_message.content = final_content
                     else:
-                        chat_history.append(
-                            ChatMessage(role="assistant", content=final_content)
-                        )
+                        chat_history.append(ChatMessage(role="assistant", content=final_content))
                     yield chat_history, self.display_file_path, ""
 
                 if msg.tool_calls:
@@ -190,21 +184,25 @@
                     pending_call = self.pending_tool_calls.pop(tool_call_id)
                     tool_name = pending_call["name"]
                     tool_args = pending_call["args"]
+                    # Parse content
                     try:
+                        # Try JSON parsing first
+                        result = json.loads(msg.content)
+                        tool_output_str = json.dumps(result, indent=2)
+                    except json.JSONDecodeError:
+                        try:
+                            # Use ast.literal_eval as safe fallback for Python literals
+                            content_tuple = ast.literal_eval(msg.content)
+                            result = content_tuple[0]
+                            tool_output_str = json.dumps(result, indent=2)
+                        except (ValueError, SyntaxError):
+                            # Fall back to treating as plain string
+                            result = msg.content
+                            tool_output_str = str(msg.content)
 
+                    # Display tool usage card
                    tool_args_str = json.dumps(tool_args, indent=2)
                    description = f"**Input:**\n```json\n{tool_args_str}\n```\n\n**Output:**\n```json\n{tool_output_str}\n```"
                    metadata = {
                        "title": f"⚒️ Tool: {tool_name}",
                        "description": description,
@@ -217,32 +215,33 @@
                             metadata=metadata,
                         )
                     )
-                    yield chat_history, self.display_file_path, ""
 
+                    # Special handling for image_visualizer
                     if tool_name == "image_visualizer":
+                        image_path = None
                         try:
+                            image_path = result["image_path"]
+                        except (TypeError, KeyError):
+                            try:
+                                image_path = result[0]["image_path"]
+                            except (TypeError, KeyError, IndexError):
+                                pass
+
+                        if image_path:
+                            self.display_file_path = image_path
+                            chat_history.append(
+                                ChatMessage(
+                                    role="assistant",
+                                    content={"path": self.display_file_path},
                                 )
+                            )
+
+                    # Yield a single update for this tool event
+                    yield chat_history, self.display_file_path, ""
 
                 except Exception as e:
                     chat_history.append(
-                        ChatMessage(
-                            role="assistant", content=f"❌ Error: {str(e)}", metadata={"title": "Error"}
-                        )
+                        ChatMessage(role="assistant", content=f"❌ Error: {str(e)}", metadata={"title": "Error"})
                     )
                     yield chat_history, self.display_file_path, ""
 
@@ -293,9 +292,7 @@ def create_demo(agent, tools_dict):
         )
 
         with gr.Column(scale=3):
-            image_display = gr.Image(
-                label="Image", type="filepath", height=600, container=True
-            )
+            image_display = gr.Image(label="Image", type="filepath", height=600, container=True)
             with gr.Row():
                 upload_button = gr.UploadButton(
                     "📎 Upload X-Ray",
@@ -306,25 +303,19 @@
                    file_types=["file"],
                )
            with gr.Row():
-                new_thread_btn = gr.Button("New Thread")
+                new_chat_btn = gr.Button("New Chat")
 
        # Event handlers
+        def new_chat():
            interface.original_file_path = None
            interface.display_file_path = None
-            return [], None
-
-        def new_thread():
            interface.current_thread_id = str(time.time())
+            return [], None
 
        def handle_file_upload(file):
            return interface.handle_upload(file.name)
 
-        chat_msg = txt.submit(
-            interface.add_message, inputs=[txt, image_display, chatbot], outputs=[chatbot, txt]
-        )
+        chat_msg = txt.submit(interface.add_message, inputs=[txt, image_display, chatbot], outputs=[chatbot, txt])
        bot_msg = chat_msg.then(
            interface.process_message,
            inputs=[txt, image_display, chatbot],
@@ -336,7 +327,6 @@
 
        dicom_upload.upload(handle_file_upload, inputs=dicom_upload, outputs=image_display)
 
-        new_thread_btn.click(new_thread, outputs=[chatbot, image_display])
+        new_chat_btn.click(new_chat, outputs=[chatbot, image_display])
 
    return demo
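The new parsing path above accepts tool output that arrives either as a JSON string or as the repr of a Python tuple, and degrades to the raw string otherwise. A standalone sketch of the same fallback, with made-up sample payloads for illustration:

# Sketch of the JSON-then-literal_eval fallback used above (sample strings are invented).
import ast
import json

def parse_tool_output(content: str):
    try:
        return json.loads(content)                  # e.g. '{"image_path": "out.png"}'
    except json.JSONDecodeError:
        try:
            return ast.literal_eval(content)[0]     # e.g. "({'image_path': 'out.png'}, {})"
        except (ValueError, SyntaxError):
            return content                          # plain string output

print(parse_tool_output('{"image_path": "out.png"}'))
print(parse_tool_output("({'image_path': 'out.png'}, {})"))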
main.py
CHANGED
|
@@ -11,6 +11,10 @@ with different model weights, tools, and parameters.
|
|
| 11 |
|
| 12 |
import warnings
|
| 13 |
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
from typing import Dict, List, Optional, Any
|
| 15 |
from dotenv import load_dotenv
|
| 16 |
from transformers import logging
|
|
@@ -19,6 +23,7 @@ from langgraph.checkpoint.memory import MemorySaver
|
|
| 19 |
from medrax.models import ModelFactory
|
| 20 |
|
| 21 |
from interface import create_demo
|
|
|
|
| 22 |
from medrax.agent import *
|
| 23 |
from medrax.tools import *
|
| 24 |
from medrax.utils import *
|
|
@@ -31,19 +36,93 @@ logging.set_verbosity_error()
|
|
| 31 |
_ = load_dotenv()
|
| 32 |
|
| 33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
def initialize_agent(
|
| 35 |
prompt_file: str,
|
| 36 |
tools_to_use: Optional[List[str]] = None,
|
| 37 |
-
model_dir: str = "model-weights",
|
| 38 |
temp_dir: str = "temp",
|
| 39 |
-
device: str = "
|
| 40 |
-
model: str = "
|
| 41 |
temperature: float = 1.0,
|
| 42 |
top_p: float = 0.95,
|
| 43 |
max_tokens: int = 5000,
|
| 44 |
rag_config: Optional[RAGConfig] = None,
|
| 45 |
model_kwargs: Dict[str, Any] = {},
|
| 46 |
system_prompt: str = "MEDICAL_ASSISTANT",
|
|
|
|
| 47 |
):
|
| 48 |
"""Initialize the MedRAX agent with specified tools and configuration.
|
| 49 |
|
|
@@ -55,7 +134,6 @@ def initialize_agent(
|
|
| 55 |
device (str, optional): Device to run models on. Defaults to "cuda".
|
| 56 |
model (str, optional): Model to use. Defaults to "gpt-4o".
|
| 57 |
temperature (float, optional): Temperature for the model. Defaults to 0.7.
|
| 58 |
-
top_p (float, optional): Top P for the model. Defaults to 0.95.
|
| 59 |
rag_config (RAGConfig, optional): Configuration for the RAG tool. Defaults to None.
|
| 60 |
model_kwargs (dict, optional): Additional keyword arguments for model.
|
| 61 |
system_prompt (str, optional): System prompt to use. Defaults to "MEDICAL_ASSISTANT".
|
|
@@ -68,18 +146,13 @@ def initialize_agent(
|
|
| 68 |
prompts = load_prompts_from_file(prompt_file)
|
| 69 |
prompt = prompts[system_prompt]
|
| 70 |
|
| 71 |
-
# Define the URL of the MedGemma FastAPI service.
|
| 72 |
-
MEDGEMMA_API_URL = os.getenv("MEDGEMMA_API_URL", "http://localhost:8002")
|
| 73 |
-
|
| 74 |
all_tools = {
|
| 75 |
"TorchXRayVisionClassifierTool": lambda: TorchXRayVisionClassifierTool(device=device),
|
| 76 |
"ArcPlusClassifierTool": lambda: ArcPlusClassifierTool(cache_dir=model_dir, device=device),
|
| 77 |
"ChestXRaySegmentationTool": lambda: ChestXRaySegmentationTool(device=device),
|
| 78 |
"LlavaMedTool": lambda: LlavaMedTool(cache_dir=model_dir, device=device, load_in_8bit=True),
|
| 79 |
"CheXagentXRayVQATool": lambda: CheXagentXRayVQATool(cache_dir=model_dir, device=device),
|
| 80 |
-
"ChestXRayReportGeneratorTool": lambda: ChestXRayReportGeneratorTool(
|
| 81 |
-
cache_dir=model_dir, device=device
|
| 82 |
-
),
|
| 83 |
"XRayPhraseGroundingTool": lambda: XRayPhraseGroundingTool(
|
| 84 |
cache_dir=model_dir, temp_dir=temp_dir, load_in_8bit=True, device=device
|
| 85 |
),
|
|
@@ -91,18 +164,21 @@ def initialize_agent(
|
|
| 91 |
"MedicalRAGTool": lambda: RAGTool(config=rag_config),
|
| 92 |
"WebBrowserTool": lambda: WebBrowserTool(),
|
| 93 |
"DuckDuckGoSearchTool": lambda: DuckDuckGoSearchTool(),
|
| 94 |
-
"MedSAM2Tool": lambda: MedSAM2Tool(
|
| 95 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
),
|
| 97 |
-
|
| 98 |
-
}
|
| 99 |
|
| 100 |
# Initialize only selected tools or all if none specified
|
| 101 |
tools_dict: Dict[str, BaseTool] = {}
|
| 102 |
|
| 103 |
if tools_to_use is None:
|
| 104 |
tools_to_use = []
|
| 105 |
-
|
| 106 |
for tool_name in tools_to_use:
|
| 107 |
if tool_name == "PythonSandboxTool":
|
| 108 |
try:
|
|
@@ -112,7 +188,6 @@ def initialize_agent(
|
|
| 112 |
print("Skipping PythonSandboxTool")
|
| 113 |
if tool_name in all_tools:
|
| 114 |
tools_dict[tool_name] = all_tools[tool_name]()
|
| 115 |
-
|
| 116 |
|
| 117 |
# Set up checkpointing for conversation state
|
| 118 |
checkpointer = MemorySaver()
|
|
@@ -130,8 +205,6 @@ def initialize_agent(
|
|
| 130 |
agent = Agent(
|
| 131 |
llm,
|
| 132 |
tools=list(tools_dict.values()),
|
| 133 |
-
log_tools=True,
|
| 134 |
-
log_dir="logs",
|
| 135 |
system_prompt=prompt,
|
| 136 |
checkpointer=checkpointer,
|
| 137 |
)
|
|
@@ -140,50 +213,262 @@ def initialize_agent(
|
|
| 140 |
return agent, tools_dict
|
| 141 |
|
| 142 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
if __name__ == "__main__":
|
| 144 |
"""
|
| 145 |
This is the main entry point for the MedRAX application.
|
| 146 |
-
It initializes the agent with the selected tools and creates the demo.
|
| 147 |
"""
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
#
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
|
| 171 |
# Setup the MedGemma environment if the MedGemmaVQATool is selected
|
|
|
|
|
|
|
| 172 |
if "MedGemmaVQATool" in selected_tools:
|
| 173 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
|
| 175 |
# Configure the Retrieval Augmented Generation (RAG) system
|
| 176 |
# This allows the agent to access and use medical knowledge documents
|
| 177 |
rag_config = RAGConfig(
|
| 178 |
-
model=
|
| 179 |
-
embedding_model=
|
| 180 |
-
rerank_model=
|
| 181 |
-
temperature=
|
| 182 |
-
pinecone_index_name=
|
| 183 |
-
chunk_size=
|
| 184 |
-
chunk_overlap=
|
| 185 |
-
retriever_k=
|
| 186 |
-
local_docs_dir=
|
| 187 |
huggingface_datasets=["VictorLJZ/medrax2"], # List of HuggingFace datasets to load
|
| 188 |
dataset_split="train", # Which split of the datasets to use
|
| 189 |
)
|
|
@@ -192,19 +477,46 @@ if __name__ == "__main__":
|
|
| 192 |
model_kwargs = {}
|
| 193 |
|
| 194 |
agent, tools_dict = initialize_agent(
|
| 195 |
-
prompt_file=
|
| 196 |
tools_to_use=selected_tools,
|
| 197 |
-
model_dir=
|
| 198 |
-
temp_dir=
|
| 199 |
-
device=
|
| 200 |
-
model=
|
| 201 |
-
temperature=
|
| 202 |
-
top_p=0.95,
|
| 203 |
model_kwargs=model_kwargs,
|
| 204 |
rag_config=rag_config,
|
| 205 |
-
system_prompt=
|
|
|
|
| 206 |
)
|
| 207 |
|
| 208 |
-
#
|
| 209 |
-
|
| 210 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
import warnings
|
| 13 |
import os
|
| 14 |
+
import argparse
|
| 15 |
+
from pyngrok import ngrok
|
| 16 |
+
import threading
|
| 17 |
+
import uvicorn
|
| 18 |
from typing import Dict, List, Optional, Any
|
| 19 |
from dotenv import load_dotenv
|
| 20 |
from transformers import logging
|
|
|
|
| 23 |
from medrax.models import ModelFactory
|
| 24 |
|
| 25 |
from interface import create_demo
|
| 26 |
+
from api import create_api
|
| 27 |
from medrax.agent import *
|
| 28 |
from medrax.tools import *
|
| 29 |
from medrax.utils import *
|
|
|
|
| 36 |
_ = load_dotenv()
|
| 37 |
|
| 38 |
|
| 39 |
+
def resolve_medgemma_api_url_from_value(value: Optional[str]) -> str:
|
| 40 |
+
"""Resolve the MedGemma API base URL using CLI value, env var, and SLURM-aware fallback.
|
| 41 |
+
|
| 42 |
+
Resolution order:
|
| 43 |
+
1) Explicit provided value (e.g., CLI flag)
|
| 44 |
+
2) MEDGEMMA_API_URL environment variable
|
| 45 |
+
3) If on SLURM, require explicit URL (raise)
|
| 46 |
+
4) Otherwise, default to localhost for single-box setups
|
| 47 |
+
"""
|
| 48 |
+
if value:
|
| 49 |
+
return value
|
| 50 |
+
|
| 51 |
+
env_url = os.getenv("MEDGEMMA_API_URL")
|
| 52 |
+
if env_url:
|
| 53 |
+
return env_url
|
| 54 |
+
|
| 55 |
+
if os.getenv("SLURM_JOB_ID") or os.getenv("SLURM_NODEID"):
|
| 56 |
+
raise RuntimeError(
|
| 57 |
+
"MEDGEMMA_API_URL not set and --medgemma-api-url not provided. "
|
| 58 |
+
"On SLURM, the client usually runs on a different node, "
|
| 59 |
+
"so you must point to the server’s reachable IP, e.g. http://<node-ip>:8002"
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
return "http://127.0.0.1:8002"
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def resolve_medgemma_api_url(args) -> str:
|
| 66 |
+
"""Helper that reads from an argparse Namespace if available."""
|
| 67 |
+
return resolve_medgemma_api_url_from_value(getattr(args, "medgemma_api_url", None))
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def resolve_auth_credentials(args) -> Optional[tuple]:
|
| 71 |
+
"""Resolve authentication credentials from CLI args or environment variables.
|
| 72 |
+
|
| 73 |
+
Resolution order:
|
| 74 |
+
1) Explicit --no-auth flag (returns None, no warnings)
|
| 75 |
+
2) Explicit --auth USERNAME PASSWORD (returns credentials tuple)
|
| 76 |
+
3) MEDRAX_AUTH_USERNAME and MEDRAX_AUTH_PASSWORD environment variables
|
| 77 |
+
4) Default to None with warning messages
|
| 78 |
+
|
| 79 |
+
Args:
|
| 80 |
+
args: Parsed command-line arguments
|
| 81 |
+
|
| 82 |
+
Returns:
|
| 83 |
+
Optional[tuple]: (username, password) tuple if auth is enabled, None otherwise
|
| 84 |
+
"""
|
| 85 |
+
if args.no_auth:
|
| 86 |
+
print("⚠️ Authentication disabled (public access)")
|
| 87 |
+
return None
|
| 88 |
+
|
| 89 |
+
if args.auth:
|
| 90 |
+
username, password = args.auth
|
| 91 |
+
print(f"✅ Authentication enabled for user: {username}")
|
| 92 |
+
return (username, password)
|
| 93 |
+
|
| 94 |
+
# Try to read from environment variables
|
| 95 |
+
auth_username = os.getenv("MEDRAX_AUTH_USERNAME")
|
| 96 |
+
auth_password = os.getenv("MEDRAX_AUTH_PASSWORD")
|
| 97 |
+
|
| 98 |
+
if auth_username and auth_password:
|
| 99 |
+
print(f"✅ Authentication enabled from environment for user: {auth_username}")
|
| 100 |
+
return (auth_username, auth_password)
|
| 101 |
+
|
| 102 |
+
# No auth specified anywhere - default to no auth with warning
|
| 103 |
+
print("⚠️ No authentication configured!")
|
| 104 |
+
print("⚠️ Running without authentication (public access)")
|
| 105 |
+
print("⚠️ To enable auth, either:")
|
| 106 |
+
print(" - Use --auth USERNAME PASSWORD")
|
| 107 |
+
print(" - Set MEDRAX_AUTH_USERNAME and MEDRAX_AUTH_PASSWORD in .env")
|
| 108 |
+
print(" - Or explicitly use --no-auth to suppress this warning")
|
| 109 |
+
return None
|
| 110 |
+
|
| 111 |
+
|
| 112 |
def initialize_agent(
|
| 113 |
prompt_file: str,
|
| 114 |
tools_to_use: Optional[List[str]] = None,
|
| 115 |
+
model_dir: str = "/model-weights",
|
| 116 |
temp_dir: str = "temp",
|
| 117 |
+
device: str = "cuda",
|
| 118 |
+
model: str = "gpt-4.1",
|
| 119 |
temperature: float = 1.0,
|
| 120 |
top_p: float = 0.95,
|
| 121 |
max_tokens: int = 5000,
|
| 122 |
rag_config: Optional[RAGConfig] = None,
|
| 123 |
model_kwargs: Dict[str, Any] = {},
|
| 124 |
system_prompt: str = "MEDICAL_ASSISTANT",
|
| 125 |
+
medgemma_api_url: Optional[str] = None,
|
| 126 |
):
|
| 127 |
"""Initialize the MedRAX agent with specified tools and configuration.
|
| 128 |
|
|
|
|
| 134 |
device (str, optional): Device to run models on. Defaults to "cuda".
|
| 135 |
model (str, optional): Model to use. Defaults to "gpt-4o".
|
| 136 |
temperature (float, optional): Temperature for the model. Defaults to 0.7.
|
|
|
|
| 137 |
rag_config (RAGConfig, optional): Configuration for the RAG tool. Defaults to None.
|
| 138 |
model_kwargs (dict, optional): Additional keyword arguments for model.
|
| 139 |
system_prompt (str, optional): System prompt to use. Defaults to "MEDICAL_ASSISTANT".
|
|
|
|
| 146 |
prompts = load_prompts_from_file(prompt_file)
|
| 147 |
prompt = prompts[system_prompt]
|
| 148 |
|
|
|
|
|
|
|
|
|
|
| 149 |
all_tools = {
|
| 150 |
"TorchXRayVisionClassifierTool": lambda: TorchXRayVisionClassifierTool(device=device),
|
| 151 |
"ArcPlusClassifierTool": lambda: ArcPlusClassifierTool(cache_dir=model_dir, device=device),
|
| 152 |
"ChestXRaySegmentationTool": lambda: ChestXRaySegmentationTool(device=device),
|
| 153 |
"LlavaMedTool": lambda: LlavaMedTool(cache_dir=model_dir, device=device, load_in_8bit=True),
|
| 154 |
"CheXagentXRayVQATool": lambda: CheXagentXRayVQATool(cache_dir=model_dir, device=device),
|
| 155 |
+
"ChestXRayReportGeneratorTool": lambda: ChestXRayReportGeneratorTool(cache_dir=model_dir, device=device),
|
|
|
|
|
|
|
| 156 |
"XRayPhraseGroundingTool": lambda: XRayPhraseGroundingTool(
|
| 157 |
cache_dir=model_dir, temp_dir=temp_dir, load_in_8bit=True, device=device
|
| 158 |
),
|
|
|
|
| 164 |
"MedicalRAGTool": lambda: RAGTool(config=rag_config),
|
| 165 |
"WebBrowserTool": lambda: WebBrowserTool(),
|
| 166 |
"DuckDuckGoSearchTool": lambda: DuckDuckGoSearchTool(),
|
| 167 |
+
"MedSAM2Tool": lambda: MedSAM2Tool(device=device, cache_dir=model_dir, temp_dir=temp_dir),
|
| 168 |
+
"MedGemmaVQATool": lambda: MedGemmaAPIClientTool(
|
| 169 |
+
cache_dir=model_dir,
|
| 170 |
+
device=device,
|
| 171 |
+
load_in_8bit=True,
|
| 172 |
+
api_url=resolve_medgemma_api_url_from_value(medgemma_api_url),
|
| 173 |
),
|
| 174 |
+
}
|
|
|
|
| 175 |
|
| 176 |
# Initialize only selected tools or all if none specified
|
| 177 |
tools_dict: Dict[str, BaseTool] = {}
|
| 178 |
|
| 179 |
if tools_to_use is None:
|
| 180 |
tools_to_use = []
|
| 181 |
+
|
| 182 |
for tool_name in tools_to_use:
|
| 183 |
if tool_name == "PythonSandboxTool":
|
| 184 |
try:
|
|
|
|
| 188 |
print("Skipping PythonSandboxTool")
|
| 189 |
if tool_name in all_tools:
|
| 190 |
tools_dict[tool_name] = all_tools[tool_name]()
|
|
|
|
| 191 |
|
| 192 |
# Set up checkpointing for conversation state
|
| 193 |
checkpointer = MemorySaver()
|
|
|
|
| 205 |
agent = Agent(
|
| 206 |
llm,
|
| 207 |
tools=list(tools_dict.values()),
|
|
|
|
|
|
|
| 208 |
system_prompt=prompt,
|
| 209 |
checkpointer=checkpointer,
|
| 210 |
)
|
|
|
|
| 213 |
return agent, tools_dict
|
| 214 |
|
| 215 |
|
| 216 |
+
def run_gradio_interface(agent, tools_dict, host="0.0.0.0", port=8686,
|
| 217 |
+
auth=None, share=False):
|
| 218 |
+
"""
|
| 219 |
+
Run the Gradio web interface.
|
| 220 |
+
|
| 221 |
+
Args:
|
| 222 |
+
agent: The initialized MedRAX agent
|
| 223 |
+
tools_dict: Dictionary of available tools
|
| 224 |
+
host (str): Host to bind the server to
|
| 225 |
+
port (int): Port to run the server on
|
| 226 |
+
auth: Authentication credentials (tuple)
|
| 227 |
+
share (bool): Whether to create a shareable public link
|
| 228 |
+
"""
|
| 229 |
+
print(f"Starting Gradio interface on {host}:{port}")
|
| 230 |
+
|
| 231 |
+
if auth:
|
| 232 |
+
print(f"🔐 Authentication enabled for user: {auth[0]}")
|
| 233 |
+
else:
|
| 234 |
+
print("⚠️ Running without authentication (public access)")
|
| 235 |
+
|
| 236 |
+
if share:
|
| 237 |
+
print("🌍 Creating shareable public link (expires in 1 week)...")
|
| 238 |
+
|
| 239 |
+
demo = create_demo(agent, tools_dict)
|
| 240 |
+
|
| 241 |
+
# Prepare launch parameters
|
| 242 |
+
launch_kwargs = {
|
| 243 |
+
"server_name": host,
|
| 244 |
+
"server_port": port,
|
| 245 |
+
"share": share
|
| 246 |
+
}
|
| 247 |
+
|
| 248 |
+
if auth:
|
| 249 |
+
launch_kwargs["auth"] = auth
|
| 250 |
+
|
| 251 |
+
demo.launch(**launch_kwargs)
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
def run_api_server(agent, tools_dict, host="0.0.0.0", port=8585, public=False):
|
| 255 |
+
"""
|
| 256 |
+
Run the FastAPI server.
|
| 257 |
+
|
| 258 |
+
Args:
|
| 259 |
+
agent: The initialized MedRAX agent
|
| 260 |
+
tools_dict: Dictionary of available tools
|
| 261 |
+
host (str): Host to bind the server to
|
| 262 |
+
port (int): Port to run the server on
|
| 263 |
+
public (bool): Whether to expose via ngrok tunnel
|
| 264 |
+
"""
|
| 265 |
+
print(f"Starting API server on {host}:{port}")
|
| 266 |
+
|
| 267 |
+
if public:
|
| 268 |
+
try:
|
| 269 |
+
public_tunnel = ngrok.connect(port)
|
| 270 |
+
public_url = public_tunnel.public_url
|
| 271 |
+
print(
|
| 272 |
+
f"🌍 Public URL: {public_url}\n🌍 API Documentation: {public_url}/docs\n🌍 Share this URL with your friend!\n{'=' * 60}"
|
| 273 |
+
)
|
| 274 |
+
except ImportError:
|
| 275 |
+
print("⚠️ pyngrok not installed. Install with: pip install pyngrok\nRunning locally only...")
|
| 276 |
+
public = False
|
| 277 |
+
except Exception as e:
|
| 278 |
+
print(f"⚠️ Failed to create public tunnel: {e}\nRunning locally only...")
|
| 279 |
+
public = False
|
| 280 |
+
|
| 281 |
+
app = create_api(agent, tools_dict)
|
| 282 |
+
|
| 283 |
+
try:
|
| 284 |
+
uvicorn.run(app, host=host, port=port)
|
| 285 |
+
finally:
|
| 286 |
+
if public:
|
| 287 |
+
try:
|
| 288 |
+
ngrok.disconnect(public_tunnel.public_url)
|
| 289 |
+
ngrok.kill()
|
| 290 |
+
except:
|
| 291 |
+
pass
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
def parse_arguments():
|
| 295 |
+
"""Parse command line arguments."""
|
| 296 |
+
parser = argparse.ArgumentParser(description="MedRAX - Medical Reasoning Agent for Chest X-ray")
|
| 297 |
+
|
| 298 |
+
# Run mode
|
| 299 |
+
parser.add_argument(
|
| 300 |
+
"--mode",
|
| 301 |
+
choices=["gradio", "api", "both"],
|
| 302 |
+
default="gradio",
|
| 303 |
+
help="Run mode: 'gradio' for web interface, 'api' for REST API, 'both' for both services",
|
| 304 |
+
)
|
| 305 |
+
|
| 306 |
+
# Gradio interface options
|
| 307 |
+
parser.add_argument("--gradio-host", default="0.0.0.0", help="Gradio host address")
|
| 308 |
+
parser.add_argument("--gradio-port", type=int, default=8686, help="Gradio port")
|
| 309 |
+
parser.add_argument("--auth", nargs=2, metavar=("USERNAME", "PASSWORD"),
|
| 310 |
+
default=None,
|
| 311 |
+
help="Enable password authentication with specified username and password")
|
| 312 |
+
parser.add_argument("--no-auth", action="store_true",
|
| 313 |
+
help="Disable authentication (public access)")
|
| 314 |
+
parser.add_argument("--share", action="store_true",
|
| 315 |
+
help="Create a temporary shareable link (expires in 1 week)")
|
| 316 |
+
|
| 317 |
+
# API server options
|
| 318 |
+
parser.add_argument("--api-host", default="0.0.0.0", help="API host address")
|
| 319 |
+
parser.add_argument("--api-port", type=int, default=8000, help="API port")
|
| 320 |
+
parser.add_argument("--public", action="store_true", help="Make API publicly accessible via ngrok tunnel")
|
| 321 |
+
|
| 322 |
+
# Model and system configuration
|
| 323 |
+
parser.add_argument(
|
| 324 |
+
"--model-dir",
|
| 325 |
+
default="/model-weights",
|
| 326 |
+
help="Directory containing model weights (default: uses MODEL_WEIGHTS_DIR env var or '/model-weights')",
|
| 327 |
+
)
|
| 328 |
+
parser.add_argument(
|
| 329 |
+
"--device", default="cuda", help="Device to run models on (default: uses MEDRAX_DEVICE env var or 'cuda:1')"
|
| 330 |
+
)
|
| 331 |
+
parser.add_argument(
|
| 332 |
+
"--model",
|
| 333 |
+
default="gpt-4.1",
|
| 334 |
+
help="Model to use (default: gpt-4.1). Examples: gpt-4.1-2025-04-14, gemini-2.5-pro, gpt-5",
|
| 335 |
+
)
|
| 336 |
+
parser.add_argument("--temperature", type=float, default=1.0, help="Temperature for the model (default: 1.0)")
|
| 337 |
+
parser.add_argument("--temp-dir", default="temp2", help="Directory for temporary files (default: temp2)")
|
| 338 |
+
parser.add_argument(
|
| 339 |
+
"--prompt-file",
|
| 340 |
+
default="medrax/docs/system_prompts.txt",
|
| 341 |
+
help="Path to file containing system prompts (default: medrax/docs/system_prompts.txt)",
|
| 342 |
+
)
|
| 343 |
+
parser.add_argument(
|
| 344 |
+
"--system-prompt", default="MEDICAL_ASSISTANT", help="System prompt to use (default: MEDICAL_ASSISTANT)"
|
| 345 |
+
)
|
| 346 |
+
|
| 347 |
+
# RAG configuration
|
| 348 |
+
parser.add_argument(
|
| 349 |
+
"--rag-model", default="command-a-03-2025", help="Chat model for RAG responses (default: command-a-03-2025)"
|
| 350 |
+
)
|
| 351 |
+
parser.add_argument(
|
| 352 |
+
"--rag-embedding-model", default="embed-v4.0", help="Embedding model for RAG system (default: embed-v4.0)"
|
| 353 |
+
)
|
| 354 |
+
parser.add_argument(
|
| 355 |
+
"--rag-rerank-model", default="rerank-v3.5", help="Reranking model for RAG system (default: rerank-v3.5)"
|
| 356 |
+
)
|
| 357 |
+
parser.add_argument("--rag-temperature", type=float, default=0.3, help="Temperature for RAG model (default: 0.3)")
|
| 358 |
+
parser.add_argument("--pinecone-index", default="medrax2", help="Pinecone index name (default: medrax2)")
|
| 359 |
+
parser.add_argument("--chunk-size", type=int, default=1500, help="RAG chunk size (default: 1500)")
|
| 360 |
+
parser.add_argument("--chunk-overlap", type=int, default=300, help="RAG chunk overlap (default: 300)")
|
| 361 |
+
parser.add_argument("--retriever-k", type=int, default=3, help="Number of documents to retrieve (default: 3)")
|
| 362 |
+
parser.add_argument("--rag-docs-dir", default="rag_docs", help="Directory for RAG documents (default: rag_docs)")
|
| 363 |
+
|
| 364 |
+
# Tools configuration
|
| 365 |
+
parser.add_argument(
|
| 366 |
+
"--tools",
|
| 367 |
+
nargs="*",
|
| 368 |
+
help="Specific tools to enable (if not provided, uses default set). Available tools: "
|
| 369 |
+
+ "ImageVisualizerTool, DicomProcessorTool, MedSAM2Tool, ChestXRaySegmentationTool, "
|
| 370 |
+
+ "ChestXRayGeneratorTool, TorchXRayVisionClassifierTool, ArcPlusClassifierTool, "
|
| 371 |
+
+ "ChestXRayReportGeneratorTool, XRayPhraseGroundingTool, MedGemmaVQATool, "
|
| 372 |
+
+ "XRayVQATool, LlavaMedTool, MedicalRAGTool, WebBrowserTool, DuckDuckGoSearchTool, "
|
| 373 |
+
+ "PythonSandboxTool",
|
| 374 |
+
)
|
| 375 |
+
|
| 376 |
+
# MedGemma API configuration
|
| 377 |
+
parser.add_argument(
|
| 378 |
+
"--medgemma-api-url",
|
| 379 |
+
default=None,
|
| 380 |
+
help="MedGemma API base URL, e.g. http://127.0.0.1:8002 or http://<node-ip>:8002"
|
| 381 |
+
)
|
| 382 |
+
|
| 383 |
+
return parser.parse_args()
|
| 384 |
+
|
| 385 |
+
|
| 386 |
if __name__ == "__main__":
|
| 387 |
"""
|
| 388 |
This is the main entry point for the MedRAX application.
|
| 389 |
+
It initializes the agent with the selected tools and creates the demo/API.
|
| 390 |
"""
|
| 391 |
+
args = parse_arguments()
|
| 392 |
+
print(f"Starting MedRAX in {args.mode} mode...")
|
| 393 |
+
|
| 394 |
+
# Configure tools based on arguments
|
| 395 |
+
if args.tools is not None:
|
| 396 |
+
# Use tools specified via command line
|
| 397 |
+
selected_tools = args.tools
|
| 398 |
+
else:
|
| 399 |
+
# Use default tools selection
|
| 400 |
+
selected_tools = [
|
| 401 |
+
# Image Processing Tools
|
| 402 |
+
"ImageVisualizerTool", # For displaying images in the UI
|
| 403 |
+
# "DicomProcessorTool", # For processing DICOM medical image files
|
| 404 |
+
# Segmentation Tools
|
| 405 |
+
"MedSAM2Tool", # For advanced medical image segmentation using MedSAM2
|
| 406 |
+
"ChestXRaySegmentationTool", # For segmenting anatomical regions in chest X-rays
|
| 407 |
+
# Generation Tools
|
| 408 |
+
# "ChestXRayGeneratorTool", # For generating synthetic chest X-rays
|
| 409 |
+
# Classification Tools
|
| 410 |
+
"TorchXRayVisionClassifierTool", # For classifying chest X-ray images using TorchXRayVision
|
| 411 |
+
"ArcPlusClassifierTool", # For advanced chest X-ray classification using ArcPlus
|
| 412 |
+
# Report Generation Tools
|
| 413 |
+
"ChestXRayReportGeneratorTool", # For generating medical reports from X-rays
|
| 414 |
+
# Grounding Tools
|
| 415 |
+
"XRayPhraseGroundingTool", # For locating described features in X-rays
|
| 416 |
+
# VQA Tools
|
| 417 |
+
# "MedGemmaVQATool", # Google MedGemma VQA tool
|
| 418 |
+
"XRayVQATool", # For visual question answering on X-rays
|
| 419 |
+
# "LlavaMedTool", # For multimodal medical image understanding
|
| 420 |
+
# RAG Tools
|
| 421 |
+
"MedicalRAGTool", # For retrieval-augmented generation with medical knowledge
|
| 422 |
+
# Search Tools
|
| 423 |
+
# "WebBrowserTool", # For web browsing and search capabilities
|
| 424 |
+
"DuckDuckGoSearchTool", # For privacy-focused web search using DuckDuckGo
|
| 425 |
+
# Development Tools
|
| 426 |
+
# "PythonSandboxTool", # Add the Python sandbox tool
|
| 427 |
+
]
|
| 428 |
+
|
| 429 |
+
# Configure model directory and device
|
| 430 |
+
+    model_dir = args.model_dir or os.getenv("MODEL_WEIGHTS_DIR", "/model-weights")
+    device = args.device or os.getenv("MEDRAX_DEVICE", "cuda:0")
+
+    print(f"Using model directory: {model_dir}")
+    print(f"Using device: {device}")
+    print(f"Using model: {args.model}")
+    print(f"Selected tools: {selected_tools}")
+    print(f"Using system prompt: {args.system_prompt}")
+
+    # Set up authentication (reads from CLI, env vars, or requires explicit choice)
+    auth_credentials = resolve_auth_credentials(args)
 
     # Setup the MedGemma environment if the MedGemmaVQATool is selected
+    medgemma_base_url_from_setup: Optional[str] = None
+    medgemma_api_url_effective: Optional[str] = args.medgemma_api_url
     if "MedGemmaVQATool" in selected_tools:
+        # Launch server and capture its URL if no explicit URL/ENV provided
+        try:
+            if medgemma_api_url_effective is None and os.getenv("MEDGEMMA_API_URL") is None:
+                medgemma_base_url_from_setup = setup_medgemma_env(cache_dir=model_dir, device=device)
+                # If we auto-launched, use this URL unless overridden later
+                if medgemma_base_url_from_setup:
+                    medgemma_api_url_effective = medgemma_base_url_from_setup
+                    print(f"MedGemma API auto-launched at {medgemma_api_url_effective}")
+            else:
+                # Still ensure environment is set up; it will bind to provided host/port
+                setup_medgemma_env(cache_dir=model_dir, device=device)
+        except Exception as e:
+            print(f"Warning: Failed to launch MedGemma service automatically: {e}")
 
     # Configure the Retrieval Augmented Generation (RAG) system
     # This allows the agent to access and use medical knowledge documents
     rag_config = RAGConfig(
+        model=args.rag_model,
+        embedding_model=args.rag_embedding_model,
+        rerank_model=args.rag_rerank_model,
+        temperature=args.rag_temperature,
+        pinecone_index_name=args.pinecone_index,
+        chunk_size=args.chunk_size,
+        chunk_overlap=args.chunk_overlap,
+        retriever_k=args.retriever_k,
+        local_docs_dir=args.rag_docs_dir,
         huggingface_datasets=["VictorLJZ/medrax2"],  # List of HuggingFace datasets to load
         dataset_split="train",  # Which split of the datasets to use
     )
 
     model_kwargs = {}
 
     agent, tools_dict = initialize_agent(
+        prompt_file=args.prompt_file,
         tools_to_use=selected_tools,
+        model_dir=model_dir,
+        temp_dir=args.temp_dir,
+        device=device,
+        model=args.model,
+        temperature=args.temperature,
         model_kwargs=model_kwargs,
         rag_config=rag_config,
+        system_prompt=args.system_prompt,
+        medgemma_api_url=medgemma_api_url_effective,
     )
 
+    # Launch based on selected mode
+    if args.mode == "gradio":
+        run_gradio_interface(
+            agent, tools_dict,
+            host=args.gradio_host,
+            port=args.gradio_port,
+            auth=auth_credentials,
+            share=args.share
+        )
+
+    elif args.mode == "api":
+        run_api_server(agent, tools_dict, args.api_host, args.api_port, args.public)
+
+    elif args.mode == "both":
+        # Run both services in separate threads
+        api_thread = threading.Thread(
+            target=run_api_server,
+            args=(agent, tools_dict, args.api_host, args.api_port, args.public)
+        )
+        api_thread.daemon = True
+        api_thread.start()
+
+        # Run Gradio in main thread with authentication and sharing
+        run_gradio_interface(
+            agent, tools_dict,
+            host=args.gradio_host,
+            port=args.gradio_port,
+            auth=auth_credentials,
+            share=args.share
+        )
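The MedGemma branch above encodes a simple precedence for the API URL: an explicit CLI value (args.medgemma_api_url) is kept as-is, a set MEDGEMMA_API_URL environment variable suppresses auto-launch and is read downstream, and only when neither is present is the auto-launched server's base URL adopted. A hedged sketch of that decision, using a hypothetical helper name that is not part of the codebase:

def effective_medgemma_url(cli_url, env_url, launch_server):
    # Illustrative only: mirrors the branch in the launcher above.
    if cli_url is None and env_url is None:
        return launch_server()  # setup_medgemma_env(...) returns the auto-launched base URL
    return cli_url              # otherwise keep the CLI value; a set MEDGEMMA_API_URL is honored downstream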
medrax/agent/agent.py
CHANGED

@@ -1,37 +1,17 @@
-import json
 import operator
-from pathlib import Path
-from dotenv import load_dotenv
-from datetime import datetime
 from typing import List, Dict, Any, TypedDict, Annotated, Optional
+from dotenv import load_dotenv
 
 from langgraph.graph import StateGraph, END
 from langchain_core.messages import AnyMessage, SystemMessage, ToolMessage, HumanMessage
+from langgraph.prebuilt import ToolNode
+from langchain_core.messages import AnyMessage, SystemMessage
 from langchain_core.language_models import BaseLanguageModel
 from langchain_core.tools import BaseTool
 
 _ = load_dotenv()
 
 
-class ToolCallLog(TypedDict):
-    """
-    A TypedDict representing a log entry for a tool call.
-
-    Attributes:
-        timestamp (str): The timestamp of when the tool call was made.
-        tool_call_id (str): The unique identifier for the tool call.
-        name (str): The name of the tool that was called.
-        args (Any): The arguments passed to the tool.
-        content (str): The content or result of the tool call.
-    """
-
-    timestamp: str
-    tool_call_id: str
-    name: str
-    args: Any
-    content: str
-
-
 class AgentState(TypedDict):
     """
     A TypedDict representing the state of an agent.
@@ -48,16 +28,14 @@ class AgentState(TypedDict):
 class Agent:
     """
     A class representing an agent that processes requests and executes tools based on
-    language model responses.
+    language model responses with parallel tool execution capabilities.
 
     Attributes:
         model (BaseLanguageModel): The language model used for processing.
-
+        tool_node (ToolNode): The parallel tool execution node.
         checkpointer (Any): Manages and persists the agent's state.
         system_prompt (str): The system instructions for the agent.
         workflow (StateGraph): The compiled workflow for the agent's processing.
-        log_tools (bool): Whether to log tool calls.
-        log_path (Path): Path to save tool call logs.
     """
 
     def __init__(
@@ -66,8 +44,6 @@ class Agent:
         tools: List[BaseTool],
         checkpointer: Any = None,
         system_prompt: str = "",
-        log_tools: bool = True,
-        log_dir: Optional[str] = "logs",
     ):
         """
         Initialize the Agent.
@@ -77,28 +53,21 @@ class Agent:
             tools (List[BaseTool]): A list of available tools.
             checkpointer (Any, optional): State persistence manager. Defaults to None.
             system_prompt (str, optional): System instructions. Defaults to "".
-            log_tools (bool, optional): Whether to log tool calls. Defaults to True.
-            log_dir (str, optional): Directory to save logs. Defaults to 'logs'.
         """
         self.system_prompt = system_prompt
-        self.log_tools = log_tools
 
-        self.log_path.mkdir(exist_ok=True)
+        # Create the parallel tool execution node
+        self.tool_node = ToolNode(tools)
 
-        # Define the agent workflow
+        # Define the agent workflow with parallel tool execution
         workflow = StateGraph(AgentState)
-        workflow.add_node("process", self.process_request)
-        workflow.add_node("execute", self.execute_tools)
-        workflow.add_conditional_edges(
-            "process", self.has_tool_calls, {True: "execute", False: END}
-        )
-        workflow.add_edge("execute", "process")
-        workflow.set_entry_point("process")
+        workflow.add_node("agent", self.process_request)
+        workflow.add_node("tools", self.tool_node)
+        workflow.add_conditional_edges("agent", self.has_tool_calls, {True: "tools", False: END})
+        workflow.add_edge("tools", "agent")
+        workflow.set_entry_point("agent")
 
         self.workflow = workflow.compile(checkpointer=checkpointer)
-        self.tools = {t.name: t for t in tools}
         self.model = model.bind_tools(tools)
 
     def process_request(self, state: AgentState) -> Dict[str, List[AnyMessage]]:
@@ -148,65 +117,3 @@ class Agent:
         """
         response = state["messages"][-1]
         return len(response.tool_calls) > 0
-
-    def execute_tools(self, state: AgentState) -> Dict[str, List[ToolMessage]]:
-        """
-        Execute tool calls from the model's response.
-
-        Args:
-            state (AgentState): The current state of the agent.
-
-        Returns:
-            Dict[str, List[ToolMessage]]: A dictionary containing tool execution results.
-        """
-        tool_calls = state["messages"][-1].tool_calls
-        results = []
-
-        for call in tool_calls:
-            print(f"Executing tool: {call}")
-            if call["name"] not in self.tools:
-                print("\n....invalid tool....")
-                result = "invalid tool, please retry"
-            else:
-                result = self.tools[call["name"]].invoke(call["args"])
-
-            results.append(
-                ToolMessage(
-                    tool_call_id=call["id"],
-                    name=call["name"],
-                    args=call["args"],
-                    content=str(result),
-                )
-            )
-
-        self._save_tool_calls(results)
-        print("Returning to model processing!")
-
-        return {"messages": results}
-
-    def _save_tool_calls(self, tool_calls: List[ToolMessage]) -> None:
-        """
-        Save tool calls to a JSON file with timestamp-based naming.
-
-        Args:
-            tool_calls (List[ToolMessage]): List of tool calls to save.
-        """
-        if not self.log_tools:
-            return
-
-        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
-        filename = self.log_path / f"tool_calls_{timestamp}.json"
-
-        logs: List[ToolCallLog] = []
-        for call in tool_calls:
-            log_entry = {
-                "tool_call_id": call.tool_call_id,
-                "name": call.name,
-                "args": call.args,
-                "content": call.content,
-                "timestamp": datetime.now().isoformat(),
-            }
-            logs.append(log_entry)
-
-        with open(filename, "w") as f:
-            json.dump(logs, f, indent=4)
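The net effect of this rewrite is that the hand-rolled execute_tools loop and the JSON tool-call logging (ToolCallLog, log_tools, log_dir) are dropped in favor of LangGraph's prebuilt ToolNode, which executes all tool calls from a single model turn and can run them in parallel. A minimal, self-contained sketch of the same wiring outside MedRAX; the example tool and the stub model node are illustrative, not part of the repository:

import operator
from typing import Annotated, List, TypedDict

from langchain_core.messages import AnyMessage
from langchain_core.tools import tool
from langgraph.graph import StateGraph, END
from langgraph.prebuilt import ToolNode


class State(TypedDict):
    messages: Annotated[List[AnyMessage], operator.add]


@tool
def word_count(text: str) -> int:
    """Count the words in a piece of text."""
    return len(text.split())


def call_model(state: State) -> dict:
    # Stand-in for Agent.process_request, which invokes model.bind_tools(tools).
    return {"messages": []}


def has_tool_calls(state: State) -> bool:
    return bool(getattr(state["messages"][-1], "tool_calls", []))


graph = StateGraph(State)
graph.add_node("agent", call_model)
graph.add_node("tools", ToolNode([word_count]))  # runs every tool call of a turn, in parallel
graph.add_conditional_edges("agent", has_tool_calls, {True: "tools", False: END})
graph.add_edge("tools", "agent")
graph.set_entry_point("agent")
workflow = graph.compile()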
medrax/docs/system_prompts.txt
CHANGED

@@ -33,4 +33,6 @@ Your final response for a multiple-choice question must strictly follow this format:
 3. **Critical Thinking & Tool Use:** [Show your reasoning, including how you used tools and evaluated their output]
 4. **Final Answer:** \boxed{A}
 
-Do not provide a definitive diagnosis or treatment plan for a patient. Your purpose is to assist medical professionals with your analysis, not to replace them. You must maintain this persona and adhere to all instructions.
+Do not provide a definitive diagnosis or treatment plan for a patient. Your purpose is to assist medical professionals with your analysis, not to replace them. You must maintain this persona and adhere to all instructions.
+
+
medrax/llava/conversation.py
CHANGED

@@ -230,9 +230,7 @@ class Conversation:
                     buffered = BytesIO()
                     image.save(buffered, format="JPEG")
                     img_b64_str = base64.b64encode(buffered.getvalue()).decode()
-                    img_str = (
-                        f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
-                    )
+                    img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
                     msg = img_str + msg.replace("<image>", "").strip()
                     ret.append([msg, None])
                 else:
medrax/llava/eval/eval_multimodal_chat_gpt_score.py
CHANGED

@@ -14,6 +14,7 @@ INSTRUCT_PROMPT = """We would like to request your feedback on the performance o
 Please first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space. In the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."""
 ROLE = "Assistant"
 
+
 # Generate instruction for GPT-4 to score the two answers.
 def conv_to_str(fig_label, fig_caption, fig_context, question, ans1, ans2):
     return (
@@ -127,17 +128,13 @@ def main(args):
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser("GPT-4 Multimodal Chat Scoring", add_help=True)
-    parser.add_argument(
-        "--answers-file", default="", metavar="FILE", help="path to model answer file"
-    )
+    parser.add_argument("--answers-file", default="", metavar="FILE", help="path to model answer file")
     parser.add_argument(
        "--question-file",
        default="data/questions/llava_med_eval_qa50_qa.jsonl",
        metavar="FILE",
        help="path to multichat questions file",
     )
-    parser.add_argument(
-        "--scores-file", default="", metavar="FILE", help="path to save gpt-4 score file"
-    )
+    parser.add_argument("--scores-file", default="", metavar="FILE", help="path to save gpt-4 score file")
     args = parser.parse_args()
     main(args)
medrax/llava/eval/llm.py
CHANGED

@@ -21,9 +21,7 @@ class LLM(abc.ABC):
         raise NotImplementedError("Subclasses should implement this!")
 
     @abstractmethod
-    def split_input(
-        self, fixed_instruction, few_shot_examples, splittable_input, input_header, output_header
-    ):
+    def split_input(self, fixed_instruction, few_shot_examples, splittable_input, input_header, output_header):
         raise NotImplementedError("Subclasses should implement this!")
 
 
@@ -49,9 +47,7 @@ class GPT(LLM):
     def __init__(self, model_id):
         self.temperature = 0.0
         self.top_k = 1
-        self.encoding = tiktoken.encoding_for_model(
-            "-".join(model_id.split("-", 2)[:2]).replace("5", ".5")
-        )
+        self.encoding = tiktoken.encoding_for_model("-".join(model_id.split("-", 2)[:2]).replace("5", ".5"))
         self.openai_api = "default"
         self.model_id = model_id
         self.max_length = self.deployment_max_length_dict[model_id]
@@ -61,9 +57,7 @@ class GPT(LLM):
             azure_endpoint=self.openai_cxn_dict[self.openai_api]["endpoint"],
         )
 
-    def gen_messages(
-        self, fixed_instruction, few_shot_examples, input, input_header, output_header
-    ):
+    def gen_messages(self, fixed_instruction, few_shot_examples, input, input_header, output_header):
         messages = [
             {
                 "role": "system",
@@ -120,18 +114,13 @@ class GPT(LLM):
     ):
         return asyncio.run(self.dispatch_openai_requests(messages_list))
 
-    def split_input(
-        self, fixed_instruction, few_shot_examples, splittable_input, input_header, output_header
-    ):
+    def split_input(self, fixed_instruction, few_shot_examples, splittable_input, input_header, output_header):
         # Tokenize fixed_prompt
         fixed_token_ids = self.encoding.encode(
-            fixed_instruction
-            + " ".join([x["user"] + " " + x["assistant"] for x in few_shot_examples])
+            fixed_instruction + " ".join([x["user"] + " " + x["assistant"] for x in few_shot_examples])
         )
         # Calculate remaining token length
-        remaining_token_len = math.ceil(
-            (self.prompt_percent * self.max_length) - len(fixed_token_ids)
-        )
+        remaining_token_len = math.ceil((self.prompt_percent * self.max_length) - len(fixed_token_ids))
 
         # Tokenize splittable_input
         split_token_ids = self.encoding.encode(splittable_input)
@@ -141,14 +130,10 @@ class GPT(LLM):
             split_token_ids[i : i + remaining_token_len + 10]
             for i in range(0, len(split_token_ids), remaining_token_len)
         ]
-        split_input_list = [
-            self.encoding.decode(split_token_ids) for split_token_ids in split_token_ids_list
-        ]
+        split_input_list = [self.encoding.decode(split_token_ids) for split_token_ids in split_token_ids_list]
 
         # Take the fixed_prompt, few_shot_examples, splitted inputs, and input/output headers and generate list of prompt strings.
         return [
-            self.gen_messages(
-                fixed_instruction, few_shot_examples, split_input, input_header, output_header
-            )
+            self.gen_messages(fixed_instruction, few_shot_examples, split_input, input_header, output_header)
             for split_input in split_input_list
         ]
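For context, split_input budgets prompt_percent of the model's context window for the fixed instruction plus few-shot examples, then slices the remaining token budget of the splittable input into chunks. A compressed illustration of the same budgeting with tiktoken; the model name, budget fraction, and context size below are placeholders, not values from this file:

import math
import tiktoken

enc = tiktoken.encoding_for_model("gpt-4")
prompt_percent, max_length = 0.9, 8192            # placeholder budget fraction and context size

fixed = "You are a grader."                       # stands in for fixed_instruction + few-shot examples
splittable = "very long transcript ... " * 2000

remaining = math.ceil(prompt_percent * max_length - len(enc.encode(fixed)))
ids = enc.encode(splittable)
# Same slicing pattern as split_input: fixed stride, small overlap of 10 tokens per chunk.
chunks = [enc.decode(ids[i : i + remaining + 10]) for i in range(0, len(ids), remaining)]
print(len(chunks), "chunks within the token budget")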
medrax/llava/eval/model_vqa.py
CHANGED

@@ -45,9 +45,7 @@ def eval_model(args):
     disable_torch_init()
     model_path = os.path.expanduser(args.model_path)
     model_name = get_model_name_from_path(model_path)
-    tokenizer, model, image_processor, context_len = load_pretrained_model(
-        model_path, args.model_base, model_name
-    )
+    tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)
 
     questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")]
     questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
@@ -69,11 +67,7 @@ def eval_model(args):
         conv.append_message(conv.roles[1], None)
         prompt = conv.get_prompt()
 
-        input_ids = (
-            tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
-            .unsqueeze(0)
-            .cuda()
-        )
+        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).cuda()
 
         image = Image.open(os.path.join(args.image_folder, image_file))
         image_tensor = process_images([image], image_processor, model.config)[0]
medrax/llava/eval/summarize_gpt_review.py
CHANGED

@@ -14,8 +14,7 @@ def get_domain(x):
 def main(args):
     scores_data = util.load_file_jsonl(args.scores_file)
     predictions = [
-        (x["question_id"], x["type"], get_domain(x), x["gpt_eval"].split("\n")[0].split(" "))
-        for x in scores_data
+        (x["question_id"], x["type"], get_domain(x), x["gpt_eval"].split("\n")[0].split(" ")) for x in scores_data
     ]
 
     score_type_dict = defaultdict(lambda: defaultdict(list))
@@ -33,8 +32,7 @@ def main(args):
         result[q_type]["gpt4_score"] = util.get_avg(score_dict[1])
         result[q_type]["pred_score"] = util.get_avg(score_dict[2])
         result[q_type]["pred_relative_score"] = (
-            util.get_avg([float(s2) / float(s1) for s1, s2 in zip(score_dict[1], score_dict[2])])
-            * 100
+            util.get_avg([float(s2) / float(s1) for s1, s2 in zip(score_dict[1], score_dict[2])]) * 100
         )
         result[q_type]["data_size"] = len(score_dict[1])
 
@@ -55,8 +53,6 @@ def main(args):
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser("GPT-4 Multimodal Chat Eval Postprocessing", add_help=True)
-    parser.add_argument(
-        "--scores-file", default="", metavar="FILE", help="input path to gpt-4 score file"
-    )
+    parser.add_argument("--scores-file", default="", metavar="FILE", help="input path to gpt-4 score file")
     args = parser.parse_args()
     main(args)
medrax/llava/mm_utils.py
CHANGED

@@ -35,9 +35,7 @@ def process_images(images, image_processor, model_cfg):
     for image in images:
         if image_aspect_ratio == "pad":
             if image.mode == "L":
-                background_color = int(
-                    255 * sum(image_processor.image_mean) / len(image_processor.image_mean)
-                )
+                background_color = int(255 * sum(image_processor.image_mean) / len(image_processor.image_mean))
             else:
                 background_color = tuple(int(x * 255) for x in image_processor.image_mean)
             image = expand2square(image, background_color)
@@ -48,9 +46,7 @@ def process_images(images, image_processor, model_cfg):
     return new_images
 
 
-def tokenizer_image_token(
-    prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None
-):
+def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
     prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>")]
 
     def insert_separator(X, sep):
@@ -58,11 +54,7 @@ def tokenizer_image_token(
 
     input_ids = []
     offset = 0
-    if (
-        len(prompt_chunks) > 0
-        and len(prompt_chunks[0]) > 0
-        and prompt_chunks[0][0] == tokenizer.bos_token_id
-    ):
+    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
         offset = 1
         input_ids.append(prompt_chunks[0][0])
 
@@ -100,9 +92,7 @@ class KeywordsStoppingCriteria(StoppingCriteria):
         self.tokenizer = tokenizer
         self.start_len = input_ids.shape[1]
 
-    def call_for_batch(
-        self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
-    ) -> bool:
+    def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
         offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
         self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
         for keyword_id in self.keyword_ids:
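tokenizer_image_token, reformatted above, splits the prompt on the "<image>" placeholder, tokenizes each text chunk, and splices a sentinel image token id between the chunks. A toy stand-in for that splicing step, separate from the real tokenizer; the default sentinel of -200 follows LLaVA's IMAGE_TOKEN_INDEX convention:

def splice_image_tokens(chunks, image_token_index=-200):
    # chunks: token id lists for the text on either side of each "<image>" placeholder.
    ids = []
    for i, chunk in enumerate(chunks):
        if i > 0:
            ids.append(image_token_index)  # one sentinel id per "<image>" occurrence
        ids.extend(chunk)
    return ids

# "a <image> b" tokenized into two chunks around the placeholder:
print(splice_image_tokens([[1, 42, 7], [99, 100]]))  # [1, 42, 7, -200, 99, 100]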
medrax/llava/model/builder.py
CHANGED

@@ -59,9 +59,7 @@ def load_pretrained_model(
         # PEFT model
         from peft import PeftModel
 
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_base, use_fast=False, cache_dir=cache_dir
-        )
+        tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False, cache_dir=cache_dir)
         model = AutoModelForCausalLM.from_pretrained(
             model_base,
             low_cpu_mem_usage=True,
@@ -78,9 +76,7 @@ def load_pretrained_model(
     else:
         use_fast = False
         if "mpt" in model_name.lower():
-            tokenizer = AutoTokenizer.from_pretrained(
-                model_path, use_fast=True, cache_dir=cache_dir
-            )
+            tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True, cache_dir=cache_dir)
             model = AutoModelForCausalLM.from_pretrained(
                 model_path,
                 low_cpu_mem_usage=True,
@@ -90,9 +86,7 @@ def load_pretrained_model(
                 **kwargs,
             )
         else:
-            tokenizer = AutoTokenizer.from_pretrained(
-                model_path, use_fast=False, cache_dir=cache_dir
-            )
+            tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, cache_dir=cache_dir)
             model = AutoModelForCausalLM.from_pretrained(
                 model_path,
                 low_cpu_mem_usage=True,
@@ -109,9 +103,7 @@ def load_pretrained_model(
         if mm_use_im_patch_token:
             tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
         if mm_use_im_start_end:
-            tokenizer.add_tokens(
-                [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
-            )
+            tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
         model.resize_token_embeddings(len(tokenizer))
 
     vision_tower = model.get_vision_tower()
medrax/llava/model/language_model/llava_mistral.py
CHANGED

@@ -125,9 +125,7 @@ class LlavaMistralForCausalLM(MistralForCausalLM, LlavaMetaForCausalLM):
             **kwargs,
         )
 
-    def prepare_inputs_for_generation(
-        self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
-    ):
+    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
         images = kwargs.pop("images", None)
         image_sizes = kwargs.pop("image_sizes", None)
         inputs = super().prepare_inputs_for_generation(
medrax/llava/model/llava_arch.py
CHANGED

@@ -104,9 +104,7 @@ class LlavaMetaModel:
             checkpoint_folder = os.path.dirname(pretrain_mm_mlp_adapter)
             ckpts = glob(f"{checkpoint_folder}/checkpoint-*", recursive=False)
             if len(ckpts) > 0:
-                vision_module_weights = torch.load(
-                    f"{ckpts[-1]}/mm_projector.bin", map_location="cpu"
-                )
+                vision_module_weights = torch.load(f"{ckpts[-1]}/mm_projector.bin", map_location="cpu")
                 model_dict = get_w(vision_module_weights, "vision_tower")
                 print(f"Loading vision module weights from {ckpts[-1]}/mm_projector.bin")
                 # print keys in model_dict
@@ -170,9 +168,7 @@ class LlavaMetaForCausalLM(ABC):
            image_features = self.encode_images(images).to(self.device)
 
        # TODO: image start / end is not implemented here to support pretraining.
-        if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(
-            self.config, "mm_use_im_start_end", False
-        ):
+        if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(self.config, "mm_use_im_start_end", False):
            raise NotImplementedError
 
        # Let's just add dummy tensors if they do not exist,
@@ -188,21 +184,15 @@ class LlavaMetaForCausalLM(ABC):
        else:
            attention_mask = attention_mask.bool()
        if position_ids is None:
-            position_ids = torch.arange(
-                0, input_ids.shape[1], dtype=torch.long, device=input_ids.device
-            )
+            position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
 
        if labels is None:
            labels = torch.full_like(input_ids, IGNORE_INDEX)
 
        input_ids = [
-            cur_input_ids[cur_attention_mask]
-            for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)
-        ]
-        labels = [
-            cur_labels[cur_attention_mask]
-            for cur_labels, cur_attention_mask in zip(labels, attention_mask)
+            cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)
        ]
+        labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
 
        new_input_embeds = []
        new_labels = []
@@ -219,20 +209,14 @@ class LlavaMetaForCausalLM(ABC):
                continue
 
            image_token_indices = (
-                [-1]
-                + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist()
-                + [cur_input_ids.shape[0]]
+                [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
            )
            cur_input_ids_noim = []
            cur_labels = labels[batch_idx]
            cur_labels_noim = []
            for i in range(len(image_token_indices) - 1):
-                cur_input_ids_noim.append(
-                    cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]]
-                )
-                cur_labels_noim.append(
-                    cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]]
-                )
+                cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]])
+                cur_labels_noim.append(cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]])
 
            split_sizes = [x.shape[0] for x in cur_labels_noim]
            cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
@@ -279,12 +263,8 @@ class LlavaMetaForCausalLM(ABC):
                dtype=new_labels[0].dtype,
                device=new_labels[0].device,
            )
-            attention_mask = torch.zeros(
-                (batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device
-            )
-            position_ids = torch.zeros(
-                (batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device
-            )
+            attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
+            position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
 
        for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
            cur_len = cur_new_embed.shape[0]
@@ -351,9 +331,7 @@ class LlavaMetaForCausalLM(ABC):
            self.resize_token_embeddings(len(tokenizer))
 
        if model_args.mm_use_im_start_end:
-            num_new_tokens = tokenizer.add_tokens(
-                [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
-            )
+            num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))
 
            if num_new_tokens > 0:
@@ -361,9 +339,7 @@ class LlavaMetaForCausalLM(ABC):
                output_embeddings = self.get_output_embeddings().weight.data
 
                input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
-                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
-                    dim=0, keepdim=True
-                )
+                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
 
                input_embeddings[-num_new_tokens:] = input_embeddings_avg
                output_embeddings[-num_new_tokens:] = output_embeddings_avg
@@ -375,9 +351,7 @@ class LlavaMetaForCausalLM(ABC):
                    p.requires_grad = False
 
            if model_args.pretrain_mm_mlp_adapter:
-                mm_projector_weights = torch.load(
-                    model_args.pretrain_mm_mlp_adapter, map_location="cpu"
-                )
+                mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location="cpu")
                embed_tokens_weight = mm_projector_weights["model.embed_tokens.weight"]
                assert num_new_tokens == 2
                if input_embeddings.shape == embed_tokens_weight.shape:
medrax/llava/model/multimodal_encoder/builder.py
CHANGED

@@ -3,13 +3,7 @@ from .clip_encoder import CLIPVisionTower
 
 
 def build_vision_tower(vision_tower_cfg, **kwargs):
-    vision_tower = getattr(
-        vision_tower_cfg, "mm_vision_tower", getattr(vision_tower_cfg, "vision_tower", None)
-    )
+    vision_tower = getattr(vision_tower_cfg, "mm_vision_tower", getattr(vision_tower_cfg, "vision_tower", None))
     is_absolute_path_exists = os.path.exists(vision_tower)
-    if (
-        is_absolute_path_exists
-        or vision_tower.startswith("openai")
-        or vision_tower.startswith("laion")
-    ):
+    if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion"):
         return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
medrax/llava/model/multimodal_projector/builder.py
CHANGED

@@ -19,9 +19,7 @@ class SimpleResBlock(nn.Module):
         super().__init__()
         self.pre_norm = nn.LayerNorm(channels)
 
-        self.proj = nn.Sequential(
-            nn.Linear(channels, channels), nn.GELU(), nn.Linear(channels, channels)
-        )
+        self.proj = nn.Sequential(nn.Linear(channels, channels), nn.GELU(), nn.Linear(channels, channels))
 
     def forward(self, x):
         x = self.pre_norm(x)
medrax/llava/serve/cli.py
CHANGED

@@ -94,9 +94,7 @@ def main(args):
         if image is not None:
             # first message
             if model.config.mm_use_im_start_end:
-                inp = (
-                    DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + "\n" + inp
-                )
+                inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + "\n" + inp
             else:
                 inp = DEFAULT_IMAGE_TOKEN + "\n" + inp
             conv.append_message(conv.roles[0], inp)
medrax/llava/serve/controller.py
CHANGED

@@ -2,6 +2,7 @@
 A controller manages distributed workers.
 It sends worker addresses to clients.
 """
+
 import argparse
 import dataclasses
 from enum import Enum, auto
@@ -199,9 +200,7 @@ class Controller:
             yield json.dumps(ret).encode() + b"\0"
 
         try:
-            response = requests.post(
-                worker_addr + "/worker_generate_stream", json=params, stream=True, timeout=5
-            )
+            response = requests.post(worker_addr + "/worker_generate_stream", json=params, stream=True, timeout=5)
             for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
                 if chunk:
                     yield chunk + b"\0"
@@ -240,9 +239,7 @@ app = FastAPI()
 @app.post("/register_worker")
 async def register_worker(request: Request):
     data = await request.json()
-    controller.register_worker(
-        data["worker_name"], data["check_heart_beat"], data.get("worker_status", None)
-    )
+    controller.register_worker(data["worker_name"], data["check_heart_beat"], data.get("worker_status", None))
 
 
 @app.post("/refresh_all_workers")
medrax/llava/serve/gradio_web_server.py
CHANGED

@@ -216,9 +216,7 @@ def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request:
     all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
     for image, hash in zip(all_images, all_image_hash):
         t = datetime.datetime.now()
-        filename = os.path.join(
-            LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg"
-        )
+        filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg")
         if not os.path.isfile(filename):
             os.makedirs(os.path.dirname(filename), exist_ok=True)
             image.save(filename)
@@ -230,9 +228,7 @@ def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request:
         "temperature": float(temperature),
         "top_p": float(top_p),
         "max_new_tokens": min(int(max_new_tokens), 1536),
-        "stop": state.sep
-        if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT]
-        else state.sep2,
+        "stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2,
         "images": f"List of {len(state.get_images())} images: {all_image_hash}",
     }
     logger.info(f"==== request ====\n{pload}")
@@ -330,9 +326,7 @@ block_css = """
 
 
 def build_demo(embed_mode):
-    textbox = gr.Textbox(
-        show_label=False, placeholder="Enter text and press ENTER", container=False
-    )
+    textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
     with gr.Blocks(title="LLaVA", theme=gr.themes.Default(), css=block_css) as demo:
         state = gr.State()
 
@@ -468,9 +462,7 @@ def build_demo(embed_mode):
         [state, chatbot] + btn_list,
     )
 
-    clear_btn.click(
-        clear_history, None, [state, chatbot, textbox, imagebox] + btn_list, queue=False
-    )
+    clear_btn.click(clear_history, None, [state, chatbot, textbox, imagebox] + btn_list, queue=False)
 
     textbox.submit(
         add_text,
medrax/llava/serve/model_worker.py
CHANGED

@@ -1,6 +1,7 @@
 """
 A model worker executes the model.
 """
+
 import argparse
 import asyncio
 import json
@@ -155,9 +156,7 @@ class ModelWorker:
         if images is not None and len(images) > 0 and self.is_multimodal:
             if len(images) > 0:
                 if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
-                    raise ValueError(
-                        "Number of images does not match number of <image> tokens in prompt"
-                    )
+                    raise ValueError("Number of images does not match number of <image> tokens in prompt")
 
                 images = [load_image_from_base64(image) for image in images]
                 images = process_images(images, image_processor, model.config)
@@ -172,9 +171,7 @@ class ModelWorker:
                 replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
                 prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
 
-                num_image_tokens = (
-                    prompt.count(replace_token) * model.get_vision_tower().num_patches
-                )
+                num_image_tokens = prompt.count(replace_token) * model.get_vision_tower().num_patches
             else:
                 images = None
             image_args = {"images": images}
@@ -196,19 +193,14 @@ class ModelWorker:
         )
         keywords = [stop_str]
         stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
-        streamer = TextIteratorStreamer(
-            tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15
-        )
+        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
 
-        max_new_tokens = min(
-            max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens
-        )
+        max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)
 
         if max_new_tokens < 1:
             yield json.dumps(
                 {
-                    "text": ori_prompt
-                    + "Exceeds max token length. Please start a new conversation, thanks.",
+                    "text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.",
                     "error_code": 0,
                 }
             ).encode() + b"\0"
medrax/llava/serve/test_message.py
CHANGED

@@ -17,9 +17,7 @@ def main():
     models.sort()
     print(f"Models: {models}")
 
-    ret = requests.post(
-        controller_addr + "/get_worker_address", json={"model": args.model_name}
-    )
+    ret = requests.post(controller_addr + "/get_worker_address", json={"model": args.model_name})
     worker_addr = ret.json()["address"]
     print(f"worker_addr: {worker_addr}")
 
@@ -38,9 +36,7 @@ def main():
         "temperature": 0.7,
         "stop": conv.sep2,
     }
-    response = requests.post(
-        worker_addr + "/worker_generate_stream", headers=headers, json=pload, stream=True
-    )
+    response = requests.post(worker_addr + "/worker_generate_stream", headers=headers, json=pload, stream=True)
 
     print(prompt, end="")
     for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
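Both the controller and the worker stream results as JSON blobs terminated by a NUL byte, which is why every client in these files iterates with delimiter=b"\0". A hedged client-side sketch of consuming such a stream; the worker address and payload keys are illustrative, since real values come from the controller's /get_worker_address response:

import json
import requests

# Illustrative worker address and payload; not a documented fixed endpoint of this deployment.
resp = requests.post(
    "http://localhost:40000/worker_generate_stream",
    json={"prompt": "Hello", "temperature": 0.7},
    stream=True,
)
for chunk in resp.iter_lines(decode_unicode=False, delimiter=b"\0"):
    if chunk:
        data = json.loads(chunk.decode())
        print(data.get("text", ""), end="\r")  # each NUL-delimited blob carries the text so far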
medrax/llava/utils.py
CHANGED

@@ -45,9 +45,7 @@ def build_logger(logger_name, logger_filename):
     if handler is None:
         os.makedirs(LOGDIR, exist_ok=True)
         filename = os.path.join(LOGDIR, logger_filename)
-        handler = logging.handlers.TimedRotatingFileHandler(
-            filename, when="D", utc=True, encoding="UTF-8"
-        )
+        handler = logging.handlers.TimedRotatingFileHandler(filename, when="D", utc=True, encoding="UTF-8")
         handler.setFormatter(formatter)
 
     for name, item in logging.root.manager.loggerDict.items():
medrax/models/model_factory.py
CHANGED

@@ -29,7 +29,7 @@ class ModelFactory:
             "base_url_key": "OPENAI_BASE_URL",
         },
         "gemini": {
-            "class": ChatGoogleGenerativeAI,
+            "class": ChatGoogleGenerativeAI,
             "env_key": "GOOGLE_API_KEY",
             "base_url_key": "GOOGLE_BASE_URL",
         },
@@ -42,14 +42,12 @@ class ModelFactory:
         "grok": {
             "class": ChatXAI,
             "env_key": "XAI_API_KEY",
-        }
+        },
         # Add more providers with default configurations here
     }
 
     @classmethod
-    def register_provider(
-        cls, prefix: str, model_class: Type[BaseLanguageModel], env_key: str, **kwargs
-    ) -> None:
+    def register_provider(cls, prefix: str, model_class: Type[BaseLanguageModel], env_key: str, **kwargs) -> None:
         """Register a new model provider.
 
         Args:
@@ -81,9 +79,7 @@ class ModelFactory:
             ValueError: If the required API key is missing
         """
         # Find the matching provider based on model name prefix
-        provider_prefix = next(
-            (prefix for prefix in cls._model_providers if model_name.startswith(prefix)), None
-        )
+        provider_prefix = next((prefix for prefix in cls._model_providers if model_name.startswith(prefix)), None)
 
         if not provider_prefix:
             raise ValueError(
@@ -153,7 +149,4 @@ class ModelFactory:
             Dict[str, Dict[str, Any]]: Dictionary of registered providers and their configurations
         """
         # Return a copy to prevent accidental modification
-        return {
-            k: {kk: vv for kk, vv in v.items() if kk != "class"}
-            for k, v in cls._model_providers.items()
-        }
+        return {k: {kk: vv for kk, vv in v.items() if kk != "class"} for k, v in cls._model_providers.items()}
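The provider table above is extensible at runtime through register_provider. A hedged sketch of registering an additional provider; the Mistral chat class and key name are illustrative and not shipped with MedRAX:

from langchain_mistralai import ChatMistralAI  # illustrative provider class, assumed installed

from medrax.models.model_factory import ModelFactory

# Map any model name starting with "mistral" to ChatMistralAI, reading MISTRAL_API_KEY from the environment.
ModelFactory.register_provider("mistral", ChatMistralAI, "MISTRAL_API_KEY")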
medrax/rag/rag.py
CHANGED

@@ -107,9 +107,7 @@ class CohereRAG:
         # Initialize Pinecone
         self.pinecone_api_key = os.getenv("PINECONE_API_KEY")
         if not self.pinecone_api_key:
-            raise ValueError(
-                "PINECONE_API_KEY environment variable not set. Please get a key from app.pinecone.io"
-            )
+            raise ValueError("PINECONE_API_KEY environment variable not set. Please get a key from app.pinecone.io")
         self.pinecone = Pinecone(api_key=self.pinecone_api_key)
         self.index_name = self.config.pinecone_index_name
 
@@ -161,9 +159,7 @@ class CohereRAG:
         )
 
         print(f"Connecting to existing Pinecone index: {self.index_name}")
-        vectorstore = PineconeVectorStore.from_existing_index(
-            index_name=self.index_name, embedding=self.embeddings
-        )
+        vectorstore = PineconeVectorStore.from_existing_index(index_name=self.index_name, embedding=self.embeddings)
 
         # Check if the index is empty and needs to be populated
         try:

@@ -329,9 +325,7 @@ class CohereRAG:
                 )
                 documents.append(doc)
 
-            print(
-                f"Loaded {len(documents)} document chunks from HuggingFace dataset: {dataset_name}"
-            )
+            print(f"Loaded {len(documents)} document chunks from HuggingFace dataset: {dataset_name}")
             return documents
 
         except Exception as e:
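For context on the Pinecone calls being reflowed above, here is a standalone sketch of the same pattern: fail fast when the API key is missing, then attach to an index that already exists. The index name and embedding model are placeholders, and the import paths are assumptions based on current langchain-pinecone and langchain-cohere packages rather than anything this diff states:

    import os

    from pinecone import Pinecone
    from langchain_pinecone import PineconeVectorStore
    from langchain_cohere import CohereEmbeddings  # embedding backend is an assumption

    api_key = os.getenv("PINECONE_API_KEY")
    if not api_key:
        raise ValueError("PINECONE_API_KEY environment variable not set. Please get a key from app.pinecone.io")

    pc = Pinecone(api_key=api_key)  # client handle, useful for index management
    embeddings = CohereEmbeddings(model="embed-english-v3.0")  # placeholder model name

    # Attach to an index created earlier; nothing is re-uploaded here
    vectorstore = PineconeVectorStore.from_existing_index(index_name="medrax-rag", embedding=embeddings)
    retriever = vectorstore.as_retriever(search_kwargs={"k": 4})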
medrax/tools/browsing/__init__.py
CHANGED

@@ -6,8 +6,8 @@ from .web_browser import WebBrowserTool, WebBrowserSchema, SearchQuerySchema, VisitUrlSchema
 __all__ = [
     "DuckDuckGoSearchTool",
     "WebSearchInput",
-    "WebBrowserTool",
+    "WebBrowserTool",
     "WebBrowserSchema",
     "SearchQuerySchema",
-    "VisitUrlSchema"
-]
+    "VisitUrlSchema",
+]
medrax/tools/browsing/duckduckgo.py
CHANGED

@@ -95,18 +95,12 @@ class DuckDuckGoSearchTool(BaseTool):
         super().__init__(**kwargs)
 
         if DDGS is None:
-            logger.error(
-                "duckduckgo-search package not installed. Install with: pip install duckduckgo-search"
-            )
-            raise ImportError(
-                "duckduckgo-search package is required for web search functionality"
-            )
+            logger.error("duckduckgo-search package not installed. Install with: pip install duckduckgo-search")
+            raise ImportError("duckduckgo-search package is required for web search functionality")
 
         logger.info("DuckDuckGo search tool initialized successfully")
 
-    def _perform_search_sync(
-        self, query: str, max_results: int = 5, region: str = "us-en"
-    ) -> Dict[str, Any]:
+    def _perform_search_sync(self, query: str, max_results: int = 5, region: str = "us-en") -> Dict[str, Any]:
         """
         Perform the actual web search using DuckDuckGo synchronously.

@@ -118,9 +112,7 @@ class DuckDuckGoSearchTool(BaseTool):
         Returns:
             Dict[str, Any]: Structured search results.
         """
-        logger.info(
-            f"Performing web search: '{query}' (max_results={max_results}, region={region})"
-        )
+        logger.info(f"Performing web search: '{query}' (max_results={max_results}, region={region})")
 
         try:
             # Initialize DDGS with error handling

@@ -158,9 +150,7 @@ class DuckDuckGoSearchTool(BaseTool):
                 summary = f"No results found for '{query}'"
 
             # Log successful completion
-            logger.info(
-                f"Web search completed successfully: {len(formatted_results)} results"
-            )
+            logger.info(f"Web search completed successfully: {len(formatted_results)} results")
 
             return {
                 "query": query,

@@ -217,7 +207,7 @@ class DuckDuckGoSearchTool(BaseTool):
 
         try:
             result = self._perform_search_sync(query, max_results, region)
-
+
             # Check if search was successful
             if "error" in result:
                 metadata["analysis_status"] = "failed"

@@ -239,7 +229,7 @@ class DuckDuckGoSearchTool(BaseTool):
             }
             metadata["analysis_status"] = "failed"
             metadata["error_details"] = str(e)
-
+
             return error_result, metadata
 
     async def _arun(

@@ -296,9 +286,7 @@ class DuckDuckGoSearchTool(BaseTool):
 
             # Use asyncio to run sync search in executor
             loop = asyncio.get_event_loop()
-            result, metadata = await loop.run_in_executor(
-                None, self._run, query, max_results, region
-            )
+            result, metadata = await loop.run_in_executor(None, self._run, query, max_results, region)
 
             if writer:
                 # Parse result to get count for progress update

@@ -333,7 +321,7 @@ class DuckDuckGoSearchTool(BaseTool):
                 "search_engine": "DuckDuckGo",
                 "timestamp": datetime.now().isoformat(),
             }
-
+
             metadata = {
                 "query": query,
                 "max_results": max_results,

@@ -344,12 +332,10 @@ class DuckDuckGoSearchTool(BaseTool):
                 "analysis_status": "failed",
                 "error_details": str(e),
             }
-
+
             return error_result, metadata
 
-    def get_search_summary(
-        self, query: str, max_results: int = 3
-    ) -> dict[str, str | list[str]]:
+    def get_search_summary(self, query: str, max_results: int = 3) -> dict[str, str | list[str]]:
         """
         Get a quick summary of search results for a given query.

@@ -375,14 +361,7 @@ class DuckDuckGoSearchTool(BaseTool):
         results = result.get("results", [])
         titles = [r["title"] for r in results]
         urls = [r["url"] for r in results]
-        snippets = [
-            (
-                r["snippet"][:100] + "..."
-                if len(r["snippet"]) > 100
-                else r["snippet"]
-            )
-            for r in results
-        ]
+        snippets = [(r["snippet"][:100] + "..." if len(r["snippet"]) > 100 else r["snippet"]) for r in results]
 
         return {
             "query": query,
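The reflowed `_perform_search_sync` above wraps the duckduckgo-search client; the sketch below shows the underlying call pattern in isolation. It assumes the `DDGS.text()` interface of the duckduckgo-search package and its usual result keys (`title`, `href`, `body`), which can differ between package versions:

    from typing import Any, Dict

    from duckduckgo_search import DDGS


    def search(query: str, max_results: int = 5, region: str = "us-en") -> Dict[str, Any]:
        """Run a synchronous DuckDuckGo text search and normalize the results."""
        with DDGS() as ddgs:
            raw = list(ddgs.text(query, region=region, max_results=max_results))

        results = [
            {
                "title": r.get("title", ""),
                "url": r.get("href", ""),
                # Truncate long snippets the same way get_search_summary does
                "snippet": (r.get("body", "")[:100] + "...") if len(r.get("body", "")) > 100 else r.get("body", ""),
            }
            for r in raw
        ]
        return {"query": query, "results": results}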
medrax/tools/browsing/web_browser.py
CHANGED

@@ -78,9 +78,7 @@ class WebBrowserTool(BaseTool):
     max_results: int = 5
     args_schema: Type[BaseModel] = WebBrowserSchema
 
-    def __init__(
-        self, search_api_key: Optional[str] = None, search_engine_id: Optional[str] = None, **kwargs
-    ):
+    def __init__(self, search_api_key: Optional[str] = None, search_engine_id: Optional[str] = None, **kwargs):
         """Initialize the web browser tool with optional search API credentials.
 
         Args:

@@ -145,9 +143,7 @@ class WebBrowserTool(BaseTool):
         except Exception as e:
             return {"error": f"Search failed: {str(e)}"}
 
-    def visit_url(
-        self, url: str, max_content_length: int = 5000, max_links: int = 5
-    ) -> Dict[str, Any]:
+    def visit_url(self, url: str, max_content_length: int = 5000, max_links: int = 5) -> Dict[str, Any]:
         """Visit a URL and extract its content with comprehensive parsing.
 
         Args:

@@ -218,9 +214,7 @@ class WebBrowserTool(BaseTool):
             return {
                 "title": title,
                 "content": (
-                    text_content[:max_content_length]
-                    if len(text_content) > max_content_length
-                    else text_content
+                    text_content[:max_content_length] if len(text_content) > max_content_length else text_content
                 ),
                 "url": url,
                 "links": links[:max_links],  # Limit to max_links
medrax/tools/classification/__init__.py
CHANGED

@@ -3,9 +3,4 @@
 from .torchxrayvision import TorchXRayVisionClassifierTool, TorchXRayVisionInput
 from .arcplus import ArcPlusClassifierTool, ArcPlusInput
 
-__all__ = [
-    "TorchXRayVisionClassifierTool",
-    "TorchXRayVisionInput",
-    "ArcPlusClassifierTool",
-    "ArcPlusInput"
-]
+__all__ = ["TorchXRayVisionClassifierTool", "TorchXRayVisionInput", "ArcPlusClassifierTool", "ArcPlusInput"]
medrax/tools/classification/arcplus.py
CHANGED

@@ -38,9 +38,7 @@ class OmniSwinTransformer(SwinTransformer):
 
         self.omni_heads = []
         for num_classes in num_classes_list:
-            self.omni_heads.append(
-                nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
-            )
+            self.omni_heads.append(nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())
         self.omni_heads = nn.ModuleList(self.omni_heads)
 
     def forward(self, x, head_n=None):

@@ -62,9 +60,7 @@ class OmniSwinTransformer(SwinTransformer):
 class ArcPlusInput(BaseModel):
     """Input for ArcPlus chest X-ray analysis tool. Only supports JPG or PNG images."""
 
-    image_path: str = Field(
-        ..., description="Path to the radiology image file, only supports JPG or PNG images"
-    )
+    image_path: str = Field(..., description="Path to the radiology image file, only supports JPG or PNG images")
 
 
 class ArcPlusClassifierTool(BaseTool):

@@ -249,11 +245,7 @@ class ArcPlusClassifierTool(BaseTool):
 
         # Remove "module." prefix if present (improved logic from example)
         if any([True if "module." in k else False for k in state_dict.keys()]):
-            state_dict = {
-                k.replace("module.", ""): v
-                for k, v in state_dict.items()
-                if k.startswith("module.")
-            }
+            state_dict = {k.replace("module.", ""): v for k, v in state_dict.items() if k.startswith("module.")}
 
         # Load the model weights
         msg = self.model.load_state_dict(state_dict, strict=False)

@@ -342,14 +334,10 @@ class ArcPlusClassifierTool(BaseTool):
 
         # Map predictions to disease names
         if len(predictions) != len(self.disease_list):
-            print(
-                f"Warning: Expected {len(self.disease_list)} predictions, got {len(predictions)}"
-            )
+            print(f"Warning: Expected {len(self.disease_list)} predictions, got {len(predictions)}")
             # Pad or truncate as needed
             if len(predictions) < len(self.disease_list):
-                predictions = np.pad(
-                    predictions, (0, len(self.disease_list) - len(predictions))
-                )
+                predictions = np.pad(predictions, (0, len(self.disease_list) - len(predictions)))
             else:
                 predictions = predictions[: len(self.disease_list)]
medrax/tools/classification/torchxrayvision.py
CHANGED

@@ -19,9 +19,7 @@ from medrax.utils.utils import preprocess_medical_image
 class TorchXRayVisionInput(BaseModel):
     """Input for TorchXRayVision chest X-ray analysis tools. Only supports JPG or PNG images."""
 
-    image_path: str = Field(
-        ..., description="Path to the radiology image file, only supports JPG or PNG images"
-    )
+    image_path: str = Field(..., description="Path to the radiology image file, only supports JPG or PNG images")
 
 
 class TorchXRayVisionClassifierTool(BaseTool):
medrax/tools/dicom.py
CHANGED

@@ -14,9 +14,7 @@ class DicomProcessorInput(BaseModel):
     """Input schema for the DICOM Processor Tool."""
 
     dicom_path: str = Field(..., description="Path to the DICOM file")
-    window_center: Optional[float] = Field(
-        None, description="Window center for contrast adjustment"
-    )
+    window_center: Optional[float] = Field(None, description="Window center for contrast adjustment")
     window_width: Optional[float] = Field(None, description="Window width for contrast adjustment")
 
medrax/tools/grounding.py
CHANGED

@@ -90,11 +90,8 @@ class XRayPhraseGroundingTool(BaseTool):
             trust_remote_code=True,
             quantization_config=quantization_config,
         )
-        self.processor = AutoProcessor.from_pretrained(
-            model_path, cache_dir=cache_dir, trust_remote_code=True
-        )
+        self.processor = AutoProcessor.from_pretrained(model_path, cache_dir=cache_dir, trust_remote_code=True)
 
-
         self.model = self.model.eval()
 
         self.temp_dir = Path(temp_dir if temp_dir else tempfile.mkdtemp())

@@ -176,12 +173,8 @@ class XRayPhraseGroundingTool(BaseTool):
         )
 
         prompt_length = inputs["input_ids"].shape[-1]
-        decoded_text = self.processor.decode(
-            output[0][prompt_length:], skip_special_tokens=True
-        )
-        predictions = self.processor.convert_output_to_plaintext_or_grounded_sequence(
-            decoded_text
-        )
+        decoded_text = self.processor.decode(output[0][prompt_length:], skip_special_tokens=True)
+        predictions = self.processor.convert_output_to_plaintext_or_grounded_sequence(decoded_text)
 
         metadata = {
             "image_path": image_path,

@@ -208,9 +201,7 @@ class XRayPhraseGroundingTool(BaseTool):
         # Convert model bboxes to list format and get original image bboxes
         model_bboxes = [list(bbox) for bbox in pred_bboxes]
         original_bboxes = [
-            self.processor.adjust_box_for_original_image_size(
-                bbox, width=image.size[0], height=image.size[1]
-            )
+            self.processor.adjust_box_for_original_image_size(bbox, width=image.size[0], height=image.size[1])
             for bbox in model_bboxes
         ]
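The grounding change above keeps the usual Hugging Face pattern of slicing the prompt tokens off the generated sequence before decoding, so only the model's continuation is returned. A generic sketch of that pattern, with a placeholder model id and tokenizer rather than the processor used by this tool:

    # Generic "decode only the new tokens" sketch; model_id is a placeholder.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "some/causal-lm"  # placeholder
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)

    inputs = tokenizer("Describe the finding:", return_tensors="pt")
    prompt_length = inputs["input_ids"].shape[-1]

    output = model.generate(**inputs, max_new_tokens=64)
    # Drop the prompt tokens so the decoded text contains only the generated answer
    decoded_text = tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)
    print(decoded_text)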
medrax/tools/rag.py
CHANGED

@@ -14,7 +14,7 @@ class RAGTool(BaseTool):
 
     The knowledge base includes:
     - Medical textbooks and reference materials
-    - Research papers and clinical studies
+    - Research papers and clinical studies
    - Medical manuals and guidelines
    - Specialized medical literature
 
medrax/tools/report_generation.py
CHANGED

@@ -23,9 +23,7 @@ from transformers import (
 class ChestXRayInput(BaseModel):
     """Input for chest X-ray analysis tools. Only supports JPG or PNG images."""
 
-    image_path: str = Field(
-        ..., description="Path to the radiology image file, only supports JPG or PNG images"
-    )
+    image_path: str = Field(..., description="Path to the radiology image file, only supports JPG or PNG images")
 
 
 class ChestXRayReportGeneratorTool(BaseTool):

@@ -180,12 +178,8 @@ class ChestXRayReportGeneratorTool(BaseTool):
         """
         try:
             # Process image for both models
-            findings_pixels = self._process_image(
-                image_path, self.findings_processor, self.findings_model
-            )
-            impression_pixels = self._process_image(
-                image_path, self.impression_processor, self.impression_model
-            )
+            findings_pixels = self._process_image(image_path, self.findings_processor, self.findings_model)
+            impression_pixels = self._process_image(image_path, self.impression_processor, self.impression_model)
 
             # Generate both sections
             with torch.inference_mode():

@@ -197,11 +191,7 @@ class ChestXRayReportGeneratorTool(BaseTool):
             )
 
             # Combine into formatted report
-            report = (
-                "CHEST X-RAY REPORT\n\n"
-                f"FINDINGS:\n{findings_text}\n\n"
-                f"IMPRESSION:\n{impression_text}"
-            )
+            report = "CHEST X-RAY REPORT\n\n" f"FINDINGS:\n{findings_text}\n\n" f"IMPRESSION:\n{impression_text}"
 
             output = {
                 "report": report,
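The one-line report template above relies on Python's implicit concatenation of adjacent string literals (f-strings included), so it builds the same text as the old parenthesized form. A small illustration with example values:

    findings_text = "No acute cardiopulmonary process."  # example values only
    impression_text = "Normal chest X-ray."

    # Adjacent literals are joined into a single string, f-strings included
    report = "CHEST X-RAY REPORT\n\n" f"FINDINGS:\n{findings_text}\n\n" f"IMPRESSION:\n{impression_text}"

    assert report == (
        "CHEST X-RAY REPORT\n\n"
        f"FINDINGS:\n{findings_text}\n\n"
        f"IMPRESSION:\n{impression_text}"
    )
    print(report)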
medrax/tools/segmentation/__init__.py
CHANGED

@@ -3,10 +3,4 @@
 from .segmentation import ChestXRaySegmentationTool, ChestXRaySegmentationInput, OrganMetrics
 from .medsam2 import MedSAM2Tool, MedSAM2Input
 
-__all__ = [
-    "ChestXRaySegmentationTool",
-    "ChestXRaySegmentationInput",
-    "OrganMetrics",
-    "MedSAM2Tool",
-    "MedSAM2Input"
-]
+__all__ = ["ChestXRaySegmentationTool", "ChestXRaySegmentationInput", "OrganMetrics", "MedSAM2Tool", "MedSAM2Input"]
medrax/tools/segmentation/medsam2.py
CHANGED

@@ -26,7 +26,6 @@ from hydra import initialize_config_dir
 from hydra.core.global_hydra import GlobalHydra
 
 
-
 class MedSAM2Input(BaseModel):
     """Input schema for the MedSAM2 Tool."""
 

@@ -47,7 +46,7 @@ class MedSAM2Tool(BaseTool):
 
 class MedSAM2Tool(BaseTool):
     """Advanced medical image segmentation tool using MedSAM2.
-
+
     This tool provides state-of-the-art medical image segmentation capabilities using
     the MedSAM2 model, which is specifically adapted for medical imaging from Meta's SAM2.
     Supports interactive prompting with boxes, points, or automatic segmentation.

@@ -92,22 +91,17 @@ class MedSAM2Tool(BaseTool):
             # This works around the issue with initialize_config_module in sam2
             if GlobalHydra.instance().is_initialized():
                 GlobalHydra.instance().clear()
-
+
             config_dir = Path(__file__).parent.parent.parent.parent / "MedSAM2" / "sam2" / "configs"
             initialize_config_dir(config_dir=str(config_dir), version_base="1.2")
-
+
             hf_hub_download(
-                repo_id=model_path,
-                filename=model_file,
-                local_dir=self.cache_dir,
-                local_dir_use_symlinks=False
+                repo_id=model_path, filename=model_file, local_dir=self.cache_dir, local_dir_use_symlinks=False
             )
 
+            config_path = model_cfg.replace(".yaml", "")
             sam2_model = build_sam2(config_path, str(self.cache_dir / model_file), device=device)
             self.predictor = SAM2ImagePredictor(sam2_model)
-
-            print(f"MedSAM2 model loaded successfully on {device}")
 
         except Exception as e:
             raise RuntimeError(f"Failed to initialize MedSAM2: {str(e)}")

@@ -116,10 +110,10 @@ class MedSAM2Tool(BaseTool):
         """Load and preprocess image for medical analysis."""
         try:
             # Handle different image formats
+            if image_path.lower().endswith(".dcm"):
                 # DICOM files - would need DICOM processor
                 raise ValueError("DICOM files not directly supported. Please convert to standard image format first.")
-
+
             # Load standard image formats
             image = Image.open(image_path)

@@ -131,29 +125,29 @@ class MedSAM2Tool(BaseTool):
                 image = Image.fromarray(img_normalized, mode='L')
 
             # For medical images, convert to grayscale first if needed, then to RGB
+            if image.mode == "L":  # Grayscale
                 # Convert grayscale to RGB for SAM2
+                image = image.convert("RGB")
+            elif image.mode != "RGB":
+                if image.mode == "RGBA":
                     # Create white background for RGBA
+                    background = Image.new("RGB", image.size, (255, 255, 255))
                     background.paste(image, mask=image.split()[-1])
                     image = background
                 else:
+                    image = image.convert("RGB")
 
             # Convert to numpy array
             image_np = np.array(image)
 
             # Ensure image is in proper range [0, 255]
             if image_np.max() <= 1.0:
                 image_np = (image_np * 255).astype(np.uint8)
             else:
                 image_np = image_np.astype(np.uint8)
 
             return image_np
 
         except Exception as e:
             raise ValueError(f"Failed to load image {image_path}: {str(e)}")

@@ -161,55 +155,53 @@ class MedSAM2Tool(BaseTool):
         """Process and validate prompts."""
         if prompt_type == "auto":
             return None, None, None
-
+
         if prompt_coords is None:
             if prompt_type != "auto":
                 raise ValueError(f"Prompt coordinates required for prompt type '{prompt_type}'")
             return None, None, None
-
+
         if prompt_type == "box":
             if len(prompt_coords) != 4:
                 raise ValueError("Box prompt requires 4 coordinates: [x1,y1,x2,y2]")
-
+
             x1, y1, x2, y2 = prompt_coords
             # Validate coordinates
             if x1 >= x2 or y1 >= y2:
                 raise ValueError("Invalid box coordinates: x1 < x2 and y1 < y2 required")
-
+
             input_box = np.array([[x1, y1, x2, y2]])
             return input_box, None, None
-
+
         elif prompt_type == "point":
             if len(prompt_coords) != 2:
                 raise ValueError("Point prompt requires 2 coordinates: [x,y]")
-
+
             x, y = prompt_coords
             input_point = np.array([[x, y]])
             input_label = np.array([1])  # Positive point
             return None, input_point, input_label
-
+
         else:
             raise ValueError(f"Unknown prompt type: {prompt_type}")
 
     def _create_visualization(self, image: np.ndarray, masks: np.ndarray, prompt_info: Dict) -> str:
         """Create visualization of segmentation results."""
         plt.figure(figsize=(10, 10))
-
+
         # Convert RGB image to grayscale for background display
         if len(image.shape) == 3:
             # Convert RGB to grayscale using standard luminance formula
+            gray_image = 0.299 * image[:, :, 0] + 0.587 * image[:, :, 1] + 0.114 * image[:, :, 2]
         else:
             gray_image = image
-
+
         # Display grayscale background
-        plt.imshow(
-            gray_image, cmap="gray", extent=[0, image.shape[1], image.shape[0], 0]
-        )
+        plt.imshow(gray_image, cmap="gray", extent=[0, image.shape[1], image.shape[0], 0])
 
         # Generate color palette for multiple masks
         colors = plt.cm.rainbow(np.linspace(0, 1, len(masks)))
-
+
         # Process and overlay each mask
         for idx, (mask, color) in enumerate(zip(masks, colors)):
             if mask.sum() > 0:

@@ -217,33 +209,31 @@ class MedSAM2Tool(BaseTool):
                 mask_bool = mask.astype(bool)
                 colored_mask = np.zeros((*mask_bool.shape, 4))
                 colored_mask[mask_bool] = (*color[:3], 0.3)  # 30% transparency like segmentation tool
-                plt.imshow(
-                    colored_mask, extent=[0, image.shape[1], image.shape[0], 0]
-                )
-
+                plt.imshow(colored_mask, extent=[0, image.shape[1], image.shape[0], 0])
+
                 # Add legend entry for each mask
                 mask_label = f"Mask {idx + 1} (score: {prompt_info.get('scores', [0])[idx] if idx < len(prompt_info.get('scores', [])) else 0:.3f})"
                 plt.plot([], [], color=color, label=mask_label, linewidth=3)
-
+
         # Add prompt visualization with consistent styling
+        if prompt_info.get("box") is not None:
+            box = prompt_info["box"][0]
             x1, y1, x2, y2 = box
+            plt.plot([x1, x2, x2, x1, x1], [y1, y1, y2, y2, y1], "g-", linewidth=2, label="Box Prompt")
+
+        if prompt_info.get("point") is not None:
+            point = prompt_info["point"][0]
+            plt.plot(point[0], point[1], "go", markersize=10, label="Point Prompt")
+
         plt.title("Segmentation Overlay")
         plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
         plt.axis("off")
-
+
         # Save visualization with higher DPI like segmentation tool
         viz_path = self.temp_dir / f"medsam2_result_{uuid.uuid4().hex[:8]}.png"
+        plt.savefig(viz_path, bbox_inches="tight", dpi=300)
        plt.close()
-
+
         return str(viz_path)
 
     def _run(

@@ -258,28 +248,28 @@ class MedSAM2Tool(BaseTool):
         try:
             # Load image
             image = self._load_image(image_path)
-
+
             # Set image for predictor
             self.predictor.set_image(image)
-
+
             # Process prompts
-            input_box, input_point, input_label = self._process_prompts(
-                prompt_type, prompt_coords, image.shape[:2]
-            )
-
+            input_box, input_point, input_label = self._process_prompts(prompt_type, prompt_coords, image.shape[:2])
+
             # Run inference
             if prompt_type == "auto":
                 # For auto segmentation, try multiple approaches and select best result
                 h, w = image.shape[:2]
-
+
                 # Try multiple points in key areas for medical images
+                sample_points = np.array(
+                    [
+                        [w // 3, h // 3],  # Upper left lung area
+                        [2 * w // 3, h // 3],  # Upper right lung area
+                        [w // 2, 2 * h // 3],  # Lower center area
+                    ]
+                )
                 sample_labels = np.array([1, 1, 1])  # All positive points
-
+
                 masks, scores, logits = self.predictor.predict(
                     point_coords=sample_points,
                     point_labels=sample_labels,

@@ -292,29 +282,29 @@ class MedSAM2Tool(BaseTool):
                     box=input_box,
                     multimask_output=True,
                 )
-
+
             # Create visualization
             prompt_info = {
+                "box": input_box,
+                "point": input_point,
+                "type": prompt_type,
+                "scores": scores,  # Add scores for legend display
             }
             viz_path = self._create_visualization(image, masks, prompt_info)
-
+
             # Create output dictionary (main results)
             output = {
                 "segmentation_image_path": viz_path,
+                "confidence_scores": scores.tolist() if hasattr(scores, "tolist") else list(scores),
                 "num_masks": len(masks),
                 "best_mask_score": float(scores[0]) if len(scores) > 0 else 0.0,
                 "mask_summary": {
                     "total_masks": len(masks),
                     "mask_shapes": [list(mask.shape) for mask in masks],
-                    "segmented_area_pixels": [int(mask.sum()) for mask in masks]
+                    "segmented_area_pixels": [int(mask.sum()) for mask in masks],
                 },
             }
-
+
             # Create metadata dictionary
             metadata = {
                 "image_path": image_path,

@@ -326,9 +316,9 @@ class MedSAM2Tool(BaseTool):
                 "num_masks_generated": len(masks),
                 "analysis_status": "completed",
             }
-
+
             return output, metadata
-
+
         except Exception as e:
             error_output = {"error": str(e)}
             error_metadata = {

@@ -347,4 +337,4 @@ class MedSAM2Tool(BaseTool):
         run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
     ) -> Tuple[Dict[str, Any], Dict]:
         """Async version of _run."""
-        return self._run(image_path, prompt_type, prompt_coords, slice_index, run_manager)
+        return self._run(image_path, prompt_type, prompt_coords, slice_index, run_manager)
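For reference on the predictor calls reflowed above, here is a minimal box-prompt sketch against the SAM2ImagePredictor interface that the tool wraps. The config and checkpoint paths are placeholders, not the values MedRAX downloads, and the exact build arguments are assumed from the upstream sam2 package:

    # Hedged sketch: standalone SAM2 box-prompt segmentation, paths are placeholders.
    import numpy as np
    from PIL import Image
    from sam2.build_sam import build_sam2
    from sam2.sam2_image_predictor import SAM2ImagePredictor

    sam2_model = build_sam2("configs/sam2.1/sam2.1_hiera_t.yaml", "checkpoints/sam2.1_hiera_tiny.pt", device="cuda")
    predictor = SAM2ImagePredictor(sam2_model)

    image = np.array(Image.open("chest_xray.png").convert("RGB"))  # placeholder image path
    predictor.set_image(image)

    # Single box prompt in (x1, y1, x2, y2) pixel coordinates
    input_box = np.array([[100, 120, 400, 480]])
    masks, scores, logits = predictor.predict(box=input_box, multimask_output=True)

    best = int(np.argmax(scores))
    print(f"{len(masks)} masks, best score {scores[best]:.3f}, area {int(masks[best].sum())} px")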
medrax/tools/segmentation/segmentation.py
CHANGED

@@ -43,9 +43,7 @@ class OrganMetrics(BaseModel):
     area_pixels: int = Field(..., description="Area in pixels")
     area_cm2: float = Field(..., description="Approximate area in cm²")
     centroid: Tuple[float, float] = Field(..., description="(y, x) coordinates of centroid")
-    bbox: Tuple[int, int, int, int] = Field(
-        ..., description="Bounding box coordinates (min_y, min_x, max_y, max_x)"
-    )
+    bbox: Tuple[int, int, int, int] = Field(..., description="Bounding box coordinates (min_y, min_x, max_y, max_x)")
 
     # Size metrics
     width: int = Field(..., description="Width of the organ in pixels")

@@ -53,9 +51,7 @@ class OrganMetrics(BaseModel):
     aspect_ratio: float = Field(..., description="Height/width ratio")
 
     # Position metrics
-    relative_position: Dict[str, float] = Field(
-        ..., description="Position relative to image boundaries (0-1 scale)"
-    )
+    relative_position: Dict[str, float] = Field(..., description="Position relative to image boundaries (0-1 scale)")
 
     # Analysis metrics
     mean_intensity: float = Field(..., description="Mean pixel intensity in the organ region")

@@ -92,9 +88,7 @@ class ChestXRaySegmentationTool(BaseTool):
         self.model = self.model.to(self.device)
         self.model.eval()
 
-        self.transform = torchvision.transforms.Compose(
-            [xrv.datasets.XRayCenterCrop(), xrv.datasets.XRayResizer(512)]
-        )
+        self.transform = torchvision.transforms.Compose([xrv.datasets.XRayCenterCrop(), xrv.datasets.XRayResizer(512)])
 
         self.temp_dir = temp_dir if isinstance(temp_dir, Path) else Path(temp_dir)
         self.temp_dir.mkdir(exist_ok=True)

@@ -117,9 +111,7 @@ class ChestXRaySegmentationTool(BaseTool):
         "Spine": 13,
     }
 
-    def _align_mask_to_original(
-        self, mask: np.ndarray, original_shape: Tuple[int, int]
-    ) -> np.ndarray:
+    def _align_mask_to_original(self, mask: np.ndarray, original_shape: Tuple[int, int]) -> np.ndarray:
         """
         Align a mask from the transformed (cropped/resized) space back to the full original image.
         Assumes that the transform does a center crop to a square of side = min(original height, width)

@@ -172,23 +164,17 @@ class ChestXRaySegmentationTool(BaseTool):
             bbox=tuple(map(int, props.bbox)),
             width=int(props.bbox[3] - props.bbox[1]),
             height=int(props.bbox[2] - props.bbox[0]),
-            aspect_ratio=float(
-                (props.bbox[2] - props.bbox[0]) / max(1, props.bbox[3] - props.bbox[1])
-            ),
+            aspect_ratio=float((props.bbox[2] - props.bbox[0]) / max(1, props.bbox[3] - props.bbox[1])),
             relative_position=relative_pos,
             mean_intensity=float(mean_intensity),
             std_intensity=float(std_intensity),
             confidence_score=float(confidence),
         )
 
-    def _save_visualization(
-        self, original_img: np.ndarray, pred_masks: torch.Tensor, organ_indices: List[int]
-    ) -> str:
+    def _save_visualization(self, original_img: np.ndarray, pred_masks: torch.Tensor, organ_indices: List[int]) -> str:
         """Save visualization of original image with segmentation masks overlaid."""
         plt.figure(figsize=(10, 10))
-        plt.imshow(
-            original_img, cmap="gray", extent=[0, original_img.shape[1], original_img.shape[0], 0]
-        )
+        plt.imshow(original_img, cmap="gray", extent=[0, original_img.shape[1], original_img.shape[0], 0])
 
         # Generate color palette for organs
         colors = plt.cm.rainbow(np.linspace(0, 1, len(organ_indices)))

@@ -204,14 +190,10 @@ class ChestXRaySegmentationTool(BaseTool):
             # Create a colored overlay with transparency
             colored_mask = np.zeros((*original_img.shape, 4))
             colored_mask[mask > 0] = (*color[:3], 0.3)
-            plt.imshow(
-                colored_mask, extent=[0, original_img.shape[1], original_img.shape[0], 0]
-            )
+            plt.imshow(colored_mask, extent=[0, original_img.shape[1], original_img.shape[0], 0])
 
             # Add legend entry for the organ
-            organ_name = list(self.organ_map.keys())[
-                list(self.organ_map.values()).index(organ_idx)
-            ]
+            organ_name = list(self.organ_map.keys())[list(self.organ_map.values()).index(organ_idx)]
             plt.plot([], [], color=color, label=organ_name, linewidth=3)
 
         plt.title("Segmentation Overlay")

@@ -269,9 +251,7 @@ class ChestXRaySegmentationTool(BaseTool):
         for idx, organ_name in zip(organ_indices, organs):
             mask = pred_masks[0, idx].cpu().numpy()
             if mask.sum() > 0:
-                metrics = self._compute_organ_metrics(
-                    mask, original_img, float(pred_probs[0, idx].mean().cpu())
-                )
+                metrics = self._compute_organ_metrics(mask, original_img, float(pred_probs[0, idx].mean().cpu()))
                 if metrics:
                     results[organ_name] = metrics
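The legend code above recovers an organ's name from its channel index by reverse lookup in `organ_map`. A small illustration of that idiom; the mapping values here are examples only, not the tool's actual channel indices:

    organ_map = {"Left Lung": 4, "Right Lung": 5, "Heart": 8}  # example subset

    organ_idx = 5
    # Reverse lookup: find the position of the value, then take the key at that position
    organ_name = list(organ_map.keys())[list(organ_map.values()).index(organ_idx)]
    print(organ_name)  # "Right Lung"

    # An equivalent, often clearer form builds the inverse mapping once
    inverse = {v: k for k, v in organ_map.items()}
    print(inverse[organ_idx])  # "Right Lung"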
medrax/tools/utils.py
CHANGED

@@ -16,18 +16,10 @@ class ImageVisualizerInput(BaseModel):
 
     image_path: str = Field(..., description="Path to the image file to display, only supports JPG or PNG images")
     title: Optional[str] = Field(None, description="Optional title to display above the image")
-    description: Optional[str] = Field(
-        None, description="Optional description to display below the image"
-    )
-    width: Optional[int] = Field(
-        10, description="Optional figure width in inches"
-    )
-    height: Optional[int] = Field(
-        10, description="Optional figure height in inches"
-    )
-    cmap: Optional[str] = Field(
-        "rgb", description="Optional colormap to use for displaying the image"
-    )
+    description: Optional[str] = Field(None, description="Optional description to display below the image")
+    width: Optional[int] = Field(10, description="Optional figure width in inches")
+    height: Optional[int] = Field(10, description="Optional figure height in inches")
+    cmap: Optional[str] = Field("rgb", description="Optional colormap to use for displaying the image")
 
 
 class ImageVisualizerTool(BaseTool):

@@ -65,9 +57,7 @@ class ImageVisualizerTool(BaseTool):
 
         # Add description if provided
         if description:
-            plt.figtext(
-                0.5, 0.01, description, wrap=True, horizontalalignment="center", fontsize=10
-            )
+            plt.figtext(0.5, 0.01, description, wrap=True, horizontalalignment="center", fontsize=10)
 
         # Adjust margins to minimize whitespace while preventing overlap
         plt.subplots_adjust(top=0.95, bottom=0.05, left=0.05, right=0.95)
medrax/tools/vqa/__init__.py
CHANGED

@@ -1,16 +1,16 @@
 """Visual Question Answering tools for medical images."""
 
 from .llava_med import LlavaMedTool, LlavaMedInput
-from .xray_vqa import CheXagentXRayVQATool, XRayVQAToolInput
+from .xray_vqa import CheXagentXRayVQATool, XRayVQAToolInput
 from .medgemma.medgemma_client import MedGemmaAPIClientTool, MedGemmaVQAInput
 from .medgemma.medgemma_setup import setup_medgemma_env
 
 __all__ = [
     "LlavaMedTool",
     "LlavaMedInput",
-    "CheXagentXRayVQATool",
+    "CheXagentXRayVQATool",
     "XRayVQAToolInput",
     "MedGemmaAPIClientTool",
     "MedGemmaVQAInput",
-    "setup_medgemma_env"
-]
+    "setup_medgemma_env",
+]
medrax/tools/vqa/llava_med.py
CHANGED

@@ -84,13 +84,7 @@ class LlavaMedTool(BaseTool):
         self, question: str, image_path: Optional[str] = None
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         if self.model.config.mm_use_im_start_end:
-            question = (
-                DEFAULT_IM_START_TOKEN
-                + DEFAULT_IMAGE_TOKEN
-                + DEFAULT_IM_END_TOKEN
-                + "\n"
-                + question
-            )
+            question = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + "\n" + question
         else:
             question = DEFAULT_IMAGE_TOKEN + "\n" + question
 
@@ -100,9 +94,7 @@ class LlavaMedTool(BaseTool):
         prompt = conv.get_prompt()
 
         input_ids = (
-            tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
-            .unsqueeze(0)
-            .cuda()
+            tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).cuda()
         )
 
         image_tensor = None

@@ -156,11 +148,11 @@ class LlavaMedTool(BaseTool):
         )
 
         answer = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
-
+
         output = {
             "answer": answer,
         }
-
+
         metadata = {
             "question": question,
             "image_path": image_path,
medrax/tools/vqa/medgemma/medgemma.py
CHANGED

@@ -1,7 +1,6 @@
 import asyncio
 import os
 from pathlib import Path
-import sys
 import traceback
 from typing import Any, Dict, List, Optional, Tuple
 import uuid

@@ -22,6 +21,7 @@ UPLOAD_DIR = "./medgemma_images"
 # Create directories if they don't exist
 os.makedirs(UPLOAD_DIR, exist_ok=True)
 
+
 # Pydantic Models for API
 class VQAInput(BaseModel):
     """Input schema for the MedGemma VQA API endpoint.

@@ -100,7 +100,7 @@ class MedGemmaModel:
         device: Optional[str] = "cuda",
         dtype: torch.dtype = torch.bfloat16,
         cache_dir: Optional[str] = None,
+        load_in_8bit: bool = True,
         **kwargs: Any,
     ) -> None:
         """Initialize the MedGemmaModel.

@@ -110,7 +110,7 @@ class MedGemmaModel:
             device: Device to run model on - "cuda" or "cpu" (default: "cuda")
             dtype: Data type for model weights - bfloat16 recommended for efficiency (default: torch.bfloat16)
             cache_dir: Directory to cache downloaded models (default: None)
+            load_in_8bit: Whether to load model in 8-bit quantization for memory efficiency (default: True)
             **kwargs: Additional arguments passed to the model pipeline
 
         Raises:

@@ -140,8 +140,8 @@ class MedGemmaModel:
             "use_cache": True,
         }
 
+        if load_in_8bit:
+            model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
             model_kwargs["device_map"] = {"": self.device}
 
         try:

@@ -298,6 +298,12 @@ app = FastAPI(
 )
 
 medgemma_model: Optional[MedGemmaModel] = None
+inference_semaphore: Optional[asyncio.Semaphore] = None
+
+@app.get("/health")
+async def health():
+    """Health check endpoint."""
+    return {"status": "ok"}
 
 @app.on_event("startup")
 async def startup_event():

@@ -316,7 +322,32 @@ async def startup_event():
     """
     global medgemma_model
     try:
+        # Allow overriding Hugging Face cache directory and device via env vars
+        cache_dir_env = os.getenv("MEDGEMMA_CACHE_DIR")
+        device_env = os.getenv("MEDGEMMA_DEVICE")
+        max_concurrency_env = os.getenv("MEDGEMMA_MAX_CONCURRENCY", "1")
+
+        # Ensure the cache directory is writable; if not, fall back to a user cache
+        if cache_dir_env:
+            try:
+                os.makedirs(cache_dir_env, exist_ok=True)
+                if not os.access(cache_dir_env, os.W_OK):
+                    raise PermissionError("Cache dir not writable")
+            except Exception:
+                fallback = os.path.join(Path.home(), ".cache", "medrax", "medgemma")
+                os.makedirs(fallback, exist_ok=True)
+                print(f"Warning: MEDGEMMA_CACHE_DIR '{cache_dir_env}' not writable. Falling back to '{fallback}'.")
+                cache_dir_env = fallback
+
+        medgemma_model = MedGemmaModel(cache_dir=cache_dir_env, device=device_env)
+        # Initialize concurrency gate
+        try:
+            max_concurrency = int(max_concurrency_env)
+        except ValueError:
+            max_concurrency = 1
+        max_concurrency = max(1, max_concurrency)
+        global inference_semaphore
+        inference_semaphore = asyncio.Semaphore(max_concurrency)
         print("MedGemma model loaded successfully.")
     except RuntimeError as e:
         print(f"Error loading MedGemma model: {e}")

@@ -428,7 +463,12 @@
 if __name__ == "__main__":
     """Launch the MedGemma VQA API server.
 
     """

@@ -389,8 +420,12 @@ async def analyze_images(
         raise HTTPException(status_code=500, detail=f"Failed to save uploaded image: {str(e)}")
 
     try:
-        # Generate AI analysis
+        # Generate AI analysis with concurrency gating to avoid GPU contention timeouts
+        global inference_semaphore
+        if inference_semaphore is None:
+            inference_semaphore = asyncio.Semaphore(1)
+        async with inference_semaphore:
+            response_text = await medgemma_model.aget_response(image_paths, prompt, system_prompt, max_new_tokens)
response_text = await medgemma_model.aget_response(image_paths, prompt, system_prompt, max_new_tokens)
|
| 429 |
|
| 430 |
# Prepare success response
|
| 431 |
metadata = {
|
|
|
|
| 463 |
if __name__ == "__main__":
|
| 464 |
"""Launch the MedGemma VQA API server.
|
| 465 |
|
| 466 |
+
Reads MEDGEMMA_HOST and MEDGEMMA_PORT if provided; otherwise defaults
|
| 467 |
+
to 0.0.0.0:8002.
|
| 468 |
"""
|
| 469 |
+
host = os.getenv("MEDGEMMA_HOST", "0.0.0.0")
|
| 470 |
+
try:
|
| 471 |
+
port = int(os.getenv("MEDGEMMA_PORT", "8002"))
|
| 472 |
+
except ValueError:
|
| 473 |
+
port = 8002
|
| 474 |
+
uvicorn.run(app, host=host, port=port)
|
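Note on the startup changes above: the service now reads MEDGEMMA_CACHE_DIR, MEDGEMMA_DEVICE and MEDGEMMA_MAX_CONCURRENCY when it boots and exposes a /health route, so a launcher can export those variables before starting the process and then poll the probe instead of sleeping for a fixed time. A minimal polling sketch follows; the helper name, base URL and timeout values are illustrative assumptions, not part of this diff.

import time

import httpx


def wait_for_medgemma(base_url: str = "http://127.0.0.1:8002", timeout_s: float = 120.0) -> bool:
    """Poll the service's /health route until it reports ok or the deadline passes."""
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        try:
            resp = httpx.get(f"{base_url}/health", timeout=5.0)
            if resp.status_code == 200 and resp.json().get("status") == "ok":
                return True
        except httpx.HTTPError:
            pass  # server not up yet; keep retrying
        time.sleep(2.0)
    return False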
medrax/tools/vqa/medgemma/medgemma_client.py
CHANGED
@@ -59,15 +59,21 @@ class MedGemmaAPIClientTool(BaseTool):

     # API configuration
     api_url: str  # The URL of the running FastAPI service
+    cache_dir: Optional[str] = None  # Not used by the client directly, but accepted to keep a uniform constructor
+    device: Optional[str] = None

-    def __init__(self, api_url: str, **kwargs: Any):
+    def __init__(self, api_url: str, cache_dir: Optional[str] = None, device: Optional[str] = None, timeout_seconds: Optional[float] = None, **kwargs: Any):
         """Initialize the MedGemmaAPIClientTool.

         Args:
             api_url: The URL of the running MedGemma FastAPI service
+            cache_dir: Optional local cache directory for model weights (accepted for interface consistency)
+            device: Optional device spec (accepted for interface consistency)
+            timeout_seconds: Optional request timeout override (seconds)
             **kwargs: Additional arguments passed to BaseTool
         """
-        super().__init__(api_url=api_url, **kwargs)
+        super().__init__(api_url=api_url, cache_dir=cache_dir, device=device, **kwargs)
+        self._timeout_seconds = timeout_seconds

     def _prepare_request_data(
         self, image_paths: List[str], prompt: str, system_prompt: str, max_new_tokens: int

@@ -154,7 +160,8 @@
             Tuple of output dictionary and metadata
         """
         # httpx is a modern HTTP client that supports sync and async
+        timeout_value = self._timeout_seconds if self._timeout_seconds is not None else 600.0
+        timeout_config = httpx.Timeout(timeout_value, connect=10.0)
         client = httpx.Client(timeout=timeout_config)

         try:

@@ -238,11 +245,12 @@
                 image_paths, prompt, system_prompt, max_new_tokens
             )

+            timeout_value = self._timeout_seconds if self._timeout_seconds is not None else 600.0
             response = await client.post(
                 f"{self.api_url}/analyze-images/",
                 data=data,
                 files=files_to_send,
-                timeout=
+                timeout=timeout_value
             )
             response.raise_for_status()
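The client constructor above now accepts cache_dir, device and timeout_seconds alongside api_url; the first two exist only for interface consistency with the local tool, while timeout_seconds overrides the 600 s request timeout. A minimal construction sketch, assuming the module path mirrors the file layout and a locally running service (URL and timeout values are illustrative):

from medrax.tools.vqa.medgemma.medgemma_client import MedGemmaAPIClientTool

tool = MedGemmaAPIClientTool(
    api_url="http://127.0.0.1:8002",  # assumed address of the running FastAPI service
    timeout_seconds=900.0,            # overrides the 600.0 s default applied to requests
)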
medrax/tools/vqa/medgemma/medgemma_requirements_standard.txt
CHANGED
@@ -52,4 +52,4 @@ typing_inspection==0.4.1
 urllib3==2.5.0
 uvicorn==0.35.0
 wcwidth==0.2.13
-zstandard==0.23.0
+zstandard==0.23.0
medrax/tools/vqa/medgemma/medgemma_setup.py
CHANGED
@@ -1,9 +1,61 @@
 import os
 from pathlib import Path
 import subprocess
+import socket
+from contextlib import closing
 import venv

-def setup_medgemma_env():
+def _resolve_writable_cache_dir(preferred: str | None) -> str:
+    """Return a writable cache directory, falling back to user cache if needed."""
+    # Preferred path first
+    if preferred:
+        try:
+            os.makedirs(preferred, exist_ok=True)
+            if os.access(preferred, os.W_OK):
+                return preferred
+        except Exception:
+            pass
+    # Fallback path under user's home
+    fallback = os.path.join(Path.home(), ".cache", "medrax", "medgemma")
+    os.makedirs(fallback, exist_ok=True)
+    return fallback
+
+
+def _is_port_free(host: str, port: int) -> bool:
+    """Return True if (host, port) is free to bind on this machine."""
+    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
+        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+        try:
+            sock.bind((host, port))
+            return True
+        except OSError:
+            return False
+
+
+def _find_free_loopback_and_port(start_octet: int = 2, end_octet: int = 254, base_port: int = 8002, max_port_tries: int = 50) -> tuple[str, int]:
+    """Find a free 127.0.0.X address and port combination.
+
+    Tries 127.0.0.2..127.0.0.254 each with ports base_port..base_port+max_port_tries
+    until a free pair is found. Falls back to 127.0.0.1 if none found for other octets.
+    """
+    # Try alternate loopback IPs first
+    for last_octet in range(start_octet, end_octet + 1):
+        host = f"127.0.0.{last_octet}"
+        for port in range(base_port, base_port + max_port_tries):
+            if _is_port_free(host, port):
+                return host, port
+    # Fallback: use 127.0.0.1 with port scan
+    host = "127.0.0.1"
+    for port in range(base_port, base_port + max_port_tries):
+        if _is_port_free(host, port):
+            return host, port
+    # Last resort: system-chosen ephemeral on 127.0.0.1
+    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
+        sock.bind((host, 0))
+        return host, sock.getsockname()[1]
+
+
+def setup_medgemma_env(cache_dir: str | None = None, device: str | None = None) -> str:
     """Set up MedGemma virtual environment and launch the FastAPI service.

     This function performs the following steps:

@@ -53,12 +105,47 @@ def setup_medgemma_env():
    if not env_dir.exists():
        raise RuntimeError("Failed to create MedGemma virtual environment")

-   #
+   # Decide host/port to avoid collisions when multiple instances run
+   medgemma_host = os.getenv("MEDGEMMA_HOST")
+   medgemma_port_env = os.getenv("MEDGEMMA_PORT")
+   chosen_host: str
+   chosen_port: int
+   if medgemma_host and medgemma_port_env:
+       try:
+           port_val = int(medgemma_port_env)
+       except ValueError:
+           port_val = 8002
+       # If explicit host/port are provided, prefer them; if taken, try incrementing the port on the same host
+       chosen_host = medgemma_host
+       chosen_port = None
+       for p in range(port_val, port_val + 50):
+           if _is_port_free(medgemma_host, p):
+               chosen_port = p
+               break
+       if chosen_port is None:
+           print(f"No free ports in range {port_val}-{port_val+49} on {medgemma_host}; selecting a free loopback IP/port...")
+           chosen_host, chosen_port = _find_free_loopback_and_port()
+   else:
+       # Auto-pick a free loopback IP and port
+       chosen_host, chosen_port = _find_free_loopback_and_port()
+
+   print(f"Launching MedGemma FastAPI service on {chosen_host}:{chosen_port} ...")
+   env = os.environ.copy()
+   resolved_cache = _resolve_writable_cache_dir(cache_dir)
+   env["MEDGEMMA_CACHE_DIR"] = resolved_cache
+   if device:
+       env["MEDGEMMA_DEVICE"] = device
+   # Pass the chosen binding to the server via env
+   env["MEDGEMMA_HOST"] = chosen_host
+   env["MEDGEMMA_PORT"] = str(chosen_port)
    subprocess.Popen([
        str(python_executable),
        str(medgemma_path)
-   ])
+   ], env=env)
+
+   # Return the base URL so callers can use it. If bound to 0.0.0.0, use 127.0.0.1 for local client access.
+   chosen_client_host = "127.0.0.1" if chosen_host in ("0.0.0.0", "::") else chosen_host
+   return f"http://{chosen_client_host}:{chosen_port}"
    # Note: stdout and stderr redirection commented out for debugging
    # stdout=subprocess.DEVNULL,
    # stderr=subprocess.DEVNULL,
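setup_medgemma_env() above now probes for a free loopback address and port, exports the chosen binding to the child process, and returns the service's base URL, which pairs with the client-tool changes earlier in this diff. A short usage sketch, assuming the module paths mirror the file layout and an illustrative cache directory:

from medrax.tools.vqa.medgemma.medgemma_setup import setup_medgemma_env
from medrax.tools.vqa.medgemma.medgemma_client import MedGemmaAPIClientTool

# cache_dir and device values are illustrative assumptions.
api_url = setup_medgemma_env(cache_dir="/data/hf_cache", device="cuda")
print(api_url)  # e.g. "http://127.0.0.2:8002", depending on which binding was free
medgemma_tool = MedGemmaAPIClientTool(api_url=api_url)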
medrax/tools/vqa/xray_vqa.py
CHANGED
@@ -15,13 +15,9 @@ from langchain_core.tools import BaseTool
 class XRayVQAToolInput(BaseModel):
     """Input schema for the CheXagent Tool."""

-    image_paths: List[str] = Field(
-        ..., description="List of paths to chest X-ray images to analyze"
-    )
+    image_paths: List[str] = Field(..., description="List of paths to chest X-ray images to analyze")
     prompt: str = Field(..., description="Question or instruction about the chest X-ray images")
-    max_new_tokens: int = Field(
-        512, description="Maximum number of tokens to generate in the response"
-    )
+    max_new_tokens: int = Field(512, description="Maximum number of tokens to generate in the response")


 class CheXagentXRayVQATool(BaseTool):

@@ -99,16 +95,14 @@ class CheXagentXRayVQATool(BaseTool):
        Returns:
            str: Model's response
        """
-       query = self.tokenizer.from_list_format(
-           [*[{"image": path} for path in image_paths], {"text": prompt}]
-       )
+       query = self.tokenizer.from_list_format([*[{"image": path} for path in image_paths], {"text": prompt}])
        conv = [
            {"from": "system", "value": "You are a helpful assistant."},
            {"from": "human", "value": query},
        ]
-       input_ids = self.tokenizer.apply_chat_template(
-       )
+       input_ids = self.tokenizer.apply_chat_template(conv, add_generation_prompt=True, return_tensors="pt").to(
+           device=self.device
+       )

        # Run inference
        with torch.inference_mode():
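The input schema above keeps the same fields after the reformat; a small sketch exercising it through pydantic validation (the module path and image path are illustrative assumptions):

from medrax.tools.vqa.xray_vqa import XRayVQAToolInput

payload = XRayVQAToolInput(
    image_paths=["/tmp/chest_pa.png"],  # placeholder path
    prompt="Is there evidence of pneumothorax?",
)
print(payload.max_new_tokens)  # 512, the declared default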
medrax/tools/xray_generation.py
CHANGED
@@ -11,26 +11,15 @@ from langchain_core.tools import BaseTool

 class ChestXRayGeneratorInput(BaseModel):
     """Input schema for the Chest X-Ray Generator Tool."""

     prompt: str = Field(
-        ...,
-        description="Description of the medical condition to generate (e.g., 'big left-sided pleural effusion')"
+        ..., description="Description of the medical condition to generate (e.g., 'big left-sided pleural effusion')"
     )
-    height: int = Field(
-        512,
-        description="Height of generated image in pixels"
-    )
-    width: int = Field(
-        512,
-        description="Width of generated image in pixels"
-    )
-    num_inference_steps: int = Field(
-        75,
-        description="Number of denoising steps (higher = better quality but slower)"
-    )
+    height: int = Field(512, description="Height of generated image in pixels")
+    width: int = Field(512, description="Width of generated image in pixels")
+    num_inference_steps: int = Field(75, description="Number of denoising steps (higher = better quality but slower)")
     guidance_scale: float = Field(
-        4.0,
-        description="How closely to follow the prompt (higher = more faithful but less diverse)"
+        4.0, description="How closely to follow the prompt (higher = more faithful but less diverse)"
     )


@@ -60,11 +49,11 @@ class ChestXRayGeneratorTool(BaseTool):
    ):
        """Initialize the chest X-ray generator tool."""
        super().__init__()

        self.device = torch.device(device) if device else "cuda"
        self.model = StableDiffusionPipeline.from_pretrained(model_path, cache_dir=cache_dir)
        self.model = self.model.to(torch.float32).to(self.device)

        self.temp_dir = Path(temp_dir if temp_dir else tempfile.mkdtemp())
        self.temp_dir.mkdir(exist_ok=True)

@@ -97,7 +86,7 @@ class ChestXRayGeneratorTool(BaseTool):
                num_inference_steps=num_inference_steps,
                height=height,
                width=width,
-               guidance_scale=guidance_scale
+               guidance_scale=guidance_scale,
            )

            # Save generated image

@@ -107,7 +96,7 @@ class ChestXRayGeneratorTool(BaseTool):
            output = {
                "image_path": str(image_path),
            }

            metadata = {
                "prompt": prompt,
                "num_inference_steps": num_inference_steps,

@@ -126,7 +115,7 @@ class ChestXRayGeneratorTool(BaseTool):
                    "prompt": prompt,
                    "analysis_status": "failed",
                    "error_details": str(e),
-               }
+               },
            )

    async def _arun(

@@ -139,4 +128,4 @@ class ChestXRayGeneratorTool(BaseTool):
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, str], Dict]:
        """Async version of _run."""
-       return self._run(prompt, num_inference_steps, guidance_scale, height, width)
+       return self._run(prompt, num_inference_steps, guidance_scale, height, width)
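The generator input schema above still carries the same defaults after the cleanup; a quick sketch showing them (the module path is an assumption based on the file layout):

from medrax.tools.xray_generation import ChestXRayGeneratorInput

params = ChestXRayGeneratorInput(prompt="big left-sided pleural effusion")
print(params.height, params.width, params.num_inference_steps, params.guidance_scale)
# 512 512 75 4.0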
pyproject.toml
CHANGED
@@ -14,15 +14,15 @@ requires-python = ">=3.12"
 dependencies = [
     "requests>=2.25.0",
     "numpy>=1.19.0",
-    "langchain>=0.
-    "langchain-core>=0.
+    "langchain>=0.3.26",
+    "langchain-core>=0.3.68",
     "langchain-community>=0.0.20",
-    "langchain-openai>=0.
-    "langchain-cohere>=0.3.
-    "langchain-anthropic>=0.
-    "langchain-xai>=0.
-    "langchain-chroma>=0.
-    "langgraph>=0.
+    "langchain-openai>=0.3.27",
+    "langchain-cohere>=0.3.5",
+    "langchain-anthropic>=0.3.17",
+    "langchain-xai>=0.2.4",
+    "langchain-chroma>=0.2.4",
+    "langgraph>=0.5.1",
     "hydra-core>=1.1.0",
     "python-dotenv>=0.19.0",
     "pandas>=1.5.0",

@@ -46,8 +46,9 @@ dependencies = [
     "gradio>=3.0.0",
     "gradio_client>=0.2.0",
     "httpx>=0.23.0",
-    "uvicorn>=0.15.0",
+    "uvicorn[standard]>=0.15.0",
     "fastapi>=0.68.0",
+    "python-multipart>=0.0.6",
     "einops>=0.3.0",
     "einops-exts>=0.0.4",
     "timm==0.5.4",

@@ -73,6 +74,7 @@ dependencies = [
     "huggingface_hub>=0.17.0",
     "iopath>=0.1.10",
     "duckduckgo-search>=4.0.0",
+    "pyngrok>=7.0.0",
 ]

 [project.optional-dependencies]