VictorLJZ commited on
Commit
e08f161
·
1 Parent(s): a8f2960
benchmarking/cli.py CHANGED
@@ -60,14 +60,6 @@ def run_benchmark_command(args) -> None:
60
 
61
  # Create LLM provider
62
  provider_kwargs = {}
63
- if args.provider == "medrax":
64
- provider_kwargs = {
65
- "tools_to_use": args.medrax_tools.split(",") if args.medrax_tools else None,
66
- "model_dir": args.model_dir,
67
- "temp_dir": args.temp_dir,
68
- "device": args.device,
69
- "rag_config": None, # You might want to add RAG config options
70
- }
71
 
72
  llm_provider = create_llm_provider(args.model, args.provider, **provider_kwargs)
73
 
@@ -82,12 +74,8 @@ def run_benchmark_command(args) -> None:
82
  benchmark_name=args.benchmark,
83
  output_dir=args.output_dir,
84
  max_questions=args.max_questions,
85
- start_index=args.start_index,
86
  temperature=args.temperature,
87
- max_tokens=args.max_tokens,
88
- system_prompt=args.system_prompt,
89
- save_frequency=args.save_frequency,
90
- log_level=args.log_level,
91
  )
92
 
93
  # Run benchmark
@@ -126,39 +114,11 @@ def main():
126
  run_parser.add_argument("--data-dir", required=True, help="Directory containing benchmark data")
127
  run_parser.add_argument("--output-dir", default="benchmark_results", help="Output directory for results")
128
  run_parser.add_argument("--max-questions", type=int, help="Maximum number of questions to process")
129
- run_parser.add_argument("--start-index", type=int, default=0, help="Starting index for questions")
130
  run_parser.add_argument("--temperature", type=float, default=0.7, help="Model temperature")
131
  run_parser.add_argument("--max-tokens", type=int, default=1500, help="Maximum tokens per response")
132
- run_parser.add_argument("--system-prompt", help="System prompt for the model")
133
- run_parser.add_argument("--save-frequency", type=int, default=10, help="Save results every N questions")
134
- run_parser.add_argument("--log-level", default="INFO", choices=["DEBUG", "INFO", "WARNING", "ERROR"])
135
-
136
- # MedRAX-specific arguments
137
- run_parser.add_argument("--medrax-tools", help="Comma-separated list of tools for MedRAX (e.g., WebBrowserTool,MedicalRAGTool)")
138
- run_parser.add_argument("--model-dir", default="/model-weights", help="Directory containing model weights for MedRAX")
139
- run_parser.add_argument("--temp-dir", default="temp", help="Temporary directory for MedRAX")
140
- run_parser.add_argument("--device", default="cuda", help="Device for MedRAX models")
141
-
142
-
143
 
144
  run_parser.set_defaults(func=run_benchmark_command)
145
 
146
- # Evaluate results command
147
- eval_parser = subparsers.add_parser("evaluate", help="Evaluate benchmark results")
148
- eval_parser.add_argument("results_files", nargs="+", help="Path(s) to results files")
149
- eval_parser.add_argument("--output-dir", default="evaluation_results", help="Output directory for evaluation")
150
- eval_parser.add_argument("--report-name", default="evaluation_report", help="Name for the evaluation report")
151
- eval_parser.add_argument("--statistical-test", action="store_true", help="Run statistical significance tests")
152
- eval_parser.set_defaults(func=evaluate_results_command)
153
-
154
- # List providers command
155
- list_providers_parser = subparsers.add_parser("list-providers", help="List available LLM providers")
156
- list_providers_parser.set_defaults(func=list_providers_command)
157
-
158
- # List benchmarks command
159
- list_benchmarks_parser = subparsers.add_parser("list-benchmarks", help="List available benchmarks")
160
- list_benchmarks_parser.set_defaults(func=list_benchmarks_command)
161
-
162
  args = parser.parse_args()
163
 
164
  if args.command is None:
 
60
 
61
  # Create LLM provider
62
  provider_kwargs = {}
 
 
 
 
 
 
 
 
63
 
64
  llm_provider = create_llm_provider(args.model, args.provider, **provider_kwargs)
65
 
 
74
  benchmark_name=args.benchmark,
75
  output_dir=args.output_dir,
76
  max_questions=args.max_questions,
 
77
  temperature=args.temperature,
78
+ max_tokens=args.max_tokens
 
 
 
79
  )
80
 
81
  # Run benchmark
 
114
  run_parser.add_argument("--data-dir", required=True, help="Directory containing benchmark data")
115
  run_parser.add_argument("--output-dir", default="benchmark_results", help="Output directory for results")
116
  run_parser.add_argument("--max-questions", type=int, help="Maximum number of questions to process")
 
117
  run_parser.add_argument("--temperature", type=float, default=0.7, help="Model temperature")
118
  run_parser.add_argument("--max-tokens", type=int, default=1500, help="Maximum tokens per response")
 
 
 
 
 
 
 
 
 
 
 
119
 
120
  run_parser.set_defaults(func=run_benchmark_command)
121
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  args = parser.parse_args()
123
 
124
  if args.command is None:
benchmarking/llm_providers/base.py CHANGED
@@ -1,10 +1,9 @@
1
  """Base class for LLM providers."""
2
 
3
  from abc import ABC, abstractmethod
4
- from typing import Dict, List, Optional, Any, Union
5
  from dataclasses import dataclass
6
  import base64
7
- import time
8
  from pathlib import Path
9
 
10
 
 
1
  """Base class for LLM providers."""
2
 
3
  from abc import ABC, abstractmethod
4
+ from typing import Dict, List, Optional, Any
5
  from dataclasses import dataclass
6
  import base64
 
7
  from pathlib import Path
8
 
9
 
benchmarking/llm_providers/google_provider.py CHANGED
@@ -2,7 +2,6 @@
2
 
3
  import os
4
  import time
5
- from typing import Dict, Any
6
  from tenacity import retry, wait_exponential, stop_after_attempt
7
  from langchain_google_genai import ChatGoogleGenerativeAI
8
  from langchain_core.messages import HumanMessage, SystemMessage
 
2
 
3
  import os
4
  import time
 
5
  from tenacity import retry, wait_exponential, stop_after_attempt
6
  from langchain_google_genai import ChatGoogleGenerativeAI
7
  from langchain_core.messages import HumanMessage, SystemMessage
benchmarking/llm_providers/medrax_provider.py CHANGED
@@ -1,23 +1,14 @@
1
  """MedRAX LLM provider implementation."""
2
 
3
- import os
4
  import time
5
  import tempfile
6
  import shutil
7
- from typing import Dict, Any, List, Optional
8
  from pathlib import Path
9
- import json
10
 
11
  from .base import LLMProvider, LLMRequest, LLMResponse
12
 
13
- # Import MedRAX components
14
- from medrax.agent import Agent
15
- from medrax.tools import *
16
- from medrax.utils import load_prompts_from_file
17
  from medrax.rag.rag import RAGConfig
18
- from medrax.models import ModelFactory
19
- from langgraph.checkpoint.memory import MemorySaver
20
- from langchain_core.messages import HumanMessage
21
 
22
 
23
  class MedRAXProvider(LLMProvider):
@@ -30,21 +21,7 @@ class MedRAXProvider(LLMProvider):
30
  model_name (str): Base LLM model name (e.g., "gpt-4.1-2025-04-14")
31
  **kwargs: Additional configuration parameters
32
  """
33
- # MedRAX-specific configuration
34
- self.tools_to_use = kwargs.get("tools_to_use", [
35
- "WebBrowserTool",
36
- "MedicalRAGTool",
37
- "PythonSandboxTool"
38
- ])
39
- self.model_dir = kwargs.get("model_dir", "/model-weights")
40
- self.temp_dir = kwargs.get("temp_dir", "temp")
41
- self.device = kwargs.get("device", "cuda")
42
- self.temperature = kwargs.get("temperature", 0.7)
43
- self.top_p = kwargs.get("top_p", 0.95)
44
- self.rag_config = kwargs.get("rag_config")
45
- self.prompt_file = kwargs.get("prompt_file", "medrax/docs/system_prompts.txt")
46
-
47
- # Initialize agent as None, will be created in _setup
48
  self.agent = None
49
  self.tools_dict = None
50
 
@@ -53,71 +30,60 @@ class MedRAXProvider(LLMProvider):
53
  def _setup(self) -> None:
54
  """Set up MedRAX agent system."""
55
  try:
56
- # Load system prompts
57
- prompts = load_prompts_from_file(self.prompt_file)
58
- prompt = prompts["MEDICAL_ASSISTANT"]
59
-
60
- # Initialize tools
61
- all_tools = {
62
- "TorchXRayVisionClassifierTool": lambda: TorchXRayVisionClassifierTool(device=self.device),
63
- "ArcPlusClassifierTool": lambda: ArcPlusClassifierTool(cache_dir=self.model_dir, device=self.device),
64
- "ChestXRaySegmentationTool": lambda: ChestXRaySegmentationTool(device=self.device),
65
- "LlavaMedTool": lambda: LlavaMedTool(cache_dir=self.model_dir, device=self.device, load_in_8bit=True),
66
- "XRayVQATool": lambda: XRayVQATool(cache_dir=self.model_dir, device=self.device),
67
- "ChestXRayReportGeneratorTool": lambda: ChestXRayReportGeneratorTool(
68
- cache_dir=self.model_dir, device=self.device
69
- ),
70
- "XRayPhraseGroundingTool": lambda: XRayPhraseGroundingTool(
71
- cache_dir=self.model_dir, temp_dir=self.temp_dir, load_in_8bit=True, device=self.device
72
- ),
73
- "ChestXRayGeneratorTool": lambda: ChestXRayGeneratorTool(
74
- model_path=f"{self.model_dir}/roentgen", temp_dir=self.temp_dir, device=self.device
75
- ),
76
- "ImageVisualizerTool": lambda: ImageVisualizerTool(),
77
- "DicomProcessorTool": lambda: DicomProcessorTool(temp_dir=self.temp_dir),
78
- "MedicalRAGTool": lambda: RAGTool(config=self.rag_config) if self.rag_config else None,
79
- "WebBrowserTool": lambda: WebBrowserTool(),
80
- }
81
-
82
- # Add PythonSandboxTool if available
83
- try:
84
- all_tools["PythonSandboxTool"] = lambda: create_python_sandbox()
85
- except Exception as e:
86
- print(f"Warning: PythonSandboxTool not available: {e}")
87
-
88
- # Initialize selected tools
89
- self.tools_dict = {}
90
- for tool_name in self.tools_to_use:
91
- if tool_name in all_tools:
92
- try:
93
- tool_instance = all_tools[tool_name]()
94
- if tool_instance is not None:
95
- self.tools_dict[tool_name] = tool_instance
96
- except Exception as e:
97
- print(f"Warning: Failed to initialize {tool_name}: {e}")
98
-
99
- # Set up checkpointing
100
- checkpointer = MemorySaver()
101
-
102
- # Create the language model
103
- llm = ModelFactory.create_model(
104
- model_name=self.model_name,
105
- temperature=self.temperature,
106
- top_p=self.top_p
107
- )
108
-
109
- # Create the agent
110
- self.agent = Agent(
111
- llm,
112
- tools=list(self.tools_dict.values()),
113
- log_tools=False, # Disable logging for benchmarking
114
- system_prompt=prompt,
115
- checkpointer=checkpointer,
116
- debug=False,
117
  )
118
-
 
 
 
119
  # Create temporary directory for this session
120
  self.session_temp_dir = Path(tempfile.mkdtemp(prefix="medrax_bench_"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
  print(f"MedRAX agent initialized with tools: {list(self.tools_dict.keys())}")
123
 
 
1
  """MedRAX LLM provider implementation."""
2
 
 
3
  import time
4
  import tempfile
5
  import shutil
 
6
  from pathlib import Path
 
7
 
8
  from .base import LLMProvider, LLMRequest, LLMResponse
9
 
 
 
 
 
10
  from medrax.rag.rag import RAGConfig
11
+ from main import initialize_agent
 
 
12
 
13
 
14
  class MedRAXProvider(LLMProvider):
 
21
  model_name (str): Base LLM model name (e.g., "gpt-4.1-2025-04-14")
22
  **kwargs: Additional configuration parameters
23
  """
24
+ self.model_name = model_name
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  self.agent = None
26
  self.tools_dict = None
27
 
 
30
  def _setup(self) -> None:
31
  """Set up MedRAX agent system."""
32
  try:
33
+ print("Starting server...")
34
+
35
+ selected_tools = [
36
+ "ImageVisualizerTool", # For displaying images in the UI
37
+ # "DicomProcessorTool", # For processing DICOM medical image files
38
+ # "TorchXRayVisionClassifierTool", # For classifying chest X-ray images using TorchXRayVision
39
+ # "ArcPlusClassifierTool", # For advanced chest X-ray classification using ArcPlus
40
+ # "ChestXRaySegmentationTool", # For segmenting anatomical regions in chest X-rays
41
+ # "ChestXRayReportGeneratorTool", # For generating medical reports from X-rays
42
+ # "XRayVQATool", # For visual question answering on X-rays
43
+ # "LlavaMedTool", # For multimodal medical image understanding
44
+ # "XRayPhraseGroundingTool", # For locating described features in X-rays
45
+ # "ChestXRayGeneratorTool", # For generating synthetic chest X-rays
46
+ "WebBrowserTool", # For web browsing and search capabilities
47
+ "MedicalRAGTool", # For retrieval-augmented generation with medical knowledge
48
+ "PythonSandboxTool", # Add the Python sandbox tool
49
+ ]
50
+
51
+ rag_config = RAGConfig(
52
+ model="command-a-03-2025", # Chat model for generating responses
53
+ embedding_model="embed-v4.0", # Embedding model for the RAG system
54
+ rerank_model="rerank-v3.5", # Reranking model for the RAG system
55
+ temperature=0.3,
56
+ pinecone_index_name="medrax2", # Name for the Pinecone index
57
+ chunk_size=1500,
58
+ chunk_overlap=300,
59
+ retriever_k=7,
60
+ local_docs_dir="rag_docs", # Change this to the path of the documents for RAG
61
+ huggingface_datasets=["VictorLJZ/medrax2"], # List of HuggingFace datasets to load
62
+ dataset_split="train", # Which split of the datasets to use
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  )
64
+
65
+ # Prepare any additional model-specific kwargs
66
+ model_kwargs = {}
67
+
68
  # Create temporary directory for this session
69
  self.session_temp_dir = Path(tempfile.mkdtemp(prefix="medrax_bench_"))
70
+
71
+ agent, tools_dict = initialize_agent(
72
+ prompt_file="medrax/docs/system_prompts.txt",
73
+ tools_to_use=selected_tools,
74
+ model_dir="/model-weights",
75
+ temp_dir=self.session_temp_dir, # Change this to the path of the temporary directory
76
+ device="cuda",
77
+ model=self.model_name, # Change this to the model you want to use, e.g. gpt-4.1-2025-04-14, gemini-2.5-pro
78
+ temperature=0.7,
79
+ top_p=0.95,
80
+ model_kwargs=model_kwargs,
81
+ rag_config=rag_config,
82
+ debug=True,
83
+ )
84
+
85
+ self.agent = agent
86
+ self.tools_dict = tools_dict
87
 
88
  print(f"MedRAX agent initialized with tools: {list(self.tools_dict.keys())}")
89
 
main.py CHANGED
@@ -9,14 +9,12 @@ The system uses OpenAI's language models for reasoning and can be configured
9
  with different model weights, tools, and parameters.
10
  """
11
 
12
- import os
13
  import warnings
14
- from typing import Dict, List, Optional, Tuple, Any
15
  from dotenv import load_dotenv
16
  from transformers import logging
17
 
18
  from langgraph.checkpoint.memory import MemorySaver
19
- from langchain_openai import ChatOpenAI
20
  from medrax.models import ModelFactory
21
 
22
  from interface import create_demo
@@ -138,7 +136,7 @@ if __name__ == "__main__":
138
  # Example: initialize with only specific tools
139
  # Here three tools are commented out, you can uncomment them to use them
140
  selected_tools = [
141
- # "ImageVisualizerTool", # For displaying images in the UI
142
  # "DicomProcessorTool", # For processing DICOM medical image files
143
  # "TorchXRayVisionClassifierTool", # For classifying chest X-ray images using TorchXRayVision
144
  # "ArcPlusClassifierTool", # For advanced chest X-ray classification using ArcPlus
@@ -172,14 +170,6 @@ if __name__ == "__main__":
172
  # Prepare any additional model-specific kwargs
173
  model_kwargs = {}
174
 
175
- # Set up API keys for the web browser tool
176
- # You'll need to set these environment variables:
177
- # - GOOGLE_SEARCH_API_KEY: Your Google Custom Search API key
178
- # - GOOGLE_SEARCH_ENGINE_ID: Your Google Custom Search Engine ID
179
- # - COHERE_API_KEY: Your Cohere API key
180
- # - OPENAI_API_KEY: Your OpenAI API key
181
- # - PINECONE_API_KEY: Your Pinecone API key
182
-
183
  agent, tools_dict = initialize_agent(
184
  prompt_file="medrax/docs/system_prompts.txt",
185
  tools_to_use=selected_tools,
 
9
  with different model weights, tools, and parameters.
10
  """
11
 
 
12
  import warnings
13
+ from typing import Dict, List, Optional, Any
14
  from dotenv import load_dotenv
15
  from transformers import logging
16
 
17
  from langgraph.checkpoint.memory import MemorySaver
 
18
  from medrax.models import ModelFactory
19
 
20
  from interface import create_demo
 
136
  # Example: initialize with only specific tools
137
  # Here three tools are commented out, you can uncomment them to use them
138
  selected_tools = [
139
+ "ImageVisualizerTool", # For displaying images in the UI
140
  # "DicomProcessorTool", # For processing DICOM medical image files
141
  # "TorchXRayVisionClassifierTool", # For classifying chest X-ray images using TorchXRayVision
142
  # "ArcPlusClassifierTool", # For advanced chest X-ray classification using ArcPlus
 
170
  # Prepare any additional model-specific kwargs
171
  model_kwargs = {}
172
 
 
 
 
 
 
 
 
 
173
  agent, tools_dict = initialize_agent(
174
  prompt_file="medrax/docs/system_prompts.txt",
175
  tools_to_use=selected_tools,