victorli committed
Commit 7b3e756 · Parent(s): fbcbf94

fixed medgemma
.gitignore CHANGED
@@ -180,4 +180,5 @@ model-weights/
 .DS_Store
 
 benchmarking/data/
-model_cache/
+model_cache/
+medgemma/
benchmarking/llm_providers/medrax_provider.py CHANGED
@@ -33,12 +33,12 @@ class MedRAXProvider(LLMProvider):
         print("Starting server...")
 
         selected_tools = [
-            # "TorchXRayVisionClassifierTool",  # For classifying chest X-ray images using TorchXRayVision
-            # "ArcPlusClassifierTool",  # For advanced chest X-ray classification using ArcPlus
-            # "ChestXRayReportGeneratorTool",  # For generating medical reports from X-rays
+            "TorchXRayVisionClassifierTool",  # For classifying chest X-ray images using TorchXRayVision
+            "ArcPlusClassifierTool",  # For advanced chest X-ray classification using ArcPlus
+            "ChestXRayReportGeneratorTool",  # For generating medical reports from X-rays
+            "XRayPhraseGroundingTool",  # For locating described features in X-rays
+            "MedGemmaVQATool",
             # "XRayVQATool",  # For visual question answering on X-rays
-            "MedGemmaVQATool"
-            # "XRayPhraseGroundingTool",  # For locating described features in X-rays
             # "MedicalRAGTool",  # For retrieval-augmented generation with medical knowledge
             # "WebBrowserTool",  # For web browsing and search capabilities
             # "DuckDuckGoSearchTool",  # For privacy-focused web search using DuckDuckGo
main.py CHANGED
@@ -68,7 +68,7 @@ def initialize_agent(
     prompt = prompts[system_prompt]
 
     # Define the URL of the MedGemma FastAPI service.
-    MEDGEMMA_API_URL = os.getenv("MEDGEMMA_API_URL", "http://127.0.0.1:8002")
+    MEDGEMMA_API_URL = os.getenv("MEDGEMMA_API_URL", "http://172.17.8.141:8002")
 
     all_tools = {
         "TorchXRayVisionClassifierTool": lambda: TorchXRayVisionClassifierTool(device=device),
@@ -98,20 +98,20 @@ def initialize_agent(
 
     # Initialize only selected tools or all if none specified
     tools_dict: Dict[str, BaseTool] = {}
+
     if tools_to_use is None:
         tools_to_use = []
+
     for tool_name in tools_to_use:
         if tool_name == "PythonSandboxTool":
-            continue
+            try:
+                tools_dict["PythonSandboxTool"] = create_python_sandbox()
+            except Exception as e:
+                print(f"Error creating PythonSandboxTool: {e}")
+                print("Skipping PythonSandboxTool")
         if tool_name in all_tools:
             tools_dict[tool_name] = all_tools[tool_name]()
-
-    # Try to create the PythonSandboxTool
-    try:
-        tools_dict["PythonSandboxTool"] = create_python_sandbox()
-    except Exception as e:
-        print(f"Error creating PythonSandboxTool: {e}")
-        print("Skipping PythonSandboxTool")
+
 
     # Set up checkpointing for conversation state
     checkpointer = MemorySaver()
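
The new default URL points at a fixed LAN host, but because the value is read through os.getenv it can still be redirected per machine without editing main.py. A minimal sketch of that override (the localhost value is an illustrative assumption, not part of the commit):

import os

# Hypothetical override: point the agent at a locally running MedGemma service.
os.environ["MEDGEMMA_API_URL"] = "http://127.0.0.1:8002"

# main.py resolves the URL like this, falling back to the committed default.
MEDGEMMA_API_URL = os.getenv("MEDGEMMA_API_URL", "http://172.17.8.141:8002")
print(MEDGEMMA_API_URL)  # -> http://127.0.0.1:8002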
medrax/tools/medgemma.py DELETED
@@ -1,225 +0,0 @@
-from fastapi import FastAPI, File, UploadFile, Form, HTTPException
-from pydantic import BaseModel, Field
-from typing import List, Optional, Any, Dict, Tuple
-from pathlib import Path
-import torch
-from PIL import Image
-from transformers import pipeline, BitsAndBytesConfig
-import asyncio
-import uvicorn
-import os
-import uuid
-import traceback
-import sys
-import transformers
-
-print("--- ENVIRONMENT CHECK ---")
-print(f"Python Executable: {sys.executable}")
-print(f"PyTorch version: {torch.__version__}")
-print(f"Transformers version: {transformers.__version__}")
-print("-----------------------")
-
-# --- Configuration ---
-CACHE_DIR = "./model_cache"
-UPLOAD_DIR = "./uploaded_images"
-
-# Create directories if they don't exist
-os.makedirs(CACHE_DIR, exist_ok=True)
-os.makedirs(UPLOAD_DIR, exist_ok=True)
-
-# --- Pydantic Models for API ---
-class VQAInput(BaseModel):
-    prompt: str = Field(..., description="Question or instruction about the medical images")
-    system_prompt: Optional[str] = Field(
-        "You are an expert radiologist.",
-        description="System prompt to set the context for the model",
-    )
-    max_new_tokens: int = Field(
-        300, description="Maximum number of tokens to generate in the response"
-    )
-
-class VQAResponse(BaseModel):
-    response: str
-    metadata: Dict[str, Any]
-
-class ErrorResponse(BaseModel):
-    error: str
-    metadata: Dict[str, Any]
-
-# --- MedGemma Model Handling ---
-class MedGemmaModel:
-    _instance = None
-
-    def __new__(cls, *args, **kwargs):
-        if not cls._instance:
-            cls._instance = super(MedGemmaModel, cls).__new__(cls)
-        return cls._instance
-
-    def __init__(self,
-                 model_name: str = "google/medgemma-4b-it",
-                 device: Optional[str] = "cuda",
-                 dtype: torch.dtype = torch.bfloat16,
-                 load_in_4bit: bool = False):
-        if hasattr(self, 'pipe') and self.pipe is not None:
-            return
-
-        self.device = device if device and torch.cuda.is_available() else "cpu"
-        self.dtype = dtype
-        self.pipe = None
-
-        model_kwargs = {"torch_dtype": self.dtype, "cache_dir": CACHE_DIR}
-
-        if load_in_4bit:
-            model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)
-        model_kwargs["device_map"] = {"": self.device}
-
-        try:
-            self.pipe = pipeline("image-text-to-text",
-                                 model=model_name,
-                                 model_kwargs=model_kwargs,
-                                 trust_remote_code=True,
-                                 use_cache=True)
-        except Exception as e:
-            raise RuntimeError(f"Failed to initialize MedGemma pipeline: {str(e)}")
-
-    def _prepare_messages(
-        self, image_paths: List[str], prompt: str, system_prompt: str
-    ) -> Tuple[List[Dict[str, Any]], List[Image.Image]]:
-        images = []
-        for path in image_paths:
-            if not Path(path).is_file():
-                raise FileNotFoundError(f"Image file not found: {path}")
-
-            image = Image.open(path)
-            if image.mode != "RGB":
-                image = image.convert("RGB")
-            images.append(image)
-
-        messages = [
-            {"role": "system", "content": [{"type": "text", "text": system_prompt}]},
-            {
-                "role": "user",
-                "content": [{"type": "text", "text": prompt}]
-                + [{"type": "image", "image": img} for img in images],
-            },
-        ]
-
-        return messages, images
-
-    async def aget_response(self, image_paths: List[str], prompt: str, system_prompt: str, max_new_tokens: int) -> str:
-        loop = asyncio.get_event_loop()
-        messages, _ = await loop.run_in_executor(None, self._prepare_messages, image_paths, prompt, system_prompt)
-
-        def _generate():
-            return self.pipe(
-                text=messages,
-                max_new_tokens=max_new_tokens,
-                do_sample=False,
-            )
-
-        output = await loop.run_in_executor(None, _generate)
-
-        if (
-            isinstance(output, list)
-            and output
-            and isinstance(output[0].get("generated_text"), list)
-        ):
-            generated_text = output[0]["generated_text"]
-            if generated_text:
-                return generated_text[-1].get("content", "").strip()
-
-        return "No response generated"
-
-# --- FastAPI Application ---
-app = FastAPI(title="MedGemma VQA API",
-              description="API for medical visual question answering using Google's MedGemma model.")
-
-medgemma_model: Optional[MedGemmaModel] = None
-
-@app.on_event("startup")
-async def startup_event():
-    """Load the MedGemma model at application startup."""
-    global medgemma_model
-    try:
-        medgemma_model = MedGemmaModel()
-        print("MedGemma model loaded successfully.")
-    except RuntimeError as e:
-        print(f"Error loading MedGemma model: {e}")
-        # Depending on the desired behavior, you might want to exit the application
-        # if the model fails to load.
-        # exit(1)
-
-@app.post("/analyze-images/",
-          response_model=VQAResponse,
-          responses={500: {"model": ErrorResponse},
-                     404: {"model": ErrorResponse}},
-          summary="Analyze one or more medical images")
-async def analyze_images(
-    images: List[UploadFile] = File(..., description="List of medical image files to analyze (JPG or PNG)."),
-    prompt: str = Form(..., description="Question or instruction about the medical images."),
-    system_prompt: Optional[str] = Form("You are an expert radiologist.", description="System prompt to set the context for the model."),
-    max_new_tokens: int = Form(100, description="Maximum number of tokens to generate in the response.")
-):
-    """
-    Upload one or more medical images and a prompt to get an analysis from the MedGemma model.
-    """
-    if medgemma_model is None or medgemma_model.pipe is None:
-        raise HTTPException(status_code=503, detail="Model is not available. Please try again later.")
-
-    image_paths = []
-    for image in images:
-        if image.content_type not in ["image/jpeg", "image/png"]:
-            raise HTTPException(status_code=400, detail=f"Unsupported image format: {image.content_type}. Only JPG and PNG are supported.")
-
-        # Generate a unique filename to avoid overwrites
-        unique_filename = f"{uuid.uuid4()}_{image.filename}"
-        file_path = os.path.join(UPLOAD_DIR, unique_filename)
-
-        try:
-            with open(file_path, "wb") as buffer:
-                buffer.write(await image.read())
-            image_paths.append(file_path)
-        except Exception as e:
-            raise HTTPException(status_code=500, detail=f"Failed to save uploaded image: {str(e)}")
-
-
-    try:
-        response_text = await medgemma_model.aget_response(image_paths, prompt, system_prompt, max_new_tokens)
-        metadata = {
-            "image_paths": image_paths,
-            "prompt": prompt,
-            "system_prompt": system_prompt,
-            "max_new_tokens": max_new_tokens,
-            "num_images": len(image_paths),
-            "analysis_status": "completed",
-        }
-        return VQAResponse(response=response_text, metadata=metadata)
-    except FileNotFoundError as e:
-        raise HTTPException(status_code=404, detail=f"Image file not found: {str(e)}")
-    except Exception as e:
-        print("--- AN EXCEPTION OCCURRED IN THE ENDPOINT ---")
-        traceback.print_exc()
-        # Catch potential CUDA out-of-memory errors and other exceptions
-        error_message = "An unexpected error occurred during analysis."
-        if "CUDA out of memory" in str(e):
-            error_message = "GPU memory exhausted. Try reducing image resolution or max_new_tokens."
-
-        metadata = {
-            "image_paths": image_paths,
-            "prompt": prompt,
-            "analysis_status": "failed",
-            "error_details": str(e),
-        }
-        raise HTTPException(status_code=500, detail=error_message)
-    finally:
-        # Clean up saved images
-        for path in image_paths:
-            try:
-                os.remove(path)
-            except OSError:
-                # Log this error if needed, but don't let it crash the request
-                pass
-
-
-if __name__ == "__main__":
-    uvicorn.run(app, host="0.0.0.0", port=8002)
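
For reference, the deleted module served its VQA endpoint over multipart form data. A minimal standalone call against that contract (assuming a service with the same /analyze-images/ route is still reachable on port 8002; host and image path are illustrative):

import httpx

with open("demo/chest/pneumonia1.jpg", "rb") as f:
    resp = httpx.post(
        "http://127.0.0.1:8002/analyze-images/",
        # The "images" field name matches the FastAPI endpoint parameter.
        files=[("images", ("pneumonia1.jpg", f, "image/jpeg"))],
        data={
            "prompt": "What abnormality do you see?",
            "system_prompt": "You are an expert radiologist.",
            "max_new_tokens": 300,
        },
        timeout=300.0,  # generous timeout: the model may still be loading
    )
resp.raise_for_status()
print(resp.json()["response"])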
medrax/tools/medgemma_client.py DELETED
@@ -1,145 +0,0 @@
-import httpx
-from typing import Dict, List, Optional, Type, Any
-from langchain_core.tools import BaseTool
-from langchain_core.callbacks import (
-    AsyncCallbackManagerForToolRun,
-    CallbackManagerForToolRun,
-)
-from pydantic import BaseModel, Field
-import os
-
-# This input schema should be identical to the one in your original tool
-class MedGemmaVQAInput(BaseModel):
-    """Input schema for the MedGemma VQA Tool. The agent provides local paths to images."""
-    image_paths: List[str] = Field(
-        ...,
-        description="List of paths to medical image files to analyze. These are local paths accessible to the agent.",
-    )
-    prompt: str = Field(..., description="Question or instruction about the medical images")
-    system_prompt: Optional[str] = Field(
-        "You are an expert radiologist.",
-        description="System prompt to set the context for the model",
-    )
-    max_new_tokens: int = Field(
-        300, description="Maximum number of tokens to generate in the response"
-    )
-
-class MedGemmaAPIClientTool(BaseTool):
-    """
-    A client tool to interact with a remote MedGemma VQA FastAPI service.
-    This tool takes local image paths, reads them, and sends them to the API endpoint
-    for analysis.
-    """
-    name: str = "medgemma_medical_vqa_service"
-    description: str = (
-        "Sends medical images and a prompt to a specialized MedGemma VQA service for analysis. "
-        "Use this for expert-level reasoning, diagnosis assistance, and detailed image interpretation "
-        "across modalities like chest X-rays, dermatology, etc. Input must be local image paths and a prompt."
-    )
-    args_schema: Type[BaseModel] = MedGemmaVQAInput
-    api_url: str  # The URL of the running FastAPI service
-
-    def _run(
-        self,
-        image_paths: List[str],
-        prompt: str,
-        system_prompt: str = "You are an expert radiologist.",
-        max_new_tokens: int = 300,
-        run_manager: Optional[CallbackManagerForToolRun] = None,
-    ) -> str:
-        """Execute the tool synchronously."""
-        # httpx is a modern HTTP client that supports sync and async
-        timeout_config = httpx.Timeout(300.0, connect=10.0)
-        client = httpx.Client(timeout=timeout_config)
-
-        # Prepare the multipart form data
-        files_to_send = []
-        opened_files = []
-        try:
-            for path in image_paths:
-                f = open(path, "rb")
-                opened_files.append(f)
-                # The key 'images' must match the parameter name in the FastAPI endpoint
-                files_to_send.append(("images", (os.path.basename(path), f, "image/jpeg")))
-
-            data = {
-                "prompt": prompt,
-                "system_prompt": system_prompt,
-                "max_new_tokens": max_new_tokens,
-            }
-
-            response = client.post(
-                f"{self.api_url}/analyze-images/",
-                data=data,
-                files=files_to_send,
-            )
-            response.raise_for_status()  # Raise an exception for bad status codes (4xx or 5xx)
-
-            # The agent expects a string response from a tool
-            return response.json()["response"]
-
-        # --- KEY FIX 3: More specific exception handling for clearer errors ---
-        except httpx.TimeoutException:
-            return f"Error: The request to the MedGemma API timed out after {timeout_config.read} seconds. The server might be overloaded or the model is taking too long to load. Try again later."
-        except httpx.ConnectError:
-            return f"Error: Could not connect to the MedGemma API. Check if the server address '{self.api_url}' is correct and running."
-        except httpx.HTTPStatusError as e:
-            return f"Error: The MedGemma API returned an error (Status {e.response.status_code}): {e.response.text}"
-        except Exception as e:
-            return f"An unexpected error occurred in the MedGemma client tool: {str(e)}"
-        finally:
-            # Important: Ensure all opened files are closed.
-            for f in opened_files:
-                f.close()
-
-    async def _arun(
-        self,
-        image_paths: List[str],
-        prompt: str,
-        system_prompt: str = "You are an expert radiologist.",
-        max_new_tokens: int = 300,
-        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
-    ) -> str:
-        """Execute the tool asynchronously."""
-        async with httpx.AsyncClient() as client:
-            files_to_send = []
-            opened_files = []
-            try:
-                # Note: File I/O is blocking, for a truly async app you might use aiofiles
-                # But for this use case, this is generally acceptable.
-                for path in image_paths:
-                    f = open(path, "rb")
-                    opened_files.append(f)
-                    files_to_send.append(("images", (os.path.basename(path), f, "image/jpeg")))
-
-                data = {
-                    "prompt": prompt,
-                    "system_prompt": system_prompt,
-                    "max_new_tokens": max_new_tokens,
-                }
-
-                response = await client.post(
-                    f"{self.api_url}/analyze-images/",
-                    data=data,
-                    files=files_to_send,
-                    timeout=120.0
-                )
-                response.raise_for_status()
-
-                return response.json()["response"]
-
-            except httpx.HTTPStatusError as e:
-                return f"Error calling MedGemma API: {e.response.status_code} - {e.response.text}"
-            except Exception as e:
-                return f"An unexpected error occurred: {str(e)}"
-            finally:
-                for f in opened_files:
-                    f.close()
-
-if __name__ == "__main__":
-    client_tool = MedGemmaAPIClientTool(api_url="http://localhost:8002")
-    result = client_tool.run({
-        "image_paths": ["demo/chest/pneumonia1.jpg"],
-        "prompt": "What abnormality do you see?"
-    })
-    print(result)
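
The __main__ block above exercises the synchronous path; inside an async agent loop the same tool would reach _arun through BaseTool's public arun method. A minimal sketch (assumes the module is still importable from a local checkout; URL and image path are illustrative):

import asyncio

from medrax.tools.medgemma_client import MedGemmaAPIClientTool

async def main():
    tool = MedGemmaAPIClientTool(api_url="http://localhost:8002")
    # arun dispatches to the tool's _arun implementation.
    result = await tool.arun({
        "image_paths": ["demo/chest/pneumonia1.jpg"],
        "prompt": "What abnormality do you see?",
    })
    print(result)

asyncio.run(main())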