samwell Claude commited on
Commit
0e36bfb
·
1 Parent(s): f6108da

refactor: Simplify segmentation to single overlay panel

Browse files

Changed from 3-panel debug view to clean single overlay:
- Removed: Original image panel (left)
- Removed: Pure mask panel (middle)
- Kept: Overlay panel with contours and legend

Benefits:
- Cleaner, more professional appearance
- Focuses on the actual result (overlay)
- Larger visualization size
- Better for production use

The 3-panel view was useful for debugging "no masks detected"
but now that segmentation works, we only need the overlay.

Co-Authored-By: Claude <[email protected]>

medrax/tools/segmentation/segmentation.py CHANGED
@@ -173,9 +173,8 @@ class ChestXRaySegmentationTool(BaseTool):
173
 
174
  def _save_visualization(self, original_img: np.ndarray, pred_masks: torch.Tensor, organ_indices: List[int]) -> str:
175
  """Save visualization of original image with segmentation masks overlaid."""
176
- # Create a 3-panel figure: original | mask only | overlay
177
- fig, axes = plt.subplots(1, 3, figsize=(18, 6))
178
- ax_orig, ax_mask, ax_overlay = axes
179
 
180
  # Generate color palette for organs
181
  colors = plt.cm.tab10(np.linspace(0, 1, min(len(organ_indices), 10)))
@@ -220,27 +219,10 @@ class ChestXRaySegmentationTool(BaseTool):
220
 
221
  print(f"Total masks found and rendered: {masks_found}")
222
 
223
- # Panel 1: Original image
224
- ax_orig.imshow(original_img, cmap="gray")
225
- ax_orig.set_title("Original X-Ray", fontsize=12, color='white')
226
- ax_orig.axis("off")
227
 
228
- # Panel 2: Mask only (colorized)
229
- if masks_found > 0:
230
- # Create RGB image for mask
231
- mask_rgb = np.zeros((*original_img.shape, 3))
232
- for idx, (organ_name, color) in enumerate(legend_items):
233
- mask_region = (combined_mask == idx + 1)
234
- mask_rgb[mask_region] = color[:3]
235
- ax_mask.imshow(mask_rgb)
236
- ax_mask.set_title(f"Segmentation Mask ({masks_found} organ(s))", fontsize=12, color='white')
237
- else:
238
- ax_mask.imshow(np.zeros_like(original_img), cmap="gray")
239
- ax_mask.set_title("No Masks Detected", fontsize=12, color='red')
240
- ax_mask.axis("off")
241
-
242
- # Panel 3: Overlay
243
- ax_overlay.imshow(original_img, cmap="gray")
244
  if masks_found > 0:
245
  from matplotlib.patches import Patch
246
  for idx, (organ_name, color) in enumerate(legend_items):
@@ -249,24 +231,22 @@ class ChestXRaySegmentationTool(BaseTool):
249
  # Create colored overlay
250
  overlay = np.zeros((*original_img.shape, 4))
251
  overlay[mask_region] = [color[0], color[1], color[2], 0.5]
252
- ax_overlay.imshow(overlay)
253
 
254
- # Draw contours
255
  contours = skimage.measure.find_contours(mask_region.astype(float), 0.5)
256
  for contour in contours:
257
- ax_overlay.plot(contour[:, 1], contour[:, 0], color=color, linewidth=3, alpha=0.9)
258
 
259
- ax_overlay.set_title("Overlay", fontsize=12, color='white')
260
- # Add legend to overlay panel
261
  patches = [Patch(facecolor=c, edgecolor=c, label=n, alpha=0.7) for n, c in legend_items]
262
- ax_overlay.legend(handles=patches, loc="upper right", fontsize=8, framealpha=0.9)
 
263
  else:
264
- ax_overlay.set_title("No Overlay (No Masks)", fontsize=12, color='red')
265
- ax_overlay.axis("off")
266
 
267
- # Set background
268
  fig.patch.set_facecolor('black')
269
- plt.tight_layout()
270
 
271
  save_path = self.temp_dir / f"segmentation_{uuid.uuid4().hex[:8]}.png"
272
  plt.savefig(save_path, bbox_inches="tight", dpi=150, facecolor='black')
 
173
 
174
  def _save_visualization(self, original_img: np.ndarray, pred_masks: torch.Tensor, organ_indices: List[int]) -> str:
175
  """Save visualization of original image with segmentation masks overlaid."""
176
+ # Create single panel with overlay
177
+ fig, ax = plt.subplots(figsize=(10, 10))
 
178
 
179
  # Generate color palette for organs
180
  colors = plt.cm.tab10(np.linspace(0, 1, min(len(organ_indices), 10)))
 
219
 
220
  print(f"Total masks found and rendered: {masks_found}")
221
 
222
+ # Display original image
223
+ ax.imshow(original_img, cmap="gray")
 
 
224
 
225
+ # Overlay masks with contours
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
  if masks_found > 0:
227
  from matplotlib.patches import Patch
228
  for idx, (organ_name, color) in enumerate(legend_items):
 
231
  # Create colored overlay
232
  overlay = np.zeros((*original_img.shape, 4))
233
  overlay[mask_region] = [color[0], color[1], color[2], 0.5]
234
+ ax.imshow(overlay)
235
 
236
+ # Draw contours for clear boundaries
237
  contours = skimage.measure.find_contours(mask_region.astype(float), 0.5)
238
  for contour in contours:
239
+ ax.plot(contour[:, 1], contour[:, 0], color=color, linewidth=3, alpha=0.9)
240
 
241
+ # Add legend
 
242
  patches = [Patch(facecolor=c, edgecolor=c, label=n, alpha=0.7) for n, c in legend_items]
243
+ ax.legend(handles=patches, loc="upper right", fontsize=10, framealpha=0.9)
244
+ ax.set_title("Segmentation Overlay", fontsize=14, color='white', pad=15)
245
  else:
246
+ ax.set_title("No Masks Detected", fontsize=14, color='red', pad=15)
 
247
 
248
+ ax.axis("off")
249
  fig.patch.set_facecolor('black')
 
250
 
251
  save_path = self.temp_dir / f"segmentation_{uuid.uuid4().hex[:8]}.png"
252
  plt.savefig(save_path, bbox_inches="tight", dpi=150, facecolor='black')