Junzhe Li committed on
Commit dba3d2e · Parent: e4e9fae
2rexvqa.sh ADDED
@@ -0,0 +1,15 @@
+#!/bin/bash
+
+#SBATCH --job-name=medrax
+#SBATCH -c 4
+#SBATCH --gres=gpu:l40s:1
+#SBATCH --time=16:00:00
+#SBATCH --mem=50G
+#SBATCH --output=rexvqa-%j.out
+#SBATCH --error=rexvqa-%j.err
+
+module load arrow clang/18.1.8 scipy-stack
+
+source venv/bin/activate
+
+/scratch/lijunzh3/MedRAX2/venv/bin/python -m benchmarking.cli run --benchmark rexvqa --provider medrax --model gemini-2.5-pro --system-prompt CHESTAGENTBENCH_PROMPT --data-dir benchmarking/data/rexvqa --output-dir temp --max-questions 200 --temperature 0.7 --top-p 0.95 --max-tokens 10000 --concurrency 4 --random-seed 42
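Assuming a standard SLURM setup, this new script would be submitted with sbatch 2rexvqa.sh; it requests one L40S GPU, 4 CPU cores, and 50 GB of memory for up to 16 hours, then runs a 200-question ReXVQA pass through the medrax provider.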
analyze.py ADDED
@@ -0,0 +1,147 @@
+import json
+import argparse
+import sys
+from collections import defaultdict
+from pathlib import Path
+
+def process_single_file(json_file_path):
+    """
+    Processes a single JSON results file and returns its accuracy counts.
+
+    Args:
+        json_file_path (Path): Path to the ...results.json file.
+
+    Returns:
+        defaultdict: A dictionary with the aggregated counts for this file.
+        Returns None if the file cannot be processed.
+    """
+
+    # These counts are *only* for the file being processed
+    counts = defaultdict(lambda: defaultdict(lambda: {"total": 0, "correct": 0}))
+    keys_to_track = ["reasoning_type", "category", "class", "subcategory"]
+
+    try:
+        with open(json_file_path, 'r', encoding='utf-8') as f:
+            data = json.load(f)
+    except json.JSONDecodeError:
+        print(f" - WARNING: Could not decode JSON from '{json_file_path}'. Skipping.")
+        return None
+    except Exception as e:
+        print(f" - ERROR: Unexpected error loading '{json_file_path}': {e}. Skipping.")
+        return None
+
+    # Iterate through each record in the JSON array
+    for record in data:
+        try:
+            is_correct = record.get("is_correct", False)
+            metadata = record["metadata"]["data_point_metadata"]
+
+            for key in keys_to_track:
+                value = metadata.get(key)
+
+                if value is not None:
+                    counts[key][value]["total"] += 1
+                    if is_correct:
+                        counts[key][value]["correct"] += 1
+
+        except KeyError as e:
+            print(f" - WARNING: Record {record.get('data_point_id')} is missing expected key: {e}. Skipping record.")
+        except TypeError:
+            print(f" - WARNING: Record {record.get('data_point_id')} has unexpected data structure. Skipping record.")
+
+    return counts
+
+def generate_report_dict(counts):
+    """
+    Converts a counts dictionary into the final, formatted report dictionary.
+
+    Args:
+        counts (defaultdict): The aggregated counts from process_single_file.
+
+    Returns:
+        dict: A dictionary formatted with percentages and absolute numbers.
+    """
+    accuracy_report = defaultdict(dict)
+
+    for key, values in counts.items():
+        # Sort by the sub-category name (e.g., "Negation Assessment")
+        sorted_values = sorted(values.items(), key=lambda item: item[0])
+
+        for value, tally in sorted_values:
+            total = tally["total"]
+            correct = tally["correct"]
+
+            if total > 0:
+                accuracy = (correct / total) * 100
+            else:
+                accuracy = 0.0
+
+            # Store the full results in our report dictionary
+            accuracy_report[key][value] = {
+                "accuracy_percent": round(accuracy, 2),
+                "correct": correct,
+                "total": total
+            }
+    return accuracy_report
+
+def main():
+    """
+    Main function to find, process, and save individual reports.
+    """
+
+    parser = argparse.ArgumentParser(
+        description="Finds and processes individual benchmarking runs, saving "
+                    "a separate accuracy report for each run."
+    )
+    parser.add_argument(
+        "directory",
+        type=str,
+        help="The top-level directory to search within (e.g., 'my_experiments')."
+    )
+    args = parser.parse_args()
+
+    top_dir = Path(args.directory)
+    if not top_dir.is_dir():
+        print(f"Error: Path '{args.directory}' is not a valid directory.")
+        sys.exit(1)
+
+    # Glob pattern to find all target files
+    search_pattern = '*/final_results/*results.json'
+    json_files_to_process = list(top_dir.glob(search_pattern))
+
+    if not json_files_to_process:
+        print(f"No files matching the pattern '{search_pattern}' were found in '{top_dir}'.")
+        sys.exit(0)
+
+    print(f"Found {len(json_files_to_process)} result file(s) to process individually.")
+
+    # --- Loop and process each file ---
+    for file_path in json_files_to_process:
+        # Use relative path for cleaner logging
+        print(f"\n--- Processing: {file_path.relative_to(top_dir.parent)} ---")
+
+        # 1. Get counts for this file
+        counts = process_single_file(file_path)
+
+        if counts is None or not counts:
+            print(" - No data processed. Skipping report generation.")
+            continue
+
+        # 2. Generate the report dictionary
+        report = generate_report_dict(counts)
+
+        # 3. Determine the output path and save the file
+        # The output is saved in the *same directory* as the input file
+        output_filename = file_path.parent / "accuracy_report.json"
+
+        try:
+            with open(output_filename, 'w', encoding='utf-8') as f:
+                json.dump(report, f, indent=2, sort_keys=True)
+            print(f" > Successfully saved report to: {output_filename.relative_to(top_dir.parent)}")
+        except Exception as e:
+            print(f" > ERROR: Could not save report to '{output_filename}': {e}")
+
+    print("\nAll processing complete.")
+
+if __name__ == "__main__":
+    main()
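A usage note for the new script: python analyze.py <directory> globs for */final_results/*results.json under the given directory and writes an accuracy_report.json next to each match. Each results file is expected to be a JSON array of records; the sketch below shows only the fields the script actually reads, and the concrete values are hypothetical, not taken from a real run:

# Hypothetical input record for analyze.py; only these fields are consumed.
example_record = {
    "data_point_id": "rexvqa-0001",  # used in warning messages
    "is_correct": True,              # tallied into the "correct" count
    "metadata": {
        "data_point_metadata": {     # grouped by these four keys
            "reasoning_type": "Negation Assessment",
            "category": "presence",
            "class": "single_image",
            "subcategory": "pneumothorax",
        }
    },
}

The resulting report maps each tracked key to per-value entries of the form {"accuracy_percent": ..., "correct": ..., "total": ...}.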
benchmarking/cli.py CHANGED
@@ -118,7 +118,7 @@ def main():
     run_parser.add_argument("--model", required=True,
                             help="Model name (e.g., gpt-4o, gpt-4.1-2025-04-14, gemini-2.5-pro)")
     run_parser.add_argument("--system-prompt", required=True,
-                            choices=["MEDICAL_ASSISTANT", "CHESTAGENTBENCH_PROMPT"],
+                            choices=["MEDICAL_ASSISTANT", "CHESTAGENTBENCH_PROMPT", "MEDGEMMA_PROMPT"],
                             help="System prompt: MEDICAL_ASSISTANT (general) or CHESTAGENTBENCH_PROMPT (benchmarks)")
     run_parser.add_argument("--data-dir", required=True,
                             help="Directory containing benchmark data files")
benchmarking/llm_providers/medgemma_provider.py CHANGED
@@ -3,8 +3,6 @@
 import os
 import time
 import httpx
-from typing import Optional
-from pathlib import Path
 from tenacity import retry, wait_exponential, stop_after_attempt
 
 from .base import LLMProvider, LLMRequest, LLMResponse
@@ -36,9 +34,8 @@ class MedGemmaProvider(LLMProvider):
             - api_url: URL of the MedGemma FastAPI service
             - max_new_tokens: Maximum tokens to generate (default: 300)
         """
-        # Extract MedGemma-specific config before calling super().__init__
-        self.api_url = os.getenv('MEDGEMMA_API_URL', 'http://localhost:8002')
-        self.max_new_tokens = kwargs.pop('max_new_tokens', 300)
+        self.provider_name = "medgemma"
+        self.api_url = "http://kn132.paice.vectorinstitute.ai:8002"
         self.client = None
 
         # Call parent constructor
@@ -52,16 +49,6 @@ class MedGemmaProvider(LLMProvider):
             connect=10.0  # 10 seconds to establish connection
         )
         self.client = httpx.Client(timeout=timeout_config)
-
-        # Test connection to MedGemma service
-        try:
-            response = self.client.get(f"{self.api_url}/docs")
-            if response.status_code != 200:
-                print(f"Warning: MedGemma API at {self.api_url} may not be running (status: {response.status_code})")
-        except httpx.ConnectError:
-            print(f"Warning: Could not connect to MedGemma API at {self.api_url}")
-            print("Please ensure the MedGemma FastAPI service is running:")
-            print(f"  python medrax/tools/vqa/medgemma/medgemma.py")
 
     @retry(wait=wait_exponential(multiplier=1, min=4, max=10), stop=stop_after_attempt(3))
     def generate_response(self, request: LLMRequest) -> LLMResponse:
@@ -100,14 +87,13 @@
         files_to_send = []
         for image_path in valid_images:
             try:
-                # Detect correct MIME type based on file extension
-                ext = Path(image_path).suffix.lower()
-                mime_type = "image/png" if ext == ".png" else "image/jpeg"
-
                 # Read image file
                 with open(image_path, "rb") as f:
                     image_data = f.read()
 
+                # Detect correct MIME type based on file extension
+                mime_type = self._get_image_mime_type(image_path)
+
                 # Add to files list
                 files_to_send.append(
                     ("images", (os.path.basename(image_path), image_data, mime_type))
@@ -122,17 +108,14 @@
                 duration=time.time() - start_time
             )
 
-        # Prepare form data
         # Use system_prompt if provided, otherwise use default
         system_prompt_text = self.system_prompt if self.system_prompt else "You are an expert radiologist who is able to analyze radiological images at any resolution."
 
-        # Override max_new_tokens if provided in request
-        max_tokens = getattr(request, 'max_tokens', self.max_new_tokens)
-
+        # Prepare form data
         data = {
             "prompt": request.text,
             "system_prompt": system_prompt_text,
-            "max_new_tokens": max_tokens,
+            "max_new_tokens": self.max_tokens,
         }
 
         # Make API request
@@ -148,19 +131,14 @@
         # Parse response
         response_data = response.json()
         content = response_data.get("response", "")
-        metadata = response_data.get("metadata", {})
 
+        # record duration
        duration = time.time() - start_time
-
-        # MedGemma doesn't provide token usage, but we can include request info
-        usage = {
-            "num_images": len(valid_images),
-            "max_new_tokens": max_tokens,
-        }
-
+
+        # return response object
         return LLMResponse(
             content=content,
-            usage=usage,
+            usage=None,
             duration=duration
         )
 
@@ -199,7 +177,7 @@
             content=f"Error: {error_msg}",
             duration=duration
         )
-    
+
     def test_connection(self) -> bool:
        """Test the connection to the MedGemma API service.
 
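The rewritten image loop calls self._get_image_mime_type, whose definition is not part of this diff, and the form data now uses self.max_tokens, presumably populated by the LLMProvider base constructor. A minimal sketch of what such a helper could look like on the provider class (an assumption, not the repo's actual implementation), using only the stdlib:

import mimetypes

def _get_image_mime_type(self, image_path: str) -> str:
    # Hypothetical helper: guess the MIME type from the file extension,
    # falling back to JPEG, matching the ext-based logic this diff removes.
    mime_type, _ = mimetypes.guess_type(image_path)
    return mime_type or "image/jpeg"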
benchmarking/llm_providers/medrax_provider.py CHANGED
@@ -37,14 +37,14 @@ class MedRAXProvider(LLMProvider):
         print("Starting server...")
 
         selected_tools = [
-            "TorchXRayVisionClassifierTool",  # For classifying chest X-ray images using TorchXRayVision
-            "ArcPlusClassifierTool",  # For advanced chest X-ray classification using ArcPlus
-            "ChestXRayReportGeneratorTool",  # For generating medical reports from X-rays
+            # "TorchXRayVisionClassifierTool",  # For classifying chest X-ray images using TorchXRayVision
+            # "ArcPlusClassifierTool",  # For advanced chest X-ray classification using ArcPlus
+            # "ChestXRayReportGeneratorTool",  # For generating medical reports from X-rays
             # "XRayPhraseGroundingTool",  # For locating described features in X-rays
             "MedGemmaVQATool",  # Google MedGemma VQA tool
-            "MedicalRAGTool",  # For retrieval-augmented generation with medical knowledge
-            "WebBrowserTool",  # For web browsing and search capabilities
-            "DuckDuckGoSearchTool",  # For privacy-focused web search using DuckDuckGo
+            # "MedicalRAGTool",  # For retrieval-augmented generation with medical knowledge
+            # "WebBrowserTool",  # For web browsing and search capabilities
+            # "DuckDuckGoSearchTool",  # For privacy-focused web search using DuckDuckGo
         ]
 
         rag_config = RAGConfig(
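With this change every tool except MedGemmaVQATool is commented out, so the MedRAX agent can only answer through MedGemma; this effectively isolates MedGemma's contribution when comparing against the direct medgemma provider runs below.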
benchmarking/system_prompts.txt CHANGED
@@ -33,4 +33,9 @@ Your final response for a multiple-choice question must strictly follow this format:
 3. **Critical Thinking & Tool Use:** [Show your reasoning, including how you used tools and evaluated their output]
 4. **Final Answer:** \boxed{A}
 
-Do not provide a definitive diagnosis or treatment plan for a patient. Your purpose is to assist medical professionals with your analysis, not to replace them. You must maintain this persona and adhere to all instructions.
+Do not provide a definitive diagnosis or treatment plan for a patient. Your purpose is to assist medical professionals with your analysis, not to replace them. You must maintain this persona and adhere to all instructions.
+
+[MEDGEMMA_PROMPT]
+You are an expert in interpreting medical images and able to analyze medical images of any resolution, specifically chest X-rays, CT scans, and MRIs, with world-class accuracy and precision.
+
+Your final response for a multiple-choice question must strictly follow this boxed format for providing the final answer: **Final Answer:** \boxed{A}
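The file appears to use bracketed [SECTION] headers to delimit prompts, which the CLI's --system-prompt flag selects by name. A minimal sketch of how such a file could be parsed (an assumption; the repo's actual loader is not shown in this commit):

def load_prompts(path: str) -> dict[str, str]:
    """Parse a [SECTION]-delimited prompt file into {name: prompt_text}."""
    prompts: dict[str, str] = {}
    current = None
    with open(path, encoding="utf-8") as f:
        for line in f:
            stripped = line.strip()
            if stripped.startswith("[") and stripped.endswith("]"):
                current = stripped[1:-1]   # e.g. "MEDGEMMA_PROMPT"
                prompts[current] = ""
            elif current is not None:
                prompts[current] += line
    return {name: text.strip() for name, text in prompts.items()}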
chestagentbench_script.sh CHANGED
@@ -12,4 +12,4 @@ module load arrow clang/18.1.8 scipy-stack
 
 source venv/bin/activate
 
-/scratch/lijunzh3/MedRAX2/venv/bin/python -m benchmarking.cli run --benchmark chestagentbench --provider google --model gemini-2.5-pro --system-prompt CHESTAGENTBENCH_PROMPT --data-dir benchmarking/data/chestagentbench --output-dir temp --max-questions 500 --temperature 0.7 --top-p 0.95 --max-tokens 10000 --concurrency 4 --random-seed 42
+/scratch/lijunzh3/MedRAX2/venv/bin/python -m benchmarking.cli run --benchmark chestagentbench --provider medrax --model gemini-2.5-pro --system-prompt CHESTAGENTBENCH_PROMPT --data-dir benchmarking/data/chestagentbench --output-dir temp --max-questions 500 --temperature 0.7 --top-p 0.95 --max-tokens 10000 --concurrency 4 --random-seed 42
medgemma_script.sh CHANGED
@@ -1,6 +1,6 @@
 #!/bin/bash
 
-#SBATCH --job-name=medgemma
+#SBATCH --job-name=medgemma3
 #SBATCH -c 4
 #SBATCH --gres=gpu:l40s:1
 #SBATCH --time=16:00:00
@@ -8,6 +8,8 @@
 #SBATCH --output=medgemma-%j.out
 #SBATCH --error=medgemma-%j.err
 
+export MEDGEMMA_DEVICE=cuda
+
 cd medrax/tools/vqa/medgemma
 
 source medgemma/bin/activate
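The exported MEDGEMMA_DEVICE variable is presumably read by the MedGemma FastAPI service to choose its compute device; a one-line sketch of that pattern (hypothetical, the service code is not in this diff):

import os

# Hypothetical device selection inside the MedGemma service.
device = os.getenv("MEDGEMMA_DEVICE", "cpu")  # this script exports "cuda"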
rexvqa_script.sh CHANGED
@@ -1,6 +1,6 @@
 #!/bin/bash
 
-#SBATCH --job-name=rexvqa
+#SBATCH --job-name=medgemma_run2
 #SBATCH -c 4
 #SBATCH --gres=gpu:l40s:1
 #SBATCH --time=16:00:00
@@ -12,4 +12,4 @@ module load arrow clang/18.1.8 scipy-stack
 
 source venv/bin/activate
 
-/scratch/lijunzh3/MedRAX2/venv/bin/python -m benchmarking.cli run --benchmark rexvqa --provider medrax --model gemini-2.5-pro --system-prompt CHESTAGENTBENCH_PROMPT --data-dir benchmarking/data/rexvqa --output-dir temp --max-questions 500 --temperature 0.7 --top-p 0.95 --max-tokens 10000 --concurrency 4 --random-seed 42
+/scratch/lijunzh3/MedRAX2/venv/bin/python -m benchmarking.cli run --benchmark rexvqa --provider medgemma --model medgemma-4b --system-prompt MEDGEMMA_PROMPT --data-dir benchmarking/data/rexvqa --output-dir temp --max-questions 200 --temperature 0.7 --top-p 0.95 --max-tokens 10000 --concurrency 4 --random-seed 100
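With --provider medgemma and --system-prompt MEDGEMMA_PROMPT, this run bypasses the MedRAX agent entirely and sends each ReXVQA question straight to the MedGemma FastAPI service configured in medgemma_provider.py.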