refactor: Simplify segmentation to single overlay panel
Browse filesChanged from 3-panel debug view to clean single overlay:
- Removed: Original image panel (left)
- Removed: Pure mask panel (middle)
- Kept: Overlay panel with contours and legend
Benefits:
- Cleaner, more professional appearance
- Focuses on the actual result (overlay)
- Larger visualization size
- Better for production use
The 3-panel view was useful for debugging "no masks detected"
but now that segmentation works, we only need the overlay.
Co-Authored-By: Claude <[email protected]>
medrax/tools/segmentation/segmentation.py
CHANGED
|
@@ -173,9 +173,8 @@ class ChestXRaySegmentationTool(BaseTool):
|
|
| 173 |
|
| 174 |
def _save_visualization(self, original_img: np.ndarray, pred_masks: torch.Tensor, organ_indices: List[int]) -> str:
|
| 175 |
"""Save visualization of original image with segmentation masks overlaid."""
|
| 176 |
-
# Create
|
| 177 |
-
fig,
|
| 178 |
-
ax_orig, ax_mask, ax_overlay = axes
|
| 179 |
|
| 180 |
# Generate color palette for organs
|
| 181 |
colors = plt.cm.tab10(np.linspace(0, 1, min(len(organ_indices), 10)))
|
|
@@ -220,27 +219,10 @@ class ChestXRaySegmentationTool(BaseTool):
|
|
| 220 |
|
| 221 |
print(f"Total masks found and rendered: {masks_found}")
|
| 222 |
|
| 223 |
-
#
|
| 224 |
-
|
| 225 |
-
ax_orig.set_title("Original X-Ray", fontsize=12, color='white')
|
| 226 |
-
ax_orig.axis("off")
|
| 227 |
|
| 228 |
-
#
|
| 229 |
-
if masks_found > 0:
|
| 230 |
-
# Create RGB image for mask
|
| 231 |
-
mask_rgb = np.zeros((*original_img.shape, 3))
|
| 232 |
-
for idx, (organ_name, color) in enumerate(legend_items):
|
| 233 |
-
mask_region = (combined_mask == idx + 1)
|
| 234 |
-
mask_rgb[mask_region] = color[:3]
|
| 235 |
-
ax_mask.imshow(mask_rgb)
|
| 236 |
-
ax_mask.set_title(f"Segmentation Mask ({masks_found} organ(s))", fontsize=12, color='white')
|
| 237 |
-
else:
|
| 238 |
-
ax_mask.imshow(np.zeros_like(original_img), cmap="gray")
|
| 239 |
-
ax_mask.set_title("No Masks Detected", fontsize=12, color='red')
|
| 240 |
-
ax_mask.axis("off")
|
| 241 |
-
|
| 242 |
-
# Panel 3: Overlay
|
| 243 |
-
ax_overlay.imshow(original_img, cmap="gray")
|
| 244 |
if masks_found > 0:
|
| 245 |
from matplotlib.patches import Patch
|
| 246 |
for idx, (organ_name, color) in enumerate(legend_items):
|
|
@@ -249,24 +231,22 @@ class ChestXRaySegmentationTool(BaseTool):
|
|
| 249 |
# Create colored overlay
|
| 250 |
overlay = np.zeros((*original_img.shape, 4))
|
| 251 |
overlay[mask_region] = [color[0], color[1], color[2], 0.5]
|
| 252 |
-
|
| 253 |
|
| 254 |
-
# Draw contours
|
| 255 |
contours = skimage.measure.find_contours(mask_region.astype(float), 0.5)
|
| 256 |
for contour in contours:
|
| 257 |
-
|
| 258 |
|
| 259 |
-
|
| 260 |
-
# Add legend to overlay panel
|
| 261 |
patches = [Patch(facecolor=c, edgecolor=c, label=n, alpha=0.7) for n, c in legend_items]
|
| 262 |
-
|
|
|
|
| 263 |
else:
|
| 264 |
-
|
| 265 |
-
ax_overlay.axis("off")
|
| 266 |
|
| 267 |
-
|
| 268 |
fig.patch.set_facecolor('black')
|
| 269 |
-
plt.tight_layout()
|
| 270 |
|
| 271 |
save_path = self.temp_dir / f"segmentation_{uuid.uuid4().hex[:8]}.png"
|
| 272 |
plt.savefig(save_path, bbox_inches="tight", dpi=150, facecolor='black')
|
|
|
|
| 173 |
|
| 174 |
def _save_visualization(self, original_img: np.ndarray, pred_masks: torch.Tensor, organ_indices: List[int]) -> str:
|
| 175 |
"""Save visualization of original image with segmentation masks overlaid."""
|
| 176 |
+
# Create single panel with overlay
|
| 177 |
+
fig, ax = plt.subplots(figsize=(10, 10))
|
|
|
|
| 178 |
|
| 179 |
# Generate color palette for organs
|
| 180 |
colors = plt.cm.tab10(np.linspace(0, 1, min(len(organ_indices), 10)))
|
|
|
|
| 219 |
|
| 220 |
print(f"Total masks found and rendered: {masks_found}")
|
| 221 |
|
| 222 |
+
# Display original image
|
| 223 |
+
ax.imshow(original_img, cmap="gray")
|
|
|
|
|
|
|
| 224 |
|
| 225 |
+
# Overlay masks with contours
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 226 |
if masks_found > 0:
|
| 227 |
from matplotlib.patches import Patch
|
| 228 |
for idx, (organ_name, color) in enumerate(legend_items):
|
|
|
|
| 231 |
# Create colored overlay
|
| 232 |
overlay = np.zeros((*original_img.shape, 4))
|
| 233 |
overlay[mask_region] = [color[0], color[1], color[2], 0.5]
|
| 234 |
+
ax.imshow(overlay)
|
| 235 |
|
| 236 |
+
# Draw contours for clear boundaries
|
| 237 |
contours = skimage.measure.find_contours(mask_region.astype(float), 0.5)
|
| 238 |
for contour in contours:
|
| 239 |
+
ax.plot(contour[:, 1], contour[:, 0], color=color, linewidth=3, alpha=0.9)
|
| 240 |
|
| 241 |
+
# Add legend
|
|
|
|
| 242 |
patches = [Patch(facecolor=c, edgecolor=c, label=n, alpha=0.7) for n, c in legend_items]
|
| 243 |
+
ax.legend(handles=patches, loc="upper right", fontsize=10, framealpha=0.9)
|
| 244 |
+
ax.set_title("Segmentation Overlay", fontsize=14, color='white', pad=15)
|
| 245 |
else:
|
| 246 |
+
ax.set_title("No Masks Detected", fontsize=14, color='red', pad=15)
|
|
|
|
| 247 |
|
| 248 |
+
ax.axis("off")
|
| 249 |
fig.patch.set_facecolor('black')
|
|
|
|
| 250 |
|
| 251 |
save_path = self.temp_dir / f"segmentation_{uuid.uuid4().hex[:8]}.png"
|
| 252 |
plt.savefig(save_path, bbox_inches="tight", dpi=150, facecolor='black')
|