edeler committed on
Commit e6cb34f · verified
1 Parent(s): d611a6e

lorai (#1)

- Update app.py with Spaces-optimized medical image analysis and enhanced README (46d6674ae12fcd833c37edb5df9b4e72cbab8790)
- Add proper Space metadata to README for better Space configuration (12c0045ea2c4166f7e0372f8e70fdfbcedb5d7e4)

Files changed (2)
  1. README.md +48 -19
  2. app.py +451 -136
README.md CHANGED
@@ -1,3 +1,15 @@
 # 🏥 Medical Image Analysis Tool
 
 An AI-powered medical image analysis application using advanced detection models and large language models for medical image interpretation.
@@ -8,17 +20,23 @@ An AI-powered medical image analysis application using advanced detection models
 - **Medical AI Analysis**: Integrates MedGemma, a specialized medical vision-language model
 - **Interactive Interface**: Built with Gradio for easy web-based interaction
 - **Configurable Thresholds**: Adjustable confidence thresholds for detection sensitivity
- - **GPU Acceleration**: Optimized for GPU usage when available
 
 ## Models Used
 
 - **RF-DETR Medium**: State-of-the-art object detection model
- - **MedGemma 4B**: Medical-specialized vision-language model for analysis and descriptions
 
 ## Usage
 
 1. **Upload Image**: Click on the image upload area or drag and drop a medical image
- 2. **Adjust Settings**: Use the confidence threshold slider to control detection sensitivity
 3. **Analyze**: Click "Analyze Image" to run the AI analysis
 4. **View Results**: See the annotated image with detected objects and AI-generated descriptions
 
@@ -26,24 +44,28 @@ An AI-powered medical image analysis application using advanced detection models
 
 This application is designed to run on Hugging Face Spaces. The following files are required:
 
- - `app.py` - Main application file
 - `requirements.txt` - Python dependencies
 - `packages.txt` - System packages
- - Model files in the `models/` directory
 
- ## Model Files Structure
 
- The application expects the following model files:
 
- ```
- models/
- ├── medgemma-4b-it/                    # MedGemma model files
- │   ├── config.json
- │   ├── tokenizer.json
- │   ├── model-00001-of-00002.safetensors
- │   └── model-00002-of-00002.safetensors
- └── rf-detr-medium.pth                 # RF-DETR model weights
- ```
 
 ## Technical Details
 
@@ -54,9 +76,11 @@ models/
 
 ## Performance Tips
 
- - Higher confidence thresholds reduce false positives but may miss subtle findings
- - The application automatically uses GPU acceleration when available
- - Model loading happens on first use and is cached for subsequent analyses
 
 ## Limitations
 
@@ -73,6 +97,11 @@ pip install -r requirements.txt
 python app.py
 ```
 
 ## License
 
 This project is for research and educational purposes. Medical applications should be developed and validated according to appropriate regulatory standards.
 
+ ---
+ title: Medical Image Analysis Tool
+ emoji: 🏥
+ colorFrom: blue
+ colorTo: green
+ sdk: gradio
+ sdk_version: "4.0.0"
+ app_file: app.py
+ pinned: false
+ license: mit
+ ---
+
 # 🏥 Medical Image Analysis Tool
 
 An AI-powered medical image analysis application using advanced detection models and large language models for medical image interpretation.

 - **Medical AI Analysis**: Integrates MedGemma, a specialized medical vision-language model
 - **Interactive Interface**: Built with Gradio for easy web-based interaction
 - **Configurable Thresholds**: Adjustable confidence thresholds for detection sensitivity
+ - **Model Size Selection**: Choose between MedGemma 4B (faster) or 27B (more accurate) models
+ - **GPU Acceleration**: Optimized for GPU usage when available with 4-bit quantization
+ - **Automatic Model Downloads**: Models download automatically from Hugging Face Hub
 
 ## Models Used
 
 - **RF-DETR Medium**: State-of-the-art object detection model
+ - **MedGemma 4B/27B**: Medical-specialized vision-language models for analysis and descriptions
+   - 4B model: Faster inference, lower memory usage
+   - 27B model: Higher accuracy, requires more resources
 
 ## Usage
 
 1. **Upload Image**: Click on the image upload area or drag and drop a medical image
+ 2. **Adjust Settings**:
+    - Use the confidence threshold slider to control detection sensitivity
+    - Select model size (4B for speed, 27B for accuracy)
 3. **Analyze**: Click "Analyze Image" to run the AI analysis
 4. **View Results**: See the annotated image with detected objects and AI-generated descriptions
 

 This application is designed to run on Hugging Face Spaces. The following files are required:
 
+ - `app.py` - Main application file (optimized for Spaces)
 - `requirements.txt` - Python dependencies
 - `packages.txt` - System packages
+ - `README.md` - This documentation
 
+ ## Model Loading
 
+ **RF-DETR Model:**
+ - Upload your trained `rf-detr-medium.pth` file to the Space
+ - The application will automatically find and load it
 
+ **MedGemma Models:**
+ - Models download automatically from Hugging Face Hub on first use
+ - No manual installation required
+ - Choose between 4B (faster) or 27B (more accurate) models
+
+ ## Space Configuration
+
+ For optimal performance, configure your Space settings:
+ - **Hardware**: GPU (T4 minimum, A100 recommended for 27B models)
+ - **Storage**: Enable persistent storage for model caching
+ - **Timeout**: 30+ minutes for large model downloads
 
 ## Technical Details
 

 ## Performance Tips
 
+ - **Model Selection**: Use MedGemma 4B for faster processing or 27B for higher accuracy
+ - **Confidence Thresholds**: Higher values reduce false positives but may miss subtle findings
+ - **GPU Acceleration**: The application automatically uses GPU acceleration when available
+ - **Memory Optimization**: Uses 4-bit quantization to reduce memory usage
+ - **Model Caching**: Models are cached after first load for faster subsequent analyses
 
 ## Limitations
 

 python app.py
 ```
 
+ **Note**: For local development, you'll need to:
+ 1. Install the RF-DETR package or ensure it's available
+ 2. Place your `rf-detr-medium.pth` file in the project directory
+ 3. Models will download automatically on first run
+
 ## License
 
 This project is for research and educational purposes. Medical applications should be developed and validated according to appropriate regulatory standards.
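
The updated README notes that the MedGemma weights are fetched from the Hugging Face Hub on first use and recommends persistent storage for caching. If cold starts are a concern, the weights can be pre-fetched into the cache before the first request. The snippet below is a minimal sketch, not part of this commit; it assumes `huggingface_hub` is installed and that an `HF_TOKEN` secret grants access to the gated MedGemma repositories.

```python
# Optional warm-up sketch: pre-download MedGemma weights into the Hub cache
# so the first analysis request does not stall on a multi-GB download.
# Assumptions: huggingface_hub is available; HF_TOKEN can access the gated repo.
import os
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="google/medgemma-4b-it",       # or "google/medgemma-27b-it"
    token=os.environ.get("HF_TOKEN"),      # gated model: a token is required
)
```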
app.py CHANGED
@@ -1,185 +1,500 @@
 import os
- import gc
 import json
 import time
- import warnings
- from typing import Dict, List, Optional, Tuple, Any
 import traceback
 
 import torch
- import cv2
- import numpy as np
- from PIL import Image
 import gradio as gr
 
- # Import ML libraries
 try:
-     import supervision as sv
-     from transformers import AutoModelForImageTextToText, AutoProcessor
- except ImportError as e:
-     print(f"Warning: Missing dependencies: {e}")
 
- # Suppress warnings
- warnings.filterwarnings("ignore")
 
- # Model paths - adjust these for your Space
- MODEL_DIR = "models"
- RESULTS_DIR = "results"
- CACHE_DIR = os.path.join(MODEL_DIR, "hf_cache")
 
- class ModelManager:
     def __init__(self):
-         self.detector = None
-         self.processor = None
-         self.llm_model = None
-         self.device = "cuda" if torch.cuda.is_available() else "cpu"
 
-     def load_models(self):
-         """Load the detection and LLM models"""
         try:
-             print(f"Loading models on device: {self.device}")
-
-             # Load RF-DETR detector
-             print("Loading RF-DETR detector...")
-             self.detector = torch.load("rf-detr-medium.pth", map_location=self.device)
-             self.detector.eval()
-
-             # Load MedGemma processor and model
-             print("Loading MedGemma model...")
-             processor_path = os.path.join(MODEL_DIR, "medgemma-4b-it")
-             if os.path.exists(processor_path):
-                 self.processor = AutoProcessor.from_pretrained(processor_path)
-                 self.llm_model = AutoModelForImageTextToText.from_pretrained(
-                     processor_path,
-                     torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
-                     device_map="auto" if self.device == "cuda" else None
                 )
-             else:
-                 print("Warning: MedGemma model not found locally, using basic detection only")
 
-         except Exception as e:
-             print(f"Error loading models: {e}")
-             self.detector = None
-             self.processor = None
-             self.llm_model = None
 
-     def detect_objects(self, image: Image.Image, threshold: float = 0.7) -> Tuple[Image.Image, str]:
-         """Run object detection on the image"""
-         if self.detector is None:
-             return image, "Error: Detector not loaded"
 
         try:
-             # Convert PIL to numpy
-             image_np = np.array(image)
 
-             # Run detection (simplified - adjust based on your RF-DETR implementation)
-             with torch.no_grad():
-                 # This is a placeholder - you'll need to adapt based on your RF-DETR usage
-                 detections = self.detector(image_np, threshold=threshold)
 
-             # Annotate image
-             annotated_image = self._annotate_image(image_np, detections)
 
-             # Generate description
-             description = self._generate_description(annotated_image, detections)
 
-             return Image.fromarray(annotated_image), description
 
         except Exception as e:
-             return image, f"Error during detection: {str(e)}"
 
-     def _annotate_image(self, image: np.ndarray, detections) -> np.ndarray:
-         """Annotate image with detections"""
-         # Placeholder annotation - adapt based on your detection format
-         annotated = image.copy()
 
-         # Add detection boxes (adjust based on your detection format)
-         if hasattr(detections, 'boxes') and len(detections.boxes) > 0:
-             for box in detections.boxes:
-                 x1, y1, x2, y2 = box.cpu().numpy().astype(int)
-                 cv2.rectangle(annotated, (x1, y1), (x2, y2), (0, 255, 0), 2)
 
-         return annotated
 
-     def _generate_description(self, image: np.ndarray, detections) -> str:
-         """Generate text description using LLM"""
-         if self.processor is None or self.llm_model is None:
-             return "Basic detection completed (LLM not available)"
 
         try:
-             # Prepare image for LLM
-             pil_image = Image.fromarray(image)
-
-             # Create prompt for medical analysis
-             prompt = "Analyze this medical image and describe any findings related to larynx granuloma or other abnormalities."
 
-             # Process image and text
-             inputs = self.processor(text=prompt, images=pil_image, return_tensors="pt")
 
-             if self.device == "cuda":
-                 inputs = {k: v.to(self.device) for k, v in inputs.items()}
 
-             # Generate response
-             with torch.no_grad():
-                 outputs = self.llm_model.generate(
-                     **inputs,
-                     max_new_tokens=200,
-                     temperature=0.2,
-                     do_sample=True
-                 )
 
-             # Decode response
-             response = self.processor.batch_decode(outputs, skip_special_tokens=True)[0]
-             return response.strip()
 
-         except Exception as e:
-             return f"LLM analysis failed: {str(e)}"
 
- # Global model manager
- model_manager = ModelManager()
 
- def analyze_image(image: Image.Image, threshold: float = 0.7, use_llm: bool = True) -> Tuple[Image.Image, str]:
-     """Main function to analyze uploaded image"""
-     if model_manager.detector is None:
-         model_manager.load_models()
 
-     if model_manager.detector is None:
-         return image, "Error: Could not load models. Please check the model files."
 
-     return model_manager.detect_objects(image, threshold)
 
- # Create Gradio interface
- with gr.Blocks(title="Medical Image Analysis") as demo:
-     gr.Markdown(
-         "# 🏥 Medical Image Analysis Tool\n\n"
-         "Upload a medical image for AI-powered analysis using advanced detection models."
-     )
 
-     with gr.Row():
-         with gr.Column():
-             input_image = gr.Image(type="pil", label="Upload Medical Image")
-             threshold_slider = gr.Slider(
-                 0.1, 1.0, value=0.7, step=0.05,
-                 label="Detection Threshold",
-                 info="Higher values = fewer but more confident detections"
-             )
-             analyze_btn = gr.Button("Analyze Image", variant="primary")
 
-         with gr.Column():
-             output_image = gr.Image(type="pil", label="Analysis Results")
-             description = gr.Markdown(label="AI Analysis", value="Upload an image to begin analysis")
 
-     analyze_btn.click(
-         analyze_image,
-         inputs=[input_image, threshold_slider],
-         outputs=[output_image, description]
-     )
 
-     input_image.change(
-         analyze_image,
-         inputs=[input_image, threshold_slider],
-         outputs=[output_image, description]
     )
 
 if __name__ == "__main__":
-     demo.launch()
 
 import os
 import json
+ import gc
 import time
 import traceback
+ from typing import Dict, List, Optional, Tuple, Callable, Any
 
 import torch
 import gradio as gr
+ import supervision as sv
+ from PIL import Image
 
+ # Try to import optional dependencies
+ try:
+     from transformers import (
+         AutoModelForCausalLM,
+         AutoTokenizer,
+         AutoModelForImageTextToText,
+         AutoProcessor,
+         BitsAndBytesConfig,
+     )
+ except Exception:
+     AutoModelForCausalLM = None
+     AutoTokenizer = None
+     AutoModelForImageTextToText = None
+     AutoProcessor = None
+     BitsAndBytesConfig = None
+
+ # Import RF-DETR (assumes it's in the same directory or installed)
 try:
+     from rfdetr import RFDETRMedium
+ except ImportError:
+     print("Warning: RF-DETR not found. Please ensure it's properly installed.")
+     RFDETRMedium = None
 
+ # ============================================================================
+ # Configuration for Hugging Face Spaces
+ # ============================================================================
 
+ class SpacesConfig:
+     """Configuration optimized for Hugging Face Spaces."""
 
     def __init__(self):
+         self.settings = {
+             'results_dir': '/tmp/results',
+             'checkpoint': None,
+             'resolution': 576,
+             'threshold': 0.7,
+             'use_llm': True,
+             'llm_model_id': 'google/medgemma-4b-it',
+             'llm_max_new_tokens': 200,
+             'llm_temperature': 0.2,
+             'llm_4bit': True,
+             'enable_caching': True,
+             'max_cache_size': 100,
+         }
+
+     def get(self, key: str, default: Any = None) -> Any:
+         return self.settings.get(key, default)
+
+ # ============================================================================
+ # Memory Management (simplified for Spaces)
+ # ============================================================================
+
+ class MemoryManager:
+     """Simplified memory management for Spaces."""
 
+     def __init__(self):
+         self.memory_thresholds = {
+             'gpu_warning': 0.8,
+             'system_warning': 0.85,
+         }
+
+     def cleanup_memory(self, force: bool = False) -> None:
+         """Perform memory cleanup."""
         try:
+             gc.collect()
+             if torch and torch.cuda.is_available():
+                 torch.cuda.empty_cache()
+                 torch.cuda.synchronize()
+         except Exception as e:
+             print(f"Memory cleanup error: {e}")
+
+ # Global memory manager
+ memory_manager = MemoryManager()
+
+ # ============================================================================
+ # Model Loading
+ # ============================================================================
+
+ def find_checkpoint() -> Optional[str]:
+     """Find RF-DETR checkpoint in various locations."""
+     candidates = [
+         "rf-detr-medium.pth",  # Current directory
+         "/tmp/results/checkpoint_best_total.pth",
+         "/tmp/results/checkpoint_best_ema.pth",
+         "/tmp/results/checkpoint_best_regular.pth",
+         "/tmp/results/checkpoint.pth",
+     ]
+
+     for path in candidates:
+         if os.path.isfile(path):
+             return path
+     return None
+
+ def load_model(checkpoint_path: str, resolution: int):
+     """Load RF-DETR model."""
+     if RFDETRMedium is None:
+         raise RuntimeError("RF-DETR not available. Please install it properly.")
+
+     model = RFDETRMedium(pretrain_weights=checkpoint_path, resolution=resolution)
+     try:
+         model.optimize_for_inference()
+     except Exception:
+         pass
+     return model
+
+ # ============================================================================
+ # LLM Integration
+ # ============================================================================
+
+ class TextGenerator:
+     """Simplified text generator for Spaces."""
+
+     def __init__(self, model_id: str, max_tokens: int = 200, temperature: float = 0.2):
+         self.model_id = model_id
+         self.max_tokens = max_tokens
+         self.temperature = temperature
+         self.model = None
+         self.tokenizer = None
+         self.processor = None
+         self.is_multimodal = False
+
+     def load_model(self):
+         """Load the LLM model."""
+         if self.model is not None:
+             return
+
+         if (AutoModelForCausalLM is None and AutoModelForImageTextToText is None):
+             raise RuntimeError("Transformers not available")
+
+         # Clear memory before loading
+         memory_manager.cleanup_memory()
+
+         print(f"Loading model: {self.model_id}")
+
+         model_kwargs = {
+             "device_map": "auto",
+             "low_cpu_mem_usage": True,
+         }
+
+         if torch and torch.cuda.is_available():
+             model_kwargs["torch_dtype"] = torch.bfloat16
+
+         # Use 4-bit quantization if available
+         if BitsAndBytesConfig is not None:
+             try:
+                 compute_dtype = torch.bfloat16 if torch and torch.cuda.is_available() else torch.float16
+                 model_kwargs["quantization_config"] = BitsAndBytesConfig(
+                     load_in_4bit=True,
+                     bnb_4bit_compute_dtype=compute_dtype,
+                     bnb_4bit_use_double_quant=True,
+                     bnb_4bit_quant_type="nf4"
                 )
+                 model_kwargs["torch_dtype"] = compute_dtype
+             except Exception:
+                 pass
 
+         # Check if it's a multimodal model
+         is_multimodal = "medgemma" in self.model_id.lower()
+
+         if is_multimodal and AutoModelForImageTextToText is not None and AutoProcessor is not None:
+             self.processor = AutoProcessor.from_pretrained(self.model_id)
+             self.model = AutoModelForImageTextToText.from_pretrained(self.model_id, **model_kwargs)
+             self.is_multimodal = True
+         elif AutoModelForCausalLM is not None and AutoTokenizer is not None:
+             self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
+             self.model = AutoModelForCausalLM.from_pretrained(self.model_id, **model_kwargs)
+             self.is_multimodal = False
+         else:
+             raise RuntimeError("Required model classes not available")
+
+         print("✓ Model loaded successfully")
 
+     def generate(self, text: str, image: Optional[Image.Image] = None) -> str:
+         """Generate text using the loaded model."""
+         self.load_model()
+
+         if self.model is None:
+             return f"[Model not loaded: {text}]"
 
         try:
+             # Create messages
+             system_text = "You are a concise medical assistant. Provide a brief, clear summary of detection results. Avoid repetition and be direct. Do not give medical advice."
+             user_text = f"Summarize these detection results in 3 clear sentences:\n\n{text}"
+
+             if self.is_multimodal:
+                 # Multimodal model
+                 user_content = [{"type": "text", "text": user_text}]
+                 if image is not None:
+                     user_content.append({"type": "image", "image": image})
+
+                 messages = [
+                     {"role": "system", "content": [{"type": "text", "text": system_text}]},
+                     {"role": "user", "content": user_content},
+                 ]
+
+                 inputs = self.processor.apply_chat_template(
+                     messages,
+                     add_generation_prompt=True,
+                     tokenize=True,
+                     return_dict=True,
+                     return_tensors="pt",
+                 )
 
+                 if torch:
+                     inputs = inputs.to(self.model.device, dtype=torch.bfloat16)
 
+                 with torch.inference_mode():
+                     generation = self.model.generate(
+                         **inputs,
+                         max_new_tokens=self.max_tokens,
+                         do_sample=self.temperature > 0,
+                         temperature=max(0.01, self.temperature) if self.temperature > 0 else None,
+                         use_cache=False,
+                     )
 
+                 input_len = inputs["input_ids"].shape[-1]
+                 generation = generation[0][input_len:]
+                 decoded = self.processor.decode(generation, skip_special_tokens=True)
+                 return decoded.strip()
+
+             else:
+                 # Text-only model
+                 messages = [
+                     {"role": "system", "content": system_text},
+                     {"role": "user", "content": user_text},
+                 ]
+
+                 inputs = self.tokenizer.apply_chat_template(
+                     messages,
+                     add_generation_prompt=True,
+                     tokenize=True,
+                     return_dict=True,
+                     return_tensors="pt",
+                 )
+
+                 inputs = inputs.to(self.model.device)
 
+                 with torch.inference_mode():
+                     generation = self.model.generate(
+                         **inputs,
+                         max_new_tokens=self.max_tokens,
+                         do_sample=self.temperature > 0,
+                         temperature=max(0.01, self.temperature) if self.temperature > 0 else None,
+                         use_cache=False,
+                     )
+
+                 input_len = inputs["input_ids"].shape[-1]
+                 generation = generation[0][input_len:]
+                 decoded = self.tokenizer.decode(generation, skip_special_tokens=True)
+                 return decoded.strip()
 
         except Exception as e:
+             error_msg = f"[Generation error: {e}]"
+             print(f"Generation error: {traceback.format_exc()}")
+             return f"{error_msg}\n\n{text}"
 
+ # ============================================================================
+ # Application State
+ # ============================================================================
 
+ class AppState:
+     """Application state for Spaces."""
 
+     def __init__(self):
+         self.config = SpacesConfig()
+         self.model = None
+         self.class_names = None
+         self.text_generator = None
+
+     def load_model(self):
+         """Load the detection model."""
+         if self.model is not None:
+             return
+
+         checkpoint = find_checkpoint()
+         if not checkpoint:
+             raise FileNotFoundError(
+                 "No RF-DETR checkpoint found. Please upload rf-detr-medium.pth to your Space."
+             )
 
+         print(f"Loading RF-DETR from: {checkpoint}")
+         self.model = load_model(checkpoint, self.config.get('resolution'))
 
+         # Try to load class names
         try:
+             results_json = "/tmp/results/results.json"
+             if os.path.isfile(results_json):
+                 with open(results_json, 'r') as f:
+                     data = json.load(f)
+                 classes = []
+                 for split in ("valid", "test", "train"):
+                     if "class_map" in data and split in data["class_map"]:
+                         for item in data["class_map"][split]:
+                             name = item.get("class")
+                             if name and name != "all" and name not in classes:
+                                 classes.append(name)
+                 self.class_names = classes if classes else None
+         except Exception:
+             pass
+
+         print("✓ RF-DETR model loaded")
+
+     def get_text_generator(self, model_size: str = "4B") -> TextGenerator:
+         """Get or create text generator."""
+         # Determine model ID based on size selection
+         model_id = 'google/medgemma-27b-it' if model_size == "27B" else 'google/medgemma-4b-it'
+
+         # Check if we need to create a new generator for different model size
+         if (self.text_generator is None or
+                 hasattr(self.text_generator, 'model_id') and
+                 self.text_generator.model_id != model_id):
+
+             max_tokens = self.config.get('llm_max_new_tokens')
+             temperature = self.config.get('llm_temperature')
+
+             self.text_generator = TextGenerator(model_id, max_tokens, temperature)
+         return self.text_generator
+
+ # ============================================================================
+ # UI and Inference
+ # ============================================================================
+
+ def create_detection_interface():
+     """Create the Gradio interface."""
+
+     # Color palette for annotations
+     COLOR_PALETTE = sv.ColorPalette.from_hex([
+         "#ffff00", "#ff9b00", "#ff66ff", "#3399ff", "#ff66b2",
+         "#ff8080", "#b266ff", "#9999ff", "#66ffff", "#33ff99",
+         "#66ff66", "#99ff00",
+     ])
+
+     def annotate_image(image: Image.Image, threshold: float, model_size: str = "4B") -> Tuple[Image.Image, str]:
+         """Process an image and return annotated version with description."""
+
+         if image is None:
+             return None, "Please upload an image."
 
+         try:
+             # Load model if needed
+             app_state.load_model()
 
+             # Run detection
+             detections = app_state.model.predict(image, threshold=threshold)
 
+             # Annotate image
+             bbox_annotator = sv.BoxAnnotator(color=COLOR_PALETTE, thickness=2)
+             label_annotator = sv.LabelAnnotator(text_scale=0.5, text_color=sv.Color.BLACK)
 
+             labels = []
+             for i in range(len(detections)):
+                 class_id = int(detections.class_id[i]) if detections.class_id is not None else None
+                 conf = float(detections.confidence[i]) if detections.confidence is not None else 0.0
 
+                 if app_state.class_names and class_id is not None:
+                     if 0 <= class_id < len(app_state.class_names):
+                         label_name = app_state.class_names[class_id]
+                     else:
+                         label_name = str(class_id)
+                 else:
+                     label_name = str(class_id) if class_id is not None else "object"
 
+                 labels.append(f"{label_name} {conf:.2f}")
 
+             annotated = image.copy()
+             annotated = bbox_annotator.annotate(annotated, detections)
+             annotated = label_annotator.annotate(annotated, detections, labels)
 
+             # Generate description
+             description = f"Found {len(detections)} detections above threshold {threshold}:\n\n"
+
+             if len(detections) > 0:
+                 counts = {}
+                 for i in range(len(detections)):
+                     class_id = int(detections.class_id[i]) if detections.class_id is not None else None
+                     if app_state.class_names and class_id is not None:
+                         if 0 <= class_id < len(app_state.class_names):
+                             name = app_state.class_names[class_id]
+                         else:
+                             name = str(class_id)
+                     else:
+                         name = str(class_id) if class_id is not None else "object"
+                     counts[name] = counts.get(name, 0) + 1
+
+                 for name, count in counts.items():
+                     description += f"- {count}× {name}\n"
+
+                 # Use LLM for description if enabled
+                 if app_state.config.get('use_llm'):
+                     try:
+                         generator = app_state.get_text_generator(model_size)
+                         llm_description = generator.generate(description, image=annotated)
+                         description = llm_description
+                     except Exception as e:
+                         description = f"[LLM error: {e}]\n\n{description}"
+             else:
+                 description += "No objects detected above the confidence threshold."
 
+             return annotated, description
 
+         except Exception as e:
+             error_msg = f"Error processing image: {str(e)}"
+             print(f"Processing error: {traceback.format_exc()}")
+             return None, error_msg
+
+     # Create the interface
+     with gr.Blocks(title="Medical Image Analysis", theme=gr.themes.Soft()) as demo:
+         gr.Markdown("# 🏥 Medical Image Analysis")
+         gr.Markdown("Upload a medical image to detect and analyze findings using AI.")
+
+         with gr.Row():
+             with gr.Column():
+                 input_image = gr.Image(type="pil", label="Upload Image", height=400)
+                 threshold_slider = gr.Slider(
+                     minimum=0.1,
+                     maximum=1.0,
+                     value=0.7,
+                     step=0.05,
+                     label="Confidence Threshold",
+                     info="Higher values = fewer but more confident detections"
+                 )
 
+                 model_size_radio = gr.Radio(
+                     choices=["4B", "27B"],
+                     value="4B",
+                     label="MedGemma Model Size",
+                     info="4B: Faster, less memory | 27B: More accurate, more memory"
+                 )
 
+                 analyze_btn = gr.Button("🔍 Analyze Image", variant="primary")
 
+             with gr.Column():
+                 output_image = gr.Image(type="pil", label="Results", height=400)
+                 output_text = gr.Textbox(
+                     label="Analysis Results",
+                     lines=8,
+                     max_lines=15,
+                     show_copy_button=True
+                 )
 
+         # Wire up the interface
+         analyze_btn.click(
+             fn=annotate_image,
+             inputs=[input_image, threshold_slider, model_size_radio],
+             outputs=[output_image, output_text]
+         )
+
+         # Also run when image is uploaded
+         input_image.change(
+             fn=annotate_image,
+             inputs=[input_image, threshold_slider, model_size_radio],
+             outputs=[output_image, output_text]
+         )
+
+         # Footer
+         gr.Markdown("---")
+         gr.Markdown("*Powered by RF-DETR and MedGemma • Built for Hugging Face Spaces*")
+
+     return demo
+
+ # ============================================================================
+ # Main Application
+ # ============================================================================
+
+ # Global app state
+ app_state = AppState()
+
+ def main():
+     """Main entry point for the Spaces app."""
+     print("🚀 Starting Medical Image Analysis App")
+
+     # Ensure results directory exists
+     os.makedirs(app_state.config.get('results_dir'), exist_ok=True)
+
+     # Create and launch the interface
+     demo = create_detection_interface()
+
+     # Launch with Spaces-optimized settings
+     demo.launch(
+         server_name="0.0.0.0",
+         server_port=7860,
+         share=False,  # Spaces handles this
+         show_error=True,
+         show_api=False,
     )
 
 if __name__ == "__main__":
+     main()
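
For a quick local check of the detection path introduced in this commit, the new module's helpers can be exercised without launching the Gradio UI. The following is a minimal sketch, assuming the new file is importable as `app`, the `rfdetr` package is installed, an RF-DETR checkpoint sits in one of the locations `find_checkpoint()` searches, and `sample.png` is a hypothetical test image.

```python
# Minimal smoke test for the new detection path (no Gradio UI).
from PIL import Image

import app  # the Spaces entry point added in this commit (assumed importable)

checkpoint = app.find_checkpoint()
if checkpoint is None:
    raise SystemExit("No RF-DETR checkpoint found; upload rf-detr-medium.pth first.")

model = app.load_model(checkpoint, resolution=576)   # resolution mirrors SpacesConfig
image = Image.open("sample.png").convert("RGB")      # hypothetical test image
detections = model.predict(image, threshold=0.7)     # same call app.py makes
print(f"{len(detections)} detections above threshold 0.7")
```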