Commit b2aba7d by Emily Xie
Parent: c34de72

MedGemma fixes

main.py CHANGED
@@ -91,7 +91,7 @@ def initialize_agent(
         "MedSAM2Tool": lambda: MedSAM2Tool(
             device=device, cache_dir=model_dir, temp_dir=temp_dir
         ),
-        "MedGemmaVQATool": lambda: MedGemmaAPIClientTool(cache_dir=model_dir, device=device, api_url=MEDGEMMA_API_URL)
+        "MedGemmaVQATool": lambda: MedGemmaAPIClientTool(cache_dir=model_dir, device=device, load_in_8bit=True, api_url=MEDGEMMA_API_URL)
     }

     # Initialize only selected tools or all if none specified
@@ -184,9 +184,13 @@ if __name__ == "__main__":
         # "PythonSandboxTool",  # Add the Python sandbox tool
     ]

+    # Share a single cache directory and device across tools
+    shared_model_dir = os.getenv("MODEL_WEIGHTS_DIR", "/model-weights")
+    shared_device = os.getenv("MEDRAX_DEVICE", "cuda:0")
+
     # Setup the MedGemma environment if the MedGemmaVQATool is selected
     if "MedGemmaVQATool" in selected_tools:
-        setup_medgemma_env()
+        setup_medgemma_env(cache_dir=shared_model_dir, device=shared_device)

     # Configure the Retrieval Augmented Generation (RAG) system
     # This allows the agent to access and use medical knowledge documents
@@ -210,9 +214,9 @@ if __name__ == "__main__":
     agent, tools_dict = initialize_agent(
         prompt_file="medrax/docs/system_prompts.txt",
         tools_to_use=selected_tools,
-        model_dir="/model-weights",
+        model_dir=shared_model_dir,
         temp_dir="temp2",  # Change this to the path of the temporary directory
-        device="cuda:0",
+        device=shared_device,
         model="gpt-5",  # Change this to the model you want to use, e.g. gpt-4.1-2025-04-14, gemini-2.5-pro, gpt-5
         temperature=1.0,
         model_kwargs=model_kwargs,
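
Note: main.py now derives the shared cache directory and device from MODEL_WEIGHTS_DIR and MEDRAX_DEVICE. A minimal launch sketch, assuming you start main.py as a child process; the path and device values below are illustrative, not part of the commit:

    import os
    import subprocess

    # Override the two environment variables read by main.py above
    env = os.environ.copy()
    env["MODEL_WEIGHTS_DIR"] = "/scratch/model-weights"  # hypothetical writable path
    env["MEDRAX_DEVICE"] = "cuda:1"                      # hypothetical second GPU
    subprocess.run(["python", "main.py"], env=env, check=True)
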
medrax/tools/vqa/medgemma/medgemma.py CHANGED
@@ -98,7 +98,7 @@ class MedGemmaModel:
         device: Optional[str] = "cuda",
         dtype: torch.dtype = torch.bfloat16,
         cache_dir: Optional[str] = None,
-        load_in_4bit: bool = True,
+        load_in_8bit: bool = True,
         **kwargs: Any,
     ) -> None:
         """Initialize the MedGemmaModel.
@@ -108,7 +108,7 @@ class MedGemmaModel:
             device: Device to run model on - "cuda" or "cpu" (default: "cuda")
             dtype: Data type for model weights - bfloat16 recommended for efficiency (default: torch.bfloat16)
             cache_dir: Directory to cache downloaded models (default: None)
-            load_in_4bit: Whether to load model in 4-bit quantization for memory efficiency (default: True)
+            load_in_8bit: Whether to load model in 8-bit quantization for memory efficiency (default: True)
             **kwargs: Additional arguments passed to the model pipeline

         Raises:
@@ -138,8 +138,8 @@ class MedGemmaModel:
             "use_cache": True,
         }

-        if load_in_4bit:
-            model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)
+        if load_in_8bit:
+            model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
             model_kwargs["device_map"] = {"": self.device}

         try:
@@ -288,6 +288,7 @@ app = FastAPI(
 )

 medgemma_model: Optional[MedGemmaModel] = None
+inference_semaphore: Optional[asyncio.Semaphore] = None

 @app.on_event("startup")
 async def startup_event():
@@ -306,7 +307,32 @@ async def startup_event():
     """
     global medgemma_model
     try:
-        medgemma_model = MedGemmaModel()
+        # Allow overriding Hugging Face cache directory and device via env vars
+        cache_dir_env = os.getenv("MEDGEMMA_CACHE_DIR")
+        device_env = os.getenv("MEDGEMMA_DEVICE")
+        max_concurrency_env = os.getenv("MEDGEMMA_MAX_CONCURRENCY", "1")
+
+        # Ensure the cache directory is writable; if not, fall back to a user cache
+        if cache_dir_env:
+            try:
+                os.makedirs(cache_dir_env, exist_ok=True)
+                if not os.access(cache_dir_env, os.W_OK):
+                    raise PermissionError("Cache dir not writable")
+            except Exception:
+                fallback = os.path.join(Path.home(), ".cache", "medrax", "medgemma")
+                os.makedirs(fallback, exist_ok=True)
+                print(f"Warning: MEDGEMMA_CACHE_DIR '{cache_dir_env}' not writable. Falling back to '{fallback}'.")
+                cache_dir_env = fallback
+
+        medgemma_model = MedGemmaModel(cache_dir=cache_dir_env, device=device_env)
+        # Initialize concurrency gate
+        try:
+            max_concurrency = int(max_concurrency_env)
+        except ValueError:
+            max_concurrency = 1
+        max_concurrency = max(1, max_concurrency)
+        global inference_semaphore
+        inference_semaphore = asyncio.Semaphore(max_concurrency)
         print("MedGemma model loaded successfully.")
     except RuntimeError as e:
         print(f"Error loading MedGemma model: {e}")
@@ -379,8 +405,12 @@ async def analyze_images(
         raise HTTPException(status_code=500, detail=f"Failed to save uploaded image: {str(e)}")

     try:
-        # Generate AI analysis
-        response_text = await medgemma_model.aget_response(image_paths, prompt, system_prompt, max_new_tokens)
+        # Generate AI analysis with concurrency gating to avoid GPU contention timeouts
+        global inference_semaphore
+        if inference_semaphore is None:
+            inference_semaphore = asyncio.Semaphore(1)
+        async with inference_semaphore:
+            response_text = await medgemma_model.aget_response(image_paths, prompt, system_prompt, max_new_tokens)

         # Prepare success response
         metadata = {
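
Note: the new module-level semaphore makes concurrent /analyze-images/ requests queue for the GPU instead of contending for it. A self-contained sketch of the same gating pattern, where fake_inference is a stand-in for medgemma_model.aget_response and the timing is illustrative:

    import asyncio

    inference_semaphore = asyncio.Semaphore(1)  # mirrors MEDGEMMA_MAX_CONCURRENCY=1

    async def fake_inference(request_id: int) -> str:
        await asyncio.sleep(0.1)  # stand-in for GPU-bound generation
        return f"result {request_id}"

    async def handle_request(request_id: int) -> str:
        # Concurrent callers block here until the previous inference finishes
        async with inference_semaphore:
            return await fake_inference(request_id)

    async def main() -> None:
        print(await asyncio.gather(*(handle_request(i) for i in range(4))))

    asyncio.run(main())
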
medrax/tools/vqa/medgemma/medgemma_client.py CHANGED
@@ -59,15 +59,21 @@ class MedGemmaAPIClientTool(BaseTool):

     # API configuration
     api_url: str  # The URL of the running FastAPI service
+    cache_dir: Optional[str] = None  # Not used by the client directly, but accepted to keep a uniform constructor
+    device: Optional[str] = None

-    def __init__(self, api_url: str, **kwargs: Any):
+    def __init__(self, api_url: str, cache_dir: Optional[str] = None, device: Optional[str] = None, timeout_seconds: Optional[float] = None, **kwargs: Any):
         """Initialize the MedGemmaAPIClientTool.

         Args:
             api_url: The URL of the running MedGemma FastAPI service
+            cache_dir: Optional local cache directory for model weights (accepted for interface consistency)
+            device: Optional device spec (accepted for interface consistency)
+            timeout_seconds: Optional request timeout override (seconds)
             **kwargs: Additional arguments passed to BaseTool
         """
-        super().__init__(api_url=api_url, **kwargs)
+        super().__init__(api_url=api_url, cache_dir=cache_dir, device=device, **kwargs)
+        self._timeout_seconds = timeout_seconds

     def _prepare_request_data(
         self, image_paths: List[str], prompt: str, system_prompt: str, max_new_tokens: int
@@ -149,7 +155,8 @@ class MedGemmaAPIClientTool(BaseTool):
             Tuple of output dictionary and metadata
         """
         # httpx is a modern HTTP client that supports sync and async
-        timeout_config = httpx.Timeout(300.0, connect=10.0)
+        timeout_value = self._timeout_seconds if self._timeout_seconds is not None else 600.0
+        timeout_config = httpx.Timeout(timeout_value, connect=10.0)
         client = httpx.Client(timeout=timeout_config)

         try:
@@ -233,11 +240,12 @@ class MedGemmaAPIClientTool(BaseTool):
                 image_paths, prompt, system_prompt, max_new_tokens
             )

+            timeout_value = self._timeout_seconds if self._timeout_seconds is not None else 600.0
             response = await client.post(
                 f"{self.api_url}/analyze-images/",
                 data=data,
                 files=files_to_send,
-                timeout=120.0
+                timeout=timeout_value
             )
             response.raise_for_status()

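
Note: both the sync and async request paths now honor timeout_seconds, defaulting to 600 s while keeping a 10 s connect timeout. A short sketch of the resulting httpx configuration; the service URL in the comment is illustrative:

    import httpx

    timeout_seconds = 120.0  # as in MedGemmaAPIClientTool(..., timeout_seconds=120.0)
    timeout_value = timeout_seconds if timeout_seconds is not None else 600.0
    timeout_config = httpx.Timeout(timeout_value, connect=10.0)

    with httpx.Client(timeout=timeout_config) as client:
        # The tool's request then looks like:
        # client.post("http://localhost:8000/analyze-images/", data=..., files=...)
        pass
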
medrax/tools/vqa/medgemma/medgemma_setup.py CHANGED
@@ -3,7 +3,23 @@ from pathlib import Path
 import subprocess
 import venv

-def setup_medgemma_env():
+def _resolve_writable_cache_dir(preferred: str | None) -> str:
+    """Return a writable cache directory, falling back to user cache if needed."""
+    # Preferred path first
+    if preferred:
+        try:
+            os.makedirs(preferred, exist_ok=True)
+            if os.access(preferred, os.W_OK):
+                return preferred
+        except Exception:
+            pass
+    # Fallback path under user's home
+    fallback = os.path.join(Path.home(), ".cache", "medrax", "medgemma")
+    os.makedirs(fallback, exist_ok=True)
+    return fallback
+
+
+def setup_medgemma_env(cache_dir: str | None = None, device: str | None = None):
     """Set up MedGemma virtual environment and launch the FastAPI service.

     This function performs the following steps:
@@ -55,10 +71,15 @@ def setup_medgemma_env():

     # Launch MedGemma FastAPI service
     print("Launching MedGemma FastAPI service...")
+    env = os.environ.copy()
+    resolved_cache = _resolve_writable_cache_dir(cache_dir)
+    env["MEDGEMMA_CACHE_DIR"] = resolved_cache
+    if device:
+        env["MEDGEMMA_DEVICE"] = device
     subprocess.Popen([
         str(python_executable),
         str(medgemma_path)
-    ])
+    ], env=env)
     # Note: stdout and stderr redirection commented out for debugging
     #     stdout=subprocess.DEVNULL,
     #     stderr=subprocess.DEVNULL,
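
Note: setup_medgemma_env now publishes the resolved cache directory and device to the service process via MEDGEMMA_CACHE_DIR and MEDGEMMA_DEVICE, which startup_event in medgemma.py reads back. A minimal sketch of that environment handoff; the child command and paths are illustrative:

    import os
    import subprocess
    import sys

    env = os.environ.copy()
    env["MEDGEMMA_CACHE_DIR"] = "/tmp/medgemma-cache"  # hypothetical writable dir
    env["MEDGEMMA_DEVICE"] = "cuda:0"

    # The child process sees the variables exactly as the FastAPI service would
    subprocess.run(
        [sys.executable, "-c", "import os; print(os.getenv('MEDGEMMA_CACHE_DIR'))"],
        env=env,
        check=True,
    )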