Spaces:
Sleeping
Sleeping
updates
Browse files
- benchmarking/benchmarks/chestagentbench_benchmark.py +1 -2
- benchmarking/cli.py +4 -3
- benchmarking/llm_providers/__init__.py +2 -0
- benchmarking/llm_providers/base.py +3 -2
- benchmarking/llm_providers/google_provider.py +3 -4
- benchmarking/llm_providers/medrax_provider.py +2 -19
- benchmarking/llm_providers/openai_provider.py +1 -4
- benchmarking/runner.py +9 -11
- medrax/docs/system_prompts.txt +10 -5
- medrax/models/model_factory.py +2 -2
benchmarking/benchmarks/chestagentbench_benchmark.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
| 1 |
-
import os
|
| 2 |
import json
|
| 3 |
from pathlib import Path
|
| 4 |
-
from typing import Dict,
|
| 5 |
from .base import Benchmark, BenchmarkDataPoint
|
| 6 |
|
| 7 |
class ChestAgentBenchBenchmark(Benchmark):
|
|
|
|
|
|
|
| 1 |
import json
|
| 2 |
from pathlib import Path
|
| 3 |
+
from typing import Dict, Optional, Any
|
| 4 |
from .base import Benchmark, BenchmarkDataPoint
|
| 5 |
|
| 6 |
class ChestAgentBenchBenchmark(Benchmark):
|
benchmarking/cli.py
CHANGED
|
@@ -13,7 +13,7 @@ def create_llm_provider(model_name: str, provider_type: str, **kwargs) -> LLMPro
|
|
| 13 |
|
| 14 |
Args:
|
| 15 |
model_name (str): Name of the model
|
| 16 |
-
provider_type (str): Type of provider (openai, google, medrax)
|
| 17 |
**kwargs: Additional configuration parameters
|
| 18 |
|
| 19 |
Returns:
|
|
@@ -22,6 +22,7 @@ def create_llm_provider(model_name: str, provider_type: str, **kwargs) -> LLMPro
|
|
| 22 |
provider_map = {
|
| 23 |
"openai": OpenAIProvider,
|
| 24 |
"google": GoogleProvider,
|
|
|
|
| 25 |
"medrax": MedRAXProvider,
|
| 26 |
}
|
| 27 |
|
|
@@ -111,13 +112,13 @@ def main():
|
|
| 111 |
# Run benchmark command
|
| 112 |
run_parser = subparsers.add_parser("run", help="Run a benchmark")
|
| 113 |
run_parser.add_argument("--model", required=True, help="Model name (e.g., gpt-4o, gemini-2.5-pro)")
|
| 114 |
-
run_parser.add_argument("--provider", required=True, choices=["openai", "google", "medrax"], help="LLM provider")
|
| 115 |
run_parser.add_argument("--benchmark", required=True, choices=["rexvqa", "chestagentbench"], help="Benchmark to run")
|
| 116 |
run_parser.add_argument("--data-dir", required=True, help="Directory containing benchmark data")
|
| 117 |
run_parser.add_argument("--output-dir", default="benchmark_results", help="Output directory for results")
|
| 118 |
run_parser.add_argument("--max-questions", type=int, help="Maximum number of questions to process")
|
| 119 |
run_parser.add_argument("--temperature", type=float, default=0.7, help="Model temperature")
|
| 120 |
-
run_parser.add_argument("--max-tokens", type=int, default=
|
| 121 |
|
| 122 |
run_parser.set_defaults(func=run_benchmark_command)
|
| 123 |
|
|
|
|
| 13 |
|
| 14 |
Args:
|
| 15 |
model_name (str): Name of the model
|
| 16 |
+
provider_type (str): Type of provider (openai, google, xai, medrax)
|
| 17 |
**kwargs: Additional configuration parameters
|
| 18 |
|
| 19 |
Returns:
|
|
|
|
| 22 |
provider_map = {
|
| 23 |
"openai": OpenAIProvider,
|
| 24 |
"google": GoogleProvider,
|
| 25 |
+
"xai": XAIProvider,
|
| 26 |
"medrax": MedRAXProvider,
|
| 27 |
}
|
| 28 |
|
|
|
|
| 112 |
# Run benchmark command
|
| 113 |
run_parser = subparsers.add_parser("run", help="Run a benchmark")
|
| 114 |
run_parser.add_argument("--model", required=True, help="Model name (e.g., gpt-4o, gemini-2.5-pro)")
|
| 115 |
+
run_parser.add_argument("--provider", required=True, choices=["openai", "google", "xai", "medrax"], help="LLM provider")
|
| 116 |
run_parser.add_argument("--benchmark", required=True, choices=["rexvqa", "chestagentbench"], help="Benchmark to run")
|
| 117 |
run_parser.add_argument("--data-dir", required=True, help="Directory containing benchmark data")
|
| 118 |
run_parser.add_argument("--output-dir", default="benchmark_results", help="Output directory for results")
|
| 119 |
run_parser.add_argument("--max-questions", type=int, help="Maximum number of questions to process")
|
| 120 |
run_parser.add_argument("--temperature", type=float, default=0.7, help="Model temperature")
|
| 121 |
+
run_parser.add_argument("--max-tokens", type=int, default=5000, help="Maximum tokens per response")
|
| 122 |
|
| 123 |
run_parser.set_defaults(func=run_benchmark_command)
|
| 124 |
|
benchmarking/llm_providers/__init__.py
CHANGED
|
@@ -4,6 +4,7 @@ from .base import LLMProvider, LLMRequest, LLMResponse
|
|
| 4 |
from .openai_provider import OpenAIProvider
|
| 5 |
from .google_provider import GoogleProvider
|
| 6 |
from .medrax_provider import MedRAXProvider
|
|
|
|
| 7 |
|
| 8 |
__all__ = [
|
| 9 |
"LLMProvider",
|
|
@@ -12,4 +13,5 @@ __all__ = [
|
|
| 12 |
"OpenAIProvider",
|
| 13 |
"GoogleProvider",
|
| 14 |
"MedRAXProvider",
|
|
|
|
| 15 |
]
|
|
|
|
| 4 |
from .openai_provider import OpenAIProvider
|
| 5 |
from .google_provider import GoogleProvider
|
| 6 |
from .medrax_provider import MedRAXProvider
|
| 7 |
+
from .xai_provider import XAIProvider
|
| 8 |
|
| 9 |
__all__ = [
|
| 10 |
"LLMProvider",
|
|
|
|
| 13 |
"OpenAIProvider",
|
| 14 |
"GoogleProvider",
|
| 15 |
"MedRAXProvider",
|
| 16 |
+
"XAIProvider",
|
| 17 |
]
|
benchmarking/llm_providers/base.py
CHANGED
|
@@ -14,7 +14,8 @@ class LLMRequest:
|
|
| 14 |
text: str
|
| 15 |
images: Optional[List[str]] = None # List of image paths
|
| 16 |
temperature: float = 0.7
|
| 17 |
-
|
|
|
|
| 18 |
additional_params: Optional[Dict[str, Any]] = None
|
| 19 |
|
| 20 |
|
|
@@ -47,7 +48,7 @@ class LLMProvider(ABC):
|
|
| 47 |
# Always load system prompt from file
|
| 48 |
try:
|
| 49 |
prompts = load_prompts_from_file("medrax/docs/system_prompts.txt")
|
| 50 |
-
self.system_prompt = prompts.get("
|
| 51 |
if self.system_prompt is None:
|
| 52 |
print(f"Warning: System prompt type 'MEDICAL_ASSISTANT' not found in medrax/docs/system_prompts.txt.")
|
| 53 |
except Exception as e:
|
|
|
|
| 14 |
text: str
|
| 15 |
images: Optional[List[str]] = None # List of image paths
|
| 16 |
temperature: float = 0.7
|
| 17 |
+
top_p: float = 0.95
|
| 18 |
+
max_tokens: int = 5000
|
| 19 |
additional_params: Optional[Dict[str, Any]] = None
|
| 20 |
|
| 21 |
|
|
|
|
| 48 |
# Always load system prompt from file
|
| 49 |
try:
|
| 50 |
prompts = load_prompts_from_file("medrax/docs/system_prompts.txt")
|
| 51 |
+
self.system_prompt = prompts.get("CHESTAGENTBENCH_PROMPT", None)
|
| 52 |
if self.system_prompt is None:
|
| 53 |
print(f"Warning: System prompt type 'MEDICAL_ASSISTANT' not found in medrax/docs/system_prompts.txt.")
|
| 54 |
except Exception as e:
|
benchmarking/llm_providers/google_provider.py
CHANGED
|
@@ -71,12 +71,11 @@ class GoogleProvider(LLMProvider):
|
|
| 71 |
# Update client parameters for this request
|
| 72 |
self.client.temperature = request.temperature
|
| 73 |
self.client.max_output_tokens = request.max_tokens
|
| 74 |
-
|
| 75 |
-
if request.additional_params and "top_p" in request.additional_params:
|
| 76 |
-
self.client.top_p = request.additional_params["top_p"]
|
| 77 |
|
| 78 |
response = self.client.invoke(messages)
|
| 79 |
-
|
|
|
|
| 80 |
duration = time.time() - start_time
|
| 81 |
|
| 82 |
# Extract response content
|
|
|
|
| 71 |
# Update client parameters for this request
|
| 72 |
self.client.temperature = request.temperature
|
| 73 |
self.client.max_output_tokens = request.max_tokens
|
| 74 |
+
self.client.top_p = request.top_p
|
|
|
|
|
|
|
| 75 |
|
| 76 |
response = self.client.invoke(messages)
|
| 77 |
+
print(response)
|
| 78 |
+
|
| 79 |
duration = time.time() - start_time
|
| 80 |
|
| 81 |
# Extract response content
|
benchmarking/llm_providers/medrax_provider.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
| 1 |
"""MedRAX LLM provider implementation."""
|
| 2 |
|
| 3 |
import time
|
| 4 |
-
import tempfile
|
| 5 |
import shutil
|
| 6 |
from pathlib import Path
|
| 7 |
|
|
@@ -65,14 +64,11 @@ class MedRAXProvider(LLMProvider):
|
|
| 65 |
# Prepare any additional model-specific kwargs
|
| 66 |
model_kwargs = {}
|
| 67 |
|
| 68 |
-
# Create temporary directory for this session
|
| 69 |
-
self.session_temp_dir = Path(tempfile.mkdtemp(prefix="medrax_bench_"))
|
| 70 |
-
|
| 71 |
agent, tools_dict = initialize_agent(
|
| 72 |
prompt_file="medrax/docs/system_prompts.txt",
|
| 73 |
tools_to_use=selected_tools,
|
| 74 |
model_dir="/model-weights",
|
| 75 |
-
temp_dir=
|
| 76 |
device="cpu",
|
| 77 |
model=self.model_name, # Change this to the model you want to use, e.g. gpt-4.1-2025-04-14, gemini-2.5-pro
|
| 78 |
temperature=0.7,
|
|
@@ -122,7 +118,7 @@ class MedRAXProvider(LLMProvider):
|
|
| 122 |
for i, image_path in enumerate(valid_images):
|
| 123 |
print(f"Original image path: {image_path}")
|
| 124 |
# Copy image to session temp directory
|
| 125 |
-
dest_path =
|
| 126 |
print(f"Destination path: {dest_path}")
|
| 127 |
shutil.copy2(image_path, dest_path)
|
| 128 |
image_paths.append(str(dest_path))
|
|
@@ -189,16 +185,3 @@ class MedRAXProvider(LLMProvider):
|
|
| 189 |
duration=time.time() - start_time,
|
| 190 |
raw_response=None
|
| 191 |
)
|
| 192 |
-
|
| 193 |
-
def _cleanup_temp_files(self) -> None:
|
| 194 |
-
"""Clean up temporary files."""
|
| 195 |
-
try:
|
| 196 |
-
if hasattr(self, 'session_temp_dir') and self.session_temp_dir.exists():
|
| 197 |
-
shutil.rmtree(self.session_temp_dir)
|
| 198 |
-
print(f"Cleaned up temporary directory: {self.session_temp_dir}")
|
| 199 |
-
except Exception as e:
|
| 200 |
-
print(f"Warning: Failed to cleanup temp files: {e}")
|
| 201 |
-
|
| 202 |
-
def cleanup(self) -> None:
|
| 203 |
-
"""Clean up resources when done with the provider."""
|
| 204 |
-
self._cleanup_temp_files()
|
|
|
|
| 1 |
"""MedRAX LLM provider implementation."""
|
| 2 |
|
| 3 |
import time
|
|
|
|
| 4 |
import shutil
|
| 5 |
from pathlib import Path
|
| 6 |
|
|
|
|
| 64 |
# Prepare any additional model-specific kwargs
|
| 65 |
model_kwargs = {}
|
| 66 |
|
|
|
|
|
|
|
|
|
|
| 67 |
agent, tools_dict = initialize_agent(
|
| 68 |
prompt_file="medrax/docs/system_prompts.txt",
|
| 69 |
tools_to_use=selected_tools,
|
| 70 |
model_dir="/model-weights",
|
| 71 |
+
temp_dir="temp", # Change this to the path of the temporary directory
|
| 72 |
device="cpu",
|
| 73 |
model=self.model_name, # Change this to the model you want to use, e.g. gpt-4.1-2025-04-14, gemini-2.5-pro
|
| 74 |
temperature=0.7,
|
|
|
|
| 118 |
for i, image_path in enumerate(valid_images):
|
| 119 |
print(f"Original image path: {image_path}")
|
| 120 |
# Copy image to session temp directory
|
| 121 |
+
dest_path = Path("temp") / f"image_{i}_{Path(image_path).name}"
|
| 122 |
print(f"Destination path: {dest_path}")
|
| 123 |
shutil.copy2(image_path, dest_path)
|
| 124 |
image_paths.append(str(dest_path))
|
|
|
|
| 185 |
duration=time.time() - start_time,
|
| 186 |
raw_response=None
|
| 187 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
benchmarking/llm_providers/openai_provider.py
CHANGED
|
@@ -2,7 +2,6 @@
|
|
| 2 |
|
| 3 |
import os
|
| 4 |
import time
|
| 5 |
-
from typing import Dict, Any
|
| 6 |
from tenacity import retry, wait_exponential, stop_after_attempt
|
| 7 |
from langchain_openai import ChatOpenAI
|
| 8 |
from langchain_core.messages import HumanMessage, SystemMessage
|
|
@@ -81,9 +80,7 @@ class OpenAIProvider(LLMProvider):
|
|
| 81 |
# Update client parameters for this request
|
| 82 |
self.client.temperature = request.temperature
|
| 83 |
self.client.max_tokens = request.max_tokens
|
| 84 |
-
|
| 85 |
-
if request.additional_params and "top_p" in request.additional_params:
|
| 86 |
-
self.client.model_kwargs = {"top_p": request.additional_params["top_p"]}
|
| 87 |
|
| 88 |
response = self.client.invoke(messages)
|
| 89 |
|
|
|
|
| 2 |
|
| 3 |
import os
|
| 4 |
import time
|
|
|
|
| 5 |
from tenacity import retry, wait_exponential, stop_after_attempt
|
| 6 |
from langchain_openai import ChatOpenAI
|
| 7 |
from langchain_core.messages import HumanMessage, SystemMessage
|
|
|
|
| 80 |
# Update client parameters for this request
|
| 81 |
self.client.temperature = request.temperature
|
| 82 |
self.client.max_tokens = request.max_tokens
|
| 83 |
+
self.client.top_p = request.top_p
|
|
|
|
|
|
|
| 84 |
|
| 85 |
response = self.client.invoke(messages)
|
| 86 |
|
benchmarking/runner.py
CHANGED
|
@@ -36,7 +36,8 @@ class BenchmarkRunConfig:
|
|
| 36 |
output_dir: str
|
| 37 |
max_questions: Optional[int] = None
|
| 38 |
temperature: float = 0.7
|
| 39 |
-
|
|
|
|
| 40 |
additional_params: Optional[Dict[str, Any]] = None
|
| 41 |
|
| 42 |
|
|
@@ -167,14 +168,6 @@ class BenchmarkRunner:
|
|
| 167 |
# Save final results
|
| 168 |
summary = self._save_final_results(benchmark)
|
| 169 |
|
| 170 |
-
# Clean up provider resources
|
| 171 |
-
if hasattr(llm_provider, 'cleanup'):
|
| 172 |
-
try:
|
| 173 |
-
llm_provider.cleanup()
|
| 174 |
-
self.logger.info("Provider cleanup completed")
|
| 175 |
-
except Exception as e:
|
| 176 |
-
self.logger.warning(f"Provider cleanup failed: {e}")
|
| 177 |
-
|
| 178 |
self.logger.info(f"Benchmark run completed: {self.run_id}")
|
| 179 |
self.logger.info(f"Final accuracy: {summary['results']['accuracy']:.2f}%")
|
| 180 |
self.logger.info(f"Total duration: {summary['results']['total_duration']:.2f}s")
|
|
@@ -203,6 +196,7 @@ class BenchmarkRunner:
|
|
| 203 |
text=data_point.text,
|
| 204 |
images=data_point.images,
|
| 205 |
temperature=self.config.temperature,
|
|
|
|
| 206 |
max_tokens=self.config.max_tokens,
|
| 207 |
additional_params=self.config.additional_params
|
| 208 |
)
|
|
@@ -260,10 +254,14 @@ class BenchmarkRunner:
|
|
| 260 |
Returns:
|
| 261 |
str: The extracted answer
|
| 262 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 263 |
# This is a simple implementation - may need customization per benchmark
|
| 264 |
# For multiple choice, look for single letters A, B, C, D, E, F
|
| 265 |
-
|
| 266 |
-
# Look for patterns like "A", "B)", "(C)", "Answer: D", etc.
|
| 267 |
patterns = [
|
| 268 |
r'\b([A-F])\b', # Single letter
|
| 269 |
r'\b([A-F])\)', # Letter with closing parenthesis
|
|
|
|
| 36 |
output_dir: str
|
| 37 |
max_questions: Optional[int] = None
|
| 38 |
temperature: float = 0.7
|
| 39 |
+
top_p: float = 0.95
|
| 40 |
+
max_tokens: int = 5000
|
| 41 |
additional_params: Optional[Dict[str, Any]] = None
|
| 42 |
|
| 43 |
|
|
|
|
| 168 |
# Save final results
|
| 169 |
summary = self._save_final_results(benchmark)
|
| 170 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
self.logger.info(f"Benchmark run completed: {self.run_id}")
|
| 172 |
self.logger.info(f"Final accuracy: {summary['results']['accuracy']:.2f}%")
|
| 173 |
self.logger.info(f"Total duration: {summary['results']['total_duration']:.2f}s")
|
|
|
|
| 196 |
text=data_point.text,
|
| 197 |
images=data_point.images,
|
| 198 |
temperature=self.config.temperature,
|
| 199 |
+
top_p=self.config.top_p,
|
| 200 |
max_tokens=self.config.max_tokens,
|
| 201 |
additional_params=self.config.additional_params
|
| 202 |
)
|
|
|
|
| 254 |
Returns:
|
| 255 |
str: The extracted answer
|
| 256 |
"""
|
| 257 |
+
# First, look for the 'Final answer: <|A|>' format
|
| 258 |
+
final_answer_pattern = r'Final answer:\s*<\|([A-F])\|>'
|
| 259 |
+
match = re.search(final_answer_pattern, response_text)
|
| 260 |
+
if match:
|
| 261 |
+
return match.group(1).upper()
|
| 262 |
+
|
| 263 |
# This is a simple implementation - may need customization per benchmark
|
| 264 |
# For multiple choice, look for single letters A, B, C, D, E, F
|
|
|
|
|
|
|
| 265 |
patterns = [
|
| 266 |
r'\b([A-F])\b', # Single letter
|
| 267 |
r'\b([A-F])\)', # Letter with closing parenthesis
|
medrax/docs/system_prompts.txt
CHANGED
|
@@ -1,10 +1,9 @@
|
|
| 1 |
[MEDICAL_ASSISTANT]
|
| 2 |
You are an expert medical AI assistant who can answer any medical questions and analyze medical images similar to a doctor.
|
| 3 |
Solve using your own vision and reasoning and use tools to complement your reasoning.
|
| 4 |
-
|
| 5 |
-
|
| 6 |
If you need to look up some information before asking a follow up question, you are allowed to do that.
|
| 7 |
-
When encountering a multiple-choice question, give the final answer in closed parentheses without further elaborations; give a definitive answer even if you're not sure.
|
| 8 |
|
| 9 |
CITATION REQUIREMENTS:
|
| 10 |
- When referencing information from RAG and/or web search tools, ALWAYS include numbered citations [1], [2], [3], etc.
|
|
@@ -17,5 +16,11 @@ Examples:
|
|
| 17 |
- "The medical literature indicates [2] that this condition typically presents with..."
|
| 18 |
- "Based on clinical guidelines [3], the recommended treatment approach is..."
|
| 19 |
|
| 20 |
-
[
|
| 21 |
-
You are
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
[MEDICAL_ASSISTANT]
|
| 2 |
You are an expert medical AI assistant who can answer any medical questions and analyze medical images similar to a doctor.
|
| 3 |
Solve using your own vision and reasoning and use tools to complement your reasoning.
|
| 4 |
+
You can make multiple tool calls in parallel or in sequence as needed for comprehensive answers.
|
| 5 |
+
Think critically about and criticize the tool outputs.
|
| 6 |
If you need to look up some information before asking a follow up question, you are allowed to do that.
|
|
|
|
| 7 |
|
| 8 |
CITATION REQUIREMENTS:
|
| 9 |
- When referencing information from RAG and/or web search tools, ALWAYS include numbered citations [1], [2], [3], etc.
|
|
|
|
| 16 |
- "The medical literature indicates [2] that this condition typically presents with..."
|
| 17 |
- "Based on clinical guidelines [3], the recommended treatment approach is..."
|
| 18 |
|
| 19 |
+
[CHESTAGENTBENCH_PROMPT]
|
| 20 |
+
You are an expert medical AI assistant who can answer any medical questions and analyze medical images similar to a doctor.
|
| 21 |
+
Solve using your own vision and reasoning and use tools (if available) to complement your reasoning.
|
| 22 |
+
You can make multiple tool calls in parallel or in sequence as needed for comprehensive answers.
|
| 23 |
+
Think critically about and criticize the tool outputs.
|
| 24 |
+
If you need to look up some information before asking a follow up question, you are allowed to do that.
|
| 25 |
+
When encountering a multiple-choice question, your final response should end with "Final answer: <|A|>" from the list of possible choices A, B, C, D, E, F.
|
| 26 |
+
It is extremely important that you strictly answer in the format mentioned above.
|
medrax/models/model_factory.py
CHANGED
|
@@ -36,8 +36,8 @@ class ModelFactory:
|
|
| 36 |
"default_base_url": "https://openrouter.ai/api/v1",
|
| 37 |
},
|
| 38 |
"grok": {
|
| 39 |
-
|
| 40 |
-
|
| 41 |
}
|
| 42 |
# Add more providers with default configurations here
|
| 43 |
}
|
|
|
|
| 36 |
"default_base_url": "https://openrouter.ai/api/v1",
|
| 37 |
},
|
| 38 |
"grok": {
|
| 39 |
+
"class": ChatXAI,
|
| 40 |
+
"env_key": "XAI_API_KEY",
|
| 41 |
}
|
| 42 |
# Add more providers with default configurations here
|
| 43 |
}
|