VictorLJZ committed
Commit 4c987e5 · 2 parents: c94d452 2639ef4

Merge branch 'main' into emily/medgemma

.gitignore CHANGED
@@ -178,3 +178,5 @@ medrax-pdfs/
 model-weights/
 
 .DS_Store
+
+benchmarking/data/
README.md CHANGED
@@ -29,6 +29,7 @@ MedRAX is built on a robust technical foundation:
 - **Disease Classification**: Leverages DenseNet-121 from TorchXRayVision for detecting 18 pathology classes
 - **X-ray Generation**: Utilizes RoentGen for synthetic CXR generation
 - **Web Browser**: Provides web search capabilities and URL content retrieval using Google Custom Search API
+- **DuckDuckGo Search**: Offers privacy-focused web search capabilities using DuckDuckGo search engine for medical research, fact-checking, and accessing current medical information without API keys
 - **Python Sandbox**: Executes Python code in a secure, stateful sandbox environment using `langchain-sandbox` and Pyodide. Supports custom data analysis, calculations, and dynamic package installations. Pre-configured with medical analysis packages including pandas, numpy, pydicom, SimpleITK, scikit-image, Pillow, scikit-learn, matplotlib, seaborn, and openpyxl. **Requires Deno runtime.**
 - **Utilities**: Includes DICOM processing, visualization tools, and custom plotting capabilities
 <br><br>
@@ -164,6 +165,7 @@ selected_tools = [
     "ChestXRaySegmentationTool",
     "PythonSandboxTool",  # Python code execution
     "WebBrowserTool",  # Web search and URL access
+    "DuckDuckGoSearchTool",  # Privacy-focused web search
     # Add or remove tools as needed
 ]
 
@@ -179,17 +181,10 @@ agent, tools_dict = initialize_agent(
 
 The following tools will automatically download their model weights when initialized:
 
-### Classification Tools
+### Classification Tool
 ```python
 # TorchXRayVision-based classifier (original)
 TorchXRayVisionClassifierTool(device=device)
-
-# ArcPlus SwinTransformer-based classifier (new)
-ArcPlusClassifierTool(
-    model_path="/path/to/Ark6_swinLarge768_ep50.pth.tar",  # Optional
-    num_classes=18,  # Default
-    device=device
-)
 ```
 
 ### Segmentation Tool
@@ -283,6 +278,7 @@ No additional model weights required:
 ImageVisualizerTool()
 DicomProcessorTool(temp_dir=temp_dir)
 WebBrowserTool()  # Requires Google Search API credentials
+DuckDuckGoSearchTool()  # No API key required, privacy-focused search
 ```
 <br>
 
@@ -301,6 +297,25 @@ ChestXRayGeneratorTool(
 2. Place weights in `{model_dir}/roentgen`
 3. Optional tool, can be excluded if not needed
 
+### ArcPlus SwinTransformer-based Classifier
+```python
+ArcPlusClassifierTool(
+    model_path="/path/to/Ark6_swinLarge768_ep50.pth.tar",  # Optional
+    num_classes=18,  # Default
+    device=device
+)
+```
+
+The ArcPlus classifier requires manual setup as the pre-trained model is not publicly available for automatic download:
+
+1. **Request Access**: Visit [https://github.com/jlianglab/Ark](https://github.com/jlianglab/Ark) and request the pretrained model through their Google Forms
+2. **Download Model**: Once approved, download the `Ark6_swinLarge768_ep50.pth.tar` file
+3. **Place in Directory**: Drag the downloaded file into your `model-weights` directory
+4. **Initialize Tool**: The tool will automatically look for the model file in the specified `cache_dir`
+
+The ArcPlus model provides advanced chest X-ray classification across 6 medical datasets (MIMIC, CheXpert, NIH, RSNA, VinDr, Shenzhen) with 52+ pathology categories.
+```
+
 ### Knowledge Base Setup (MedicalRAGTool)
 
 The `MedicalRAGTool` uses a Pinecone vector database to store and retrieve medical knowledge. To use this tool, you need to set up a Pinecone account and a Cohere account.
@@ -403,6 +418,8 @@ If you are running a local LLM using frameworks like [Ollama](https://ollama.com
 
 **WebBrowserTool**: Requires Google Custom Search API credentials, which can be set in the `.env` file.
 
+**DuckDuckGoSearchTool**: No API key required. Uses DuckDuckGo's privacy-focused search engine for medical research and fact-checking.
+
 **PythonSandboxTool**: Requires Deno runtime installation:
 ```bash
 # Verify Deno is installed
benchmarking/benchmarks/base.py CHANGED
@@ -4,6 +4,7 @@ from abc import ABC, abstractmethod
 from typing import Dict, List, Optional, Any, Iterator, Tuple
 from dataclasses import dataclass
 from pathlib import Path
+import random
 
 
 @dataclass
@@ -31,17 +32,31 @@ class Benchmark(ABC):
         Args:
             data_dir (str): Directory containing benchmark data
             **kwargs: Additional configuration parameters
+                random_seed (int): Random seed for shuffling data (default: None, no shuffling)
         """
         self.data_dir = Path(data_dir)
         self.config = kwargs
         self.data_points = []
         self._load_data()
+        self._shuffle_data()
 
     @abstractmethod
     def _load_data(self) -> None:
         """Load benchmark data from the data directory."""
         pass
 
+    def _shuffle_data(self) -> None:
+        """Shuffle the data points if a random seed is provided.
+
+        This method is called automatically after data loading to ensure
+        reproducible benchmark runs when a random seed is specified.
+        """
+        random_seed = self.config.get("random_seed", None)
+        if random_seed is not None:
+            random.seed(random_seed)
+            random.shuffle(self.data_points)
+            print(f"Shuffled {len(self.data_points)} data points with seed {random_seed}")
+
     def get_data_point(self, index: int) -> BenchmarkDataPoint:
         """Get a specific data point by index.
 
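For context, a minimal sketch (not from this commit) of how the new `random_seed` option behaves, assuming `Benchmark` is importable from `benchmarking.benchmarks.base` and that `_load_data` is the only abstract method a subclass must provide:

```python
# Hypothetical illustration: two instances built with the same seed end up
# with their data points in the same shuffled order.
from benchmarking.benchmarks.base import Benchmark


class DemoBenchmark(Benchmark):
    """Stand-in subclass used only for this sketch."""

    def _load_data(self) -> None:
        # Real benchmarks populate self.data_points with BenchmarkDataPoint objects;
        # plain integers are enough to show the shuffling behaviour.
        self.data_points = list(range(10))


a = DemoBenchmark(data_dir=".", random_seed=42)
b = DemoBenchmark(data_dir=".", random_seed=42)
assert a.data_points == b.data_points  # same seed -> same order
```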
benchmarking/benchmarks/rexvqa_benchmark.py CHANGED
@@ -2,10 +2,12 @@
 
 import json
 import os
-from typing import Dict, List, Optional, Any
+from typing import Dict, Optional, Any
 from datasets import load_dataset
 from .base import Benchmark, BenchmarkDataPoint
 from pathlib import Path
+import subprocess
+from huggingface_hub import hf_hub_download, list_repo_files
 
 
 class ReXVQABenchmark(Benchmark):
@@ -47,11 +49,128 @@ class ReXVQABenchmark(Benchmark):
 
         super().__init__(data_dir, **kwargs)
 
+    @staticmethod
+    def download_rexgradient_images(output_dir: str = "benchmarking/data/rexvqa", repo_id: str = "rajpurkarlab/ReXGradient-160K"):
+        """Download and extract ReXGradient-160K images if not already present."""
+        output_dir = Path(output_dir)
+        tar_path = output_dir / "deid_png.tar"
+        images_dir = output_dir / "images"
+
+        # Check if images already exist
+        if images_dir.exists() and any(images_dir.rglob("*.png")):
+            print(f"Images already exist in {images_dir}, skipping download.")
+            return
+        output_dir.mkdir(parents=True, exist_ok=True)
+        print(f"Output directory: {output_dir}")
+        try:
+            print("Listing files in repository...")
+            files = list_repo_files(repo_id, repo_type='dataset')
+            part_files = [f for f in files if f.startswith("deid_png.part")]
+            if not part_files:
+                print("No part files found. The images might be in a different format.")
+                return
+            print(f"Found {len(part_files)} part files.")
+            # Download part files
+            for part_file in part_files:
+                output_path = output_dir / part_file
+                if output_path.exists():
+                    print(f"Skipping {part_file} (already exists)")
+                    continue
+                print(f"Downloading {part_file}...")
+                hf_hub_download(
+                    repo_id=repo_id,
+                    filename=part_file,
+                    local_dir=output_dir,
+                    local_dir_use_symlinks=False,
+                    repo_type='dataset'
+                )
+            # Concatenate part files
+            if not tar_path.exists():
+                print("\nConcatenating part files...")
+                with open(tar_path, 'wb') as tar_file:
+                    for part_file in sorted(part_files):
+                        part_path = output_dir / part_file
+                        if part_path.exists():
+                            print(f"Adding {part_file}...")
+                            with open(part_path, 'rb') as f:
+                                tar_file.write(f.read())
+                        else:
+                            print(f"Warning: {part_file} not found, skipping...")
+            else:
+                print(f"Tar file already exists: {tar_path}")
+            # Extract tar file
+            if tar_path.exists():
+                print("\nExtracting images...")
+                images_dir.mkdir(exist_ok=True)
+                if any(images_dir.rglob("*.png")):
+                    print("Images already extracted.")
+                else:
+                    try:
+                        subprocess.run([
+                            "tar", "-xf", str(tar_path),
+                            "-C", str(images_dir)
+                        ], check=True)
+                        print("Extraction completed!")
+                    except subprocess.CalledProcessError as e:
+                        print(f"Error extracting tar file: {e}")
+                        return
+                    except FileNotFoundError:
+                        print("Error: 'tar' command not found. Please install tar or extract manually.")
+                        return
+                png_files = list(images_dir.rglob("*.png"))
+                print(f"Extracted {len(png_files)} PNG images to {images_dir}")
+
+                # Clean up part and tar files after successful extraction
+                print("Cleaning up part and tar files...")
+                # Remove deid_png.part* files
+                for part_file in output_dir.glob("deid_png.part*"):
+                    try:
+                        part_file.unlink()
+                        print(f"Deleted {part_file}")
+                    except Exception as e:
+                        print(f"Could not delete {part_file}: {e}")
+                # Remove deid_png.tar
+                if tar_path.exists():
+                    try:
+                        tar_path.unlink()
+                        print(f"Deleted {tar_path}")
+                    except Exception as e:
+                        print(f"Could not delete {tar_path}: {e}")
+        except Exception as e:
+            print(f"Error: {e}")
+
+    @staticmethod
+    def download_test_vqa_data_json(output_dir: str = "benchmarking/data/rexvqa", repo_id: str = "rajpurkarlab/ReXVQA"):
+        """Download test_vqa_data.json from the ReXVQA HuggingFace repo if not already present."""
+        output_dir = Path(output_dir)
+        output_dir.mkdir(parents=True, exist_ok=True)
+        json_path = output_dir / "metadata" / "test_vqa_data.json"
+        if json_path.exists():
+            print(f"test_vqa_data.json already exists at {json_path}, skipping download.")
+            return
+        print(f"Downloading test_vqa_data.json to {json_path}...")
+        try:
+            hf_hub_download(
+                repo_id=repo_id,
+                filename="metadata/test_vqa_data.json",
+                local_dir=output_dir,
+                local_dir_use_symlinks=False,
+                repo_type='dataset'
+            )
+            print("Download complete.")
+        except Exception as e:
+            print(f"Error downloading test_vqa_data.json: {e}")
+            print("You may need to accept the license agreement on HuggingFace.")
+
     def _load_data(self) -> None:
         """Load ReXVQA data from local JSON file."""
         try:
+            # Check for images and test_vqa_data.json, download if missing
+            self.download_test_vqa_data_json()
+            self.download_rexgradient_images()
+
             # Construct path to the JSON file
-            json_file_path = os.path.join("benchmarking", "data", "rexvqa", "test_vqa_data.json")
+            json_file_path = os.path.join("benchmarking", "data", "rexvqa", "metadata", "test_vqa_data.json")
 
             # Check if file exists
             if not os.path.exists(json_file_path):
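If the dataset needs to be staged ahead of a run, the two new static methods can also be called directly; a small sketch (not part of the diff above), assuming the module path `benchmarking.benchmarks.rexvqa_benchmark`:

```python
# Pre-fetch sketch: _load_data() already calls both of these automatically,
# so running them separately is only useful for warming the cache or debugging downloads.
from benchmarking.benchmarks.rexvqa_benchmark import ReXVQABenchmark

# Downloads metadata/test_vqa_data.json from rajpurkarlab/ReXVQA (skipped if already present).
ReXVQABenchmark.download_test_vqa_data_json(output_dir="benchmarking/data/rexvqa")

# Downloads the deid_png.part* archives from rajpurkarlab/ReXGradient-160K, concatenates
# them into deid_png.tar, and extracts the PNGs into images/ (skipped if already present).
ReXVQABenchmark.download_rexgradient_images(output_dir="benchmarking/data/rexvqa")
```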
benchmarking/cli.py CHANGED
@@ -3,34 +3,40 @@
 import argparse
 import sys
 
-from .llm_providers import *
+from .llm_providers.base import LLMProvider
 from .benchmarks import *
 from .runner import BenchmarkRunner, BenchmarkRunConfig
 
 
-def create_llm_provider(model_name: str, provider_type: str, **kwargs) -> LLMProvider:
+def create_llm_provider(model_name: str, provider_type: str, system_prompt: str, **kwargs) -> LLMProvider:
     """Create an LLM provider based on the model name and type.
 
     Args:
         model_name (str): Name of the model
-        provider_type (str): Type of provider (openai, google, openrouter, anthropic, medrax)
+        provider_type (str): Type of provider (openai, google, openrouter, medrax)
+        system_prompt (str): System prompt identifier to load from file
         **kwargs: Additional configuration parameters
 
     Returns:
         LLMProvider: The configured LLM provider
     """
-    provider_map = {
-        "openai": OpenAIProvider,
-        "google": GoogleProvider,
-        "openrouter": OpenRouterProvider,
-        "medrax": MedRAXProvider,
-    }
-
-    if provider_type not in provider_map:
-        raise ValueError(f"Unknown provider type: {provider_type}. Available: {list(provider_map.keys())}")
-
-    provider_class = provider_map[provider_type]
-    return provider_class(model_name, **kwargs)
+    # Lazy imports to avoid slow startup
+    if provider_type == "openai":
+        from .llm_providers.openai_provider import OpenAIProvider
+        provider_class = OpenAIProvider
+    elif provider_type == "google":
+        from .llm_providers.google_provider import GoogleProvider
+        provider_class = GoogleProvider
+    elif provider_type == "openrouter":
+        from .llm_providers.openrouter_provider import OpenRouterProvider
+        provider_class = OpenRouterProvider
+    elif provider_type == "medrax":
+        from .llm_providers.medrax_provider import MedRAXProvider
+        provider_class = MedRAXProvider
+    else:
+        raise ValueError(f"Unknown provider type: {provider_type}. Available: openai, google, openrouter, medrax")
+
+    return provider_class(model_name, system_prompt, **kwargs)
 
 
 def create_benchmark(benchmark_name: str, data_dir: str, **kwargs) -> Benchmark:
@@ -63,12 +69,14 @@ def run_benchmark_command(args) -> None:
     # Create LLM provider
     provider_kwargs = {}
 
-    llm_provider = create_llm_provider(args.model, args.provider, **provider_kwargs)
+    llm_provider = create_llm_provider(model_name=args.model, provider_type=args.provider, system_prompt=args.system_prompt, **provider_kwargs)
 
     # Create benchmark
     benchmark_kwargs = {}
+    if args.random_seed is not None:
+        benchmark_kwargs["random_seed"] = args.random_seed
 
-    benchmark = create_benchmark(args.benchmark, args.data_dir, **benchmark_kwargs)
+    benchmark = create_benchmark(benchmark_name=args.benchmark, data_dir=args.data_dir, **benchmark_kwargs)
 
     # Create runner config
     config = BenchmarkRunConfig(
@@ -111,16 +119,32 @@ def main():
     subparsers = parser.add_subparsers(dest="command", help="Available commands")
 
     # Run benchmark command
-    run_parser = subparsers.add_parser("run", help="Run a benchmark")
-    run_parser.add_argument("--model", required=True, help="Model name (e.g., gpt-4o, gemini-2.5-pro)")
-    run_parser.add_argument("--provider", required=True, choices=["openai", "google", "openrouter", "medrax"], help="LLM provider")
-    run_parser.add_argument("--benchmark", required=True, choices=["rexvqa", "chestagentbench"], help="Benchmark to run")
-    run_parser.add_argument("--data-dir", required=True, help="Directory containing benchmark data")
-    run_parser.add_argument("--output-dir", default="benchmark_results", help="Output directory for results")
-    run_parser.add_argument("--max-questions", type=int, help="Maximum number of questions to process")
-    run_parser.add_argument("--temperature", type=float, default=0.7, help="Model temperature")
-    run_parser.add_argument("--top-p", type=float, default=0.95, help="Top-p value")
-    run_parser.add_argument("--max-tokens", type=int, default=1000, help="Maximum tokens per response")
+    run_parser = subparsers.add_parser("run", help="Run a benchmark evaluation")
+    run_parser.add_argument("--model", required=True,
+                            help="Model name (e.g., gpt-4o, gpt-4.1-2025-04-14, gemini-2.5-pro)")
+    run_parser.add_argument("--provider", required=True,
+                            choices=["openai", "google", "openrouter", "medrax"],
+                            help="LLM provider to use")
+    run_parser.add_argument("--system-prompt", required=True,
+                            choices=["MEDICAL_ASSISTANT", "CHESTAGENTBENCH_PROMPT"],
+                            help="System prompt: MEDICAL_ASSISTANT (general) or CHESTAGENTBENCH_PROMPT (benchmarks)")
+    run_parser.add_argument("--benchmark", required=True,
+                            choices=["rexvqa", "chestagentbench"],
+                            help="Benchmark dataset: rexvqa (radiology VQA) or chestagentbench (chest X-ray reasoning)")
+    run_parser.add_argument("--data-dir", required=True,
+                            help="Directory containing benchmark data files")
+    run_parser.add_argument("--output-dir", default="benchmark_results",
+                            help="Output directory for results (default: benchmark_results)")
+    run_parser.add_argument("--max-questions", type=int,
+                            help="Maximum number of questions to process (default: all)")
+    run_parser.add_argument("--temperature", type=float, default=1,
+                            help="Model temperature for response generation (default: 0.7)")
+    run_parser.add_argument("--top-p", type=float, default=0.95,
+                            help="Top-p nucleus sampling parameter (default: 0.95)")
+    run_parser.add_argument("--max-tokens", type=int, default=5000,
+                            help="Maximum tokens per model response (default: 5000)")
+    run_parser.add_argument("--random-seed", type=int, default=42,
+                            help="Random seed for shuffling benchmark data (enables reproducible runs, default: None)")
 
     run_parser.set_defaults(func=run_benchmark_command)
 
benchmarking/data/rexvqa/download_rexgradient_images.py DELETED
@@ -1,172 +0,0 @@
-#!/usr/bin/env python3
-"""
-Utility script to download and extract ReXGradient-160K images.
-
-This script helps users download the actual PNG images from the ReXGradient-160K dataset,
-which are stored as part files on HuggingFace and need to be concatenated and extracted.
-
-Usage:
-    python download_rexgradient_images.py --output_dir /path/to/images
-"""
-
-import argparse
-import subprocess
-from pathlib import Path
-from huggingface_hub import hf_hub_download, list_repo_files
-import requests
-from tqdm import tqdm
-
-
-def download_file(url, output_path, chunk_size=8192):
-    """Download a file with progress bar."""
-    response = requests.get(url, stream=True)
-    total_size = int(response.headers.get('content-length', 0))
-
-    with open(output_path, 'wb') as f:
-        with tqdm(total=total_size, unit='B', unit_scale=True, desc=output_path.name) as pbar:
-            for chunk in response.iter_content(chunk_size=chunk_size):
-                if chunk:
-                    f.write(chunk)
-                    pbar.update(len(chunk))
-
-
-def main():
-    parser = argparse.ArgumentParser(description="Download ReXGradient-160K images")
-    parser.add_argument(
-        "--output_dir",
-        type=str,
-        required=True,
-        help="Directory to save extracted images"
-    )
-    parser.add_argument(
-        "--repo_id",
-        type=str,
-        default="rajpurkarlab/ReXGradient-160K",
-        help="HuggingFace repository ID"
-    )
-    parser.add_argument(
-        "--skip_download",
-        action="store_true",
-        help="Skip downloading and only extract if files exist"
-    )
-
-    args = parser.parse_args()
-
-    output_dir = Path(args.output_dir)
-    output_dir.mkdir(parents=True, exist_ok=True)
-
-    print(f"Output directory: {output_dir}")
-
-    # Check if we need to accept the license first
-    print("Note: You may need to accept the dataset license on HuggingFace first:")
-    print(f"Visit: https://huggingface.co/datasets/{args.repo_id}")
-    print("Click 'Access repository' and accept the license agreement.")
-    print()
-
-    try:
-        # List files in the repository
-        print("Listing files in repository...")
-        files = list_repo_files(args.repo_id, repo_type='dataset')
-        part_files = [f for f in files if f.startswith("deid_png.part")]
-
-        if not part_files:
-            print("No part files found. The images might be in a different format.")
-            print("Available files:")
-            for f in files:
-                print(f"  - {f}")
-            return
-
-        print(f"Found {len(part_files)} part files:")
-        for f in part_files:
-            print(f"  - {f}")
-
-        # Download part files
-        if not args.skip_download:
-            print("\nDownloading part files...")
-            for part_file in part_files:
-                output_path = output_dir / part_file
-                if output_path.exists():
-                    print(f"Skipping {part_file} (already exists)")
-                    continue
-
-                print(f"Downloading {part_file}...")
-                try:
-                    hf_hub_download(
-                        repo_id=args.repo_id,
-                        filename=part_file,
-                        local_dir=output_dir,
-                        local_dir_use_symlinks=False,
-                        repo_type='dataset'
-                    )
-                except Exception as e:
-                    print(f"Error downloading {part_file}: {e}")
-                    print("You may need to accept the license agreement on HuggingFace.")
-                    return
-
-        # Concatenate part files
-        tar_path = output_dir / "deid_png.tar"
-        if not tar_path.exists():
-            print("\nConcatenating part files...")
-            with open(tar_path, 'wb') as tar_file:
-                for part_file in sorted(part_files):
-                    part_path = output_dir / part_file
-                    if part_path.exists():
-                        print(f"Adding {part_file}...")
-                        with open(part_path, 'rb') as f:
-                            tar_file.write(f.read())
-                    else:
-                        print(f"Warning: {part_file} not found, skipping...")
-        else:
-            print(f"Tar file already exists: {tar_path}")
-
-        # Extract tar file
-        if tar_path.exists():
-            print("\nExtracting images...")
-            images_dir = output_dir / "images"
-            images_dir.mkdir(exist_ok=True)
-
-            # Check if already extracted
-            if any(images_dir.glob("*.png")):
-                print("Images already extracted.")
-            else:
-                try:
-                    subprocess.run([
-                        "tar", "-xf", str(tar_path),
-                        "-C", str(images_dir)
-                    ], check=True)
-                    print("Extraction completed!")
-                except subprocess.CalledProcessError as e:
-                    print(f"Error extracting tar file: {e}")
-                    return
-                except FileNotFoundError:
-                    print("Error: 'tar' command not found. Please install tar or extract manually.")
-                    return
-
-            # Count extracted images
-            png_files = list(images_dir.glob("*.png"))
-            print(f"Extracted {len(png_files)} PNG images to {images_dir}")
-
-            # Show some example filenames
-            if png_files:
-                print("\nExample image filenames:")
-                for f in png_files[:5]:
-                    print(f"  - {f.name}")
-                if len(png_files) > 5:
-                    print(f"  ... and {len(png_files) - 5} more")
-
-        print(f"\nSetup complete! Use this directory as images_dir in ReXVQABenchmark:")
-        print(f"images_dir='{images_dir}'")
-
-    except Exception as e:
-        print(f"Error: {e}")
-        print("\nManual setup instructions:")
-        print("1. Visit https://huggingface.co/datasets/rajpurkarlab/ReXGradient-160K")
-        print("2. Accept the license agreement")
-        print("3. Download the deid_png.part* files")
-        print("4. Concatenate: cat deid_png.part* > deid_png.tar")
-        print("5. Extract: tar -xf deid_png.tar")
-        print("6. Use the extracted directory as images_dir")
-
-
-if __name__ == "__main__":
-    main()
benchmarking/llm_providers/base.py CHANGED
@@ -25,7 +25,7 @@ class LLMResponse:
     content: str
     usage: Optional[Dict[str, Any]] = None
    duration: Optional[float] = None
-    raw_response: Optional[Any] = None
+    chunk_history: Optional[Any] = None
 
 
 class LLMProvider(ABC):
@@ -35,22 +35,24 @@
     text + image input -> text output across different models and APIs.
     """
 
-    def __init__(self, model_name: str, **kwargs):
+    def __init__(self, model_name: str, system_prompt: str, **kwargs):
         """Initialize the LLM provider.
 
         Args:
             model_name (str): Name of the model to use
+            system_prompt (str): System prompt identifier to load from file
             **kwargs: Additional configuration parameters
         """
         self.model_name = model_name
         self.config = kwargs
+        self.prompt_name = system_prompt  # Store the original prompt identifier
 
-        # Always load system prompt from file
+        # Load system prompt content from file
         try:
             prompts = load_prompts_from_file("medrax/docs/system_prompts.txt")
-            self.system_prompt = prompts.get("CHESTAGENTBENCH_PROMPT", None)
+            self.system_prompt = prompts.get(system_prompt, None)
             if self.system_prompt is None:
-                print(f"Warning: System prompt not found in medrax/docs/system_prompts.txt.")
+                print(f"Warning: System prompt '{system_prompt}' not found in medrax/docs/system_prompts.txt.")
         except Exception as e:
             print(f"Error loading system prompt: {e}")
             self.system_prompt = None
@@ -102,8 +104,12 @@
         Returns:
             str: Base64 encoded image string
         """
-        with open(image_path, "rb") as image_file:
-            return base64.b64encode(image_file.read()).decode('utf-8')
+        try:
+            with open(image_path, "rb") as image_file:
+                return base64.b64encode(image_file.read()).decode('utf-8')
+        except Exception as e:
+            print(f"ERROR: _encode_image failed for {image_path} (type: {type(image_path)}): {e}")
+            raise
 
     def _validate_image_paths(self, image_paths: List[str]) -> List[str]:
         """Validate that image paths exist and are readable.
benchmarking/llm_providers/google_provider.py CHANGED
@@ -92,13 +92,11 @@ class GoogleProvider(LLMProvider):
             return LLMResponse(
                 content=content,
                 usage=usage,
-                duration=duration,
-                raw_response=response
+                duration=duration
             )
 
         except Exception as e:
             return LLMResponse(
                 content=f"Error: {str(e)}",
-                duration=time.time() - start_time,
-                raw_response=None
+                duration=time.time() - start_time
             )
benchmarking/llm_providers/medrax_provider.py CHANGED
@@ -1,10 +1,10 @@
 """MedRAX LLM provider implementation."""
 
 import time
-import shutil
-from pathlib import Path
+import re
 
 from .base import LLMProvider, LLMRequest, LLMResponse
+from langchain_core.messages import AIMessage, HumanMessage
 
 from medrax.rag.rag import RAGConfig
 from main import initialize_agent
@@ -13,18 +13,19 @@ from main import initialize_agent
 class MedRAXProvider(LLMProvider):
     """MedRAX LLM provider that uses the full MedRAX agent system."""
 
-    def __init__(self, model_name: str, **kwargs):
+    def __init__(self, model_name: str, system_prompt: str, **kwargs):
         """Initialize MedRAX provider.
 
         Args:
             model_name (str): Base LLM model name (e.g., "gpt-4.1-2025-04-14")
+            system_prompt (str): System prompt to use
             **kwargs: Additional configuration parameters
         """
         self.model_name = model_name
         self.agent = None
         self.tools_dict = None
-
-        super().__init__(model_name, **kwargs)
+
+        super().__init__(model_name, system_prompt, **kwargs)
 
     def _setup(self) -> None:
         """Set up MedRAX agent system."""
@@ -32,19 +33,14 @@
         print("Starting server...")
 
         selected_tools = [
-            # "ImageVisualizerTool",  # For displaying images in the UI
-            # "DicomProcessorTool",  # For processing DICOM medical image files
-            # "TorchXRayVisionClassifierTool",  # For classifying chest X-ray images using TorchXRayVision
-            # "ArcPlusClassifierTool",  # For advanced chest X-ray classification using ArcPlus
-            # "ChestXRaySegmentationTool",  # For segmenting anatomical regions in chest X-rays
-            # "ChestXRayReportGeneratorTool",  # For generating medical reports from X-rays
-            # "XRayVQATool",  # For visual question answering on X-rays
-            # "LlavaMedTool",  # For multimodal medical image understanding
-            # "XRayPhraseGroundingTool",  # For locating described features in X-rays
-            # "ChestXRayGeneratorTool",  # For generating synthetic chest X-rays
-            "WebBrowserTool",  # For web browsing and search capabilities
+            "ChestXRayReportGeneratorTool",  # For generating medical reports from X-rays
             "MedicalRAGTool",  # For retrieval-augmented generation with medical knowledge
-            # "PythonSandboxTool",  # Add the Python sandbox tool
+            "WebBrowserTool",  # For web browsing and search capabilities
+            "TorchXRayVisionClassifierTool",  # For classifying chest X-ray images using TorchXRayVision
+            "ArcPlusClassifierTool",  # For advanced chest X-ray classification using ArcPlus
+            "DuckDuckGoSearchTool",  # For privacy-focused web search using DuckDuckGo
+            "XRayVQATool",  # For visual question answering on X-rays
+            "XRayPhraseGroundingTool",  # For locating described features in X-rays
         ]
 
         rag_config = RAGConfig(
@@ -55,7 +51,7 @@
             pinecone_index_name="medrax2",  # Name for the Pinecone index
             chunk_size=1500,
             chunk_overlap=300,
-            retriever_k=7,
+            retriever_k=3,
             local_docs_dir="rag_docs",  # Change this to the path of the documents for RAG
             huggingface_datasets=["VictorLJZ/medrax2"],  # List of HuggingFace datasets to load
             dataset_split="train",  # Which split of the datasets to use
@@ -69,13 +65,13 @@
             tools_to_use=selected_tools,
             model_dir="/model-weights",
             temp_dir="temp",  # Change this to the path of the temporary directory
-            device="cpu",
+            device="cuda:0",
             model=self.model_name,  # Change this to the model you want to use, e.g. gpt-4.1-2025-04-14, gemini-2.5-pro
-            temperature=0.7,
+            temperature=0.3,
             top_p=0.95,
             model_kwargs=model_kwargs,
             rag_config=rag_config,
-            debug=True,
+            system_prompt=self.prompt_name,
         )
 
         self.agent = agent
@@ -101,8 +97,7 @@
         if self.agent is None:
             return LLMResponse(
                 content="Error: MedRAX agent not initialized",
-                duration=time.time() - start_time,
-                raw_response=None
+                duration=time.time() - start_time
             )
 
         try:
@@ -110,78 +105,118 @@
             messages = []
             thread_id = str(int(time.time() * 1000))  # Unique thread ID
 
-            # Copy images to session temp directory and provide paths
-            image_paths = []
             if request.images:
                 valid_images = self._validate_image_paths(request.images)
                 print(f"Processing {len(valid_images)} images")
                 for i, image_path in enumerate(valid_images):
-                    print(f"Original image path: {image_path}")
-                    # Copy image to session temp directory
-                    dest_path = Path("temp") / f"image_{i}_{Path(image_path).name}"
-                    print(f"Destination path: {dest_path}")
-                    shutil.copy2(image_path, dest_path)
-                    image_paths.append(str(dest_path))
-
-                    # Verify file exists after copy
-                    if not dest_path.exists():
-                        print(f"ERROR: File not found after copy: {dest_path}")
-                    else:
-                        print(f"File successfully copied: {dest_path}")
-
                     # Add image path message for tools
-                    messages.append({
-                        "role": "user",
-                        "content": f"image_path: {dest_path}"
-                    })
+                    messages.append(HumanMessage(content=f"image_path: {image_path}"))
 
                     # Add image content for multimodal LLM
-                    with open(image_path, "rb") as img_file:
-                        img_base64 = self._encode_image(image_path)
-
-                    messages.append({
-                        "role": "user",
-                        "content": [{
+                    try:
+                        with open(image_path, "rb") as img_file:
+                            img_base64 = self._encode_image(image_path)
+
+                        messages.append(HumanMessage(content=[{
                             "type": "image_url",
                             "image_url": {"url": f"data:image/jpeg;base64,{img_base64}"}
-                        }]
-                    })
+                        }]))
+                    except Exception as e:
+                        print(f"ERROR: Image encoding failed for {image_path}: {e}")
+                        raise
 
             # Add text message
-            messages.append({
-                "role": "user",
-                "content": [{
+            if request.images:
+                # If there are images, add text as part of multimodal content
+                messages.append(HumanMessage(content=[{
                     "type": "text",
                     "text": request.text
-                }]
-            })
+                }]))
+            else:
+                # If no images, add text as simple string
+                messages.append(HumanMessage(content=request.text))
 
-            # Run the agent
-            response_content = ""
+            # Run the agent with proper message type handling
+            final_response = ""
+            chunk_history = []
 
             for chunk in self.agent.workflow.stream(
                 {"messages": messages},
                 {"configurable": {"thread_id": thread_id}},
                 stream_mode="updates"
             ):
-                if isinstance(chunk, dict):
-                    for node_name, node_output in chunk.items():
-                        if "messages" in node_output:
-                            for msg in node_output["messages"]:
-                                if hasattr(msg, 'content') and msg.content:
-                                    response_content += str(msg.content)
+                if not isinstance(chunk, dict):
+                    continue
+
+                for node_name, node_output in chunk.items():
+                    # Log chunk and get serializable version
+                    serializable_chunk = self._log_chunk(node_output, node_name)
+                    chunk_history.append(serializable_chunk)
+
+                    if "messages" not in node_output:
+                        continue
+
+                    for msg in node_output["messages"]:
+                        if isinstance(msg, AIMessage) and msg.content:
+                            # Handle case where content is a list
+                            content = msg.content
+                            if isinstance(content, list):
+                                content = " ".join(content)
+                            # Clean up the content (remove temp paths, etc.)
+                            final_response = re.sub(r"temp/[^\s]*", "", content).strip()
+
+            # Determine the final response
+            if final_response:
+                response_content = final_response
+            else:
+                # Fallback if no LLM response was received
+                response_content = "No response generated"
 
             duration = time.time() - start_time
 
             return LLMResponse(
-                content=response_content.strip(),
+                content=response_content,
                 usage={"agent_tools": list(self.tools_dict.keys())},
                 duration=duration,
-                raw_response={"thread_id": thread_id, "image_paths": image_paths}
+                chunk_history=chunk_history
            )
 
         except Exception as e:
+            print(f"ERROR: MedRAX agent failed: {e}")
             return LLMResponse(
                 content=f"Error: {str(e)}",
-                duration=time.time() - start_time,
-                raw_response=None
+                duration=time.time() - start_time
             )
+
+    def _log_chunk(self, chunk: dict, node_name: str) -> dict:
+        """Log and process a chunk from the agent workflow.
+
+        Args:
+            chunk (dict): The chunk data from the agent workflow
+            node_name (str): Name of the node that produced the chunk
+
+        Returns:
+            dict: Serializable version of the chunk for debugging
+        """
+        # Log every chunk for debugging
+        print(f"Chunk from node '{node_name}': {type(chunk)}")
+
+        # Store serializable version of chunk for debugging
+        serializable_chunk = {
+            "node_name": node_name,
+            "node_type": type(chunk).__name__,
+        }
+
+        # Log messages in this chunk
+        if "messages" in chunk and isinstance(chunk, dict):
+            chunk_messages = []
+            for msg in chunk["messages"]:
+                msg_info = {
+                    "type": type(msg).__name__,
+                    "content": str(msg.content) if hasattr(msg, 'content') else str(msg)
+                }
+                chunk_messages.append(msg_info)
+                print(f"Message in chunk: {msg_info}")
+            serializable_chunk["messages"] = chunk_messages
+
+        return serializable_chunk
benchmarking/llm_providers/openai_provider.py CHANGED
@@ -101,13 +101,11 @@ class OpenAIProvider(LLMProvider):
             return LLMResponse(
                 content=content,
                 usage=usage,
-                duration=duration,
-                raw_response=response
+                duration=duration
             )
 
         except Exception as e:
             return LLMResponse(
                 content=f"Error: {str(e)}",
-                duration=time.time() - start_time,
-                raw_response=None
+                duration=time.time() - start_time
             )
benchmarking/llm_providers/openrouter_provider.py CHANGED
@@ -78,12 +78,10 @@ class OpenRouterProvider(LLMProvider):
             return LLMResponse(
                 content=content,
                 usage=usage,
-                duration=duration,
-                raw_response=response
+                duration=duration
             )
         except Exception as e:
             return LLMResponse(
                 content=f"Error: {str(e)}",
-                duration=time.time() - start_time,
-                raw_response=None
+                duration=time.time() - start_time
             )
benchmarking/runner.py CHANGED
@@ -24,6 +24,7 @@ class BenchmarkResult:
     duration: float
     usage: Optional[Dict[str, Any]] = None
     error: Optional[str] = None
+    chunk_history: Optional[Dict[str, Any]] = None
     metadata: Optional[Dict[str, Any]] = None
 
 
@@ -138,9 +139,11 @@ class BenchmarkRunner:
                 # Add to results
                 self.results.append(result)
 
+                # Save individual result immediately
+                self._save_individual_result(result)
+
                 # Log progress
                 if processed % 10 == 0:
-                    self._save_intermediate_results()
                     accuracy = (correct / processed) * 100
                     avg_duration = total_duration / processed
 
@@ -163,6 +166,9 @@
                     error=str(e)
                 )
                 self.results.append(error_result)
+
+                # Save individual error result immediately
+                self._save_individual_result(error_result)
                 continue
 
         # Save final results
@@ -220,6 +226,7 @@
                 is_correct=is_correct,
                 duration=duration,
                 usage=response.usage,
+                chunk_history=response.chunk_history,
                 metadata={
                     "data_point_metadata": data_point.metadata,
                     "case_id": data_point.case_id,
@@ -238,6 +245,7 @@
                 is_correct=False,
                 duration=duration,
                 error=str(e),
+                chunk_history=None,
                 metadata={
                     "data_point_metadata": data_point.metadata,
                     "case_id": data_point.case_id,
@@ -254,9 +262,9 @@
         Returns:
             str: The extracted answer
         """
-        # First, look for the '<|A|>' format
-        final_answer_pattern = r'\s*<\|([A-F])\|>'
-        match = re.search(final_answer_pattern, response_text)
+        # Look for the '\boxed{A}' format
+        boxed_pattern = r'\\boxed\{([A-Fa-f])\}'
+        match = re.search(boxed_pattern, response_text)
         if match:
             return match.group(1).upper()
 
@@ -286,11 +294,55 @@
 
         return model_letter == correct_letter
 
-    def _save_intermediate_results(self) -> None:
-        """Save intermediate results to disk."""
-        results_file = self.output_dir / f"{self.run_id}_intermediate.json"
+    def _save_individual_result(self, result: BenchmarkResult) -> None:
+        """Save a single result to its own JSON file.
+
+        Args:
+            result (BenchmarkResult): The result to save
+        """
+        # Sanitize data_point_id for filename (remove invalid characters)
+        safe_id = re.sub(r'[^\w\-_.]', '_', result.data_point_id)
+
+        # Create filename with benchmark name and data point ID
+        filename = f"{self.config.benchmark_name}_{safe_id}.json"
+        result_file = self.output_dir / "individual_results" / filename
+
+        # Create individual_results directory if it doesn't exist
+        result_file.parent.mkdir(exist_ok=True)
+
+        # Convert result to serializable format
+        result_data = {
+            "timestamp": datetime.now().isoformat(),
+            "run_id": self.run_id,
+            "data_point_id": result.data_point_id,
+            "question": result.question,
+            "model_answer": result.model_answer,
+            "correct_answer": result.correct_answer,
+            "is_correct": result.is_correct,
+            "duration": result.duration,
+            "usage": result.usage,
+            "error": result.error,
+            "chunk_history": result.chunk_history,
+            "metadata": result.metadata
+        }
+
+        # Save to file
+        with open(result_file, 'w') as f:
+            json.dump(result_data, f, indent=2)
+
+    def _save_final_results(self, benchmark: Benchmark) -> Dict[str, Any]:
+        """Save final results and return summary.
+
+        Args:
+            benchmark (Benchmark): The benchmark that was run
+
+        Returns:
+            Dict[str, Any]: Summary of results
+        """
+        # Save detailed results
+        results_file = self.output_dir / f"{self.run_id}_results.json"
 
-        # Convert results to serializable format
+        # Convert results to serializable format for final file
         results_data = []
         for result in self.results:
             results_data.append({
@@ -307,19 +359,6 @@
 
         with open(results_file, 'w') as f:
             json.dump(results_data, f, indent=2)
-
-    def _save_final_results(self, benchmark: Benchmark) -> Dict[str, Any]:
-        """Save final results and return summary.
-
-        Args:
-            benchmark (Benchmark): The benchmark that was run
-
-        Returns:
-            Dict[str, Any]: Summary of results
-        """
-        # Save detailed results
-        results_file = self.output_dir / f"{self.run_id}_results.json"
-        self._save_intermediate_results()
 
         # Calculate summary statistics
         total_questions = len(self.results)
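The answer-extraction switch above (from `<|A|>` to `\boxed{A}`) can be sanity-checked in isolation; a quick sketch of what the new pattern matches:

```python
# The new pattern accepts upper- or lower-case option letters inside \boxed{...}
# and normalises them to upper case, mirroring what _extract_answer does.
import re

boxed_pattern = r'\\boxed\{([A-Fa-f])\}'

sample = r"The findings are most consistent with atelectasis, so the answer is \boxed{c}."
match = re.search(boxed_pattern, sample)
assert match is not None and match.group(1).upper() == "C"
```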
interface.py CHANGED
@@ -192,7 +192,11 @@ class ChatInterface:
                 tool_args = pending_call["args"]
 
                 try:
-                    tool_output_json = json.loads(msg.content)
+                    # Handle case where tool returns tuple (output, metadata)
+                    content = msg.content
+                    content_tuple = ast.literal_eval(content)
+                    content = json.dumps(content_tuple[0])
+                    tool_output_json = json.loads(content)
                     tool_output_str = json.dumps(tool_output_json, indent=2)
                 except (json.JSONDecodeError, TypeError):
                     tool_output_str = str(msg.content)
@@ -217,10 +221,11 @@
 
                 if tool_name == "image_visualizer":
                     try:
-                        result = json.loads(msg.content)
-                        # Handle case where tool returns array [output, metadata]
-                        if isinstance(result, list) and len(result) > 0:
-                            result = result[0]  # Take the first element (output)
+                        # Handle case where tool returns tuple (output, metadata)
+                        content = msg.content
+                        content_tuple = ast.literal_eval(content)
+                        result = content_tuple[0]
+
                         if isinstance(result, dict) and "image_path" in result:
                             self.display_file_path = result["image_path"]
                             chat_history.append(
main.py CHANGED
@@ -41,7 +41,7 @@ def initialize_agent(
41
  top_p: float = 0.95,
42
  rag_config: Optional[RAGConfig] = None,
43
  model_kwargs: Dict[str, Any] = {},
44
- debug: bool = False,
45
  ):
46
  """Initialize the MedRAX agent with specified tools and configuration.
47
 
@@ -56,6 +56,7 @@ def initialize_agent(
56
  top_p (float, optional): Top P for the model. Defaults to 0.95.
57
  rag_config (RAGConfig, optional): Configuration for the RAG tool. Defaults to None.
58
  model_kwargs (dict, optional): Additional keyword arguments for model.
 
59
  debug (bool, optional): Whether to enable debug mode. Defaults to False.
60
 
61
  Returns:
@@ -63,7 +64,7 @@ def initialize_agent(
63
  """
64
  # Load system prompts from file
65
  prompts = load_prompts_from_file(prompt_file)
66
- prompt = prompts["MEDICAL_ASSISTANT"]
67
 
68
  # Define the URL of the MedGemma FastAPI service.
69
  MEDGEMMA_API_URL = os.getenv("MEDGEMMA_API_URL", "http://127.0.0.1:8002")
@@ -126,7 +127,6 @@ def initialize_agent(
126
  log_dir="logs",
127
  system_prompt=prompt,
128
  checkpointer=checkpointer,
129
- debug=debug,
130
  )
131
  print("Agent initialized")
132
 
@@ -145,19 +145,20 @@ if __name__ == "__main__":
145
  selected_tools = [
146
  "ImageVisualizerTool", # For displaying images in the UI
147
  # "DicomProcessorTool", # For processing DICOM medical image files
148
- # "TorchXRayVisionClassifierTool", # For classifying chest X-ray images using TorchXRayVision
149
- # "ArcPlusClassifierTool", # For advanced chest X-ray classification using ArcPlus
150
- # "ChestXRaySegmentationTool", # For segmenting anatomical regions in chest X-rays
151
- # "ChestXRayReportGeneratorTool", # For generating medical reports from X-rays
152
- # "XRayVQATool", # For visual question answering on X-rays
153
  # "LlavaMedTool", # For multimodal medical image understanding
154
- # "XRayPhraseGroundingTool", # For locating described features in X-rays
155
  # "ChestXRayGeneratorTool", # For generating synthetic chest X-rays
156
  # "MedSAM2Tool", # For advanced medical image segmentation using MedSAM2
157
  # "WebBrowserTool", # For web browsing and search capabilities
158
  # "MedicalRAGTool", # For retrieval-augmented generation with medical knowledge
159
  # "PythonSandboxTool", # Add the Python sandbox tool
160
  "MedGemmaVQATool" # Google MedGemma VQA tool
 
161
  ]
162
 
163
  # Setup the MedGemma environment if the MedGemmaVQATool is selected
@@ -174,7 +175,7 @@ if __name__ == "__main__":
174
  pinecone_index_name="medrax2", # Name for the Pinecone index
175
  chunk_size=1500,
176
  chunk_overlap=300,
177
- retriever_k=7,
178
  local_docs_dir="rag_docs", # Change this to the path of the documents for RAG
179
  huggingface_datasets=["VictorLJZ/medrax2"], # List of HuggingFace datasets to load
180
  dataset_split="train", # Which split of the datasets to use
@@ -186,15 +187,15 @@ if __name__ == "__main__":
186
  agent, tools_dict = initialize_agent(
187
  prompt_file="medrax/docs/system_prompts.txt",
188
  tools_to_use=selected_tools,
189
- model_dir="model-weights",
190
  temp_dir="temp", # Change this to the path of the temporary directory
191
- device="cuda",
192
  model="gpt-4.1-2025-04-14", # Change this to the model you want to use, e.g. gpt-4.1-2025-04-14, gemini-2.5-pro
193
  temperature=0.7,
194
  top_p=0.95,
195
  model_kwargs=model_kwargs,
196
  rag_config=rag_config,
197
- debug=True,
198
  )
199
 
200
  # Create and launch the web interface
 
41
  top_p: float = 0.95,
42
  rag_config: Optional[RAGConfig] = None,
43
  model_kwargs: Dict[str, Any] = {},
44
+ system_prompt: str = "MEDICAL_ASSISTANT",
45
  ):
46
  """Initialize the MedRAX agent with specified tools and configuration.
47
 
 
56
  top_p (float, optional): Top P for the model. Defaults to 0.95.
57
  rag_config (RAGConfig, optional): Configuration for the RAG tool. Defaults to None.
58
  model_kwargs (dict, optional): Additional keyword arguments for model.
59
+ system_prompt (str, optional): System prompt to use. Defaults to "MEDICAL_ASSISTANT".
60
  debug (bool, optional): Whether to enable debug mode. Defaults to False.
61
 
62
  Returns:
 
64
  """
65
  # Load system prompts from file
66
  prompts = load_prompts_from_file(prompt_file)
67
+ prompt = prompts[system_prompt]
68
 
69
  # Define the URL of the MedGemma FastAPI service.
70
  MEDGEMMA_API_URL = os.getenv("MEDGEMMA_API_URL", "http://127.0.0.1:8002")
 
127
  log_dir="logs",
128
  system_prompt=prompt,
129
  checkpointer=checkpointer,
 
130
  )
131
  print("Agent initialized")
132
 
 
145
  selected_tools = [
146
  "ImageVisualizerTool", # For displaying images in the UI
147
  # "DicomProcessorTool", # For processing DICOM medical image files
148
+ "TorchXRayVisionClassifierTool", # For classifying chest X-ray images using TorchXRayVision
149
+ "ArcPlusClassifierTool", # For advanced chest X-ray classification using ArcPlus
150
+ "ChestXRaySegmentationTool", # For segmenting anatomical regions in chest X-rays
151
+ "ChestXRayReportGeneratorTool", # For generating medical reports from X-rays
152
+ "XRayVQATool", # For visual question answering on X-rays
153
  # "LlavaMedTool", # For multimodal medical image understanding
154
+ "XRayPhraseGroundingTool", # For locating described features in X-rays
155
  # "ChestXRayGeneratorTool", # For generating synthetic chest X-rays
156
  # "MedSAM2Tool", # For advanced medical image segmentation using MedSAM2
157
  # "WebBrowserTool", # For web browsing and search capabilities
158
  # "MedicalRAGTool", # For retrieval-augmented generation with medical knowledge
159
  # "PythonSandboxTool", # Add the Python sandbox tool
160
  "MedGemmaVQATool" # Google MedGemma VQA tool
161
+ "DuckDuckGoSearchTool", # For privacy-focused web search using DuckDuckGo
162
  ]
163
 
164
  # Setup the MedGemma environment if the MedGemmaVQATool is selected
 
175
  pinecone_index_name="medrax2", # Name for the Pinecone index
176
  chunk_size=1500,
177
  chunk_overlap=300,
178
+ retriever_k=3,
179
  local_docs_dir="rag_docs", # Change this to the path of the documents for RAG
180
  huggingface_datasets=["VictorLJZ/medrax2"], # List of HuggingFace datasets to load
181
  dataset_split="train", # Which split of the datasets to use
 
187
  agent, tools_dict = initialize_agent(
188
  prompt_file="medrax/docs/system_prompts.txt",
189
  tools_to_use=selected_tools,
190
+ model_dir="/model-weights",
191
  temp_dir="temp", # Change this to the path of the temporary directory
192
+ device="cuda:0",
193
  model="gpt-4.1-2025-04-14", # Change this to the model you want to use, e.g. gpt-4.1-2025-04-14, gemini-2.5-pro
194
  temperature=0.7,
195
  top_p=0.95,
196
  model_kwargs=model_kwargs,
197
  rag_config=rag_config,
198
+ system_prompt="MEDICAL_ASSISTANT",
199
  )
200
 
201
  # Create and launch the web interface
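For orientation, here is a minimal sketch of how the `agent` returned by `initialize_agent` might be queried directly; the `thread_id`, image path, and question are placeholders, and the call assumes the LangGraph workflow compiled with a checkpointer in `medrax/agent/agent.py` below.

```python
from langchain_core.messages import HumanMessage

# Hypothetical direct invocation of the initialized agent; thread_id and the
# question are illustrative only and not part of main.py.
config = {"configurable": {"thread_id": "demo-session"}}
result = agent.workflow.invoke(
    {"messages": [HumanMessage(content="Is there cardiomegaly in /tmp/cxr.png?")]},
    config=config,
)
print(result["messages"][-1].content)  # final model response after any tool calls
```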
medrax/agent/__init__.py CHANGED
@@ -1 +1 @@
- from .agent import State, Agent
+ from .agent import AgentState, Agent
medrax/agent/agent.py CHANGED
@@ -5,9 +5,8 @@ from dotenv import load_dotenv
  from datetime import datetime
  from typing import List, Dict, Any, TypedDict, Annotated, Optional
 
- from langgraph.prebuilt import create_react_agent
- from langchain_core.messages import AnyMessage
- from langgraph.prebuilt.chat_agent_executor import AgentState
+ from langgraph.graph import StateGraph, END
+ from langchain_core.messages import AnyMessage, SystemMessage, ToolMessage
  from langchain_core.language_models import BaseLanguageModel
  from langchain_core.tools import BaseTool
 
@@ -33,19 +32,17 @@ class ToolCallLog(TypedDict):
      content: str
 
 
- class State(AgentState):
+ class AgentState(TypedDict):
      """
-     A AgentState representing the state of an agent.
+     A TypedDict representing the state of an agent.
 
      Attributes:
-         session_bytes (bytes): The pickled state of the sandbox session. This is
-             required for stateful tools and should not be modified directly.
-         session_metadata (dict): Metadata associated with the sandbox session.
+         messages (Annotated[List[AnyMessage], operator.add]): A list of messages
+             representing the conversation history. The operator.add annotation
+             indicates that new messages should be appended to this list.
      """
 
-     # Required for the stateful PyodideSandboxTool
-     session_bytes: bytes = b""
-     session_metadata: dict = {}
+     messages: Annotated[List[AnyMessage], operator.add]
 
 
  class Agent:
@@ -55,7 +52,7 @@ class Agent:
 
      Attributes:
          model (BaseLanguageModel): The language model used for processing.
-         tools (List[BaseTool]): A list of available tools.
+         tools (Dict[str, BaseTool]): A dictionary of available tools.
          checkpointer (Any): Manages and persists the agent's state.
          system_prompt (str): The system instructions for the agent.
          workflow (StateGraph): The compiled workflow for the agent's processing.
@@ -71,7 +68,6 @@ class Agent:
          system_prompt: str = "",
          log_tools: bool = True,
          log_dir: Optional[str] = "logs",
-         debug: bool = False,
      ):
          """
          Initialize the Agent.
@@ -83,7 +79,6 @@ class Agent:
              system_prompt (str, optional): System instructions. Defaults to "".
              log_tools (bool, optional): Whether to log tool calls. Defaults to True.
              log_dir (str, optional): Directory to save logs. Defaults to 'logs'.
-             debug (bool, optional): Whether to enable debug mode. Defaults to False.
          """
          self.system_prompt = system_prompt
          self.log_tools = log_tools
@@ -92,12 +87,107 @@ class Agent:
          self.log_path = Path(log_dir or "logs")
          self.log_path.mkdir(exist_ok=True)
 
-         self.workflow = create_react_agent(
-             model=model,
-             tools=tools,
-             checkpointer=checkpointer,
-             state_schema=State,
-             prompt=system_prompt if system_prompt else None,
-             debug=debug,
+         # Define the agent workflow
+         workflow = StateGraph(AgentState)
+         workflow.add_node("process", self.process_request)
+         workflow.add_node("execute", self.execute_tools)
+         workflow.add_conditional_edges(
+             "process", self.has_tool_calls, {True: "execute", False: END}
          )
+         workflow.add_edge("execute", "process")
+         workflow.set_entry_point("process")
+
+         self.workflow = workflow.compile(checkpointer=checkpointer)
          self.tools = {t.name: t for t in tools}
+         self.model = model.bind_tools(tools)
+
+     def process_request(self, state: AgentState) -> Dict[str, List[AnyMessage]]:
+         """
+         Process the request using the language model.
+
+         Args:
+             state (AgentState): The current state of the agent.
+
+         Returns:
+             Dict[str, List[AnyMessage]]: A dictionary containing the model's response.
+         """
+         messages = state["messages"]
+         if self.system_prompt:
+             messages = [SystemMessage(content=self.system_prompt)] + messages
+         response = self.model.invoke(messages)
+         return {"messages": [response]}
+
+     def has_tool_calls(self, state: AgentState) -> bool:
+         """
+         Check if the response contains any tool calls.
+
+         Args:
+             state (AgentState): The current state of the agent.
+
+         Returns:
+             bool: True if tool calls exist, False otherwise.
+         """
+         response = state["messages"][-1]
+         return len(response.tool_calls) > 0
+
+     def execute_tools(self, state: AgentState) -> Dict[str, List[ToolMessage]]:
+         """
+         Execute tool calls from the model's response.
+
+         Args:
+             state (AgentState): The current state of the agent.
+
+         Returns:
+             Dict[str, List[ToolMessage]]: A dictionary containing tool execution results.
+         """
+         tool_calls = state["messages"][-1].tool_calls
+         results = []
+
+         for call in tool_calls:
+             print(f"Executing tool: {call}")
+             if call["name"] not in self.tools:
+                 print("\n....invalid tool....")
+                 result = "invalid tool, please retry"
+             else:
+                 result = self.tools[call["name"]].invoke(call["args"])
+
+             results.append(
+                 ToolMessage(
+                     tool_call_id=call["id"],
+                     name=call["name"],
+                     args=call["args"],
+                     content=str(result),
+                 )
+             )
+
+         self._save_tool_calls(results)
+         print("Returning to model processing!")
+
+         return {"messages": results}
+
+     def _save_tool_calls(self, tool_calls: List[ToolMessage]) -> None:
+         """
+         Save tool calls to a JSON file with timestamp-based naming.
+
+         Args:
+             tool_calls (List[ToolMessage]): List of tool calls to save.
+         """
+         if not self.log_tools:
+             return
+
+         timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+         filename = self.log_path / f"tool_calls_{timestamp}.json"
+
+         logs: List[ToolCallLog] = []
+         for call in tool_calls:
+             log_entry = {
+                 "tool_call_id": call.tool_call_id,
+                 "name": call.name,
+                 "args": call.args,
+                 "content": call.content,
+                 "timestamp": datetime.now().isoformat(),
+             }
+             logs.append(log_entry)
+
+         with open(filename, "w") as f:
+             json.dump(logs, f, indent=4)
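This diff replaces LangGraph's prebuilt ReAct executor with an explicit two-node graph that loops `process` → `execute` → `process` until the model stops emitting tool calls. A self-contained toy sketch of the same routing pattern, with stand-in nodes instead of an LLM, may make the control flow easier to follow; all names in it are illustrative.

```python
import operator
from typing import Annotated, List, TypedDict

from langgraph.graph import StateGraph, END


class ToyState(TypedDict):
    messages: Annotated[List[str], operator.add]


def process(state: ToyState) -> dict:
    # Stand-in for Agent.process_request (the LLM call)
    return {"messages": ["model-response"]}


def execute(state: ToyState) -> dict:
    # Stand-in for Agent.execute_tools
    return {"messages": ["tool-result"]}


def needs_tools(state: ToyState) -> bool:
    # Stand-in for Agent.has_tool_calls; here we stop after one tool pass
    return len(state["messages"]) < 2


graph = StateGraph(ToyState)
graph.add_node("process", process)
graph.add_node("execute", execute)
graph.add_conditional_edges("process", needs_tools, {True: "execute", False: END})
graph.add_edge("execute", "process")
graph.set_entry_point("process")
app = graph.compile()

print(app.invoke({"messages": []}))  # runs process -> execute -> process, then ends
```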
medrax/docs/system_prompts.txt CHANGED
@@ -22,5 +22,5 @@ Solve using your own vision and reasoning and use tools (if available) to comple
  You can make multiple tool calls in parallel or in sequence as needed for comprehensive answers.
  Think critically about and criticize the tool outputs.
  If you need to look up some information before asking a follow up question, you are allowed to do that.
- When encountering a multiple-choice question, your final response should end with "Final answer: <|A|>" from list of possible choices A, B, C, D, E, F.
+ When encountering a multiple-choice question, your final response should end with "Final answer: \boxed{A}" from list of possible choices A, B, C, D, E, F.
  It is extremely important that you strictly answer in the format mentioned above.
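Because the answer convention changes from `<|A|>` to `\boxed{A}`, any benchmarking or grading script that parses model output would need an updated pattern. A hypothetical extraction snippet (not part of this commit):

```python
import re

# Hypothetical parser for the new "Final answer: \boxed{X}" convention.
ANSWER_RE = re.compile(r"Final answer:\s*\\boxed\{([A-F])\}")


def extract_choice(response: str) -> str | None:
    match = ANSWER_RE.search(response)
    return match.group(1) if match else None


print(extract_choice(r"...reasoning... Final answer: \boxed{C}"))  # prints "C"
```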
medrax/tools/__init__.py CHANGED
@@ -5,10 +5,10 @@ from .report_generation import *
  from .segmentation import *
  from .vqa import *
  from .grounding import *
- from .generation import *
+ from .xray_generation import *
  from .dicom import *
  from .utils import *
  from .rag import *
- from .web_browser import *
+ from .browsing import *
  from .python_tool import *
  from .medsam2 import *
medrax/tools/browsing/__init__.py ADDED
@@ -0,0 +1,13 @@
+ """Web browsing tools for MedRAX2 medical agents."""
+
+ from .duckduckgo import DuckDuckGoSearchTool, WebSearchInput
+ from .web_browser import WebBrowserTool, WebBrowserSchema, SearchQuerySchema, VisitUrlSchema
+
+ __all__ = [
+     "DuckDuckGoSearchTool",
+     "WebSearchInput",
+     "WebBrowserTool",
+     "WebBrowserSchema",
+     "SearchQuerySchema",
+     "VisitUrlSchema"
+ ]
medrax/tools/browsing/duckduckgo.py ADDED
@@ -0,0 +1,403 @@
1
+ """
2
+ Web search tool for MedRAX2 medical agents.
3
+
4
+ Provides DuckDuckGo search capabilities for medical agents to retrieve
5
+ real-time information from the web with proper error handling
6
+ and result formatting. Designed specifically for medical research,
7
+ fact-checking, and accessing current medical information.
8
+ """
9
+
10
+ import asyncio
11
+ import logging
12
+ import time
13
+ from datetime import datetime
14
+ from typing import Dict, Any, Tuple
15
+
16
+ from langchain_core.callbacks import (
17
+ AsyncCallbackManagerForToolRun,
18
+ CallbackManagerForToolRun,
19
+ )
20
+ from langchain_core.tools import BaseTool
21
+ from pydantic import BaseModel, Field
22
+
23
+ try:
24
+ from duckduckgo_search import DDGS
25
+ except ImportError:
26
+ DDGS = None
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
+ class WebSearchInput(BaseModel):
32
+ """Input schema for web search tool."""
33
+
34
+ query: str = Field(
35
+ ...,
36
+ description="The search query to look up on the web. Be specific and include relevant medical keywords for better results.",
37
+ min_length=1,
38
+ max_length=500,
39
+ )
40
+ max_results: int = Field(
41
+ default=5,
42
+ description="Maximum number of search results to return (1-10)",
43
+ ge=1,
44
+ le=10,
45
+ )
46
+ region: str = Field(
47
+ default="us-en",
48
+ description="Region for search results (e.g., 'us-en', 'uk-en', 'ca-en')",
49
+ )
50
+
51
+
52
+ class DuckDuckGoSearchTool(BaseTool):
53
+ """
54
+ Tool that performs web searches using DuckDuckGo search engine for medical research.
55
+
56
+ This tool provides access to real-time web information through DuckDuckGo's
57
+ search API, specifically designed for medical agents that need to retrieve current
58
+ medical information, verify facts, or find resources on medical topics.
59
+
60
+ Features:
61
+ - Real-time web search capability for medical information
62
+ - Configurable number of results (1-10)
63
+ - Regional search support for localized medical results
64
+ - Robust error handling for network issues
65
+ - Structured result formatting for easy parsing
66
+ - Privacy-focused (DuckDuckGo doesn't track users)
67
+ - Medical-focused search optimization
68
+
69
+ Use Cases:
70
+ - Medical fact checking and verification
71
+ - Finding current medical news and updates
72
+ - Researching specific medical topics or questions
73
+ - Gathering multiple perspectives on medical issues
74
+ - Locating official medical resources and documentation
75
+ - Accessing current clinical guidelines and research
76
+
77
+ Rate Limiting:
78
+ DuckDuckGo has rate limits. Avoid making too many rapid requests
79
+ to prevent temporary blocking.
80
+ """
81
+
82
+ name: str = "duckduckgo_search"
83
+ description: str = (
84
+ "Search the web using DuckDuckGo to find current medical information, research, and resources. "
85
+ "Input should be a clear search query with relevant medical keywords. The tool returns a list of relevant web results "
86
+ "with titles, URLs, and brief snippets. Useful for medical fact-checking, finding current medical events, "
87
+ "researching medical topics, and gathering information from reliable medical sources. "
88
+ "Results are privacy-focused and don't track user searches. Optimized for medical research and clinical information."
89
+ )
90
+ args_schema: type[BaseModel] = WebSearchInput
91
+ return_direct: bool = False
92
+
93
+ def __init__(self, **kwargs):
94
+ """Initialize the DuckDuckGo search tool."""
95
+ super().__init__(**kwargs)
96
+
97
+ if DDGS is None:
98
+ logger.error(
99
+ "duckduckgo-search package not installed. Install with: pip install duckduckgo-search"
100
+ )
101
+ raise ImportError(
102
+ "duckduckgo-search package is required for web search functionality"
103
+ )
104
+
105
+ logger.info("DuckDuckGo search tool initialized successfully")
106
+
107
+ def _perform_search_sync(
108
+ self, query: str, max_results: int = 5, region: str = "us-en"
109
+ ) -> Dict[str, Any]:
110
+ """
111
+ Perform the actual web search using DuckDuckGo synchronously.
112
+
113
+ Args:
114
+ query (str): The search query.
115
+ max_results (int): Maximum number of results to return.
116
+ region (str): Region for localized results.
117
+
118
+ Returns:
119
+ Dict[str, Any]: Structured search results.
120
+ """
121
+ logger.info(
122
+ f"Performing web search: '{query}' (max_results={max_results}, region={region})"
123
+ )
124
+
125
+ try:
126
+ # Initialize DDGS with error handling
127
+ with DDGS() as ddgs:
128
+ # Perform the search
129
+ search_results = list(
130
+ ddgs.text(
131
+ keywords=query,
132
+ region=region,
133
+ safesearch="moderate",
134
+ timelimit=None,
135
+ max_results=max_results,
136
+ )
137
+ )
138
+
139
+ # Format results for the agent
140
+ formatted_results = []
141
+ for i, result in enumerate(search_results, 1):
142
+ formatted_result = {
143
+ "rank": i,
144
+ "title": result.get("title", "No title"),
145
+ "url": result.get("href", "No URL"),
146
+ "snippet": result.get("body", "No description available"),
147
+ "source": "DuckDuckGo",
148
+ }
149
+ formatted_results.append(formatted_result)
150
+
151
+ # Create summary for the agent
152
+ if formatted_results:
153
+ summary = (
154
+ f"Found {len(formatted_results)} results for '{query}'. Top results include: "
155
+ + ", ".join([f"{r['title']}" for r in formatted_results[:3]])
156
+ )
157
+ else:
158
+ summary = f"No results found for '{query}'"
159
+
160
+ # Log successful completion
161
+ logger.info(
162
+ f"Web search completed successfully: {len(formatted_results)} results"
163
+ )
164
+
165
+ return {
166
+ "query": query,
167
+ "results_count": len(formatted_results),
168
+ "results": formatted_results,
169
+ "summary": summary,
170
+ "search_engine": "DuckDuckGo",
171
+ "timestamp": datetime.now().isoformat(),
172
+ }
173
+
174
+ except Exception as e:
175
+ error_msg = f"Web search failed for query '{query}': {str(e)}"
176
+ logger.error(f"{error_msg}")
177
+
178
+ return {
179
+ "query": query,
180
+ "results_count": 0,
181
+ "results": [],
182
+ "error": error_msg,
183
+ "search_engine": "DuckDuckGo",
184
+ "timestamp": datetime.now().isoformat(),
185
+ }
186
+
187
+ def _run(
188
+ self,
189
+ query: str,
190
+ max_results: int = 5,
191
+ region: str = "us-en",
192
+ run_manager: CallbackManagerForToolRun | None = None,
193
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
194
+ """
195
+ Execute the web search synchronously.
196
+
197
+ Args:
198
+ query (str): Search query
199
+ max_results (int): Maximum number of results
200
+ region (str): Search region
201
+ run_manager: Callback manager (unused)
202
+
203
+ Returns:
204
+ Tuple[Dict[str, Any], Dict[str, Any]]: A tuple containing:
205
+ - output: Dictionary with search results
206
+ - metadata: Dictionary with execution metadata
207
+ """
208
+ # Create metadata structure
209
+ metadata = {
210
+ "query": query,
211
+ "max_results": max_results,
212
+ "region": region,
213
+ "timestamp": time.time(),
214
+ "tool": "duckduckgo_search",
215
+ "operation": "search",
216
+ }
217
+
218
+ try:
219
+ result = self._perform_search_sync(query, max_results, region)
220
+
221
+ # Check if search was successful
222
+ if "error" in result:
223
+ metadata["analysis_status"] = "failed"
224
+ metadata["error_details"] = result["error"]
225
+ else:
226
+ metadata["analysis_status"] = "completed"
227
+ metadata["results_count"] = result.get("results_count", 0)
228
+
229
+ return result, metadata
230
+
231
+ except Exception as e:
232
+ error_result = {
233
+ "query": query,
234
+ "results_count": 0,
235
+ "results": [],
236
+ "error": str(e),
237
+ "search_engine": "DuckDuckGo",
238
+ "timestamp": datetime.now().isoformat(),
239
+ }
240
+ metadata["analysis_status"] = "failed"
241
+ metadata["error_details"] = str(e)
242
+
243
+ return error_result, metadata
244
+
245
+ async def _arun(
246
+ self,
247
+ query: str,
248
+ max_results: int = 5,
249
+ region: str = "us-en",
250
+ run_manager: AsyncCallbackManagerForToolRun | None = None,
251
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
252
+ """
253
+ Execute the web search asynchronously.
254
+
255
+ Args:
256
+ query (str): Search query
257
+ max_results (int): Maximum number of results
258
+ region (str): Search region
259
+ run_manager: Callback manager (unused)
260
+
261
+ Returns:
262
+ Tuple[Dict[str, Any], Dict[str, Any]]: A tuple containing:
263
+ - output: Dictionary with search results
264
+ - metadata: Dictionary with execution metadata
265
+ """
266
+ # Try to get LangGraph stream writer for progress updates
267
+ writer = None
268
+ try:
269
+ from langgraph.config import get_stream_writer
270
+
271
+ writer = get_stream_writer()
272
+ except Exception:
273
+ # Stream writer not available (outside LangGraph context)
274
+ pass
275
+
276
+ if writer:
277
+ writer(
278
+ {
279
+ "tool_name": "DuckDuckGoSearchTool",
280
+ "status": "started",
281
+ "query": query,
282
+ "max_results": max_results,
283
+ "step": "Initiating web search",
284
+ }
285
+ )
286
+
287
+ try:
288
+ if writer:
289
+ writer(
290
+ {
291
+ "tool_name": "DuckDuckGoSearchTool",
292
+ "status": "searching",
293
+ "step": "Fetching results from DuckDuckGo API",
294
+ }
295
+ )
296
+
297
+ # Use asyncio to run sync search in executor
298
+ loop = asyncio.get_event_loop()
299
+ result, metadata = await loop.run_in_executor(
300
+ None, self._run, query, max_results, region
301
+ )
302
+
303
+ if writer:
304
+ # Parse result to get count for progress update
305
+ results_count = result.get("results_count", 0)
306
+ writer(
307
+ {
308
+ "tool_name": "DuckDuckGoSearchTool",
309
+ "status": "completed",
310
+ "step": f"Search completed with {results_count} results",
311
+ "results_count": results_count,
312
+ }
313
+ )
314
+
315
+ return result, metadata
316
+
317
+ except Exception as e:
318
+ if writer:
319
+ writer(
320
+ {
321
+ "tool_name": "DuckDuckGoSearchTool",
322
+ "status": "error",
323
+ "step": f"Search failed: {str(e)}",
324
+ "error": str(e),
325
+ }
326
+ )
327
+
328
+ error_result = {
329
+ "query": query,
330
+ "results_count": 0,
331
+ "results": [],
332
+ "error": str(e),
333
+ "search_engine": "DuckDuckGo",
334
+ "timestamp": datetime.now().isoformat(),
335
+ }
336
+
337
+ metadata = {
338
+ "query": query,
339
+ "max_results": max_results,
340
+ "region": region,
341
+ "timestamp": time.time(),
342
+ "tool": "duckduckgo_search",
343
+ "operation": "search",
344
+ "analysis_status": "failed",
345
+ "error_details": str(e),
346
+ }
347
+
348
+ return error_result, metadata
349
+
350
+ def get_search_summary(
351
+ self, query: str, max_results: int = 3
352
+ ) -> dict[str, str | list[str]]:
353
+ """
354
+ Get a quick summary of search results for a given query.
355
+
356
+ Args:
357
+ query (str): The search query.
358
+ max_results (int): Maximum number of results to summarize.
359
+
360
+ Returns:
361
+ Dict[str, Union[str, List[str]]]: Summary of search results.
362
+ """
363
+ try:
364
+ result, _ = self._run(query, max_results)
365
+
366
+ if "error" in result:
367
+ return {
368
+ "query": query,
369
+ "status": "error",
370
+ "error": result["error"],
371
+ "results": [],
372
+ }
373
+
374
+ # Extract key information
375
+ results = result.get("results", [])
376
+ titles = [r["title"] for r in results]
377
+ urls = [r["url"] for r in results]
378
+ snippets = [
379
+ (
380
+ r["snippet"][:100] + "..."
381
+ if len(r["snippet"]) > 100
382
+ else r["snippet"]
383
+ )
384
+ for r in results
385
+ ]
386
+
387
+ return {
388
+ "query": query,
389
+ "status": "success",
390
+ "total_results": result.get("results_count", 0),
391
+ "titles": titles,
392
+ "urls": urls,
393
+ "snippets": snippets,
394
+ }
395
+
396
+ except Exception as e:
397
+ logger.error(f"Error getting search summary: {e}")
398
+ return {
399
+ "query": query,
400
+ "status": "error",
401
+ "error": str(e),
402
+ "results": [],
403
+ }
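A short usage sketch for the new tool, based on the `_run` signature and result format defined above; the query is only an example, the call bypasses the LangChain invocation layer for brevity, and it assumes the `duckduckgo-search` package is installed.

```python
from medrax.tools.browsing import DuckDuckGoSearchTool

# Example standalone invocation outside the agent loop.
search_tool = DuckDuckGoSearchTool()
output, metadata = search_tool._run(
    query="latest ACR guidelines for incidental pulmonary nodules",
    max_results=3,
)
for hit in output["results"]:
    print(hit["rank"], hit["title"], hit["url"])
```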
medrax/tools/{web_browser.py → browsing/web_browser.py} RENAMED
File without changes
medrax/tools/classification/arcplus.py CHANGED
@@ -345,7 +345,8 @@ class ArcPlusClassifierTool(BaseTool):
          predictions = predictions[: len(self.disease_list)]
 
          # Create output dictionary mapping disease names to probabilities
-         output = dict(zip(self.disease_list, predictions.astype(float)))
+         # Convert numpy floats to native Python floats for proper serialization
+         output = dict(zip(self.disease_list, [float(pred) for pred in predictions]))
 
          metadata = {
              "image_path": image_path,
medrax/tools/llava_med.py DELETED
@@ -1,193 +0,0 @@
1
- from typing import Any, Dict, Optional, Tuple, Type
2
- from pydantic import BaseModel, Field
3
-
4
- import torch
5
-
6
- from langchain_core.callbacks import (
7
- AsyncCallbackManagerForToolRun,
8
- CallbackManagerForToolRun,
9
- )
10
- from langchain_core.tools import BaseTool
11
-
12
- from PIL import Image
13
-
14
-
15
- from medrax.llava.conversation import conv_templates
16
- from medrax.llava.model.builder import load_pretrained_model
17
- from medrax.llava.mm_utils import tokenizer_image_token, process_images
18
- from medrax.llava.constants import (
19
- IMAGE_TOKEN_INDEX,
20
- DEFAULT_IMAGE_TOKEN,
21
- DEFAULT_IM_START_TOKEN,
22
- DEFAULT_IM_END_TOKEN,
23
- )
24
-
25
-
26
- class LlavaMedInput(BaseModel):
27
- """Input for the LLaVA-Med Visual QA tool. Only supports JPG or PNG images."""
28
-
29
- question: str = Field(..., description="The question to ask about the medical image")
30
- image_path: Optional[str] = Field(
31
- None,
32
- description="Path to the medical image file (optional), only supports JPG or PNG images",
33
- )
34
-
35
-
36
- class LlavaMedTool(BaseTool):
37
- """Tool that performs medical visual question answering using LLaVA-Med.
38
-
39
- This tool uses a large language model fine-tuned on medical images to answer
40
- questions about medical images. It can handle both image-based questions and
41
- general medical questions without images.
42
- """
43
-
44
- name: str = "llava_med_qa"
45
- description: str = (
46
- "A tool that answers questions about biomedical images and general medical questions using LLaVA-Med. "
47
- "While it can process chest X-rays, it may not be as reliable for detailed chest X-ray analysis. "
48
- "Input should be a question and optionally a path to a medical image file."
49
- )
50
- args_schema: Type[BaseModel] = LlavaMedInput
51
- tokenizer: Any = None
52
- model: Any = None
53
- image_processor: Any = None
54
- context_len: int = 200000
55
-
56
- def __init__(
57
- self,
58
- model_path: str = "microsoft/llava-med-v1.5-mistral-7b",
59
- cache_dir: str = "/model-weights",
60
- low_cpu_mem_usage: bool = True,
61
- torch_dtype: torch.dtype = torch.bfloat16,
62
- device: str = "cuda",
63
- load_in_4bit: bool = False,
64
- load_in_8bit: bool = False,
65
- **kwargs,
66
- ):
67
- super().__init__()
68
- self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
69
- model_path=model_path,
70
- model_base=None,
71
- model_name=model_path,
72
- load_in_4bit=load_in_4bit,
73
- load_in_8bit=load_in_8bit,
74
- cache_dir=cache_dir,
75
- low_cpu_mem_usage=low_cpu_mem_usage,
76
- torch_dtype=torch_dtype,
77
- device=device,
78
- **kwargs,
79
- )
80
- self.model.eval()
81
-
82
- def _process_input(
83
- self, question: str, image_path: Optional[str] = None
84
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
85
- if self.model.config.mm_use_im_start_end:
86
- question = (
87
- DEFAULT_IM_START_TOKEN
88
- + DEFAULT_IMAGE_TOKEN
89
- + DEFAULT_IM_END_TOKEN
90
- + "\n"
91
- + question
92
- )
93
- else:
94
- question = DEFAULT_IMAGE_TOKEN + "\n" + question
95
-
96
- conv = conv_templates["vicuna_v1"].copy()
97
- conv.append_message(conv.roles[0], question)
98
- conv.append_message(conv.roles[1], None)
99
- prompt = conv.get_prompt()
100
-
101
- input_ids = (
102
- tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
103
- .unsqueeze(0)
104
- .cuda()
105
- )
106
-
107
- image_tensor = None
108
- if image_path:
109
- image = Image.open(image_path)
110
- image_tensor = process_images([image], self.image_processor, self.model.config)[0]
111
- image_tensor = image_tensor.unsqueeze(0).half().cuda()
112
-
113
- return input_ids, image_tensor
114
-
115
- def _run(
116
- self,
117
- question: str,
118
- image_path: Optional[str] = None,
119
- run_manager: Optional[CallbackManagerForToolRun] = None,
120
- ) -> Tuple[Dict[str, Any], Dict]:
121
- """Answer a medical question, optionally based on an input image.
122
-
123
- Args:
124
- question (str): The medical question to answer.
125
- image_path (Optional[str]): The path to the medical image file (if applicable).
126
- run_manager (Optional[CallbackManagerForToolRun]): The callback manager for the tool run.
127
-
128
- Returns:
129
- Tuple[Dict[str, Any], Dict]: A tuple containing the output dictionary and metadata dictionary.
130
-
131
- Raises:
132
- Exception: If there's an error processing the input or generating the answer.
133
- """
134
- try:
135
- input_ids, image_tensor = self._process_input(question, image_path)
136
- input_ids = input_ids.to(device=self.model.device)
137
- image_tensor = image_tensor.to(device=self.model.device, dtype=self.model.dtype)
138
-
139
- with torch.inference_mode():
140
- output_ids = self.model.generate(
141
- input_ids,
142
- images=image_tensor,
143
- do_sample=False,
144
- temperature=0.2,
145
- max_new_tokens=500,
146
- use_cache=True,
147
- )
148
-
149
- answer = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
150
-
151
- output = {
152
- "answer": answer,
153
- }
154
-
155
- metadata = {
156
- "question": question,
157
- "image_path": image_path,
158
- "analysis_status": "completed",
159
- }
160
- return output, metadata
161
- except Exception as e:
162
- output = {"error": f"Error generating answer: {str(e)}"}
163
- metadata = {
164
- "question": question,
165
- "image_path": image_path,
166
- "analysis_status": "failed",
167
- }
168
- return output, metadata
169
-
170
- async def _arun(
171
- self,
172
- question: str,
173
- image_path: Optional[str] = None,
174
- run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
175
- ) -> Tuple[Dict[str, Any], Dict]:
176
- """Asynchronously answer a medical question, optionally based on an input image.
177
-
178
- This method currently calls the synchronous version, as the model inference
179
- is not inherently asynchronous. For true asynchronous behavior, consider
180
- using a separate thread or process.
181
-
182
- Args:
183
- question (str): The medical question to answer.
184
- image_path (Optional[str]): The path to the medical image file (if applicable).
185
- run_manager (Optional[AsyncCallbackManagerForToolRun]): The async callback manager for the tool run.
186
-
187
- Returns:
188
- Tuple[Dict[str, Any], Dict]: A tuple containing the output dictionary and metadata dictionary.
189
-
190
- Raises:
191
- Exception: If there's an error processing the input or generating the answer.
192
- """
193
- return self._run(question, image_path)
medrax/tools/rag.py CHANGED
@@ -48,14 +48,14 @@ class RAGTool(BaseTool):
          self.rag = CohereRAG(config)
          self.chain = self.rag.initialize_rag(with_memory=True)
 
-     def _run(self, query: str) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+     def _run(self, query: str) -> Tuple[Dict[str, Any], Dict]:
          """Execute the RAG tool with the given query.
 
          Args:
              query (str): Medical question to answer
 
          Returns:
-             Tuple[Dict[str, Any], Dict[str, Any]]: Output dictionary and metadata dictionary
+             Tuple[Dict[str, Any], Dict]: Output dictionary and metadata dictionary
          """
          try:
              result = self.chain.invoke({"query": query})
@@ -87,14 +87,14 @@ class RAGTool(BaseTool):
          }
          return output, metadata
 
-     async def _arun(self, query: str) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+     async def _arun(self, query: str) -> Tuple[Dict[str, Any], Dict]:
          """Async version of _run.
 
          Args:
              query (str): Medical question to answer
 
          Returns:
-             Tuple[Dict[str, Any], Dict[str, Any]]: Output dictionary and metadata dictionary
+             Tuple[Dict[str, Any], Dict]: Output dictionary and metadata dictionary
 
          Raises:
              NotImplementedError: Async not implemented yet
medrax/tools/segmentation/__init__.py ADDED
@@ -0,0 +1,12 @@
+ """Medical image segmentation tools for MedRAX2."""
+
+ from .segmentation import ChestXRaySegmentationTool, ChestXRaySegmentationInput, OrganMetrics
+ from .medsam2 import MedSAM2Tool, MedSAM2Input
+
+ __all__ = [
+     "ChestXRaySegmentationTool",
+     "ChestXRaySegmentationInput",
+     "OrganMetrics",
+     "MedSAM2Tool",
+     "MedSAM2Input"
+ ]
medrax/tools/{medsam2.py → segmentation/medsam2.py} RENAMED
@@ -15,7 +15,7 @@ from langchain_core.callbacks import (
  from langchain_core.tools import BaseTool
 
  # Add MedSAM2 to Python path for proper module resolution
- medsam2_path = str(Path(__file__).parent.parent.parent / "MedSAM2")
+ medsam2_path = str(Path(__file__).parent.parent.parent.parent / "MedSAM2")
  if medsam2_path not in sys.path:
      sys.path.append(medsam2_path)
 
@@ -93,7 +93,7 @@ class MedSAM2Tool(BaseTool):
      if GlobalHydra.instance().is_initialized():
          GlobalHydra.instance().clear()
 
-     config_dir = Path(__file__).parent.parent.parent / "MedSAM2" / "sam2" / "configs"
+     config_dir = Path(__file__).parent.parent.parent.parent / "MedSAM2" / "sam2" / "configs"
      initialize_config_dir(config_dir=str(config_dir), version_base="1.2")
 
      hf_hub_download(
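The extra `.parent` reflects the module moving one directory deeper (`medrax/tools/medsam2.py` → `medrax/tools/segmentation/medsam2.py`), so one more hop is needed to reach the repository root that contains `MedSAM2/`. A quick check with a hypothetical absolute path:

```python
from pathlib import Path

# Hypothetical absolute location of the moved module.
module = Path("/repo/medrax/tools/segmentation/medsam2.py")

print(module.parent.parent.parent)         # /repo/medrax  (old depth, now one level short)
print(module.parent.parent.parent.parent)  # /repo         (repo root holding MedSAM2/)
```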
medrax/tools/{segmentation.py → segmentation/segmentation.py} RENAMED
File without changes
medrax/tools/vqa/llava_med.py CHANGED
@@ -117,7 +117,7 @@ class LlavaMedTool(BaseTool):
          question: str,
          image_path: Optional[str] = None,
          run_manager: Optional[CallbackManagerForToolRun] = None,
-     ) -> Tuple[str, Dict]:
+     ) -> Tuple[Dict[str, Any], Dict]:
          """Answer a medical question, optionally based on an input image.
 
          Args:
@@ -126,7 +126,7 @@ class LlavaMedTool(BaseTool):
              run_manager (Optional[CallbackManagerForToolRun]): The callback manager for the tool run.
 
          Returns:
-             Tuple[str, Dict]: A tuple containing the model's answer and any additional metadata.
+             Tuple[Dict[str, Any], Dict]: A tuple containing the output dictionary and metadata dictionary.
 
          Raises:
              Exception: If there's an error processing the input or generating the answer.
@@ -146,7 +146,12 @@ class LlavaMedTool(BaseTool):
                      use_cache=True,
                  )
 
-             output = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
+             answer = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
+
+             output = {
+                 "answer": answer,
+             }
+
              metadata = {
                  "question": question,
                  "image_path": image_path,
@@ -154,18 +159,20 @@ class LlavaMedTool(BaseTool):
              }
              return output, metadata
          except Exception as e:
-             return f"Error generating answer: {str(e)}", {
+             output = {"error": f"Error generating answer: {str(e)}"}
+             metadata = {
                  "question": question,
                  "image_path": image_path,
                  "analysis_status": "failed",
              }
+             return output, metadata
 
      async def _arun(
          self,
          question: str,
          image_path: Optional[str] = None,
          run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
-     ) -> Tuple[str, Dict]:
+     ) -> Tuple[Dict[str, Any], Dict]:
          """Asynchronously answer a medical question, optionally based on an input image.
 
          This method currently calls the synchronous version, as the model inference
@@ -178,9 +185,9 @@ class LlavaMedTool(BaseTool):
              run_manager (Optional[AsyncCallbackManagerForToolRun]): The async callback manager for the tool run.
 
          Returns:
-             Tuple[str, Dict]: A tuple containing the model's answer and any additional metadata.
+             Tuple[Dict[str, Any], Dict]: A tuple containing the output dictionary and metadata dictionary.
 
          Raises:
              Exception: If there's an error processing the input or generating the answer.
          """
-         return self._run(question, image_path)
+         return self._run(question, image_path)
medrax/tools/vqa/xray_vqa.py CHANGED
@@ -183,4 +183,4 @@ class CheXagentXRayVQATool(BaseTool):
          run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
      ) -> Tuple[Dict[str, Any], Dict]:
          """Async version of _run."""
-         return self._run(image_paths, prompt, max_new_tokens)
+         return self._run(image_paths, prompt, max_new_tokens)
medrax/tools/{generation.py → xray_generation.py} RENAMED
File without changes
medrax/tools/xray_vqa.py DELETED
@@ -1,186 +0,0 @@
1
- from typing import Dict, List, Optional, Tuple, Type, Any
2
- from pathlib import Path
3
- from pydantic import BaseModel, Field
4
-
5
- import torch
6
- import transformers
7
- from transformers import AutoModelForCausalLM, AutoTokenizer
8
- from langchain_core.callbacks import (
9
- AsyncCallbackManagerForToolRun,
10
- CallbackManagerForToolRun,
11
- )
12
- from langchain_core.tools import BaseTool
13
-
14
-
15
- class XRayVQAToolInput(BaseModel):
16
- """Input schema for the CheXagent Tool."""
17
-
18
- image_paths: List[str] = Field(
19
- ..., description="List of paths to chest X-ray images to analyze"
20
- )
21
- prompt: str = Field(..., description="Question or instruction about the chest X-ray images")
22
- max_new_tokens: int = Field(
23
- 512, description="Maximum number of tokens to generate in the response"
24
- )
25
-
26
-
27
- class XRayVQATool(BaseTool):
28
- """Tool that leverages CheXagent for comprehensive chest X-ray analysis."""
29
-
30
- name: str = "chest_xray_expert"
31
- description: str = (
32
- "A versatile tool for analyzing chest X-rays. "
33
- "Can perform multiple tasks including: visual question answering, report generation, "
34
- "abnormality detection, comparative analysis, anatomical description, "
35
- "and clinical interpretation. Input should be paths to X-ray images "
36
- "and a natural language prompt describing the analysis needed."
37
- )
38
- args_schema: Type[BaseModel] = XRayVQAToolInput
39
- return_direct: bool = True
40
- cache_dir: Optional[str] = None
41
- device: Optional[str] = None
42
- dtype: torch.dtype = torch.bfloat16
43
- tokenizer: Optional[AutoTokenizer] = None
44
- model: Optional[AutoModelForCausalLM] = None
45
-
46
- def __init__(
47
- self,
48
- model_name: str = "StanfordAIMI/CheXagent-2-3b",
49
- device: Optional[str] = "cuda",
50
- dtype: torch.dtype = torch.bfloat16,
51
- cache_dir: Optional[str] = None,
52
- **kwargs: Any,
53
- ) -> None:
54
- """Initialize the XRayVQATool.
55
-
56
- Args:
57
- model_name: Name of the CheXagent model to use
58
- device: Device to run model on (cuda/cpu)
59
- dtype: Data type for model weights
60
- cache_dir: Directory to cache downloaded models
61
- **kwargs: Additional arguments
62
- """
63
- super().__init__(**kwargs)
64
-
65
- # Dangerous code, but works for now
66
- import transformers
67
-
68
- original_transformers_version = transformers.__version__
69
- transformers.__version__ = "4.40.0"
70
-
71
- self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
72
- self.dtype = dtype
73
- self.cache_dir = cache_dir
74
-
75
- # Load tokenizer and model
76
- self.tokenizer = AutoTokenizer.from_pretrained(
77
- model_name,
78
- trust_remote_code=True,
79
- cache_dir=cache_dir,
80
- )
81
- self.model = AutoModelForCausalLM.from_pretrained(
82
- model_name,
83
- device_map=self.device,
84
- trust_remote_code=True,
85
- cache_dir=cache_dir,
86
- )
87
- self.model = self.model.to(dtype=self.dtype)
88
- self.model.eval()
89
-
90
- transformers.__version__ = original_transformers_version
91
-
92
- def _generate_response(self, image_paths: List[str], prompt: str, max_new_tokens: int) -> str:
93
- """Generate response using CheXagent model.
94
-
95
- Args:
96
- image_paths: List of paths to chest X-ray images
97
- prompt: Question or instruction about the images
98
- max_new_tokens: Maximum number of tokens to generate
99
- Returns:
100
- str: Model's response
101
- """
102
- query = self.tokenizer.from_list_format(
103
- [*[{"image": path} for path in image_paths], {"text": prompt}]
104
- )
105
- conv = [
106
- {"from": "system", "value": "You are a helpful assistant."},
107
- {"from": "human", "value": query},
108
- ]
109
- input_ids = self.tokenizer.apply_chat_template(
110
- conv, add_generation_prompt=True, return_tensors="pt"
111
- ).to(device=self.device)
112
-
113
- # Run inference
114
- with torch.inference_mode():
115
- output = self.model.generate(
116
- input_ids,
117
- do_sample=False,
118
- num_beams=1,
119
- temperature=1.0,
120
- top_p=1.0,
121
- use_cache=True,
122
- max_new_tokens=max_new_tokens,
123
- )[0]
124
- response = self.tokenizer.decode(output[input_ids.size(1) : -1])
125
-
126
- return response
127
-
128
- def _run(
129
- self,
130
- image_paths: List[str],
131
- prompt: str,
132
- max_new_tokens: int = 512,
133
- run_manager: Optional[CallbackManagerForToolRun] = None,
134
- ) -> Tuple[Dict[str, Any], Dict]:
135
- """Execute the chest X-ray analysis.
136
-
137
- Args:
138
- image_paths: List of paths to chest X-ray images
139
- prompt: Question or instruction about the images
140
- max_new_tokens: Maximum number of tokens to generate
141
- run_manager: Optional callback manager
142
-
143
- Returns:
144
- Tuple[Dict[str, Any], Dict]: Output dictionary and metadata dictionary
145
- """
146
- try:
147
- # Verify image paths
148
- for path in image_paths:
149
- if not Path(path).is_file():
150
- raise FileNotFoundError(f"Image file not found: {path}")
151
-
152
- response = self._generate_response(image_paths, prompt, max_new_tokens)
153
-
154
- output = {
155
- "response": response,
156
- }
157
-
158
- metadata = {
159
- "image_paths": image_paths,
160
- "prompt": prompt,
161
- "max_new_tokens": max_new_tokens,
162
- "analysis_status": "completed",
163
- }
164
-
165
- return output, metadata
166
-
167
- except Exception as e:
168
- output = {"error": str(e)}
169
- metadata = {
170
- "image_paths": image_paths,
171
- "prompt": prompt,
172
- "max_new_tokens": max_new_tokens,
173
- "analysis_status": "failed",
174
- "error_details": str(e),
175
- }
176
- return output, metadata
177
-
178
- async def _arun(
179
- self,
180
- image_paths: List[str],
181
- prompt: str,
182
- max_new_tokens: int = 512,
183
- run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
184
- ) -> Tuple[Dict[str, Any], Dict]:
185
- """Async version of _run."""
186
- return self._run(image_paths, prompt, max_new_tokens)
pyproject.toml CHANGED
@@ -50,7 +50,7 @@ dependencies = [
      "fastapi>=0.68.0",
      "einops>=0.3.0",
      "einops-exts>=0.0.4",
-     "timm>=0.5.0",
+     "timm==0.5.4",
      "tiktoken>=0.3.0",
      "openai>=0.27.0",
      "backoff>=1.10.0",
@@ -75,6 +75,7 @@ dependencies = [
      "seaborn>=0.12.0",
      "huggingface_hub>=0.17.0",
      "iopath>=0.1.10",
+     "duckduckgo-search>=4.0.0",
  ]
 
  [project.optional-dependencies]