Spaces:
Sleeping
Sleeping
victorli
commited on
Commit
·
e97f266
1
Parent(s):
8d52c4a
cleared merge issues
Browse files- benchmarking/benchmarks/rexvqa_benchmark.py +6 -6
- benchmarking/llm_providers/medrax_provider.py +10 -10
- main.py +2 -1
- pyproject.toml +0 -3
benchmarking/benchmarks/rexvqa_benchmark.py
CHANGED
|
@@ -34,20 +34,20 @@ class ReXVQABenchmark(Benchmark):
|
|
| 34 |
data_dir (str): Directory to store/cache downloaded data
|
| 35 |
**kwargs: Additional configuration parameters
|
| 36 |
split (str): Dataset split to use (default: 'test')
|
| 37 |
-
cache_dir (str): Directory for caching HuggingFace datasets
|
| 38 |
trust_remote_code (bool): Whether to trust remote code (default: False)
|
| 39 |
max_questions (int): Maximum number of questions to load (default: None, load all)
|
| 40 |
images_dir (str): Directory containing extracted PNG images (default: None)
|
| 41 |
"""
|
| 42 |
self.split = kwargs.get("split", "test")
|
| 43 |
-
self.cache_dir = kwargs.get("cache_dir", None)
|
| 44 |
self.trust_remote_code = kwargs.get("trust_remote_code", False)
|
| 45 |
self.max_questions = kwargs.get("max_questions", None)
|
| 46 |
-
self.images_dir = "benchmarking/data/rexvqa/images/deid_png"
|
| 47 |
self.image_dataset = None
|
| 48 |
self.image_mapping = {} # Maps study_id to image data
|
| 49 |
|
| 50 |
super().__init__(data_dir, **kwargs)
|
|
|
|
|
|
|
|
|
|
| 51 |
|
| 52 |
@staticmethod
|
| 53 |
def download_rexgradient_images(output_dir: str = "benchmarking/data/rexvqa", repo_id: str = "rajpurkarlab/ReXGradient-160K"):
|
|
@@ -166,8 +166,8 @@ class ReXVQABenchmark(Benchmark):
|
|
| 166 |
"""Load ReXVQA data from local JSON file."""
|
| 167 |
try:
|
| 168 |
# Check for images and test_vqa_data.json, download if missing
|
| 169 |
-
self.download_test_vqa_data_json()
|
| 170 |
-
self.download_rexgradient_images()
|
| 171 |
|
| 172 |
# Construct path to the JSON file
|
| 173 |
json_file_path = os.path.join("benchmarking", "data", "rexvqa", "metadata", "test_vqa_data.json")
|
|
@@ -197,7 +197,7 @@ class ReXVQABenchmark(Benchmark):
|
|
| 197 |
self.image_dataset = load_dataset(
|
| 198 |
"rajpurkarlab/ReXGradient-160K",
|
| 199 |
split="test",
|
| 200 |
-
cache_dir=self.
|
| 201 |
trust_remote_code=self.trust_remote_code
|
| 202 |
)
|
| 203 |
print(f"Loaded {len(self.image_dataset)} image metadata entries from ReXGradient-160K")
|
|
|
|
| 34 |
data_dir (str): Directory to store/cache downloaded data
|
| 35 |
**kwargs: Additional configuration parameters
|
| 36 |
split (str): Dataset split to use (default: 'test')
|
|
|
|
| 37 |
trust_remote_code (bool): Whether to trust remote code (default: False)
|
| 38 |
max_questions (int): Maximum number of questions to load (default: None, load all)
|
| 39 |
images_dir (str): Directory containing extracted PNG images (default: None)
|
| 40 |
"""
|
| 41 |
self.split = kwargs.get("split", "test")
|
|
|
|
| 42 |
self.trust_remote_code = kwargs.get("trust_remote_code", False)
|
| 43 |
self.max_questions = kwargs.get("max_questions", None)
|
|
|
|
| 44 |
self.image_dataset = None
|
| 45 |
self.image_mapping = {} # Maps study_id to image data
|
| 46 |
|
| 47 |
super().__init__(data_dir, **kwargs)
|
| 48 |
+
|
| 49 |
+
# Set images_dir after parent initialization
|
| 50 |
+
self.images_dir = f"{self.data_dir}/images/deid_png"
|
| 51 |
|
| 52 |
@staticmethod
|
| 53 |
def download_rexgradient_images(output_dir: str = "benchmarking/data/rexvqa", repo_id: str = "rajpurkarlab/ReXGradient-160K"):
|
|
|
|
| 166 |
"""Load ReXVQA data from local JSON file."""
|
| 167 |
try:
|
| 168 |
# Check for images and test_vqa_data.json, download if missing
|
| 169 |
+
self.download_test_vqa_data_json(self.data_dir)
|
| 170 |
+
self.download_rexgradient_images(self.data_dir)
|
| 171 |
|
| 172 |
# Construct path to the JSON file
|
| 173 |
json_file_path = os.path.join("benchmarking", "data", "rexvqa", "metadata", "test_vqa_data.json")
|
|
|
|
| 197 |
self.image_dataset = load_dataset(
|
| 198 |
"rajpurkarlab/ReXGradient-160K",
|
| 199 |
split="test",
|
| 200 |
+
cache_dir=self.data_dir,
|
| 201 |
trust_remote_code=self.trust_remote_code
|
| 202 |
)
|
| 203 |
print(f"Loaded {len(self.image_dataset)} image metadata entries from ReXGradient-160K")
|
benchmarking/llm_providers/medrax_provider.py
CHANGED
|
@@ -33,15 +33,15 @@ class MedRAXProvider(LLMProvider):
|
|
| 33 |
print("Starting server...")
|
| 34 |
|
| 35 |
selected_tools = [
|
| 36 |
-
"
|
| 37 |
-
"
|
| 38 |
-
"
|
| 39 |
-
"
|
| 40 |
-
"ArcPlusClassifierTool", # For advanced chest X-ray classification using ArcPlus
|
| 41 |
-
"DuckDuckGoSearchTool", # For privacy-focused web search using DuckDuckGo
|
| 42 |
-
"XRayVQATool", # For visual question answering on X-rays
|
| 43 |
-
"XRayPhraseGroundingTool", # For locating described features in X-rays
|
| 44 |
"MedGemmaVQATool"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
]
|
| 46 |
|
| 47 |
rag_config = RAGConfig(
|
|
@@ -64,11 +64,11 @@ class MedRAXProvider(LLMProvider):
|
|
| 64 |
agent, tools_dict = initialize_agent(
|
| 65 |
prompt_file="medrax/docs/system_prompts.txt",
|
| 66 |
tools_to_use=selected_tools,
|
| 67 |
-
model_dir="/model-weights",
|
| 68 |
temp_dir="temp", # Change this to the path of the temporary directory
|
| 69 |
device="cuda:0",
|
| 70 |
model=self.model_name, # Change this to the model you want to use, e.g. gpt-4.1-2025-04-14, gemini-2.5-pro
|
| 71 |
-
temperature=0
|
| 72 |
top_p=0.95,
|
| 73 |
model_kwargs=model_kwargs,
|
| 74 |
rag_config=rag_config,
|
|
|
|
| 33 |
print("Starting server...")
|
| 34 |
|
| 35 |
selected_tools = [
|
| 36 |
+
# "TorchXRayVisionClassifierTool", # For classifying chest X-ray images using TorchXRayVision
|
| 37 |
+
# "ArcPlusClassifierTool", # For advanced chest X-ray classification using ArcPlus
|
| 38 |
+
# "ChestXRayReportGeneratorTool", # For generating medical reports from X-rays
|
| 39 |
+
# "XRayVQATool", # For visual question answering on X-rays
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
"MedGemmaVQATool"
|
| 41 |
+
# "XRayPhraseGroundingTool", # For locating described features in X-rays
|
| 42 |
+
# "MedicalRAGTool", # For retrieval-augmented generation with medical knowledge
|
| 43 |
+
# "WebBrowserTool", # For web browsing and search capabilities
|
| 44 |
+
# "DuckDuckGoSearchTool", # For privacy-focused web search using DuckDuckGo
|
| 45 |
]
|
| 46 |
|
| 47 |
rag_config = RAGConfig(
|
|
|
|
| 64 |
agent, tools_dict = initialize_agent(
|
| 65 |
prompt_file="medrax/docs/system_prompts.txt",
|
| 66 |
tools_to_use=selected_tools,
|
| 67 |
+
model_dir="/scratch/ssd004/scratch/victorli/model-weights",
|
| 68 |
temp_dir="temp", # Change this to the path of the temporary directory
|
| 69 |
device="cuda:0",
|
| 70 |
model=self.model_name, # Change this to the model you want to use, e.g. gpt-4.1-2025-04-14, gemini-2.5-pro
|
| 71 |
+
temperature=1.0,
|
| 72 |
top_p=0.95,
|
| 73 |
model_kwargs=model_kwargs,
|
| 74 |
rag_config=rag_config,
|
main.py
CHANGED
|
@@ -33,7 +33,7 @@ _ = load_dotenv()
|
|
| 33 |
def initialize_agent(
|
| 34 |
prompt_file: str,
|
| 35 |
tools_to_use: Optional[List[str]] = None,
|
| 36 |
-
model_dir: str = "/model-weights",
|
| 37 |
temp_dir: str = "temp",
|
| 38 |
device: str = "cpu",
|
| 39 |
model: str = "gpt-4.1-2025-04-14",
|
|
@@ -88,6 +88,7 @@ def initialize_agent(
|
|
| 88 |
"DicomProcessorTool": lambda: DicomProcessorTool(temp_dir=temp_dir),
|
| 89 |
"MedicalRAGTool": lambda: RAGTool(config=rag_config),
|
| 90 |
"WebBrowserTool": lambda: WebBrowserTool(),
|
|
|
|
| 91 |
"MedSAM2Tool": lambda: MedSAM2Tool(
|
| 92 |
device=device, cache_dir=model_dir, temp_dir=temp_dir
|
| 93 |
),
|
|
|
|
| 33 |
def initialize_agent(
|
| 34 |
prompt_file: str,
|
| 35 |
tools_to_use: Optional[List[str]] = None,
|
| 36 |
+
model_dir: str = "/scratch/ssd004/scratch/victorli/model-weights",
|
| 37 |
temp_dir: str = "temp",
|
| 38 |
device: str = "cpu",
|
| 39 |
model: str = "gpt-4.1-2025-04-14",
|
|
|
|
| 88 |
"DicomProcessorTool": lambda: DicomProcessorTool(temp_dir=temp_dir),
|
| 89 |
"MedicalRAGTool": lambda: RAGTool(config=rag_config),
|
| 90 |
"WebBrowserTool": lambda: WebBrowserTool(),
|
| 91 |
+
"DuckDuckGoSearchTool": lambda: DuckDuckGoSearchTool(),
|
| 92 |
"MedSAM2Tool": lambda: MedSAM2Tool(
|
| 93 |
device=device, cache_dir=model_dir, temp_dir=temp_dir
|
| 94 |
),
|
pyproject.toml
CHANGED
|
@@ -57,7 +57,6 @@ dependencies = [
|
|
| 57 |
"torch>=2.2.0",
|
| 58 |
"torchvision>=0.10.0",
|
| 59 |
"scikit-image>=0.18.0",
|
| 60 |
-
"gradio>=5.0.0",
|
| 61 |
"opencv-python>=4.8.0",
|
| 62 |
"matplotlib>=3.8.0",
|
| 63 |
"diffusers>=0.20.0",
|
|
@@ -65,13 +64,11 @@ dependencies = [
|
|
| 65 |
"pylibjpeg>=1.0.0",
|
| 66 |
"jupyter>=1.0.0",
|
| 67 |
"albumentations>=1.0.0",
|
| 68 |
-
"pyarrow>=10.0.0",
|
| 69 |
"chromadb>=0.0.10",
|
| 70 |
"pinecone-client>=3.2.2",
|
| 71 |
"langchain-pinecone>=0.0.1",
|
| 72 |
"langchain-google-genai>=0.1.0",
|
| 73 |
"ray>=2.9.0",
|
| 74 |
-
"langchain-sandbox>=0.0.6",
|
| 75 |
"seaborn>=0.12.0",
|
| 76 |
"huggingface_hub>=0.17.0",
|
| 77 |
"iopath>=0.1.10",
|
|
|
|
| 57 |
"torch>=2.2.0",
|
| 58 |
"torchvision>=0.10.0",
|
| 59 |
"scikit-image>=0.18.0",
|
|
|
|
| 60 |
"opencv-python>=4.8.0",
|
| 61 |
"matplotlib>=3.8.0",
|
| 62 |
"diffusers>=0.20.0",
|
|
|
|
| 64 |
"pylibjpeg>=1.0.0",
|
| 65 |
"jupyter>=1.0.0",
|
| 66 |
"albumentations>=1.0.0",
|
|
|
|
| 67 |
"chromadb>=0.0.10",
|
| 68 |
"pinecone-client>=3.2.2",
|
| 69 |
"langchain-pinecone>=0.0.1",
|
| 70 |
"langchain-google-genai>=0.1.0",
|
| 71 |
"ray>=2.9.0",
|
|
|
|
| 72 |
"seaborn>=0.12.0",
|
| 73 |
"huggingface_hub>=0.17.0",
|
| 74 |
"iopath>=0.1.10",
|