Adibvafa commited on
Commit
ef6fc50
·
1 Parent(s): 5a3031b

Change default LLM to GPT-4.1 and add new ArcPlus classification tool

Browse files
README.md CHANGED
@@ -104,7 +104,8 @@ MedRAX supports selective tool initialization, allowing you to use only the tool
104
  ```python
105
  selected_tools = [
106
  "ImageVisualizerTool",
107
- "ChestXRayClassifierTool",
 
108
  "ChestXRaySegmentationTool",
109
  # Add or remove tools as needed
110
  ]
@@ -121,9 +122,17 @@ agent, tools_dict = initialize_agent(
121
 
122
  The following tools will automatically download their model weights when initialized:
123
 
124
- ### Classification Tool
125
  ```python
126
- ChestXRayClassifierTool(device=device)
 
 
 
 
 
 
 
 
127
  ```
128
 
129
  ### Segmentation Tool
 
104
  ```python
105
  selected_tools = [
106
  "ImageVisualizerTool",
107
+ "TorchXRayVisionClassifierTool", # Renamed from ChestXRayClassifierTool
108
+ "ArcPlusClassifierTool", # New ArcPlus classifier
109
  "ChestXRaySegmentationTool",
110
  # Add or remove tools as needed
111
  ]
 
122
 
123
  The following tools will automatically download their model weights when initialized:
124
 
125
+ ### Classification Tools
126
  ```python
127
+ # TorchXRayVision-based classifier (original)
128
+ TorchXRayVisionClassifierTool(device=device)
129
+
130
+ # ArcPlus SwinTransformer-based classifier (new)
131
+ ArcPlusClassifierTool(
132
+ model_path="/path/to/Ark6_swinLarge768_ep50.pth.tar", # Optional
133
+ num_classes=18, # Default
134
+ device=device
135
+ )
136
  ```
137
 
138
  ### Segmentation Tool
main.py CHANGED
@@ -23,7 +23,7 @@ def initialize_agent(
23
  model_dir="/model-weights",
24
  temp_dir="temp",
25
  device="cuda",
26
- model="chatgpt-4o-latest",
27
  temperature=0.7,
28
  top_p=0.95,
29
  model_kwargs={}
@@ -48,7 +48,11 @@ def initialize_agent(
48
  prompt = prompts["MEDICAL_ASSISTANT"]
49
 
50
  all_tools = {
51
- "ChestXRayClassifierTool": lambda: ChestXRayClassifierTool(device=device),
 
 
 
 
52
  "ChestXRaySegmentationTool": lambda: ChestXRaySegmentationTool(device=device),
53
  "LlavaMedTool": lambda: LlavaMedTool(cache_dir=model_dir, device=device, load_in_8bit=True),
54
  "XRayVQATool": lambda: XRayVQATool(cache_dir=model_dir, device=device),
@@ -113,7 +117,8 @@ if __name__ == "__main__":
113
  selected_tools = [
114
  # "ImageVisualizerTool",
115
  # "DicomProcessorTool",
116
- # "ChestXRayClassifierTool",
 
117
  # "ChestXRaySegmentationTool",
118
  # "ChestXRayReportGeneratorTool",
119
  # "XRayVQATool",
 
23
  model_dir="/model-weights",
24
  temp_dir="temp",
25
  device="cuda",
26
+ model="gpt-4.1-2025-04-14",
27
  temperature=0.7,
28
  top_p=0.95,
29
  model_kwargs={}
 
48
  prompt = prompts["MEDICAL_ASSISTANT"]
49
 
50
  all_tools = {
51
+ "TorchXRayVisionClassifierTool": lambda: TorchXRayVisionClassifierTool(device=device),
52
+ "ArcPlusClassifierTool": lambda: ArcPlusClassifierTool(
53
+ model_path=f"{model_dir}/Ark6_swinLarge768_ep50.pth.tar" if model_dir else None,
54
+ device=device
55
+ ),
56
  "ChestXRaySegmentationTool": lambda: ChestXRaySegmentationTool(device=device),
57
  "LlavaMedTool": lambda: LlavaMedTool(cache_dir=model_dir, device=device, load_in_8bit=True),
58
  "XRayVQATool": lambda: XRayVQATool(cache_dir=model_dir, device=device),
 
117
  selected_tools = [
118
  # "ImageVisualizerTool",
119
  # "DicomProcessorTool",
120
+ # "TorchXRayVisionClassifierTool", # Renamed from ChestXRayClassifierTool
121
+ # "ArcPlusClassifierTool", # New ArcPlus classifier
122
  # "ChestXRaySegmentationTool",
123
  # "ChestXRayReportGeneratorTool",
124
  # "XRayVQATool",
medrax/tools/classification/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Classification tools for chest X-ray analysis."""
2
+
3
+ from .torchxrayvision import TorchXRayVisionClassifierTool, TorchXRayVisionInput
4
+ from .arcplus import ArcPlusClassifierTool, ArcPlusInput
5
+
6
+ __all__ = [
7
+ "TorchXRayVisionClassifierTool",
8
+ "TorchXRayVisionInput",
9
+ "ArcPlusClassifierTool",
10
+ "ArcPlusInput"
11
+ ]
medrax/tools/classification/arcplus.py ADDED
@@ -0,0 +1,390 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Optional, Tuple, Type, ClassVar, List
2
+ from pydantic import BaseModel, Field
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import numpy as np
7
+ from PIL import Image
8
+ import torchvision.transforms as transforms
9
+ from timm.models.swin_transformer import SwinTransformer
10
+
11
+ from langchain_core.callbacks import (
12
+ AsyncCallbackManagerForToolRun,
13
+ CallbackManagerForToolRun,
14
+ )
15
+ from langchain_core.tools import BaseTool
16
+
17
+
18
+ class OmniSwinTransformer(SwinTransformer):
19
+ """OmniSwinTransformer with multiple classification heads and optional projector."""
20
+
21
+ def __init__(self, num_classes_list, projector_features=None, use_mlp=False, *args, **kwargs):
22
+ super().__init__(*args, **kwargs)
23
+ assert num_classes_list is not None
24
+
25
+ self.projector = None
26
+ if projector_features:
27
+ encoder_features = self.num_features
28
+ self.num_features = projector_features
29
+ if use_mlp:
30
+ self.projector = nn.Sequential(
31
+ nn.Linear(encoder_features, self.num_features),
32
+ nn.ReLU(inplace=True),
33
+ nn.Linear(self.num_features, self.num_features),
34
+ )
35
+ else:
36
+ self.projector = nn.Linear(encoder_features, self.num_features)
37
+
38
+ self.omni_heads = []
39
+ for num_classes in num_classes_list:
40
+ self.omni_heads.append(
41
+ nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
42
+ )
43
+ self.omni_heads = nn.ModuleList(self.omni_heads)
44
+
45
+ def forward(self, x, head_n=None):
46
+ x = self.forward_features(x)
47
+ if self.projector:
48
+ x = self.projector(x)
49
+ if head_n is not None:
50
+ return x, self.omni_heads[head_n](x)
51
+ else:
52
+ return [head(x) for head in self.omni_heads]
53
+
54
+ def generate_embeddings(self, x, after_proj=True):
55
+ x = self.forward_features(x)
56
+ if after_proj and self.projector:
57
+ x = self.projector(x)
58
+ return x
59
+
60
+
61
+ class ArcPlusInput(BaseModel):
62
+ """Input for ArcPlus chest X-ray analysis tool. Only supports JPG or PNG images."""
63
+
64
+ image_path: str = Field(
65
+ ..., description="Path to the radiology image file, only supports JPG or PNG images"
66
+ )
67
+
68
+
69
+ class ArcPlusClassifierTool(BaseTool):
70
+ """Tool that classifies chest X-ray images using the ArcPlus OmniSwinTransformer model.
71
+
72
+ This tool uses a pre-trained OmniSwinTransformer model (ArcPlus) to analyze chest X-ray images
73
+ and predict the likelihood of various pathologies across multiple medical datasets. The model
74
+ employs a Swin Transformer architecture with multiple classification heads, each specialized
75
+ for different medical datasets and conditions.
76
+
77
+ The ArcPlus model is trained on 6 different medical datasets:
78
+ - MIMIC-CXR: 14 pathologies including common chest conditions
79
+ - CheXpert: 14 pathologies with standardized labeling
80
+ - NIH ChestX-ray14: 14 pathologies from large-scale dataset
81
+ - RSNA: 3 classes for pneumonia detection
82
+ - VinDr-CXR: 6 categories including tuberculosis and lung tumors
83
+ - Shenzhen: 1 class for tuberculosis detection
84
+
85
+ Key Features:
86
+ - Multi-head architecture with 6 specialized classification heads
87
+ - 768x768 input resolution for high-detail analysis
88
+ - Projector layer with 1376 features for enhanced representation
89
+ - Sigmoid activation for multi-label classification
90
+ - Covers 52+ distinct pathology categories across datasets
91
+
92
+ The model outputs probabilities (0 to 1) for each condition, with higher values
93
+ indicating higher likelihood of the pathology being present in the image.
94
+ """
95
+
96
+ name: str = "arcplus_classifier"
97
+ description: str = (
98
+ "Advanced chest X-ray classification tool using ArcPlus OmniSwinTransformer with multi-dataset training. "
99
+ "Analyzes chest X-ray images and provides probability predictions for 52+ pathologies across 6 medical datasets. "
100
+ "Input: Path to chest X-ray image file (JPG/PNG). "
101
+ "Output: Dictionary mapping pathology names to probabilities (0-1). "
102
+ "Features: Multi-head architecture, 768px resolution, projector layer, specialized for medical imaging. "
103
+ "Pathologies include: Atelectasis, Cardiomegaly, Consolidation, Edema, Enlarged Cardiomediastinum, "
104
+ "Fracture, Lung Lesion, Lung Opacity, Pleural Effusion, Pneumonia, Pneumothorax, Mass, Nodule, "
105
+ "Emphysema, Fibrosis, PE, Lung Tumor, Tuberculosis, and many more across MIMIC, CheXpert, NIH, "
106
+ "RSNA, VinDr, and Shenzhen datasets. Higher probabilities indicate higher likelihood of condition presence."
107
+ )
108
+ args_schema: Type[BaseModel] = ArcPlusInput
109
+ model: OmniSwinTransformer = None
110
+ device: Optional[str] = "cuda"
111
+ normalize: transforms.Normalize = None
112
+ disease_list: List[str] = None
113
+ num_classes_list: List[int] = None
114
+
115
+ # Disease mappings from the analysis
116
+ mimic_diseases: ClassVar[List[str]] = [
117
+ "Atelectasis",
118
+ "Cardiomegaly",
119
+ "Consolidation",
120
+ "Edema",
121
+ "Enlarged Cardiomediastinum",
122
+ "Fracture",
123
+ "Lung Lesion",
124
+ "Lung Opacity",
125
+ "No Finding",
126
+ "Pleural Effusion",
127
+ "Pleural Other",
128
+ "Pneumonia",
129
+ "Pneumothorax",
130
+ "Support Devices",
131
+ ]
132
+ chexpert_diseases: ClassVar[List[str]] = [
133
+ "No Finding",
134
+ "Enlarged Cardiomediastinum",
135
+ "Cardiomegaly",
136
+ "Lung Opacity",
137
+ "Lung Lesion",
138
+ "Edema",
139
+ "Consolidation",
140
+ "Pneumonia",
141
+ "Atelectasis",
142
+ "Pneumothorax",
143
+ "Pleural Effusion",
144
+ "Pleural Other",
145
+ "Fracture",
146
+ "Support Devices",
147
+ ]
148
+ nih14_diseases: ClassVar[List[str]] = [
149
+ "Atelectasis",
150
+ "Cardiomegaly",
151
+ "Effusion",
152
+ "Infiltration",
153
+ "Mass",
154
+ "Nodule",
155
+ "Pneumonia",
156
+ "Pneumothorax",
157
+ "Consolidation",
158
+ "Edema",
159
+ "Emphysema",
160
+ "Fibrosis",
161
+ "Pleural_Thickening",
162
+ "Hernia",
163
+ ]
164
+ rsna_diseases: ClassVar[List[str]] = ["No Lung Opacity/Not Normal", "Normal", "Lung Opacity"]
165
+ vindr_diseases: ClassVar[List[str]] = [
166
+ "PE",
167
+ "Lung tumor",
168
+ "Pneumonia",
169
+ "Tuberculosis",
170
+ "Other diseases",
171
+ "No finding",
172
+ ]
173
+ shenzhen_diseases: ClassVar[List[str]] = ["TB"]
174
+
175
+ def __init__(self, model_path: str = None, device: Optional[str] = "cuda"):
176
+ """Initialize the ArcPlus Classifier Tool.
177
+
178
+ Args:
179
+ model_path (str, optional): Path to the pre-trained ArcPlus model checkpoint file.
180
+ Expected file: 'Ark6_swinLarge768_ep50.pth.tar' or similar ArcPlus checkpoint.
181
+ If None, model will be initialized with random weights (not recommended for inference).
182
+ Default: None.
183
+ device (str, optional): Device to run the model on ('cuda' for GPU, 'cpu' for CPU).
184
+ GPU is recommended for better performance. Default: "cuda".
185
+
186
+ Model Architecture Details:
187
+ - OmniSwinTransformer with 6 classification heads
188
+ - Input resolution: 768x768 pixels
189
+ - Projector features: 1376 dimensions
190
+ - Multi-head configuration: [14, 14, 14, 3, 6, 1] classes per head
191
+ - Total pathologies: 52+ across 6 medical datasets
192
+ - Preprocessing: ImageNet normalization (mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
193
+
194
+ Raises:
195
+ FileNotFoundError: If model_path is provided but file doesn't exist.
196
+ RuntimeError: If model loading fails or device is unavailable.
197
+ """
198
+ super().__init__()
199
+
200
+ # Create combined disease list from all supported datasets
201
+ self.disease_list = (
202
+ self.mimic_diseases
203
+ + self.chexpert_diseases
204
+ + self.nih14_diseases
205
+ + self.rsna_diseases
206
+ + self.vindr_diseases
207
+ + self.shenzhen_diseases
208
+ )
209
+
210
+ # Multi-head configuration: [MIMIC, CheXpert, NIH, RSNA, VinDr, Shenzhen]
211
+ self.num_classes_list = [14, 14, 14, 3, 6, 1]
212
+
213
+ # Initialize the OmniSwinTransformer model with ArcPlus architecture
214
+ self.model = OmniSwinTransformer(
215
+ num_classes_list=self.num_classes_list,
216
+ projector_features=1376, # Enhanced feature representation
217
+ use_mlp=False, # Linear projector (not MLP)
218
+ img_size=768, # High-resolution input
219
+ patch_size=4,
220
+ window_size=12,
221
+ embed_dim=192,
222
+ depths=(2, 2, 18, 2), # Swin-Large configuration
223
+ num_heads=(6, 12, 24, 48),
224
+ )
225
+
226
+ # Load pre-trained weights if provided
227
+ if model_path:
228
+ self._load_checkpoint(model_path)
229
+
230
+ self.model.eval()
231
+ self.device = torch.device(device) if device else "cuda"
232
+ self.model = self.model.to(self.device)
233
+
234
+ # ImageNet normalization parameters for optimal performance
235
+ self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
236
+
237
+ def _load_checkpoint(self, model_path: str) -> None:
238
+ """
239
+ Load the ArcPlus model checkpoint.
240
+
241
+ Args:
242
+ model_path (str): Path to the model checkpoint file.
243
+ """
244
+ # Load the checkpoint (set weights_only=False for PyTorch 2.6+ compatibility)
245
+ checkpoint = torch.load(model_path, map_location=torch.device("cpu"), weights_only=False)
246
+ state_dict = checkpoint["teacher"] # Use 'teacher' key
247
+
248
+ # Remove "module." prefix if present (improved logic from example)
249
+ if any([True if "module." in k else False for k in state_dict.keys()]):
250
+ state_dict = {
251
+ k.replace("module.", ""): v
252
+ for k, v in state_dict.items()
253
+ if k.startswith("module.")
254
+ }
255
+
256
+ # Load the model weights
257
+ msg = self.model.load_state_dict(state_dict, strict=False)
258
+
259
+ def _process_image(self, image_path: str) -> torch.Tensor:
260
+ """
261
+ Process the input chest X-ray image for model inference.
262
+
263
+ This method loads the image, applies necessary transformations,
264
+ and prepares it as a torch.Tensor for model input.
265
+
266
+ Args:
267
+ image_path (str): The file path to the chest X-ray image.
268
+
269
+ Returns:
270
+ torch.Tensor: A processed image tensor ready for model inference.
271
+
272
+ Raises:
273
+ FileNotFoundError: If the specified image file does not exist.
274
+ ValueError: If the image cannot be properly loaded or processed.
275
+ """
276
+ try:
277
+ # Load and preprocess image following the example pattern
278
+ image = Image.open(image_path).convert("RGB").resize((768, 768))
279
+
280
+ # Convert to numpy array and normalize to [0, 1]
281
+ image_array = np.array(image) / 255.0
282
+
283
+ # Apply ImageNet normalization
284
+ image_tensor = torch.from_numpy(image_array).float()
285
+ image_tensor = image_tensor.permute(2, 0, 1) # HWC to CHW
286
+ image_tensor = self.normalize(image_tensor)
287
+
288
+ # Add batch dimension and move to device
289
+ image_tensor = image_tensor.unsqueeze(0).to(self.device)
290
+
291
+ return image_tensor
292
+
293
+ except Exception as e:
294
+ raise ValueError(f"Error processing image {image_path}: {str(e)}")
295
+
296
+ def _run(
297
+ self,
298
+ image_path: str,
299
+ run_manager: Optional[CallbackManagerForToolRun] = None,
300
+ ) -> Tuple[Dict[str, float], Dict]:
301
+ """Classify the chest X-ray image using ArcPlus SwinTransformer.
302
+
303
+ Args:
304
+ image_path (str): The path to the chest X-ray image file.
305
+ run_manager (Optional[CallbackManagerForToolRun]): The callback manager for the tool run.
306
+
307
+ Returns:
308
+ Tuple[Dict[str, float], Dict]: A tuple containing the classification results
309
+ (pathologies and their probabilities from 0 to 1)
310
+ and any additional metadata.
311
+
312
+ Raises:
313
+ Exception: If there's an error processing the image or during classification.
314
+ """
315
+ try:
316
+ # Process the image
317
+ image_tensor = self._process_image(image_path)
318
+
319
+ # Run model inference
320
+ with torch.no_grad():
321
+ pre_logits = self.model(image_tensor)
322
+
323
+ # Apply sigmoid to each output head (as seen in example)
324
+ preds = [torch.sigmoid(out) for out in pre_logits]
325
+
326
+ # Concatenate all predictions into single tensor
327
+ preds = torch.cat(preds, dim=1)
328
+
329
+ # Convert to numpy
330
+ predictions = preds.cpu().numpy().flatten()
331
+
332
+ # Map predictions to disease names
333
+ if len(predictions) != len(self.disease_list):
334
+ print(
335
+ f"Warning: Expected {len(self.disease_list)} predictions, got {len(predictions)}"
336
+ )
337
+ # Pad or truncate as needed
338
+ if len(predictions) < len(self.disease_list):
339
+ predictions = np.pad(
340
+ predictions, (0, len(self.disease_list) - len(predictions))
341
+ )
342
+ else:
343
+ predictions = predictions[: len(self.disease_list)]
344
+
345
+ # Create output dictionary mapping disease names to probabilities
346
+ output = dict(zip(self.disease_list, predictions.astype(float)))
347
+
348
+ metadata = {
349
+ "image_path": image_path,
350
+ "model": "ArcPlus OmniSwinTransformer",
351
+ "analysis_status": "completed",
352
+ "num_predictions": len(predictions),
353
+ "num_heads": len(self.num_classes_list),
354
+ "projector_features": 1376,
355
+ "note": "Probabilities range from 0 to 1, with higher values indicating higher likelihood of the condition.",
356
+ }
357
+
358
+ return output, metadata
359
+
360
+ except Exception as e:
361
+ return {"error": str(e)}, {
362
+ "image_path": image_path,
363
+ "analysis_status": "failed",
364
+ "error_details": str(e),
365
+ }
366
+
367
+ async def _arun(
368
+ self,
369
+ image_path: str,
370
+ run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
371
+ ) -> Tuple[Dict[str, float], Dict]:
372
+ """Asynchronously classify the chest X-ray image using ArcPlus SwinTransformer.
373
+
374
+ This method currently calls the synchronous version, as the model inference
375
+ is not inherently asynchronous. For true asynchronous behavior, consider
376
+ using a separate thread or process.
377
+
378
+ Args:
379
+ image_path (str): The path to the chest X-ray image file.
380
+ run_manager (Optional[AsyncCallbackManagerForToolRun]): The async callback manager for the tool run.
381
+
382
+ Returns:
383
+ Tuple[Dict[str, float], Dict]: A tuple containing the classification results
384
+ (pathologies and their probabilities from 0 to 1)
385
+ and any additional metadata.
386
+
387
+ Raises:
388
+ Exception: If there's an error processing the image or during classification.
389
+ """
390
+ return self._run(image_path)
medrax/tools/{classification.py → classification/torchxrayvision.py} RENAMED
@@ -13,15 +13,15 @@ from langchain_core.callbacks import (
13
  from langchain_core.tools import BaseTool
14
 
15
 
16
- class ChestXRayInput(BaseModel):
17
- """Input for chest X-ray analysis tools. Only supports JPG or PNG images."""
18
 
19
  image_path: str = Field(
20
  ..., description="Path to the radiology image file, only supports JPG or PNG images"
21
  )
22
 
23
 
24
- class ChestXRayClassifierTool(BaseTool):
25
  """Tool that classifies chest X-ray images for multiple pathologies.
26
 
27
  This tool uses a pre-trained DenseNet model to analyze chest X-ray images and
@@ -35,9 +35,9 @@ class ChestXRayClassifierTool(BaseTool):
35
  A higher value indicates a higher likelihood of the condition being present.
36
  """
37
 
38
- name: str = "chest_xray_classifier"
39
  description: str = (
40
- "A tool that analyzes chest X-ray images and classifies them for 18 different pathologies. "
41
  "Input should be the path to a chest X-ray image file. "
42
  "Output is a dictionary of pathologies and their predicted probabilities (0 to 1). "
43
  "Pathologies include: Atelectasis, Cardiomegaly, Consolidation, Edema, Effusion, Emphysema, "
@@ -45,7 +45,7 @@ class ChestXRayClassifierTool(BaseTool):
45
  "Lung Opacity, Mass, Nodule, Pleural Thickening, Pneumonia, and Pneumothorax. "
46
  "Higher values indicate a higher likelihood of the condition being present."
47
  )
48
- args_schema: Type[BaseModel] = ChestXRayInput
49
  model: xrv.models.DenseNet = None
50
  device: Optional[str] = "cuda"
51
  transform: torchvision.transforms.Compose = None
 
13
  from langchain_core.tools import BaseTool
14
 
15
 
16
+ class TorchXRayVisionInput(BaseModel):
17
+ """Input for TorchXRayVision chest X-ray analysis tools. Only supports JPG or PNG images."""
18
 
19
  image_path: str = Field(
20
  ..., description="Path to the radiology image file, only supports JPG or PNG images"
21
  )
22
 
23
 
24
+ class TorchXRayVisionClassifierTool(BaseTool):
25
  """Tool that classifies chest X-ray images for multiple pathologies.
26
 
27
  This tool uses a pre-trained DenseNet model to analyze chest X-ray images and
 
35
  A higher value indicates a higher likelihood of the condition being present.
36
  """
37
 
38
+ name: str = "torchxrayvision_classifier"
39
  description: str = (
40
+ "A tool that analyzes chest X-ray images and classifies them for 18 different pathologies using TorchXRayVision DenseNet. "
41
  "Input should be the path to a chest X-ray image file. "
42
  "Output is a dictionary of pathologies and their predicted probabilities (0 to 1). "
43
  "Pathologies include: Atelectasis, Cardiomegaly, Consolidation, Edema, Effusion, Emphysema, "
 
45
  "Lung Opacity, Mass, Nodule, Pleural Thickening, Pneumonia, and Pneumothorax. "
46
  "Higher values indicate a higher likelihood of the condition being present."
47
  )
48
+ args_schema: Type[BaseModel] = TorchXRayVisionInput
49
  model: xrv.models.DenseNet = None
50
  device: Optional[str] = "cuda"
51
  transform: torchvision.transforms.Compose = None
quickstart.py CHANGED
@@ -11,7 +11,7 @@ from datasets import load_dataset
11
 
12
  # Initialize global variables
13
  logger = logging.getLogger('benchmark')
14
- model_name = 'chatgpt-4o-latest' # default value
15
  temperature = 0.2 # default value
16
  log_filename = None
17
 
@@ -199,7 +199,7 @@ def main():
199
  # Add command line argument parsing
200
  parser = argparse.ArgumentParser(description='Run medical image analysis benchmark')
201
  parser.add_argument('--use-urls', action='store_true', help='Use image URLs instead of local files')
202
- parser.add_argument('--model', type=str, default='chatgpt-4o-latest', help='Model name to use')
203
  parser.add_argument('--temperature', type=float, default=0.2, help='Temperature for model inference')
204
  parser.add_argument('--log-prefix', type=str, help='Prefix for log filename (default: model name)')
205
  parser.add_argument('--max-cases', type=int, default=None, help='Maximum number of cases to process (default: all)')
 
11
 
12
  # Initialize global variables
13
  logger = logging.getLogger('benchmark')
14
+ model_name = 'gpt-4.1-2025-04-14' # default value
15
  temperature = 0.2 # default value
16
  log_filename = None
17
 
 
199
  # Add command line argument parsing
200
  parser = argparse.ArgumentParser(description='Run medical image analysis benchmark')
201
  parser.add_argument('--use-urls', action='store_true', help='Use image URLs instead of local files')
202
+ parser.add_argument('--model', type=str, default='gpt-4.1-2025-04-14', help='Model name to use')
203
  parser.add_argument('--temperature', type=float, default=0.2, help='Temperature for model inference')
204
  parser.add_argument('--log-prefix', type=str, help='Prefix for log filename (default: model name)')
205
  parser.add_argument('--max-cases', type=int, default=None, help='Maximum number of cases to process (default: all)')