Adibvafa committed on
Commit
5849e07
·
2 Parent(s): 6fd286c 352f092

Fix merge conflicts

README.md CHANGED
@@ -15,7 +15,7 @@ Chest X-rays (CXRs) play an integral role in driving critical decisions in disea
  ## MedRAX
  MedRAX is built on a robust technical foundation:
  - **Core Architecture**: Built on LangChain and LangGraph frameworks
- - **Language Model**: Uses GPT-4o with vision capabilities as the backbone LLM
+ - **Language Models**: Supports multiple LLM providers including OpenAI (GPT-4o) and Google (Gemini) models
  - **Deployment**: Supports both local and cloud-based deployments
  - **Interface**: Production-ready interface built with Gradio
  - **Modular Design**: Tool-agnostic architecture allowing easy integration of new capabilities
@@ -27,6 +27,7 @@ MedRAX is built on a robust technical foundation:
  - **Report Generation**: Implements SwinV2 Transformer trained on CheXpert Plus for detailed medical reporting
  - **Disease Classification**: Leverages DenseNet-121 from TorchXRayVision for detecting 18 pathology classes
  - **X-ray Generation**: Utilizes RoentGen for synthetic CXR generation
+ - **Web Browser**: Provides web search capabilities and URL content retrieval using Google Custom Search API
  - **Utilities**: Includes DICOM processing, visualization tools, and custom plotting capabilities
  <br><br>
@@ -103,7 +104,8 @@ MedRAX supports selective tool initialization, allowing you to use only the tool
  ```python
  selected_tools = [
      "ImageVisualizerTool",
-     "ChestXRayClassifierTool",
+     "TorchXRayVisionClassifierTool",  # Renamed from ChestXRayClassifierTool
+     "ArcPlusClassifierTool",  # New ArcPlus classifier
      "ChestXRaySegmentationTool",
      # Add or remove tools as needed
  ]
@@ -120,9 +122,17 @@ agent, tools_dict = initialize_agent(

  The following tools will automatically download their model weights when initialized:

- ### Classification Tool
+ ### Classification Tools
  ```python
- ChestXRayClassifierTool(device=device)
+ # TorchXRayVision-based classifier (original)
+ TorchXRayVisionClassifierTool(device=device)
+
+ # ArcPlus SwinTransformer-based classifier (new)
+ ArcPlusClassifierTool(
+     model_path="/path/to/Ark6_swinLarge768_ep50.pth.tar",  # Optional
+     num_classes=18,  # Default
+     device=device
+ )
  ```

  ### Segmentation Tool
@@ -180,6 +190,7 @@ No additional model weights required:
  ```python
  ImageVisualizerTool()
  DicomProcessorTool(temp_dir=temp_dir)
+ WebBrowserTool()  # Requires Google Search API credentials
  ```
  <br>
@@ -239,12 +250,45 @@ The `MedicalRAGTool` uses a Pinecone vector database to store and retrieve medic
  - Some tools (LLaVA-Med, Grounding) are more resource-intensive
  <br>

- ### Local LLMs
+ ### Language Model Options
+ MedRAX supports multiple language model providers:
+
+ #### OpenAI Models
+ Supported prefixes: `gpt-` and `chatgpt-`
+ ```
+ export OPENAI_API_KEY="your-openai-api-key"
+ export OPENAI_BASE_URL="https://api.openai.com/v1"  # Optional for custom endpoints
+ ```
+
+ #### Google Gemini Models
+ Supported prefix: `gemini-`
+ ```
+ export GOOGLE_API_KEY="your-google-api-key"
+ ```
+
+ #### OpenRouter Models (Open Source & Proprietary)
+ Supported prefix: `openrouter-`
+
+ Access many open source and proprietary models via [OpenRouter](https://openrouter.ai/):
+ ```
+ export OPENROUTER_API_KEY="your-openrouter-api-key"
+ ```
+
+ **Note:** Tool compatibility may vary with open-source models. For best results with tools, we recommend using OpenAI or Google Gemini models.
+
+ #### Local LLMs
  If you are running a local LLM using frameworks like [Ollama](https://ollama.com/) or [LM Studio](https://lmstudio.ai/), you need to configure your environment variables accordingly. For example:
  ```
  export OPENAI_BASE_URL="http://localhost:11434/v1"
  export OPENAI_API_KEY="ollama"
  ```
+
+ #### WebBrowserTool Configuration
+ If you're using the WebBrowserTool, you'll need to set these environment variables:
+ ```
+ export GOOGLE_SEARCH_API_KEY="your-google-search-api-key"
+ export GOOGLE_SEARCH_ENGINE_ID="your-google-search-engine-id"
+ ```
  <br>

  ## Star History
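
The net effect of the README changes above is that the backbone LLM is now selected purely by model-name prefix. A minimal sketch of what that selection looks like in Python, assuming the relevant API keys are exported as described in the sections above (the specific model names are illustrative):

```python
from medrax.models import ModelFactory

# The prefix picks the provider: "gpt-"/"chatgpt-" -> ChatOpenAI,
# "gemini-" -> ChatGoogleGenerativeAI, "openrouter-" -> ChatOpenAI
# pointed at https://openrouter.ai/api/v1.
llm = ModelFactory.create_model("gemini-2.5-pro", temperature=0.7, top_p=0.95)

# The "openrouter-" prefix is stripped before the request is sent, so
# "openrouter-anthropic/claude-sonnet-4" becomes model="anthropic/claude-sonnet-4".
router_llm = ModelFactory.create_model("openrouter-anthropic/claude-sonnet-4")
```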
main.py CHANGED
@@ -17,6 +17,7 @@ from transformers import logging

  from langgraph.checkpoint.memory import MemorySaver
  from langchain_openai import ChatOpenAI
+ from medrax.models import ModelFactory

  from interface import create_demo
  from medrax.agent import *
@@ -37,12 +38,12 @@ def initialize_agent(
      model_dir: str = "/model-weights",
      temp_dir: str = "temp",
      device: str = "cpu",
-     model: str = "gpt-4o",
+     model: str = "gpt-4.1-2025-04-14",
      temperature: float = 0.7,
      top_p: float = 0.95,
      rag_config: Optional[RAGConfig] = None,
-     openai_kwargs: Dict[str, Any] = {},
- ) -> Tuple[Agent, Dict[str, BaseTool]]:
+     model_kwargs: Dict[str, Any] = {},
+ ):
      """Initialize the MedRAX agent with specified tools and configuration.

      Args:
@@ -51,11 +52,11 @@ def initialize_agent(
          model_dir (str, optional): Directory containing model weights. Defaults to "/model-weights".
          temp_dir (str, optional): Directory for temporary files. Defaults to "temp".
          device (str, optional): Device to run models on. Defaults to "cuda".
-         model (str, optional): Model to use. Defaults to "chatgpt-4o-latest".
+         model (str, optional): Model to use. Defaults to "gpt-4.1-2025-04-14".
          temperature (float, optional): Temperature for the model. Defaults to 0.7.
          top_p (float, optional): Top P for the model. Defaults to 0.95.
          rag_config (RAGConfig, optional): Configuration for the RAG tool. Defaults to None.
-         openai_kwargs (dict, optional): Additional keyword arguments for OpenAI API, such as API key and base URL.
+         model_kwargs (dict, optional): Additional keyword arguments for model.

      Returns:
          Tuple[Agent, Dict[str, BaseTool]]: Initialized agent and dictionary of tool instances
@@ -64,9 +65,9 @@ def initialize_agent(
      prompts = load_prompts_from_file(prompt_file)
      prompt = prompts["MEDICAL_ASSISTANT"]

-     # Define all available tools with their initialization functions
-     all_tools: Dict[str, callable] = {
-         "ChestXRayClassifierTool": lambda: ChestXRayClassifierTool(device=device),
+     all_tools = {
+         "TorchXRayVisionClassifierTool": lambda: TorchXRayVisionClassifierTool(device=device),
+         "ArcPlusClassifierTool": lambda: ArcPlusClassifierTool(cache_dir=model_dir, device=device),
          "ChestXRaySegmentationTool": lambda: ChestXRaySegmentationTool(device=device),
          "LlavaMedTool": lambda: LlavaMedTool(cache_dir=model_dir, device=device, load_in_8bit=True),
          "XRayVQATool": lambda: XRayVQATool(cache_dir=model_dir, device=device),
@@ -82,6 +83,7 @@ def initialize_agent(
          "ImageVisualizerTool": lambda: ImageVisualizerTool(),
          "DicomProcessorTool": lambda: DicomProcessorTool(temp_dir=temp_dir),
          "MedicalRAGTool": lambda: RAGTool(config=rag_config),
+         "WebBrowserTool": lambda: WebBrowserTool(),
      }

      # Initialize only selected tools or all if none specified
@@ -94,12 +96,18 @@ def initialize_agent(
      # Set up checkpointing for conversation state
      checkpointer = MemorySaver()

-     # Initialize the language model
-     model = ChatOpenAI(model=model, temperature=temperature, top_p=top_p, **openai_kwargs)
+     # Create the language model using the factory
+     try:
+         llm = ModelFactory.create_model(
+             model_name=model, temperature=temperature, top_p=top_p, **model_kwargs
+         )
+     except ValueError as e:
+         print(f"Error creating language model: {e}")
+         print(f"Available model providers: {list(ModelFactory._model_providers.keys())}")
+         raise

-     # Create the agent with the specified model, tools, and configuration
      agent = Agent(
-         model,
+         llm,
          tools=list(tools_dict.values()),
          log_tools=True,
          log_dir="logs",
@@ -118,20 +126,21 @@ if __name__ == "__main__":
      """
      print("Starting server...")

-     # Define which tools to use in the application
-     # Each tool provides specific medical imaging functionality
-     # You can uncomment the tools you dont want to use
-     tools_to_use = [
-         # "ImageVisualizerTool",  # For displaying images in the UI
+     # Example: initialize with only specific tools
+     # Several tools are commented out; uncomment them to enable them
+     selected_tools = [
+         "ImageVisualizerTool",  # For displaying images in the UI
+         "WebBrowserTool",  # For web browsing and search capabilities
          # "DicomProcessorTool",  # For processing DICOM medical image files
-         # "ChestXRayClassifierTool",  # For classifying chest X-ray images
+         # "TorchXRayVisionClassifierTool",  # For classifying chest X-ray images using TorchXRayVision
+         # "ArcPlusClassifierTool",  # For advanced chest X-ray classification using ArcPlus
          # "ChestXRaySegmentationTool",  # For segmenting anatomical regions in chest X-rays
          # "ChestXRayReportGeneratorTool",  # For generating medical reports from X-rays
          # "XRayVQATool",  # For visual question answering on X-rays
          # "LlavaMedTool",  # For multimodal medical image understanding
          # "XRayPhraseGroundingTool",  # For locating described features in X-rays
          # "ChestXRayGeneratorTool",  # For generating synthetic chest X-rays
-         "MedicalRAGTool",  # For retrieval-augmented generation with medical knowledge
+         # "MedicalRAGTool",  # For retrieval-augmented generation with medical knowledge
      ]

      # Configure the Retrieval Augmented Generation (RAG) system
@@ -147,26 +156,28 @@ if __name__ == "__main__":
          use_medrag_textbooks=True,  # Set to True if you want to use the MedRAG textbooks dataset
      )

-     # Prepare OpenAI API configuration from environment variables
-     openai_kwargs: Dict[str, str] = {}
-     if api_key := os.getenv("OPENAI_API_KEY"):
-         openai_kwargs["api_key"] = api_key
-
-     if base_url := os.getenv("OPENAI_BASE_URL"):
-         openai_kwargs["base_url"] = base_url
+     # Prepare any additional model-specific kwargs
+     model_kwargs = {}

-     # Initialize the agent with all configured components
+     # Set up API keys for the web browser tool
+     # You'll need to set these environment variables:
+     # - GOOGLE_SEARCH_API_KEY: Your Google Custom Search API key
+     # - GOOGLE_SEARCH_ENGINE_ID: Your Google Custom Search Engine ID
+     # - COHERE_API_KEY: Your Cohere API key
+     # - OPENAI_API_KEY: Your OpenAI API key
+     # - PINECONE_API_KEY: Your Pinecone API key

      agent, tools_dict = initialize_agent(
-         prompt_file="medrax/docs/system_prompts.txt",  # File containing system instructions
-         tools_to_use=tools_to_use,
-         model_dir="/model-weights",  # Change this to the path of the model weights
+         prompt_file="medrax/docs/system_prompts.txt",
+         tools_to_use=selected_tools,
+         model_dir="/model-weights",
          temp_dir="temp",  # Change this to the path of the temporary directory
-         device="cpu",  # Change this to the device you want to use
-         model="gpt-4o",  # Change this to the model you want to use, e.g. gpt-4o-mini
+         device="cuda",
+         model="gpt-4.1-2025-04-14",  # Change this to the model you want to use, e.g. gpt-4.1-2025-04-14, gemini-2.5-pro
          temperature=0.7,
          top_p=0.95,
+         model_kwargs=model_kwargs,
          rag_config=rag_config,
      )

      # Create and launch the web interface
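
Because `model_kwargs` is now forwarded verbatim to the LangChain chat-model constructor, provider-specific options ride along without changes to `initialize_agent`. A minimal sketch, assuming MedRAX is installed, `OPENAI_API_KEY` is exported, and only a lightweight tool is selected; the `max_retries` kwarg is an illustrative `ChatOpenAI` option, not something this commit sets:

```python
from main import initialize_agent

agent, tools_dict = initialize_agent(
    prompt_file="medrax/docs/system_prompts.txt",
    tools_to_use=["ImageVisualizerTool"],  # no model weights needed
    device="cpu",
    model="gpt-4.1-2025-04-14",
    model_kwargs={"max_retries": 2},  # forwarded to ChatOpenAI as-is
)
print(sorted(tools_dict))  # -> ['ImageVisualizerTool']
```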
medrax/models/__init__.py ADDED
@@ -0,0 +1,5 @@
+ """Model module for MedRAX."""
+
+ from .model_factory import ModelFactory
+
+ __all__ = ["ModelFactory"]
medrax/models/model_factory.py ADDED
@@ -0,0 +1,140 @@
+ """Factory for creating language model instances based on model name."""
+
+ import os
+ from typing import Dict, Any, Type
+
+ from langchain_core.language_models import BaseLanguageModel
+ from langchain_openai import ChatOpenAI
+ from langchain_google_genai import ChatGoogleGenerativeAI
+
+
+ class ModelFactory:
+     """Factory for creating language model instances based on model name.
+
+     This class implements a registry of language model providers and provides
+     methods to create appropriate language model instances based on the model name.
+     """
+
+     # Registry of model providers
+     _model_providers = {
+         "gpt": {
+             "class": ChatOpenAI,
+             "env_key": "OPENAI_API_KEY",
+             "base_url_key": "OPENAI_BASE_URL"
+         },
+         "chatgpt": {
+             "class": ChatOpenAI,
+             "env_key": "OPENAI_API_KEY",
+             "base_url_key": "OPENAI_BASE_URL"
+         },
+         "gemini": {
+             "class": ChatGoogleGenerativeAI,
+             "env_key": "GOOGLE_API_KEY"
+         },
+         "openrouter": {
+             "class": ChatOpenAI,  # OpenRouter uses OpenAI-compatible interface
+             "env_key": "OPENROUTER_API_KEY",
+             "base_url_key": "OPENROUTER_BASE_URL",
+             "default_base_url": "https://openrouter.ai/api/v1"
+         },
+         # Add more providers with default configurations here
+     }
+
+     @classmethod
+     def register_provider(cls, prefix: str, model_class: Type[BaseLanguageModel],
+                           env_key: str, **kwargs) -> None:
+         """Register a new model provider.
+
+         Args:
+             prefix (str): The prefix used to identify this model provider (e.g., 'gpt', 'gemini')
+             model_class (Type[BaseLanguageModel]): The LangChain model class to use
+             env_key (str): The environment variable name for the API key
+             **kwargs: Additional provider-specific configuration
+         """
+         cls._model_providers[prefix] = {
+             "class": model_class,
+             "env_key": env_key,
+             **kwargs
+         }
+
+     @classmethod
+     def create_model(cls, model_name: str, temperature: float = 0.7,
+                      top_p: float = 0.95, **kwargs) -> BaseLanguageModel:
+         """Create and return an instance of the appropriate language model.
+
+         Args:
+             model_name (str): Name of the model to create (e.g., 'gpt-4o', 'gemini-2.5-pro')
+             temperature (float, optional): Temperature parameter. Defaults to 0.7.
+             top_p (float, optional): Top-p sampling parameter. Defaults to 0.95.
+             **kwargs: Additional model-specific parameters
+
+         Returns:
+             BaseLanguageModel: An initialized language model instance
+
+         Raises:
+             ValueError: If no provider is found for the given model name
+             ValueError: If the required API key is missing
+         """
+         # Find the matching provider based on model name prefix
+         provider_prefix = next(
+             (prefix for prefix in cls._model_providers if model_name.startswith(prefix)),
+             None
+         )
+
+         if not provider_prefix:
+             raise ValueError(
+                 f"No provider found for model: {model_name}. "
+                 f"Registered providers are for: {list(cls._model_providers.keys())}"
+             )
+
+         provider = cls._model_providers[provider_prefix]
+         model_class = provider["class"]
+         env_key = provider["env_key"]
+
+         # Set up provider-specific kwargs
+         provider_kwargs = {}
+
+         # Handle API key
+         if env_key in os.environ:
+             provider_kwargs["api_key"] = os.environ[env_key]
+         else:
+             # Log warning but don't fail - the model class might handle missing API keys differently
+             print(f"Warning: Environment variable {env_key} not found. Authentication may fail.")
+
+         # Check for base_url if applicable
+         if "base_url_key" in provider:
+             if provider["base_url_key"] in os.environ:
+                 provider_kwargs["base_url"] = os.environ[provider["base_url_key"]]
+             elif "default_base_url" in provider:
+                 provider_kwargs["base_url"] = provider["default_base_url"]
+
+         # Merge with any additional provider-specific settings from the registry
+         for k, v in provider.items():
+             if k not in ["class", "env_key", "base_url_key", "default_base_url"]:
+                 provider_kwargs[k] = v
+
+         # Strip the provider prefix from the model name
+         # For example, 'openrouter-anthropic/claude-sonnet-4' becomes 'anthropic/claude-sonnet-4'
+         actual_model_name = model_name
+         if model_name.startswith(f"{provider_prefix}-"):
+             actual_model_name = model_name[len(provider_prefix)+1:]
+
+         # Create and return the model instance
+         return model_class(
+             model=actual_model_name,
+             temperature=temperature,
+             top_p=top_p,
+             **provider_kwargs,
+             **kwargs
+         )
+
+     @classmethod
+     def list_providers(cls) -> Dict[str, Dict[str, Any]]:
+         """List all registered model providers.
+
+         Returns:
+             Dict[str, Dict[str, Any]]: Dictionary of registered providers and their configurations
+         """
+         # Return a copy to prevent accidental modification
+         return {k: {kk: vv for kk, vv in v.items() if kk != "class"}
+                 for k, v in cls._model_providers.items()}
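
Because the registry is keyed by name prefix, new providers can be bolted on at runtime without editing the factory. A sketch of `register_provider` with a hypothetical OpenAI-compatible endpoint; the `deepseek` prefix, environment variable names, and base URL are illustrative assumptions, not part of this commit:

```python
from langchain_openai import ChatOpenAI
from medrax.models import ModelFactory

# Hypothetical OpenAI-compatible provider registered under a new prefix.
ModelFactory.register_provider(
    "deepseek",
    ChatOpenAI,
    env_key="DEEPSEEK_API_KEY",
    base_url_key="DEEPSEEK_BASE_URL",
    default_base_url="https://api.deepseek.com/v1",  # illustrative
)

# create_model strips the "deepseek-" prefix, so this call constructs
# ChatOpenAI(model="chat", base_url=...). Unknown prefixes raise ValueError.
llm = ModelFactory.create_model("deepseek-chat", temperature=0.2)
print(list(ModelFactory.list_providers()))
```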
medrax/tools/__init__.py CHANGED
@@ -9,4 +9,5 @@ from .grounding import *
  from .generation import *
  from .dicom import *
  from .utils import *
- from .rag import *
+ from .rag import *
+ from .web_browser import *
medrax/tools/classification/__init__.py ADDED
@@ -0,0 +1,11 @@
+ """Classification tools for chest X-ray analysis."""
+
+ from .torchxrayvision import TorchXRayVisionClassifierTool, TorchXRayVisionInput
+ from .arcplus import ArcPlusClassifierTool, ArcPlusInput
+
+ __all__ = [
+     "TorchXRayVisionClassifierTool",
+     "TorchXRayVisionInput",
+     "ArcPlusClassifierTool",
+     "ArcPlusInput"
+ ]
medrax/tools/classification/arcplus.py ADDED
@@ -0,0 +1,392 @@
+ import os
+ from typing import ClassVar, Dict, List, Optional, Tuple, Type
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torchvision.transforms as transforms
+ from PIL import Image
+ from pydantic import BaseModel, Field
+ from timm.models.swin_transformer import SwinTransformer
+
+ from langchain_core.callbacks import (
+     AsyncCallbackManagerForToolRun,
+     CallbackManagerForToolRun,
+ )
+ from langchain_core.tools import BaseTool
+
+
+ class OmniSwinTransformer(SwinTransformer):
+     """OmniSwinTransformer with multiple classification heads and optional projector."""
+
+     def __init__(self, num_classes_list, projector_features=None, use_mlp=False, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         assert num_classes_list is not None
+
+         self.projector = None
+         if projector_features:
+             encoder_features = self.num_features
+             self.num_features = projector_features
+             if use_mlp:
+                 self.projector = nn.Sequential(
+                     nn.Linear(encoder_features, self.num_features),
+                     nn.ReLU(inplace=True),
+                     nn.Linear(self.num_features, self.num_features),
+                 )
+             else:
+                 self.projector = nn.Linear(encoder_features, self.num_features)
+
+         self.omni_heads = []
+         for num_classes in num_classes_list:
+             self.omni_heads.append(
+                 nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+             )
+         self.omni_heads = nn.ModuleList(self.omni_heads)
+
+     def forward(self, x, head_n=None):
+         x = self.forward_features(x)
+         if self.projector:
+             x = self.projector(x)
+         if head_n is not None:
+             return x, self.omni_heads[head_n](x)
+         else:
+             return [head(x) for head in self.omni_heads]
+
+     def generate_embeddings(self, x, after_proj=True):
+         x = self.forward_features(x)
+         if after_proj and self.projector:
+             x = self.projector(x)
+         return x
+
+
+ class ArcPlusInput(BaseModel):
+     """Input for ArcPlus chest X-ray analysis tool. Only supports JPG or PNG images."""
+
+     image_path: str = Field(
+         ..., description="Path to the radiology image file, only supports JPG or PNG images"
+     )
+
+
+ class ArcPlusClassifierTool(BaseTool):
+     """Tool that classifies chest X-ray images using the ArcPlus OmniSwinTransformer model.
+
+     This tool uses a pre-trained OmniSwinTransformer model (ArcPlus) to analyze chest X-ray images
+     and predict the likelihood of various pathologies across multiple medical datasets. The model
+     employs a Swin Transformer architecture with multiple classification heads, each specialized
+     for different medical datasets and conditions.
+
+     The ArcPlus model is trained on 6 different medical datasets:
+     - MIMIC-CXR: 14 pathologies including common chest conditions
+     - CheXpert: 14 pathologies with standardized labeling
+     - NIH ChestX-ray14: 14 pathologies from large-scale dataset
+     - RSNA: 3 classes for pneumonia detection
+     - VinDr-CXR: 6 categories including tuberculosis and lung tumors
+     - Shenzhen: 1 class for tuberculosis detection
+
+     Key Features:
+     - Multi-head architecture with 6 specialized classification heads
+     - 768x768 input resolution for high-detail analysis
+     - Projector layer with 1376 features for enhanced representation
+     - Sigmoid activation for multi-label classification
+     - Covers 52+ distinct pathology categories across datasets
+
+     The model outputs probabilities (0 to 1) for each condition, with higher values
+     indicating higher likelihood of the pathology being present in the image.
+     """
+
+     name: str = "arcplus_classifier"
+     description: str = (
+         "Advanced chest X-ray classification tool using ArcPlus OmniSwinTransformer with multi-dataset training. "
+         "Analyzes chest X-ray images and provides probability predictions for 52+ pathologies across 6 medical datasets. "
+         "Input: Path to chest X-ray image file (JPG/PNG). "
+         "Output: Dictionary mapping pathology names to probabilities (0-1). "
+         "Features: Multi-head architecture, 768px resolution, projector layer, specialized for medical imaging. "
+         "Pathologies include: Atelectasis, Cardiomegaly, Consolidation, Edema, Enlarged Cardiomediastinum, "
+         "Fracture, Lung Lesion, Lung Opacity, Pleural Effusion, Pneumonia, Pneumothorax, Mass, Nodule, "
+         "Emphysema, Fibrosis, PE, Lung Tumor, Tuberculosis, and many more across MIMIC, CheXpert, NIH, "
+         "RSNA, VinDr, and Shenzhen datasets. Higher probabilities indicate higher likelihood of condition presence."
+     )
+     args_schema: Type[BaseModel] = ArcPlusInput
+     model: OmniSwinTransformer = None
+     device: Optional[str] = "cuda"
+     normalize: transforms.Normalize = None
+     disease_list: List[str] = None
+     num_classes_list: List[int] = None
+
+     # Disease mappings from the analysis
+     mimic_diseases: ClassVar[List[str]] = [
+         "Atelectasis",
+         "Cardiomegaly",
+         "Consolidation",
+         "Edema",
+         "Enlarged Cardiomediastinum",
+         "Fracture",
+         "Lung Lesion",
+         "Lung Opacity",
+         "No Finding",
+         "Pleural Effusion",
+         "Pleural Other",
+         "Pneumonia",
+         "Pneumothorax",
+         "Support Devices",
+     ]
+     chexpert_diseases: ClassVar[List[str]] = [
+         "No Finding",
+         "Enlarged Cardiomediastinum",
+         "Cardiomegaly",
+         "Lung Opacity",
+         "Lung Lesion",
+         "Edema",
+         "Consolidation",
+         "Pneumonia",
+         "Atelectasis",
+         "Pneumothorax",
+         "Pleural Effusion",
+         "Pleural Other",
+         "Fracture",
+         "Support Devices",
+     ]
+     nih14_diseases: ClassVar[List[str]] = [
+         "Atelectasis",
+         "Cardiomegaly",
+         "Effusion",
+         "Infiltration",
+         "Mass",
+         "Nodule",
+         "Pneumonia",
+         "Pneumothorax",
+         "Consolidation",
+         "Edema",
+         "Emphysema",
+         "Fibrosis",
+         "Pleural_Thickening",
+         "Hernia",
+     ]
+     rsna_diseases: ClassVar[List[str]] = ["No Lung Opacity/Not Normal", "Normal", "Lung Opacity"]
+     vindr_diseases: ClassVar[List[str]] = [
+         "PE",
+         "Lung tumor",
+         "Pneumonia",
+         "Tuberculosis",
+         "Other diseases",
+         "No finding",
+     ]
+     shenzhen_diseases: ClassVar[List[str]] = ["TB"]
+
+     def __init__(self, cache_dir: str = None, device: Optional[str] = "cuda"):
+         """Initialize the ArcPlus Classifier Tool.
+
+         Args:
+             cache_dir (str, optional): Directory containing the pre-trained ArcPlus model checkpoint.
+                 The tool will automatically look for 'Ark6_swinLarge768_ep50.pth.tar' in this directory.
+                 If None, model will be initialized with random weights (not recommended for inference).
+                 Default: None.
+             device (str, optional): Device to run the model on ('cuda' for GPU, 'cpu' for CPU).
+                 GPU is recommended for better performance. Default: "cuda".
+
+         Model Architecture Details:
+             - OmniSwinTransformer with 6 classification heads
+             - Input resolution: 768x768 pixels
+             - Projector features: 1376 dimensions
+             - Multi-head configuration: [14, 14, 14, 3, 6, 1] classes per head
+             - Total pathologies: 52+ across 6 medical datasets
+             - Preprocessing: ImageNet normalization (mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+         Raises:
+             FileNotFoundError: If cache_dir is provided but model file doesn't exist.
+             RuntimeError: If model loading fails or device is unavailable.
+         """
+         super().__init__()
+
+         # Create combined disease list from all supported datasets
+         self.disease_list = (
+             self.mimic_diseases
+             + self.chexpert_diseases
+             + self.nih14_diseases
+             + self.rsna_diseases
+             + self.vindr_diseases
+             + self.shenzhen_diseases
+         )
+
+         # Multi-head configuration: [MIMIC, CheXpert, NIH, RSNA, VinDr, Shenzhen]
+         self.num_classes_list = [14, 14, 14, 3, 6, 1]
+
+         # Initialize the OmniSwinTransformer model with ArcPlus architecture
+         self.model = OmniSwinTransformer(
+             num_classes_list=self.num_classes_list,
+             projector_features=1376,  # Enhanced feature representation
+             use_mlp=False,  # Linear projector (not MLP)
+             img_size=768,  # High-resolution input
+             patch_size=4,
+             window_size=12,
+             embed_dim=192,
+             depths=(2, 2, 18, 2),  # Swin-Large configuration
+             num_heads=(6, 12, 24, 48),
+         )
+
+         # Load pre-trained weights if provided
+         if cache_dir:
+             model_path = os.path.join(cache_dir, "Ark6_swinLarge768_ep50.pth.tar")
+             self._load_checkpoint(model_path)
+
+         self.model.eval()
+         self.device = torch.device(device) if device else "cuda"
+         self.model = self.model.to(self.device)
+
+         # ImageNet normalization parameters for optimal performance
+         self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+     def _load_checkpoint(self, model_path: str) -> None:
+         """
+         Load the ArcPlus model checkpoint.
+
+         Args:
+             model_path (str): Path to the model checkpoint file.
+         """
+         # Load the checkpoint (set weights_only=False for PyTorch 2.6+ compatibility)
+         checkpoint = torch.load(model_path, map_location=torch.device("cpu"), weights_only=False)
+         state_dict = checkpoint["teacher"]  # Use 'teacher' key
+
+         # Remove "module." prefix if present (improved logic from example)
+         if any([True if "module." in k else False for k in state_dict.keys()]):
+             state_dict = {
+                 k.replace("module.", ""): v
+                 for k, v in state_dict.items()
+                 if k.startswith("module.")
+             }
+
+         # Load the model weights
+         msg = self.model.load_state_dict(state_dict, strict=False)
+
+     def _process_image(self, image_path: str) -> torch.Tensor:
+         """
+         Process the input chest X-ray image for model inference.
+
+         This method loads the image, applies necessary transformations,
+         and prepares it as a torch.Tensor for model input.
+
+         Args:
+             image_path (str): The file path to the chest X-ray image.
+
+         Returns:
+             torch.Tensor: A processed image tensor ready for model inference.
+
+         Raises:
+             FileNotFoundError: If the specified image file does not exist.
+             ValueError: If the image cannot be properly loaded or processed.
+         """
+         try:
+             # Load and preprocess image following the example pattern
+             image = Image.open(image_path).convert("RGB").resize((768, 768))
+
+             # Convert to numpy array and normalize to [0, 1]
+             image_array = np.array(image) / 255.0
+
+             # Apply ImageNet normalization
+             image_tensor = torch.from_numpy(image_array).float()
+             image_tensor = image_tensor.permute(2, 0, 1)  # HWC to CHW
+             image_tensor = self.normalize(image_tensor)
+
+             # Add batch dimension and move to device
+             image_tensor = image_tensor.unsqueeze(0).to(self.device)
+
+             return image_tensor
+
+         except Exception as e:
+             raise ValueError(f"Error processing image {image_path}: {str(e)}")
+
+     def _run(
+         self,
+         image_path: str,
+         run_manager: Optional[CallbackManagerForToolRun] = None,
+     ) -> Tuple[Dict[str, float], Dict]:
+         """Classify the chest X-ray image using ArcPlus SwinTransformer.
+
+         Args:
+             image_path (str): The path to the chest X-ray image file.
+             run_manager (Optional[CallbackManagerForToolRun]): The callback manager for the tool run.
+
+         Returns:
+             Tuple[Dict[str, float], Dict]: A tuple containing the classification results
+                 (pathologies and their probabilities from 0 to 1)
+                 and any additional metadata.
+
+         Raises:
+             Exception: If there's an error processing the image or during classification.
+         """
+         try:
+             # Process the image
+             image_tensor = self._process_image(image_path)
+
+             # Run model inference
+             with torch.no_grad():
+                 pre_logits = self.model(image_tensor)
+
+                 # Apply sigmoid to each output head (as seen in example)
+                 preds = [torch.sigmoid(out) for out in pre_logits]
+
+                 # Concatenate all predictions into single tensor
+                 preds = torch.cat(preds, dim=1)
+
+                 # Convert to numpy
+                 predictions = preds.cpu().numpy().flatten()
+
+             # Map predictions to disease names
+             if len(predictions) != len(self.disease_list):
+                 print(
+                     f"Warning: Expected {len(self.disease_list)} predictions, got {len(predictions)}"
+                 )
+                 # Pad or truncate as needed
+                 if len(predictions) < len(self.disease_list):
+                     predictions = np.pad(
+                         predictions, (0, len(self.disease_list) - len(predictions))
+                     )
+                 else:
+                     predictions = predictions[: len(self.disease_list)]
+
+             # Create output dictionary mapping disease names to probabilities
+             output = dict(zip(self.disease_list, predictions.astype(float)))
+
+             metadata = {
+                 "image_path": image_path,
+                 "model": "ArcPlus OmniSwinTransformer",
+                 "analysis_status": "completed",
+                 "num_predictions": len(predictions),
+                 "num_heads": len(self.num_classes_list),
+                 "projector_features": 1376,
+                 "note": "Probabilities range from 0 to 1, with higher values indicating higher likelihood of the condition.",
+             }
+
+             return output, metadata
+
+         except Exception as e:
+             return {"error": str(e)}, {
+                 "image_path": image_path,
+                 "analysis_status": "failed",
+                 "error_details": str(e),
+             }
+
+     async def _arun(
+         self,
+         image_path: str,
+         run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
+     ) -> Tuple[Dict[str, float], Dict]:
+         """Asynchronously classify the chest X-ray image using ArcPlus SwinTransformer.
+
+         This method currently calls the synchronous version, as the model inference
+         is not inherently asynchronous. For true asynchronous behavior, consider
+         using a separate thread or process.
+
+         Args:
+             image_path (str): The path to the chest X-ray image file.
+             run_manager (Optional[AsyncCallbackManagerForToolRun]): The async callback manager for the tool run.
+
+         Returns:
+             Tuple[Dict[str, float], Dict]: A tuple containing the classification results
+                 (pathologies and their probabilities from 0 to 1)
+                 and any additional metadata.
+
+         Raises:
+             Exception: If there's an error processing the image or during classification.
+         """
+         return self._run(image_path)
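
A short usage sketch for the new tool, assuming the `Ark6_swinLarge768_ep50.pth.tar` checkpoint is present under the cache directory; the paths and image file below are illustrative:

```python
from medrax.tools.classification import ArcPlusClassifierTool

# cache_dir must contain Ark6_swinLarge768_ep50.pth.tar; "cpu" avoids requiring a GPU.
tool = ArcPlusClassifierTool(cache_dir="/model-weights", device="cpu")

predictions, metadata = tool._run("example_cxr.png")  # hypothetical image path
if "error" not in predictions:
    # Highest-probability findings across all six dataset heads.
    top5 = sorted(predictions.items(), key=lambda kv: kv[1], reverse=True)[:5]
    print(metadata["analysis_status"], top5)
```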
medrax/tools/{classification.py → classification/torchxrayvision.py} RENAMED
@@ -13,15 +13,15 @@ from langchain_core.callbacks import (
  from langchain_core.tools import BaseTool


- class ChestXRayInput(BaseModel):
-     """Input for chest X-ray analysis tools. Only supports JPG or PNG images."""
+ class TorchXRayVisionInput(BaseModel):
+     """Input for TorchXRayVision chest X-ray analysis tools. Only supports JPG or PNG images."""

      image_path: str = Field(
          ..., description="Path to the radiology image file, only supports JPG or PNG images"
      )


- class ChestXRayClassifierTool(BaseTool):
+ class TorchXRayVisionClassifierTool(BaseTool):
      """Tool that classifies chest X-ray images for multiple pathologies.

      This tool uses a pre-trained DenseNet model to analyze chest X-ray images and
@@ -35,9 +35,9 @@ class ChestXRayClassifierTool(BaseTool):
      A higher value indicates a higher likelihood of the condition being present.
      """

-     name: str = "chest_xray_classifier"
+     name: str = "torchxrayvision_classifier"
      description: str = (
-         "A tool that analyzes chest X-ray images and classifies them for 18 different pathologies. "
+         "A tool that analyzes chest X-ray images and classifies them for 18 different pathologies using TorchXRayVision DenseNet. "
          "Input should be the path to a chest X-ray image file. "
          "Output is a dictionary of pathologies and their predicted probabilities (0 to 1). "
          "Pathologies include: Atelectasis, Cardiomegaly, Consolidation, Edema, Effusion, Emphysema, "
@@ -45,7 +45,7 @@ class ChestXRayClassifierTool(BaseTool):
          "Lung Opacity, Mass, Nodule, Pleural Thickening, Pneumonia, and Pneumothorax. "
          "Higher values indicate a higher likelihood of the condition being present."
      )
-     args_schema: Type[BaseModel] = ChestXRayInput
+     args_schema: Type[BaseModel] = TorchXRayVisionInput
      model: xrv.models.DenseNet = None
      device: Optional[str] = "cuda"
      transform: torchvision.transforms.Compose = None
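
Since this is a rename plus docstring updates, migration is a one-line change; the sketch below assumes `_run` keeps its pre-rename `(predictions, metadata)` return shape, which this diff does not touch:

```python
from medrax.tools.classification import TorchXRayVisionClassifierTool

# Before this commit: ChestXRayClassifierTool(device="cpu")
tool = TorchXRayVisionClassifierTool(device="cpu")

probabilities, metadata = tool._run("example_cxr.png")  # hypothetical image path
print(probabilities.get("Pneumonia"))
```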
medrax/tools/web_browser.py ADDED
@@ -0,0 +1,215 @@
+ """Web browser tool for MedRAX2.
+
+ This module implements a web browsing tool for MedRAX2, allowing the agent
+ to search the web, visit URLs, and extract information from web pages.
+ """
+
+ import os
+ import re
+ import json
+ import time
+ from typing import Dict, Optional, Any, Type, Tuple
+ from urllib.parse import urlparse
+
+ import requests
+ from bs4 import BeautifulSoup
+ from langchain_core.tools import BaseTool
+ from pydantic import BaseModel, Field
+
+
+ class WebBrowserSchema(BaseModel):
+     """Schema for web browser tool."""
+     query: str = Field("", description="The search query (leave empty if visiting a URL)")
+     url: str = Field("", description="The URL to visit (leave empty if performing a search)")
+
+
+ class SearchQuerySchema(BaseModel):
+     """Schema for web search queries."""
+     query: str = Field(..., description="The search query string")
+
+
+ class VisitUrlSchema(BaseModel):
+     """Schema for URL visits."""
+     url: str = Field(..., description="The URL to visit")
+
+
+ class WebBrowserTool(BaseTool):
+     """Tool for browsing the web, searching for information, and visiting URLs.
+
+     This tool provides the agent with internet browsing capabilities, including:
+     1. Performing web searches using a search engine API
+     2. Visiting specific URLs and extracting their content
+     3. Following links within pages
+     """
+     name: str = "WebBrowserTool"
+     description: str = "Search the web for information or visit specific URLs to retrieve content"
+     search_api_key: Optional[str] = None
+     search_engine_id: Optional[str] = None
+     user_agent: str = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
+     max_results: int = 5
+     args_schema: Type[BaseModel] = WebBrowserSchema
+
+     def __init__(self, search_api_key: Optional[str] = None, search_engine_id: Optional[str] = None, **kwargs):
+         """Initialize the web browser tool.
+
+         Args:
+             search_api_key: Google Custom Search API key (optional)
+             search_engine_id: Google Custom Search Engine ID (optional)
+             **kwargs: Additional keyword arguments
+         """
+         super().__init__(**kwargs)
+         # Try to get API keys from environment variables if not provided
+         self.search_api_key = search_api_key or os.environ.get("GOOGLE_SEARCH_API_KEY")
+         self.search_engine_id = search_engine_id or os.environ.get("GOOGLE_SEARCH_ENGINE_ID")
+
+     def search_web(self, query: str) -> Dict[str, Any]:
+         """Search the web using Google Custom Search API.
+
+         Args:
+             query: The search query string
+
+         Returns:
+             Dict containing search results
+         """
+         if not self.search_api_key or not self.search_engine_id:
+             return {
+                 "error": "Search API key or engine ID not configured. Please set GOOGLE_SEARCH_API_KEY and GOOGLE_SEARCH_ENGINE_ID environment variables."
+             }
+
+         url = "https://www.googleapis.com/customsearch/v1"
+         params = {
+             "key": self.search_api_key,
+             "cx": self.search_engine_id,
+             "q": query,
+             "num": self.max_results
+         }
+
+         try:
+             response = requests.get(url, params=params, timeout=10)
+             response.raise_for_status()
+             results = response.json()
+
+             if "items" not in results:
+                 return {"results": [], "message": "No results found"}
+
+             formatted_results = []
+             for item in results["items"]:
+                 formatted_results.append({
+                     "title": item.get("title"),
+                     "link": item.get("link"),
+                     "snippet": item.get("snippet"),
+                     "source": item.get("displayLink")
+                 })
+
+             return {
+                 "results": formatted_results,
+                 "message": f"Found {len(formatted_results)} results for query: {query}"
+             }
+
+         except Exception as e:
+             return {"error": f"Search failed: {str(e)}"}
+
+     def visit_url(self, url: str) -> Dict[str, Any]:
+         """Visit a URL and extract its content.
+
+         Args:
+             url: The URL to visit
+
+         Returns:
+             Dict containing the page content, title, and metadata
+         """
+         try:
+             # Validate URL
+             parsed_url = urlparse(url)
+             if not parsed_url.scheme or not parsed_url.netloc:
+                 return {"error": f"Invalid URL: {url}"}
+
+             headers = {"User-Agent": self.user_agent}
+             response = requests.get(url, headers=headers, timeout=15)
+             response.raise_for_status()
+
+             # Parse the HTML content
+             soup = BeautifulSoup(response.text, "html.parser")
+
+             # Extract title
+             title = soup.title.string if soup.title else "No title"
+
+             # Extract main content (remove scripts, styles, etc.)
+             for script in soup(["script", "style", "meta", "noscript"]):
+                 script.extract()
+
+             # Get text content
+             text_content = soup.get_text(separator="\n", strip=True)
+             # Clean up whitespace
+             text_content = re.sub(r'\n+', '\n', text_content)
+             text_content = re.sub(r' +', ' ', text_content)
+
+             # Extract links
+             links = []
+             for link in soup.find_all("a", href=True):
+                 href = link["href"]
+                 # Handle relative URLs
+                 if href.startswith("/"):
+                     base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
+                     href = base_url + href
+                 if href.startswith(("http://", "https://")):
+                     links.append({
+                         "text": link.get_text(strip=True) or href,
+                         "url": href
+                     })
+
+             # Extract images (limited to first 3)
+             images = []
+             for i, img in enumerate(soup.find_all("img", src=True)[:3]):
+                 src = img["src"]
+                 # Handle relative URLs
+                 if src.startswith("/"):
+                     base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
+                     src = base_url + src
+                 if src.startswith(("http://", "https://")):
+                     images.append(src)
+
+             return {
+                 "title": title,
+                 "content": text_content[:10000] if len(text_content) > 10000 else text_content,
+                 "url": url,
+                 "links": links[:10],  # Limit to 10 links
+                 "images": images,
+                 "content_type": response.headers.get("Content-Type", ""),
+                 "content_length": len(text_content),
+                 "truncated": len(text_content) > 10000
+             }
+
+         except Exception as e:
+             return {"error": f"Failed to visit {url}: {str(e)}"}
+
+     async def _arun(self, query: str = "", url: str = "") -> str:
+         """Run the tool asynchronously."""
+         return json.dumps(self._run(query=query, url=url))
+
+     def _run(self, query: str = "", url: str = "") -> Tuple[Dict[str, Any], Dict[str, Any]]:
+         """Run the web browser tool.
+
+         Args:
+             query: Search query (if searching)
+             url: URL to visit (if visiting a specific page)
+
+         Returns:
+             Tuple[Dict[str, Any], Dict[str, Any]]: A tuple containing the results and metadata
+         """
+         metadata = {
+             "query": query if query else "",
+             "url": url if url else "",
+             "timestamp": time.time(),
+             "tool": "WebBrowserTool"
+         }
+
+         if url:
+             result = self.visit_url(url)
+             return result, metadata
+         elif query:
+             result = self.search_web(query)
+             return result, metadata
+         else:
+             return {"error": "Please provide either a search query or a URL to visit"}, metadata
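
`_run` multiplexes two modes on one signature: pass `query` to search, or `url` to fetch a page. A sketch assuming valid Google Custom Search credentials; the query and URL are illustrative:

```python
import os
from medrax.tools.web_browser import WebBrowserTool

# Credentials can also be passed to the constructor; here the environment is used.
os.environ.setdefault("GOOGLE_SEARCH_API_KEY", "your-google-search-api-key")
os.environ.setdefault("GOOGLE_SEARCH_ENGINE_ID", "your-google-search-engine-id")

tool = WebBrowserTool()

# Search mode: query set, url left empty.
results, meta = tool._run(query="pneumothorax radiographic signs")

# Visit mode: url set, query left empty; page text is truncated at 10,000 characters.
page, meta = tool._run(url="https://example.com")
print(page.get("title"), page.get("truncated"))
```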
pyproject.toml CHANGED
@@ -66,6 +66,7 @@ dependencies = [
      "chromadb>=0.0.10",
      "pinecone-client>=3.2.2",
      "langchain-pinecone>=0.0.1",
+     "langchain-google-genai>=0.1.0",
  ]

  [project.optional-dependencies]
quickstart.py CHANGED
@@ -11,7 +11,7 @@ from datasets import load_dataset

  # Initialize global variables
  logger = logging.getLogger('benchmark')
- model_name = 'chatgpt-4o-latest'  # default value
+ model_name = 'gpt-4.1-2025-04-14'  # default value
  temperature = 0.2  # default value
  log_filename = None

@@ -199,7 +199,7 @@ def main():
      # Add command line argument parsing
      parser = argparse.ArgumentParser(description='Run medical image analysis benchmark')
      parser.add_argument('--use-urls', action='store_true', help='Use image URLs instead of local files')
-     parser.add_argument('--model', type=str, default='chatgpt-4o-latest', help='Model name to use')
+     parser.add_argument('--model', type=str, default='gpt-4.1-2025-04-14', help='Model name to use')
      parser.add_argument('--temperature', type=float, default=0.2, help='Temperature for model inference')
      parser.add_argument('--log-prefix', type=str, help='Prefix for log filename (default: model name)')
      parser.add_argument('--max-cases', type=int, default=None, help='Maximum number of cases to process (default: all)')