victorli commited on
Commit
e97f266
·
1 Parent(s): 8d52c4a

cleared merge issues

Browse files
benchmarking/benchmarks/rexvqa_benchmark.py CHANGED
@@ -34,20 +34,20 @@ class ReXVQABenchmark(Benchmark):
34
  data_dir (str): Directory to store/cache downloaded data
35
  **kwargs: Additional configuration parameters
36
  split (str): Dataset split to use (default: 'test')
37
- cache_dir (str): Directory for caching HuggingFace datasets
38
  trust_remote_code (bool): Whether to trust remote code (default: False)
39
  max_questions (int): Maximum number of questions to load (default: None, load all)
40
  images_dir (str): Directory containing extracted PNG images (default: None)
41
  """
42
  self.split = kwargs.get("split", "test")
43
- self.cache_dir = kwargs.get("cache_dir", None)
44
  self.trust_remote_code = kwargs.get("trust_remote_code", False)
45
  self.max_questions = kwargs.get("max_questions", None)
46
- self.images_dir = "benchmarking/data/rexvqa/images/deid_png"
47
  self.image_dataset = None
48
  self.image_mapping = {} # Maps study_id to image data
49
 
50
  super().__init__(data_dir, **kwargs)
 
 
 
51
 
52
  @staticmethod
53
  def download_rexgradient_images(output_dir: str = "benchmarking/data/rexvqa", repo_id: str = "rajpurkarlab/ReXGradient-160K"):
@@ -166,8 +166,8 @@ class ReXVQABenchmark(Benchmark):
166
  """Load ReXVQA data from local JSON file."""
167
  try:
168
  # Check for images and test_vqa_data.json, download if missing
169
- self.download_test_vqa_data_json()
170
- self.download_rexgradient_images()
171
 
172
  # Construct path to the JSON file
173
  json_file_path = os.path.join("benchmarking", "data", "rexvqa", "metadata", "test_vqa_data.json")
@@ -197,7 +197,7 @@ class ReXVQABenchmark(Benchmark):
197
  self.image_dataset = load_dataset(
198
  "rajpurkarlab/ReXGradient-160K",
199
  split="test",
200
- cache_dir=self.cache_dir,
201
  trust_remote_code=self.trust_remote_code
202
  )
203
  print(f"Loaded {len(self.image_dataset)} image metadata entries from ReXGradient-160K")
 
34
  data_dir (str): Directory to store/cache downloaded data
35
  **kwargs: Additional configuration parameters
36
  split (str): Dataset split to use (default: 'test')
 
37
  trust_remote_code (bool): Whether to trust remote code (default: False)
38
  max_questions (int): Maximum number of questions to load (default: None, load all)
39
  images_dir (str): Directory containing extracted PNG images (default: None)
40
  """
41
  self.split = kwargs.get("split", "test")
 
42
  self.trust_remote_code = kwargs.get("trust_remote_code", False)
43
  self.max_questions = kwargs.get("max_questions", None)
 
44
  self.image_dataset = None
45
  self.image_mapping = {} # Maps study_id to image data
46
 
47
  super().__init__(data_dir, **kwargs)
48
+
49
+ # Set images_dir after parent initialization
50
+ self.images_dir = f"{self.data_dir}/images/deid_png"
51
 
52
  @staticmethod
53
  def download_rexgradient_images(output_dir: str = "benchmarking/data/rexvqa", repo_id: str = "rajpurkarlab/ReXGradient-160K"):
 
166
  """Load ReXVQA data from local JSON file."""
167
  try:
168
  # Check for images and test_vqa_data.json, download if missing
169
+ self.download_test_vqa_data_json(self.data_dir)
170
+ self.download_rexgradient_images(self.data_dir)
171
 
172
  # Construct path to the JSON file
173
  json_file_path = os.path.join("benchmarking", "data", "rexvqa", "metadata", "test_vqa_data.json")
 
197
  self.image_dataset = load_dataset(
198
  "rajpurkarlab/ReXGradient-160K",
199
  split="test",
200
+ cache_dir=self.data_dir,
201
  trust_remote_code=self.trust_remote_code
202
  )
203
  print(f"Loaded {len(self.image_dataset)} image metadata entries from ReXGradient-160K")
benchmarking/llm_providers/medrax_provider.py CHANGED
@@ -33,15 +33,15 @@ class MedRAXProvider(LLMProvider):
33
  print("Starting server...")
34
 
35
  selected_tools = [
36
- "ChestXRayReportGeneratorTool", # For generating medical reports from X-rays
37
- "MedicalRAGTool", # For retrieval-augmented generation with medical knowledge
38
- "WebBrowserTool", # For web browsing and search capabilities
39
- "TorchXRayVisionClassifierTool", # For classifying chest X-ray images using TorchXRayVision
40
- "ArcPlusClassifierTool", # For advanced chest X-ray classification using ArcPlus
41
- "DuckDuckGoSearchTool", # For privacy-focused web search using DuckDuckGo
42
- "XRayVQATool", # For visual question answering on X-rays
43
- "XRayPhraseGroundingTool", # For locating described features in X-rays
44
  "MedGemmaVQATool"
 
 
 
 
45
  ]
46
 
47
  rag_config = RAGConfig(
@@ -64,11 +64,11 @@ class MedRAXProvider(LLMProvider):
64
  agent, tools_dict = initialize_agent(
65
  prompt_file="medrax/docs/system_prompts.txt",
66
  tools_to_use=selected_tools,
67
- model_dir="/model-weights",
68
  temp_dir="temp", # Change this to the path of the temporary directory
69
  device="cuda:0",
70
  model=self.model_name, # Change this to the model you want to use, e.g. gpt-4.1-2025-04-14, gemini-2.5-pro
71
- temperature=0.3,
72
  top_p=0.95,
73
  model_kwargs=model_kwargs,
74
  rag_config=rag_config,
 
33
  print("Starting server...")
34
 
35
  selected_tools = [
36
+ # "TorchXRayVisionClassifierTool", # For classifying chest X-ray images using TorchXRayVision
37
+ # "ArcPlusClassifierTool", # For advanced chest X-ray classification using ArcPlus
38
+ # "ChestXRayReportGeneratorTool", # For generating medical reports from X-rays
39
+ # "XRayVQATool", # For visual question answering on X-rays
 
 
 
 
40
  "MedGemmaVQATool"
41
+ # "XRayPhraseGroundingTool", # For locating described features in X-rays
42
+ # "MedicalRAGTool", # For retrieval-augmented generation with medical knowledge
43
+ # "WebBrowserTool", # For web browsing and search capabilities
44
+ # "DuckDuckGoSearchTool", # For privacy-focused web search using DuckDuckGo
45
  ]
46
 
47
  rag_config = RAGConfig(
 
64
  agent, tools_dict = initialize_agent(
65
  prompt_file="medrax/docs/system_prompts.txt",
66
  tools_to_use=selected_tools,
67
+ model_dir="/scratch/ssd004/scratch/victorli/model-weights",
68
  temp_dir="temp", # Change this to the path of the temporary directory
69
  device="cuda:0",
70
  model=self.model_name, # Change this to the model you want to use, e.g. gpt-4.1-2025-04-14, gemini-2.5-pro
71
+ temperature=1.0,
72
  top_p=0.95,
73
  model_kwargs=model_kwargs,
74
  rag_config=rag_config,
main.py CHANGED
@@ -33,7 +33,7 @@ _ = load_dotenv()
33
  def initialize_agent(
34
  prompt_file: str,
35
  tools_to_use: Optional[List[str]] = None,
36
- model_dir: str = "/model-weights",
37
  temp_dir: str = "temp",
38
  device: str = "cpu",
39
  model: str = "gpt-4.1-2025-04-14",
@@ -88,6 +88,7 @@ def initialize_agent(
88
  "DicomProcessorTool": lambda: DicomProcessorTool(temp_dir=temp_dir),
89
  "MedicalRAGTool": lambda: RAGTool(config=rag_config),
90
  "WebBrowserTool": lambda: WebBrowserTool(),
 
91
  "MedSAM2Tool": lambda: MedSAM2Tool(
92
  device=device, cache_dir=model_dir, temp_dir=temp_dir
93
  ),
 
33
  def initialize_agent(
34
  prompt_file: str,
35
  tools_to_use: Optional[List[str]] = None,
36
+ model_dir: str = "/scratch/ssd004/scratch/victorli/model-weights",
37
  temp_dir: str = "temp",
38
  device: str = "cpu",
39
  model: str = "gpt-4.1-2025-04-14",
 
88
  "DicomProcessorTool": lambda: DicomProcessorTool(temp_dir=temp_dir),
89
  "MedicalRAGTool": lambda: RAGTool(config=rag_config),
90
  "WebBrowserTool": lambda: WebBrowserTool(),
91
+ "DuckDuckGoSearchTool": lambda: DuckDuckGoSearchTool(),
92
  "MedSAM2Tool": lambda: MedSAM2Tool(
93
  device=device, cache_dir=model_dir, temp_dir=temp_dir
94
  ),
pyproject.toml CHANGED
@@ -57,7 +57,6 @@ dependencies = [
57
  "torch>=2.2.0",
58
  "torchvision>=0.10.0",
59
  "scikit-image>=0.18.0",
60
- "gradio>=5.0.0",
61
  "opencv-python>=4.8.0",
62
  "matplotlib>=3.8.0",
63
  "diffusers>=0.20.0",
@@ -65,13 +64,11 @@ dependencies = [
65
  "pylibjpeg>=1.0.0",
66
  "jupyter>=1.0.0",
67
  "albumentations>=1.0.0",
68
- "pyarrow>=10.0.0",
69
  "chromadb>=0.0.10",
70
  "pinecone-client>=3.2.2",
71
  "langchain-pinecone>=0.0.1",
72
  "langchain-google-genai>=0.1.0",
73
  "ray>=2.9.0",
74
- "langchain-sandbox>=0.0.6",
75
  "seaborn>=0.12.0",
76
  "huggingface_hub>=0.17.0",
77
  "iopath>=0.1.10",
 
57
  "torch>=2.2.0",
58
  "torchvision>=0.10.0",
59
  "scikit-image>=0.18.0",
 
60
  "opencv-python>=4.8.0",
61
  "matplotlib>=3.8.0",
62
  "diffusers>=0.20.0",
 
64
  "pylibjpeg>=1.0.0",
65
  "jupyter>=1.0.0",
66
  "albumentations>=1.0.0",
 
67
  "chromadb>=0.0.10",
68
  "pinecone-client>=3.2.2",
69
  "langchain-pinecone>=0.0.1",
70
  "langchain-google-genai>=0.1.0",
71
  "ray>=2.9.0",
 
72
  "seaborn>=0.12.0",
73
  "huggingface_hub>=0.17.0",
74
  "iopath>=0.1.10",