VictorLJZ committed
Commit 5a3031b
Parents: db11328 9fe334c

Merge pull request #2 from bowang-lab/victor

Files changed (3):
  1. README.md +37 -2
  2. main.py +2 -2
  3. medrax/models/model_factory.py +24 -4
README.md CHANGED

@@ -15,7 +15,7 @@ Chest X-rays (CXRs) play an integral role in driving critical decisions in disea
  ## MedRAX
  MedRAX is built on a robust technical foundation:
  - **Core Architecture**: Built on LangChain and LangGraph frameworks
- - **Language Model**: Uses GPT-4o with vision capabilities as the backbone LLM
+ - **Language Models**: Supports multiple LLM providers including OpenAI (GPT-4o) and Google (Gemini) models
  - **Deployment**: Supports both local and cloud-based deployments
  - **Interface**: Production-ready interface built with Gradio
  - **Modular Design**: Tool-agnostic architecture allowing easy integration of new capabilities
@@ -27,6 +27,7 @@ MedRAX is built on a robust technical foundation:
  - **Report Generation**: Implements SwinV2 Transformer trained on CheXpert Plus for detailed medical reporting
  - **Disease Classification**: Leverages DenseNet-121 from TorchXRayVision for detecting 18 pathology classes
  - **X-ray Generation**: Utilizes RoentGen for synthetic CXR generation
+ - **Web Browser**: Provides web search capabilities and URL content retrieval using Google Custom Search API
  - **Utilities**: Includes DICOM processing, visualization tools, and custom plotting capabilities
  <br><br>

@@ -180,6 +181,7 @@ No additional model weights required:
  ```python
  ImageVisualizerTool()
  DicomProcessorTool(temp_dir=temp_dir)
+ WebBrowserTool()  # Requires Google Search API credentials
  ```
  <br>

@@ -212,12 +214,45 @@ ChestXRayGeneratorTool(
  - Some tools (LLaVA-Med, Grounding) are more resource-intensive
  <br>

- ### Local LLMs
+ ### Language Model Options
+ MedRAX supports multiple language model providers:
+
+ #### OpenAI Models
+ Supported prefixes: `gpt-` and `chatgpt-`
+ ```
+ export OPENAI_API_KEY="your-openai-api-key"
+ export OPENAI_BASE_URL="https://api.openai.com/v1"  # Optional, for custom endpoints
+ ```
+
+ #### Google Gemini Models
+ Supported prefix: `gemini-`
+ ```
+ export GOOGLE_API_KEY="your-google-api-key"
+ ```
+
+ #### OpenRouter Models (Open Source & Proprietary)
+ Supported prefix: `openrouter-`
+
+ Access many open source and proprietary models via [OpenRouter](https://openrouter.ai/):
+ ```
+ export OPENROUTER_API_KEY="your-openrouter-api-key"
+ ```
+
+ **Note:** Tool compatibility may vary with open-source models. For best results with tools, we recommend using OpenAI or Google Gemini models.
+
+ #### Local LLMs
  If you are running a local LLM using frameworks like [Ollama](https://ollama.com/) or [LM Studio](https://lmstudio.ai/), you need to configure your environment variables accordingly. For example:
  ```
  export OPENAI_BASE_URL="http://localhost:11434/v1"
  export OPENAI_API_KEY="ollama"
  ```
+
+ #### WebBrowserTool Configuration
+ If you're using the WebBrowserTool, you'll need to set these environment variables:
+ ```
+ export GOOGLE_SEARCH_API_KEY="your-google-search-api-key"
+ export GOOGLE_SEARCH_ENGINE_ID="your-google-search-engine-id"
+ ```
  <br>

  ## Star History
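The prefix-based routing documented in the README section above can be summarized in a few lines. Here is a minimal sketch, assuming only the prefixes and environment variables named above; the `PREFIX_TO_ENV_KEY` table and `required_env_key` helper are illustrative, not part of the MedRAX codebase:

```python
import os

# Illustrative table mapping the documented model-name prefixes to the
# API-key variable each provider expects (hypothetical helper, not MedRAX code).
PREFIX_TO_ENV_KEY = {
    "gpt": "OPENAI_API_KEY",
    "chatgpt": "OPENAI_API_KEY",
    "gemini": "GOOGLE_API_KEY",
    "openrouter": "OPENROUTER_API_KEY",
}

def required_env_key(model: str) -> str:
    """Return the API-key variable a given model string depends on."""
    prefix = model.split("-", 1)[0]
    # Local LLMs (Ollama, LM Studio) reuse the OPENAI_* variables per the README.
    return PREFIX_TO_ENV_KEY.get(prefix, "OPENAI_API_KEY")

for name in ("chatgpt-4o-latest", "gemini-2.5-pro", "openrouter-anthropic/claude-sonnet-4"):
    key = required_env_key(name)
    print(f"{name} -> {key} (set: {key in os.environ})")
```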
main.py CHANGED

@@ -23,7 +23,7 @@ def initialize_agent(
      model_dir="/model-weights",
      temp_dir="temp",
      device="cuda",
-     model="gpt-4o",
+     model="chatgpt-4o-latest",
      temperature=0.7,
      top_p=0.95,
      model_kwargs={}
@@ -137,7 +137,7 @@ if __name__ == "__main__":
      model_dir="/m_weights",  # Change this to the path of the model weights
      temp_dir="temp",  # Change this to the path of the temporary directory
      device="cpu",  # Change this to the device you want to use
-     model="gemini-2.5-pro",  # Change this to the model you want to use, e.g. gpt-4o-mini, gemini-2.5-pro
+     model="gpt-4o-mini",  # Change this to the model you want to use, e.g. gpt-4o-mini, gemini-2.5-pro
      temperature=0.7,
      top_p=0.95,
      model_kwargs=model_kwargs
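Taken together with the README changes, switching providers comes down to passing a differently prefixed `model` string. A hedged sketch of such a call, using only the `initialize_agent` parameters visible in the hunks above (any other arguments, and the return value, are omitted):

```python
# Assumes OPENROUTER_API_KEY is exported, as described in the README section above.
initialize_agent(
    model_dir="/model-weights",
    temp_dir="temp",
    device="cuda",
    model="openrouter-anthropic/claude-sonnet-4",  # routed by the "openrouter-" prefix
    temperature=0.7,
    top_p=0.95,
    model_kwargs={},
)
```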
medrax/models/model_factory.py CHANGED

@@ -22,10 +22,21 @@ class ModelFactory:
          "env_key": "OPENAI_API_KEY",
          "base_url_key": "OPENAI_BASE_URL"
      },
+     "chatgpt": {
+         "class": ChatOpenAI,
+         "env_key": "OPENAI_API_KEY",
+         "base_url_key": "OPENAI_BASE_URL"
+     },
      "gemini": {
          "class": ChatGoogleGenerativeAI,
          "env_key": "GOOGLE_API_KEY"
      },
+     "openrouter": {
+         "class": ChatOpenAI,  # OpenRouter uses an OpenAI-compatible interface
+         "env_key": "OPENROUTER_API_KEY",
+         "base_url_key": "OPENROUTER_BASE_URL",
+         "default_base_url": "https://openrouter.ai/api/v1"
+     },
      # Add more providers with default configurations here
  }

@@ -91,17 +102,26 @@ class ModelFactory:
          print(f"Warning: Environment variable {env_key} not found. Authentication may fail.")

      # Check for base_url if applicable
-     if "base_url_key" in provider and provider["base_url_key"] in os.environ:
-         provider_kwargs["base_url"] = os.environ[provider["base_url_key"]]
+     if "base_url_key" in provider:
+         if provider["base_url_key"] in os.environ:
+             provider_kwargs["base_url"] = os.environ[provider["base_url_key"]]
+         elif "default_base_url" in provider:
+             provider_kwargs["base_url"] = provider["default_base_url"]

      # Merge with any additional provider-specific settings from the registry
      for k, v in provider.items():
-         if k not in ["class", "env_key", "base_url_key"]:
+         if k not in ["class", "env_key", "base_url_key", "default_base_url"]:
              provider_kwargs[k] = v

+     # Strip the provider prefix from the model name
+     # For example, 'openrouter-anthropic/claude-sonnet-4' becomes 'anthropic/claude-sonnet-4'
+     actual_model_name = model_name
+     if model_name.startswith(f"{provider_prefix}-"):
+         actual_model_name = model_name[len(provider_prefix) + 1:]
+
      # Create and return the model instance
      return model_class(
-         model=model_name,
+         model=actual_model_name,
          temperature=temperature,
          top_p=top_p,
          **provider_kwargs,
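For quick verification of the resolution logic added above, here is a self-contained sketch: the registry is reduced to plain dicts and the provider prefix is assumed to equal the registry key, so the base-URL fallback and prefix stripping can be exercised without LangChain installed. This is a sketch under those assumptions, not the factory itself:

```python
import os

# Reduced registry: only the fields the resolution logic touches (sketch only).
MODEL_PROVIDERS = {
    "openrouter": {
        "env_key": "OPENROUTER_API_KEY",
        "base_url_key": "OPENROUTER_BASE_URL",
        "default_base_url": "https://openrouter.ai/api/v1",
    },
}

def resolve(model_name: str):
    """Mirror the base-URL fallback and prefix stripping from the diff above."""
    provider_prefix = model_name.split("-", 1)[0]
    provider = MODEL_PROVIDERS[provider_prefix]
    provider_kwargs = {}
    # Environment variable wins; otherwise fall back to the registry default.
    if provider["base_url_key"] in os.environ:
        provider_kwargs["base_url"] = os.environ[provider["base_url_key"]]
    elif "default_base_url" in provider:
        provider_kwargs["base_url"] = provider["default_base_url"]
    # Strip the provider prefix so the upstream API sees its native model id.
    actual_model_name = model_name
    if model_name.startswith(f"{provider_prefix}-"):
        actual_model_name = model_name[len(provider_prefix) + 1:]
    return actual_model_name, provider_kwargs

print(resolve("openrouter-anthropic/claude-sonnet-4"))
# ('anthropic/claude-sonnet-4', {'base_url': 'https://openrouter.ai/api/v1'})
```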