Transformers
image-processing
medical-imaging
fundus
retinal-imaging
diabetic-retinopathy
ophthalmology
clahe
preprocessing
Instructions to use iszt/eye-clahe-processor with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use iszt/eye-clahe-processor with Transformers:
```python
# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("iszt/eye-clahe-processor", dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| """ | |
| GPU-Native Eye Image Processor for Color Fundus Photography (CFP) Images. | |
| This module implements a fully PyTorch-based image processor that: | |
| 1. Localizes the eye/fundus region using gradient-based radial symmetry | |
| 2. Crops to a border-minimized square centered on the eye | |
| 3. Applies CLAHE for contrast enhancement | |
| 4. Outputs tensors compatible with Hugging Face vision models | |
| Constraints: | |
| - PyTorch only (no OpenCV, PIL, NumPy in runtime) | |
| - CUDA-compatible, batch-friendly, deterministic | |
| """ | |
| from typing import Dict, List, Optional, Union | |
| import math | |
| import torch | |
| import torch.nn.functional as F | |
| from transformers.image_processing_utils import BaseImageProcessor | |
| from transformers.image_processing_base import BatchFeature | |
| # Optional imports for broader input support | |
| try: | |
| from PIL import Image | |
| PIL_AVAILABLE = True | |
| except ImportError: | |
| PIL_AVAILABLE = False | |
| try: | |
| import numpy as np | |
| NUMPY_AVAILABLE = True | |
| except ImportError: | |
| NUMPY_AVAILABLE = False | |
| # ============================================================================= | |
| # PHASE 1: Input & Tensor Standardization | |
| # ============================================================================= | |
def _pil_to_tensor(image: "Image.Image") -> torch.Tensor:
    """Turn one PIL image into a float32 ``(C, H, W)`` tensor scaled to ``[0, 1]``.

    Images in any mode other than ``"RGB"`` are converted to RGB first, so the
    result always has three channels. When NumPy is installed it is used for the
    pixel transfer; otherwise pixels are read one by one via ``Image.getdata``.

    Raises:
        ImportError: If Pillow is not installed.
    """
    if not PIL_AVAILABLE:
        raise ImportError("PIL is required to process PIL Images")

    rgb = image if image.mode == "RGB" else image.convert("RGB")

    if not NUMPY_AVAILABLE:
        # Slow path: getdata() yields one (R, G, B) tuple per pixel in row-major order.
        w, h = rgb.size
        flat = torch.tensor(list(rgb.getdata()), dtype=torch.float32)
        hwc = flat.view(h, w, 3) / 255.0
        return hwc.permute(2, 0, 1)

    # Fast path: NumPy gives an (H, W, 3) float32 array; move channels first.
    hwc = np.array(rgb, dtype=np.float32) / 255.0
    return torch.from_numpy(hwc).permute(2, 0, 1)
def _numpy_to_tensor(arr: "np.ndarray") -> torch.Tensor:
    """Turn one NumPy image into a float32 ``(C, H, W)`` tensor.

    Layout handling:
      * ``(H, W)`` grayscale becomes ``(1, H, W)``;
      * ``(H, W, C)`` with ``C`` in {1, 3, 4} is moved to ``(C, H, W)``;
      * any other array is left in its original axis order (a 3-D array is
        assumed to already be channels-first).

    ``uint8`` data is divided by 255; every other dtype is only cast to float32,
    without rescaling. The returned tensor owns its memory because the data is
    copied before wrapping.

    Raises:
        ImportError: If NumPy is not installed.
    """
    if not NUMPY_AVAILABLE:
        raise ImportError("NumPy is required to process numpy arrays")

    data = arr[..., None] if arr.ndim == 2 else arr
    is_channels_last = data.ndim == 3 and data.shape[-1] in [1, 3, 4]
    if is_channels_last:
        data = data.transpose(2, 0, 1)

    if data.dtype == np.uint8:
        data = data.astype(np.float32) / 255.0
    elif data.dtype != np.float32:
        data = data.astype(np.float32)

    # copy(): decouple from the caller's buffer and drop any transposed/negative strides
    return torch.from_numpy(data.copy())
def standardize_input(
    images: Union[torch.Tensor, List[torch.Tensor], "Image.Image", List["Image.Image"], "np.ndarray", List["np.ndarray"]],
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """Convert heterogeneous image inputs to a standardized (B, C, H, W) float32 tensor in [0, 1].

    Accepts torch.Tensor, PIL.Image, numpy.ndarray, or lists thereof. ``uint8``
    inputs are scaled to [0, 1]; other dtypes are only cast to float32. The output
    is clamped to [0, 1].

    Note: All images in a list must have the same spatial dimensions (required by torch.stack).
    A single numpy array with ndim==3 is treated as a single HWC image if the last dimension
    is in {1, 3, 4}; otherwise it falls through to the tensor path (assumed CHW).
    A single 2-D input (numpy array or tensor) is treated as one grayscale (H, W) image.

    Args:
        images: Input as:
            - torch.Tensor (H,W), (C,H,W), (B,C,H,W), or list of tensors
            - PIL.Image.Image or list of PIL Images
            - numpy.ndarray (H,W), (H,W,C), (B,H,W,C), or list of arrays
        device: Target device. If None the tensor stays where it is
            (CPU for PIL/numpy inputs).

    Returns:
        Tensor of shape (B, C, H, W) in float32, range [0, 1].

    Raises:
        TypeError: If the input (or a list element) has an unsupported type.
    """
    # Wrap single PIL images / single numpy images so the list path handles them.
    if PIL_AVAILABLE and isinstance(images, Image.Image):
        images = [images]
    if NUMPY_AVAILABLE and isinstance(images, np.ndarray):
        if images.ndim == 2:
            # Single grayscale (H, W) image; _numpy_to_tensor turns it into (1, H, W).
            images = [images]
        elif images.ndim == 3 and images.shape[-1] in [1, 3, 4]:
            # Ambiguous with a (B, H, W) grayscale batch; assume a single HWC image.
            images = [images]

    if isinstance(images, list):
        converted = []
        for img in images:
            if PIL_AVAILABLE and isinstance(img, Image.Image):
                converted.append(_pil_to_tensor(img))
            elif NUMPY_AVAILABLE and isinstance(img, np.ndarray):
                converted.append(_numpy_to_tensor(img))
            elif isinstance(img, torch.Tensor):
                if img.dim() == 2:
                    t = img.unsqueeze(0)  # (H, W) -> (1, H, W)
                elif img.dim() == 3:
                    t = img
                else:
                    t = img.squeeze(0)  # (1, C, H, W) -> (C, H, W)
                converted.append(t)
            else:
                raise TypeError(f"Unsupported image type: {type(img)}")
        images = torch.stack(converted)
    elif NUMPY_AVAILABLE and isinstance(images, np.ndarray):
        # Batch (B, H, W, C) -> (B, C, H, W); 3-D arrays reaching here are assumed CHW.
        if images.ndim == 4:
            images = images.transpose(0, 3, 1, 2)
        if images.dtype == np.uint8:
            images = images.astype(np.float32) / 255.0
        images = torch.from_numpy(images.copy())

    if not isinstance(images, torch.Tensor):
        raise TypeError(f"Unsupported image type: {type(images)}")

    # Bring every supported rank up to (B, C, H, W).
    if images.dim() == 2:
        images = images.unsqueeze(0).unsqueeze(0)  # (H, W) -> (1, 1, H, W)
    elif images.dim() == 3:
        images = images.unsqueeze(0)  # (C, H, W) -> (1, C, H, W)

    if device is not None:
        images = images.to(device)

    # Only uint8 is rescaled; other dtypes are assumed to already be in [0, 1].
    if images.dtype == torch.uint8:
        images = images.float() / 255.0
    elif images.dtype != torch.float32:
        images = images.float()

    return images.clamp(0.0, 1.0)
def standardize_mask_input(
    masks: Union[
        torch.Tensor,
        List[torch.Tensor],
        "Image.Image",
        List["Image.Image"],
        "np.ndarray",
        List["np.ndarray"],
    ],
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """Convert heterogeneous mask inputs to a standardized (B, 1, H, W) tensor.

    Unlike ``standardize_input``, this preserves the original dtype (typically integer
    label values) and does **not** normalize to [0, 1].
    Accepts torch.Tensor, PIL.Image, numpy.ndarray, or lists thereof.
    A single 2-D input is treated as (H, W) and expanded to (1, 1, H, W); a 3-D
    tensor/array is treated as (B, H, W) and gets a channel axis; a 4-D input is
    passed through unchanged (assumed already (B, C, H, W)).

    Args:
        masks: Input masks in any supported format.
        device: Target device. If None the mask stays where it is
            (CPU for PIL/numpy inputs).

    Returns:
        Tensor of shape (B, 1, H, W) (or the given (B, C, H, W) for 4-D input)
        with original dtype preserved.

    Raises:
        ImportError: If a PIL mask is given but NumPy is not installed.
        TypeError: If the input (or a list element) has an unsupported type.
        ValueError: If the resulting tensor is not 2-, 3- or 4-dimensional.
    """
    # Wrap single 2-D inputs so the list path handles them.
    if PIL_AVAILABLE and isinstance(masks, Image.Image):
        masks = [masks]
    if NUMPY_AVAILABLE and isinstance(masks, np.ndarray) and masks.ndim == 2:
        masks = [masks]

    if isinstance(masks, list):
        converted = []
        for m in masks:
            if PIL_AVAILABLE and isinstance(m, Image.Image):
                # The PIL -> tensor route goes through NumPy; fail clearly rather than NameError.
                if not NUMPY_AVAILABLE:
                    raise ImportError("NumPy is required to process PIL masks")
                converted.append(torch.from_numpy(np.array(m)))
            elif NUMPY_AVAILABLE and isinstance(m, np.ndarray):
                # from_numpy rejects negative strides (e.g. np.flip views);
                # ascontiguousarray only copies when the array is not already C-contiguous.
                converted.append(torch.from_numpy(np.ascontiguousarray(m)))
            elif isinstance(m, torch.Tensor):
                converted.append(m)
            else:
                raise TypeError(f"Unsupported mask type: {type(m)}")
        masks = torch.stack(converted)
    elif NUMPY_AVAILABLE and isinstance(masks, np.ndarray):
        masks = torch.from_numpy(np.ascontiguousarray(masks))

    if not isinstance(masks, torch.Tensor):
        raise TypeError(f"Unsupported mask type: {type(masks)}")

    if masks.dim() == 2:
        # (H, W) -> (1, 1, H, W)
        masks = masks.unsqueeze(0).unsqueeze(0)
    elif masks.dim() == 3:
        # (B, H, W) -> (B, 1, H, W)
        masks = masks.unsqueeze(1)
    elif masks.dim() != 4:
        raise ValueError(f"Invalid mask shape: {masks.shape}")

    if device is not None:
        masks = masks.to(device)
    return masks
def rgb_to_grayscale(images: torch.Tensor) -> torch.Tensor:
    """Collapse the channel axis to ITU-R BT.601 luma: ``Y = 0.299 R + 0.587 G + 0.114 B``.

    Args:
        images: ``(B, 3, H, W)`` tensor; no particular value range is assumed.

    Returns:
        ``(B, 1, H, W)`` tensor in the same value range as the input.
    """
    # Per-channel weights broadcast over batch and spatial axes.
    luma = torch.tensor(
        [0.299, 0.587, 0.114], device=images.device, dtype=images.dtype
    )[None, :, None, None]
    return torch.sum(images * luma, dim=1, keepdim=True)
| # ============================================================================= | |
| # PHASE 2: Eye Region Localization (GPU-Safe) | |
| # ============================================================================= | |
def create_sobel_kernels(device: torch.device, dtype: torch.dtype) -> tuple:
    """Build the pair of 3x3 Sobel kernels used for image gradients.

    Each kernel is the outer product of a ``[1, 2, 1]`` smoothing vector and a
    ``[-1, 0, 1]`` central-difference vector. Swapping the factor order turns
    the horizontal kernel into the vertical one.

    Args:
        device: Device on which the kernels are allocated.
        dtype: Element type of the kernels.

    Returns:
        ``(sobel_x, sobel_y)``, each shaped ``(1, 1, 3, 3)`` so it can be used
        directly as the weight of ``F.conv2d`` on single-channel input.
        ``sobel_x`` responds to left-to-right changes, ``sobel_y`` to top-to-bottom ones.
    """
    smooth = torch.tensor([1, 2, 1], device=device, dtype=dtype)
    diff = torch.tensor([-1, 0, 1], device=device, dtype=dtype)
    sobel_x = torch.outer(smooth, diff).view(1, 1, 3, 3)
    sobel_y = torch.outer(diff, smooth).view(1, 1, 3, 3)
    return sobel_x, sobel_y
def compute_gradients(grayscale: torch.Tensor) -> tuple:
    """Sobel gradients of single-channel images.

    Both filters are applied with ``padding=1`` (implicit zero padding), so the
    outputs keep the input's spatial size; border pixels therefore see zeros
    outside the frame.

    Args:
        grayscale: ``(B, 1, H, W)`` images.

    Returns:
        ``(grad_x, grad_y, grad_magnitude)``, each ``(B, 1, H, W)``, where
        ``grad_magnitude = sqrt(grad_x**2 + grad_y**2 + 1e-8)``. The epsilon keeps
        the value under the square root strictly positive.
    """
    kernel_x, kernel_y = create_sobel_kernels(grayscale.device, grayscale.dtype)
    gx = F.conv2d(grayscale, kernel_x, padding=1)
    gy = F.conv2d(grayscale, kernel_y, padding=1)
    magnitude = (gx ** 2 + gy ** 2 + 1e-8).sqrt()
    return gx, gy, magnitude
def _separable_gaussian_smooth(maps: torch.Tensor, kernel_size: int) -> torch.Tensor:
    """Blur single-channel maps with a normalised odd-sized Gaussian (sigma = kernel_size / 6), keeping size."""
    half = kernel_size // 2
    sigma = kernel_size / 6.0
    taps = torch.arange(kernel_size, device=maps.device, dtype=maps.dtype) - half
    gaussian_1d = torch.exp(-taps ** 2 / (2 * sigma ** 2))
    gaussian_1d = gaussian_1d / gaussian_1d.sum()

    def _pad_mode(size: int) -> str:
        # F.pad 'reflect' requires pad < dimension size; very small or strongly
        # elongated images would otherwise crash, so fall back to edge replication.
        return "reflect" if half < size else "replicate"

    H, W = maps.shape[-2:]
    # Horizontal pass, then vertical pass.
    out = F.pad(maps, (half, half, 0, 0), mode=_pad_mode(W))
    out = F.conv2d(out, gaussian_1d.view(1, 1, 1, kernel_size))
    out = F.pad(out, (0, 0, half, half), mode=_pad_mode(H))
    out = F.conv2d(out, gaussian_1d.view(1, 1, kernel_size, 1))
    return out


def compute_radial_symmetry_response(
    grayscale: torch.Tensor,
    grad_x: torch.Tensor,
    grad_y: torch.Tensor,
    grad_magnitude: torch.Tensor,
) -> torch.Tensor:
    """Compute a radial-symmetry response map for circular-region detection.

    The algorithm:
      1. Estimates an initial center as the intensity-weighted center of mass of
         dark regions (squared inverse intensity).
      2. For each pixel, computes the dot product between the normalized gradient
         vector and the unit vector pointing toward the estimated center.
      3. Weights this alignment score by gradient magnitude and darkness.
      4. Smooths the response with a separable Gaussian whose size is
         proportional to the image size (kernel_size = max(H,W)//8 rounded up to
         odd, at least 5; sigma = kernel_size/6). Reflect padding is used when
         the image is large enough, otherwise edge-replicate padding.

    High response indicates pixels whose gradients point radially toward the
    estimated center while lying in darker regions.

    Args:
        grayscale: Grayscale images (B, 1, H, W) in [0, 1].
        grad_x: Horizontal gradient (B, 1, H, W).
        grad_y: Vertical gradient (B, 1, H, W).
        grad_magnitude: Gradient magnitude (B, 1, H, W).

    Returns:
        Smoothed radial symmetry response map (B, 1, H, W).
    """
    B, _, H, W = grayscale.shape
    device = grayscale.device
    dtype = grayscale.dtype

    y_coords = torch.arange(H, device=device, dtype=dtype).view(1, 1, H, 1).expand(B, 1, H, W)
    x_coords = torch.arange(W, device=device, dtype=dtype).view(1, 1, 1, W).expand(B, 1, H, W)

    # Squared inverse intensity: dark pixels dominate the center-of-mass estimate.
    dark_weight = (1.0 - grayscale) ** 2
    weight_sum = dark_weight.sum(dim=(2, 3), keepdim=True) + 1e-8
    cx_init = (dark_weight * x_coords).sum(dim=(2, 3), keepdim=True) / weight_sum
    cy_init = (dark_weight * y_coords).sum(dim=(2, 3), keepdim=True) / weight_sum

    # Unit vectors from each pixel toward the initial center.
    dx_to_center = cx_init - x_coords
    dy_to_center = cy_init - y_coords
    dist_to_center = torch.sqrt(dx_to_center ** 2 + dy_to_center ** 2 + 1e-8)
    dx_norm = dx_to_center / dist_to_center
    dy_norm = dy_to_center / dist_to_center

    # Unit gradient directions.
    grad_norm = grad_magnitude + 1e-8
    gx_norm = grad_x / grad_norm
    gy_norm = grad_y / grad_norm

    # Cosine between gradient and center direction, weighted by edge strength and darkness.
    radial_alignment = gx_norm * dx_norm + gy_norm * dy_norm
    response = radial_alignment * grad_magnitude * dark_weight

    kernel_size = max(H, W) // 8
    if kernel_size % 2 == 0:
        kernel_size += 1
    kernel_size = max(kernel_size, 5)
    return _separable_gaussian_smooth(response, kernel_size)
def soft_argmax_2d(response: torch.Tensor, temperature: float = 0.1) -> tuple:
    """Find the sub-pixel peak location in a response map via softmax-weighted coordinates.

    Divides the flattened response by ``temperature`` before applying softmax, then
    computes the weighted mean of the (x, y) coordinate grids. Lower temperature yields
    a sharper, more argmax-like result; higher temperature yields a broader average.

    Caution: very low temperatures scale ``response`` up. In low-precision dtypes
    (e.g. float16, max ~65504), ``response / temperature`` itself can overflow to
    inf and produce NaN weights.

    Args:
        response: Response map (B, 1, H, W). It does not need to be contiguous.
        temperature: Softmax temperature, must be > 0. Default 0.1.

    Returns:
        Tuple of (cx, cy), each of shape (B,), in pixel coordinates.

    Raises:
        ValueError: If ``temperature`` is not strictly positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")

    B, _, H, W = response.shape
    device = response.device
    dtype = response.dtype

    # reshape (not view): accepts non-contiguous maps such as transposed tensors.
    weights = F.softmax(response.reshape(B, -1) / temperature, dim=1).view(B, 1, H, W)

    # Coordinate grids broadcast against the (B, 1, H, W) weights.
    y_coords = torch.arange(H, device=device, dtype=dtype).view(1, 1, H, 1)
    x_coords = torch.arange(W, device=device, dtype=dtype).view(1, 1, 1, W)

    cx = (weights * x_coords).sum(dim=(2, 3)).squeeze(-1)  # (B,)
    cy = (weights * y_coords).sum(dim=(2, 3)).squeeze(-1)  # (B,)
    return cx, cy
def estimate_eye_center(
    images: torch.Tensor,
    softmax_temperature: float = 0.1,
) -> tuple:
    """Locate the center of the fundus disc in every image of a batch.

    Steps: BT.601 grayscale, Sobel gradients, radial-symmetry response map,
    then a soft-argmax over that map.

    Args:
        images: RGB images of shape (B, 3, H, W) in [0, 1].
        softmax_temperature: Temperature passed to the soft-argmax. Smaller values
            (about 0.01-0.1) localize more sharply; larger ones (about 0.3-0.5)
            average more broadly, which can help on noisy or low-contrast images.
            Default 0.1.

    Returns:
        ``(cx, cy)``, each of shape (B,), in pixel coordinates.
    """
    gray = rgb_to_grayscale(images)
    gx, gy, gmag = compute_gradients(gray)
    symmetry_map = compute_radial_symmetry_response(gray, gx, gy, gmag)
    return soft_argmax_2d(symmetry_map, temperature=softmax_temperature)
| # ============================================================================= | |
| # PHASE 2.3: Radius Estimation | |
| # ============================================================================= | |
def estimate_radius(
    images: torch.Tensor,
    cx: torch.Tensor,
    cy: torch.Tensor,
    num_radii: int = 100,
    num_angles: int = 36,
    min_radius_frac: float = 0.1,
    max_radius_frac: float = 0.5,
) -> torch.Tensor:
    """Estimate the radius of the fundus disc by analyzing radial intensity profiles.

    Samples grayscale intensity along ``num_angles`` rays emanating from ``(cx, cy)``
    at ``num_radii`` radial distances. The per-radius mean intensity across all angles
    gives a 1-D radial profile. The discrete derivative of this profile is linearly
    weighted by radius (range 0.5-1.5) to bias toward the outer disc boundary
    rather than smaller inner intensity transitions. The radius with the strongest
    weighted negative derivative (bright-to-dark when moving outward) is selected
    as the disc edge.

    Uses ``F.grid_sample`` with bilinear interpolation and border padding for
    sub-pixel sampling.

    Args:
        images: RGB images (B, 3, H, W) in [0, 1].
        cx, cy: Center coordinates (B,) in pixel units.
        num_radii: Number of radial sample points (>= 2). Default 100.
        num_angles: Number of angular sample rays (>= 1). Default 36.
        min_radius_frac: Minimum search radius as fraction of min(H, W). Default 0.1.
        max_radius_frac: Maximum search radius as fraction of min(H, W). Default 0.5.

    Returns:
        Estimated radius for each image (B,), clamped to [min_radius, max_radius].

    Raises:
        ValueError: If ``num_radii < 2`` (no derivative can be formed) or
            ``num_angles < 1`` (empty angular average).
    """
    if num_radii < 2:
        raise ValueError(f"num_radii must be >= 2, got {num_radii}")
    if num_angles < 1:
        raise ValueError(f"num_angles must be >= 1, got {num_angles}")

    B, _, H, W = images.shape
    device = images.device
    dtype = images.dtype
    grayscale = rgb_to_grayscale(images)  # (B, 1, H, W)

    min_dim = min(H, W)
    min_radius = int(min_radius_frac * min_dim)
    max_radius = int(max_radius_frac * min_dim)

    radii = torch.linspace(min_radius, max_radius, num_radii, device=device, dtype=dtype)
    # num_angles evenly spaced directions in [0, 2*pi); the duplicate 2*pi endpoint is dropped.
    angles = torch.linspace(0, 2 * math.pi, num_angles + 1, device=device, dtype=dtype)[:-1]

    # Ray offsets from the center: (num_angles, num_radii).
    dx = torch.cos(angles).view(-1, 1) * radii
    dy = torch.sin(angles).view(-1, 1) * radii

    # Absolute sample positions per image: (B, num_angles, num_radii).
    sample_x = cx.view(B, 1, 1) + dx.unsqueeze(0)
    sample_y = cy.view(B, 1, 1) + dy.unsqueeze(0)

    # Pixel -> [-1, 1] for grid_sample with align_corners=True.
    sample_x_norm = 2.0 * sample_x / (W - 1) - 1.0
    sample_y_norm = 2.0 * sample_y / (H - 1) - 1.0
    # Cast: grid_sample requires grid and input to share a dtype (cx/cy may differ from images).
    grid = torch.stack([sample_x_norm, sample_y_norm], dim=-1).to(grayscale.dtype)

    sampled = F.grid_sample(
        grayscale, grid, mode='bilinear', padding_mode='border', align_corners=True
    )  # (B, 1, num_angles, num_radii)

    radial_profile = sampled.mean(dim=2).squeeze(1)  # (B, num_radii)

    # Outward derivative; the disc edge shows up as a strong negative step.
    radial_gradient = radial_profile[:, 1:] - radial_profile[:, :-1]  # (B, num_radii-1)

    # Larger weights at larger radii favor the outer boundary over inner transitions.
    radius_weights = torch.linspace(0.5, 1.5, num_radii - 1, device=device, dtype=dtype)
    weighted_gradient = radial_gradient * radius_weights.unsqueeze(0)

    min_idx = weighted_gradient.argmin(dim=1)  # (B,)
    # +1: derivative index i lies between radii[i] and radii[i + 1]; take the outer sample.
    estimated_radius = radii[min_idx + 1]
    return estimated_radius.clamp(min_radius, max_radius)
| # ============================================================================= | |
| # PHASE 3: Border-Minimized Square Crop | |
| # ============================================================================= | |
def compute_crop_box(
    cx: torch.Tensor,
    cy: torch.Tensor,
    radius: torch.Tensor,
    H: int,
    W: int,
    scale_factor: float = 1.1,
    allow_overflow: bool = False,
) -> tuple:
    """Build a square crop box around each detected eye.

    The box initially extends ``radius * scale_factor`` in every direction from
    ``(cx, cy)``.

    With ``allow_overflow=True`` that raw box is returned as-is and may reach
    outside the image. This suits partially clipped discs; the sampler later
    fills the missing area according to its padding mode.

    Otherwise each edge is clipped to ``[0, W-1] x [0, H-1]``. The clipped
    rectangle is then shrunk to its shorter side around its own midpoint, so the
    result is square and inside the image. Note that its center can drift away
    from ``(cx, cy)`` when clipping was one-sided.

    Args:
        cx, cy: Eye center coordinates (B,).
        radius: Estimated disc radius (B,).
        H, W: Spatial size of the source images.
        scale_factor: Margin multiplier applied to ``radius``. Default 1.1.
        allow_overflow: Return the unclipped box. Default False.

    Returns:
        ``(x1, y1, x2, y2)``, each of shape (B,), in pixel coordinates.
    """
    half_side = radius * scale_factor
    left, right = cx - half_side, cx + half_side
    top, bottom = cy - half_side, cy + half_side

    if allow_overflow:
        return left, top, right, bottom

    # Clip every edge independently to the image.
    left = left.clamp(min=0)
    top = top.clamp(min=0)
    right = right.clamp(max=W - 1)
    bottom = bottom.clamp(max=H - 1)

    # Square up: keep the shorter side, centered on the clipped rectangle.
    side = torch.minimum(right - left, bottom - top)
    mid_x = (left + right) / 2
    mid_y = (top + bottom) / 2

    new_left = (mid_x - side / 2).clamp(min=0)
    new_top = (mid_y - side / 2).clamp(min=0)
    new_right = (new_left + side).clamp(max=W - 1)
    new_bottom = (new_top + side).clamp(max=H - 1)
    return new_left, new_top, new_right, new_bottom
def batch_crop_and_resize(
    images: torch.Tensor,
    x1: torch.Tensor,
    y1: torch.Tensor,
    x2: torch.Tensor,
    y2: torch.Tensor,
    output_size: int,
    padding_mode: str = 'border',
) -> torch.Tensor:
    """Crop and resize images to a square using ``F.grid_sample`` (GPU-friendly).

    Builds a regular output grid in [0, 1]^2, maps it to the source rectangle
    [x1, x2] x [y1, y2] via affine scaling, normalizes to [-1, 1] for
    ``grid_sample``, and samples with bilinear interpolation (``align_corners=True``).
    Crop coordinates may extend beyond image bounds; the ``padding_mode``
    controls how out-of-bounds pixels are filled. Box coordinates may use a
    different float dtype than ``images``; the sampling grid is cast to the image
    dtype.

    Note: plain bilinear sampling does not low-pass filter, so strong
    downscaling can alias.

    Args:
        images: Input images (B, C, H, W).
        x1, y1, x2, y2: Crop box corners (B,). May exceed [0, W-1] / [0, H-1].
        output_size: Side length of the square output.
        padding_mode: ``'border'`` (repeat edge, default) or ``'zeros'`` (black fill).

    Returns:
        Cropped and resized images (B, C, output_size, output_size).
    """
    B, C, H, W = images.shape
    device = images.device
    dtype = images.dtype

    # Regular output lattice in [0, 1]^2, laid out as (x, y) pairs.
    out_coords = torch.linspace(0, 1, output_size, device=device, dtype=dtype)
    out_y, out_x = torch.meshgrid(out_coords, out_coords, indexing='ij')
    out_grid = torch.stack([out_x, out_y], dim=-1)  # (S, S, 2)
    out_grid = out_grid.unsqueeze(0).expand(B, -1, -1, -1)  # (B, S, S, 2)

    x1 = x1.view(B, 1, 1, 1)
    y1 = y1.view(B, 1, 1, 1)
    x2 = x2.view(B, 1, 1, 1)
    y2 = y2.view(B, 1, 1, 1)

    # [0, 1] -> source pixel coordinates inside each box.
    sample_x = x1 + out_grid[..., 0:1] * (x2 - x1)
    sample_y = y1 + out_grid[..., 1:2] * (y2 - y1)

    # Pixel -> [-1, 1] for align_corners=True.
    sample_x_norm = 2.0 * sample_x / (W - 1) - 1.0
    sample_y_norm = 2.0 * sample_y / (H - 1) - 1.0
    # Type promotion with differently-typed boxes would change the grid dtype, and
    # grid_sample requires grid and input to share one dtype.
    grid = torch.cat([sample_x_norm, sample_y_norm], dim=-1).to(dtype)  # (B, S, S, 2)

    return F.grid_sample(
        images, grid, mode='bilinear', padding_mode=padding_mode, align_corners=True
    )
def batch_crop_and_resize_mask(
    masks: torch.Tensor,  # (B, 1, H, W)
    x1: torch.Tensor,
    y1: torch.Tensor,
    x2: torch.Tensor,
    y2: torch.Tensor,
    output_size: int,
    padding_mode: str = "zeros",
) -> torch.Tensor:
    """Crop and resize segmentation masks using nearest-neighbor sampling.

    Same spatial transform as ``batch_crop_and_resize`` but uses ``mode='nearest'``
    to preserve discrete label values. The output is rounded and cast to ``torch.long``
    to guard against floating-point drift in ``grid_sample``.

    Labels pass through float32 internally, so they are exact only up to
    ``2**24`` in magnitude. Box coordinates may use any float dtype; the sampling
    grid is cast to float32 to match.

    Args:
        masks: Integer label masks (B, 1, H, W). Any dtype is accepted and converted to float internally.
        x1, y1, x2, y2: Crop box corners (B,). May exceed image bounds.
        output_size: Side length of the square output.
        padding_mode: ``'zeros'`` (background = 0, default) or ``'border'`` (repeat edge).

    Returns:
        Cropped and resized masks (B, 1, output_size, output_size) as ``torch.long``.
    """
    B, C, H, W = masks.shape
    device = masks.device

    # grid_sample requires floating point input.
    masks_f = masks.float()

    # Regular output lattice in [0, 1]^2, laid out as (x, y) pairs.
    coords = torch.linspace(0, 1, output_size, device=device, dtype=masks_f.dtype)
    out_y, out_x = torch.meshgrid(coords, coords, indexing="ij")
    out_grid = torch.stack([out_x, out_y], dim=-1)  # (S, S, 2)
    out_grid = out_grid.unsqueeze(0).expand(B, -1, -1, -1)

    x1 = x1.view(B, 1, 1, 1)
    y1 = y1.view(B, 1, 1, 1)
    x2 = x2.view(B, 1, 1, 1)
    y2 = y2.view(B, 1, 1, 1)

    # [0, 1] -> source pixel coordinates inside each box.
    sample_x = x1 + out_grid[..., 0:1] * (x2 - x1)
    sample_y = y1 + out_grid[..., 1:2] * (y2 - y1)

    # Pixel -> [-1, 1] for align_corners=True.
    sample_x = 2.0 * sample_x / (W - 1) - 1.0
    sample_y = 2.0 * sample_y / (H - 1) - 1.0
    # Boxes in another float dtype would promote the grid; grid_sample needs matching dtypes.
    grid = torch.cat([sample_x, sample_y], dim=-1).to(masks_f.dtype)

    cropped = F.grid_sample(
        masks_f,
        grid,
        mode="nearest",
        padding_mode=padding_mode,
        align_corners=True,
    )

    # Round before converting to handle floating point errors from grid_sample.
    # Even with mode="nearest", grid_sample can produce values like 0.9999999
    # which would truncate to 0 instead of rounding to 1.
    return cropped.round().long()
| # ============================================================================= | |
| # PHASE 4: CLAHE (Torch-Native) | |
| # ============================================================================= | |
| def _srgb_to_linear(rgb: torch.Tensor) -> torch.Tensor: | |
| """Apply the sRGB electro-optical transfer function (EOTF) to convert sRGB to linear RGB. | |
| Uses the IEC 61966-2-1 piecewise formula with threshold 0.04045. | |
| """ | |
| threshold = 0.04045 | |
| linear = torch.where( | |
| rgb <= threshold, | |
| rgb / 12.92, | |
| ((rgb + 0.055) / 1.055) ** 2.4 | |
| ) | |
| return linear | |
| def _linear_to_srgb(linear: torch.Tensor) -> torch.Tensor: | |
| """Apply the inverse sRGB EOTF to convert linear RGB to sRGB. | |
| Uses the IEC 61966-2-1 piecewise formula with threshold 0.0031308. | |
| Input must be non-negative; negative values will produce NaN from the power function. | |
| """ | |
| threshold = 0.0031308 | |
| srgb = torch.where( | |
| linear <= threshold, | |
| linear * 12.92, | |
| 1.055 * (linear ** (1.0 / 2.4)) - 0.055 | |
| ) | |
| return srgb | |
def rgb_to_lab(images: torch.Tensor) -> tuple:
    """Convert sRGB images to CIE LAB colour space (D65 illuminant).

    Conversion chain: sRGB -> linear RGB -> CIE XYZ -> CIE LAB.
    The raw LAB values are rescaled for internal convenience:
      - L in [0, 100]        -> L / 100       -> [0, 1]
      - a in ~[-128, 127]    -> a / 256 + 0.5 -> ~[0, 1]
      - b in ~[-128, 127]    -> b / 256 + 0.5 -> ~[0, 1]
    These normalised values are **not** standard LAB; use ``lab_to_rgb`` to
    invert them back to sRGB.

    The cube-root branch of the LAB companding function is evaluated on a clamped
    copy. Its derivative is infinite at 0, so black pixels would otherwise yield
    NaN gradients through ``torch.where``. Forward values are unchanged.

    Args:
        images: RGB images (B, 3, H, W) in [0, 1] sRGB.

    Returns:
        Tuple of (L, a, b_ch), each (B, 1, H, W):
          - L: Normalised luminance in [0, 1].
          - a: Normalised green-red chrominance, roughly [0, 1].
          - b_ch: Normalised blue-yellow chrominance, roughly [0, 1].
    """
    # Step 1: sRGB to linear RGB
    linear_rgb = _srgb_to_linear(images)

    # Step 2: Linear RGB to XYZ (sRGB primaries, D65 white)
    r = linear_rgb[:, 0:1, :, :]
    g = linear_rgb[:, 1:2, :, :]
    b = linear_rgb[:, 2:3, :, :]

    x = 0.4124564 * r + 0.3575761 * g + 0.1804375 * b
    y = 0.2126729 * r + 0.7151522 * g + 0.0721750 * b
    z = 0.0193339 * r + 0.1191920 * g + 0.9503041 * b

    # Normalise by the D65 reference white
    xn, yn, zn = 0.95047, 1.0, 1.08883
    x = x / xn
    y = y / yn
    z = z / zn

    # Step 3: XYZ to LAB
    delta = 6.0 / 29.0
    delta_cube = delta ** 3

    def f(t):
        # Cube root above delta^3, linear segment below. The clamp only affects
        # entries where the linear segment is selected.
        return torch.where(
            t > delta_cube,
            t.clamp(min=delta_cube) ** (1.0 / 3.0),
            t / (3.0 * delta ** 2) + 4.0 / 29.0,
        )

    fx = f(x)
    fy = f(y)
    fz = f(z)

    L = 116.0 * fy - 16.0  # Range [0, 100]
    a = 500.0 * (fx - fy)  # Range roughly [-128, 127]
    b_ch = 200.0 * (fy - fz)  # Range roughly [-128, 127]

    # Normalize to convenient ranges for processing
    L = L / 100.0  # [0, 1]
    a = a / 256.0 + 0.5  # Roughly [0, 1]
    b_ch = b_ch / 256.0 + 0.5  # Roughly [0, 1]

    return L, a, b_ch
def lab_to_rgb(L: torch.Tensor, a: torch.Tensor, b_ch: torch.Tensor) -> torch.Tensor:
    """Invert ``rgb_to_lab``: map normalised CIE LAB planes back to sRGB.

    The planes are de-normalised (``L * 100``, ``(a - 0.5) * 256``,
    ``(b_ch - 0.5) * 256``) and then taken through LAB → XYZ (D65 white)
    → linear RGB → sRGB. Linear RGB is clamped to [0, 1] before gamma
    encoding, and the encoded result is clamped to [0, 1] again.

    Args:
        L: Normalised luminance (B, 1, H, W) in [0, 1].
        a: Normalised green–red chrominance (B, 1, H, W), roughly [0, 1].
        b_ch: Normalised blue–yellow chrominance (B, 1, H, W), roughly [0, 1].

    Returns:
        sRGB images (B, 3, H, W) clamped to [0, 1].
    """
    # De-normalise to standard LAB and form the companded XYZ ratios.
    fy = (L * 100.0 + 16.0) / 116.0
    fx = (a - 0.5) * 256.0 / 500.0 + fy
    fz = fy - (b_ch - 0.5) * 256.0 / 200.0

    # Undo the LAB companding (cube above delta, linear segment below) and
    # rescale by the D65 reference white.
    eps = 6.0 / 29.0
    d65_white = (0.95047, 1.0, 1.08883)
    x, y, z = (
        ref * torch.where(t > eps, t ** 3, 3.0 * (eps ** 2) * (t - 4.0 / 29.0))
        for ref, t in zip(d65_white, (fx, fy, fz))
    )

    # XYZ (D65) -> linear sRGB, one matrix row per output channel.
    xyz_to_linear_rgb = (
        (3.2404542, -1.5371385, -0.4985314),
        (-0.9692660, 1.8760108, 0.0415560),
        (0.0556434, -0.2040259, 1.0572252),
    )
    channels = [kx * x + ky * y + kz * z for kx, ky, kz in xyz_to_linear_rgb]

    # Clamp before gamma encoding so the power law never sees negatives.
    linear_rgb = torch.cat(channels, dim=1).clamp(0.0, 1.0)
    return _linear_to_srgb(linear_rgb).clamp(0.0, 1.0)
def compute_histogram(
    tensor: torch.Tensor,
    num_bins: int = 256,
) -> torch.Tensor:
    """Compute per-image histograms for a batch of single-channel images.

    Bins are uniformly spaced over [0, 1]. Each pixel is assigned to a bin via
    ``floor(value * (num_bins - 1))``, clamped to ``[0, num_bins - 1]``. Counts
    are accumulated for the whole batch with a single ``scatter_add_``.

    Note: This function is used only by ``clahe_single_tile``.
    The vectorized CLAHE path (``apply_clahe_vectorized``) computes histograms
    inline for better GPU efficiency.

    Args:
        tensor: Input (B, 1, H, W) with values in [0, 1]. Non-contiguous
            tensors and an empty batch (B = 0) are accepted.
        num_bins: Number of histogram bins. Default 256.

    Returns:
        Histograms of shape (B, num_bins), dtype matching input.
    """
    # flatten(1) returns a view when possible and copies otherwise, so it
    # works for non-contiguous inputs and B == 0, unlike view(B, -1).
    flat = tensor.flatten(start_dim=1)  # (B, H*W)

    bin_indices = (flat * (num_bins - 1)).long().clamp(0, num_bins - 1)

    histograms = torch.zeros(
        tensor.shape[0], num_bins, device=tensor.device, dtype=tensor.dtype
    )
    # One batched scatter along the bin axis replaces the per-sample loop.
    histograms.scatter_add_(1, bin_indices, torch.ones_like(flat))
    return histograms
def clahe_single_tile(
    tile: torch.Tensor,
    clip_limit: float,
    num_bins: int = 256,
) -> torch.Tensor:
    """Compute the clipped-and-redistributed CDF for a single CLAHE tile.

    Clips the histogram so no bin exceeds ``clip_limit * num_pixels / num_bins``,
    redistributes the excess uniformly, then computes and min-max normalises the CDF.

    Note: This function is not used by the main pipeline — see
    ``apply_clahe_vectorized`` which processes all tiles in a single pass.

    Args:
        tile: Single-channel tile images (B, 1, tile_h, tile_w) in [0, 1].
        clip_limit: Relative clip limit (higher = less contrast limiting).
        num_bins: Number of histogram bins. Default 256.

    Returns:
        Normalised CDF lookup table (B, num_bins) in [0, 1].
    """
    _, _, tile_h, tile_w = tile.shape
    num_pixels = tile_h * tile_w

    hist = compute_histogram(tile, num_bins)  # (B, num_bins)

    # Clip each bin, then spread the clipped mass evenly over all bins.
    clip_value = clip_limit * num_pixels / num_bins
    excess = (hist - clip_value).clamp(min=0).sum(dim=1, keepdim=True)  # (B, 1)
    hist = hist.clamp(max=clip_value) + excess / num_bins

    # Min-max normalised CDF; the epsilon guards against a zero range.
    cdf = hist.cumsum(dim=1)  # (B, num_bins)
    cdf_min = cdf[:, 0:1]
    cdf_max = cdf[:, -1:]
    return (cdf - cdf_min) / (cdf_max - cdf_min + 1e-8)
def apply_clahe_vectorized(
    images: torch.Tensor,
    grid_size: int = 8,
    clip_limit: float = 2.0,
    num_bins: int = 256,
) -> torch.Tensor:
    """Fully-vectorized CLAHE (Contrast Limited Adaptive Histogram Equalisation).

    For RGB input, converts to CIE LAB, applies CLAHE to the L channel only,
    then converts back to sRGB. For single-channel input, operates directly.

    Algorithm:
      1. Pads the luminance channel so it is divisible by ``grid_size``. Reflect
         padding is used; replicate padding is used instead when the image is
         too small to reflect.
      2. Reshapes into ``grid_size x grid_size`` non-overlapping tiles.
      3. Computes a histogram per tile via ``scatter_add_`` (fully batched, no loops).
      4. Clips each histogram at ``clip_limit * num_pixels / num_bins`` and
         redistributes excess counts uniformly across all bins.
      5. Computes the cumulative distribution function (CDF) per tile and
         min-max normalises it to [0, 1].
      6. Maps each output pixel to the four surrounding tile centres and
         bilinearly interpolates their CDF values for a smooth result. Pixels
         beyond the outermost tile centres use that tile's mapping alone.

    Args:
        images: Input images (B, C, H, W) in [0, 1]. C must be 1 or 3.
        grid_size: Tile grid resolution (tiles per axis). Default 8.
        clip_limit: Relative clip limit for histogram clipping. Default 2.0.
        num_bins: Number of histogram bins. Default 256.

    Returns:
        CLAHE-enhanced images (B, C, H, W) in [0, 1].

    Raises:
        ValueError: If ``images`` has a channel count other than 1 or 3.
    """
    B, C, H, W = images.shape
    if C not in (1, 3):
        raise ValueError(f"apply_clahe_vectorized expects 1 or 3 channels, got {C}")
    device = images.device
    dtype = images.dtype

    # Work on luminance only; chrominance is carried through untouched.
    if C == 3:
        L, a, b_ch = rgb_to_lab(images)
    else:
        L = images.clone()
        a = b_ch = None

    # Pad bottom/right so both sides are divisible by grid_size. Reflect
    # padding needs pad < size on each axis, so fall back to replicate for
    # tiny inputs instead of raising.
    pad_h = (grid_size - H % grid_size) % grid_size
    pad_w = (grid_size - W % grid_size) % grid_size
    if pad_h > 0 or pad_w > 0:
        pad_mode = 'reflect' if (pad_h < H and pad_w < W) else 'replicate'
        L_padded = F.pad(L, (0, pad_w, 0, pad_h), mode=pad_mode)
    else:
        L_padded = L
    _, _, H_pad, W_pad = L_padded.shape
    tile_h = H_pad // grid_size
    tile_w = W_pad // grid_size
    num_pixels = tile_h * tile_w
    num_tiles = B * grid_size * grid_size

    # Tiles: (B, 1, gy, tile_h, gx, tile_w) -> (B*gy*gx, tile_h*tile_w).
    # reshape (not view) tolerates non-contiguous inputs.
    flat = (
        L_padded.reshape(B, grid_size, tile_h, grid_size, tile_w)
        .permute(0, 1, 3, 2, 4)
        .reshape(num_tiles, num_pixels)
    )

    # Per-tile histograms in one batched scatter.
    bin_indices = (flat * (num_bins - 1)).long().clamp(0, num_bins - 1)
    histograms = torch.zeros(num_tiles, num_bins, device=device, dtype=dtype)
    histograms.scatter_add_(1, bin_indices, torch.ones_like(flat))

    # Clip and redistribute the excess uniformly.
    clip_value = clip_limit * num_pixels / num_bins
    excess = (histograms - clip_value).clamp(min=0).sum(dim=1, keepdim=True)
    histograms = histograms.clamp(max=clip_value) + excess / num_bins

    # Min-max normalised CDFs, laid out as (B, grid_size, grid_size, num_bins).
    cdfs = histograms.cumsum(dim=1)
    cdf_min = cdfs[:, 0:1]
    cdf_max = cdfs[:, -1:]
    cdfs = ((cdfs - cdf_min) / (cdf_max - cdf_min + 1e-8)).view(
        B, grid_size, grid_size, num_bins
    )

    # Fractional tile-centre coordinate of every pixel centre. Clamping to
    # exactly [0, grid_size - 1] gives border pixels weight 1.0 on their own
    # tile; clamping to grid_size - 1.001 would leak ~0.1% of a neighbour in.
    last = grid_size - 1
    tile_y = ((torch.arange(H_pad, device=device, dtype=dtype) + 0.5) / tile_h - 0.5).clamp(0, last)
    tile_x = ((torch.arange(W_pad, device=device, dtype=dtype) + 0.5) / tile_w - 0.5).clamp(0, last)

    # Neighbouring tile indices; max(..., 0) keeps grid_size == 1 in range.
    ty0 = tile_y.long().clamp(0, max(grid_size - 2, 0))
    tx0 = tile_x.long().clamp(0, max(grid_size - 2, 0))
    ty1 = (ty0 + 1).clamp(max=last)
    tx1 = (tx0 + 1).clamp(max=last)
    # Interpolation weights stay in the working dtype (not forced to float32).
    wy = (tile_y - ty0.to(dtype)).view(1, H_pad, 1)
    wx = (tile_x - tx0.to(dtype)).view(1, 1, W_pad)

    # Bin of every pixel, then gather the four neighbouring tile CDFs.
    # Index tensors broadcast to (B, H_pad, W_pad), so no explicit expand.
    bin_idx = (L_padded * (num_bins - 1)).long().clamp(0, num_bins - 1).squeeze(1)
    b_idx = torch.arange(B, device=device).view(B, 1, 1)
    ty0_i, ty1_i = ty0.view(1, H_pad, 1), ty1.view(1, H_pad, 1)
    tx0_i, tx1_i = tx0.view(1, 1, W_pad), tx1.view(1, 1, W_pad)
    v00 = cdfs[b_idx, ty0_i, tx0_i, bin_idx]  # (B, H_pad, W_pad)
    v01 = cdfs[b_idx, ty0_i, tx1_i, bin_idx]
    v10 = cdfs[b_idx, ty1_i, tx0_i, bin_idx]
    v11 = cdfs[b_idx, ty1_i, tx1_i, bin_idx]

    # Bilinear interpolation between the four tile mappings.
    L_out = (
        (1 - wy) * (1 - wx) * v00
        + (1 - wy) * wx * v01
        + wy * (1 - wx) * v10
        + wy * wx * v11
    )
    # Restore the channel axis and drop any padding.
    L_out = L_out.unsqueeze(1)[:, :, :H, :W]

    if C == 3:
        return lab_to_rgb(L_out, a, b_ch)
    return L_out
# =============================================================================
# PHASE 5: Resize & Normalization
# =============================================================================
# ImageNet normalization constants: per-channel mean/std applied by
# ``normalize_images`` when ``mode='imagenet'``. The images are expected in
# [0, 1]. Channel order is assumed to be R, G, B — TODO confirm against
# standardize_input.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
def resize_images(
    images: torch.Tensor,
    size: int,
    mode: str = 'bilinear',
    antialias: bool = True,
) -> torch.Tensor:
    """Resize a batch of images to ``size x size`` with ``F.interpolate``.

    For the smooth modes (``'bilinear'``, ``'bicubic'``) the call uses
    ``align_corners=False`` and honours ``antialias``. For every other mode,
    ``align_corners`` is left unset and antialiasing is disabled, because
    ``F.interpolate`` only supports them for the smooth modes.

    Args:
        images: Input images (B, C, H, W). Must be float for bilinear/bicubic modes.
        size: Target side length; the output is always square.
        mode: Interpolation mode (``'bilinear'``, ``'bicubic'``, ``'nearest'``, etc.).
            Default ``'bilinear'``.
        antialias: Enable antialiasing for bilinear/bicubic downscaling. Default True.

    Returns:
        Resized images (B, C, size, size).
    """
    if mode in ('bilinear', 'bicubic'):
        align_corners, use_antialias = False, antialias
    else:
        align_corners, use_antialias = None, False

    return F.interpolate(
        images,
        size=(size, size),
        mode=mode,
        align_corners=align_corners,
        antialias=use_antialias,
    )
def normalize_images(
    images: torch.Tensor,
    mean: Optional[List[float]] = None,
    std: Optional[List[float]] = None,
    mode: str = 'imagenet',
) -> torch.Tensor:
    """Standardise each channel as ``(image - mean) / std``.

    Args:
        images: Input images (B, C, H, W) in [0, 1].
        mean: Per-channel means (length C). Only used, and then required,
            when ``mode='custom'``.
        std: Per-channel standard deviations (length C). Only used, and then
            required, when ``mode='custom'``.
        mode: ``'imagenet'`` uses the ImageNet statistics, ``'none'`` returns
            the input object unchanged, and ``'custom'`` uses ``mean``/``std``.
            Default ``'imagenet'``.

    Returns:
        Normalised images (B, C, H, W) on the input's device and dtype.

    Raises:
        ValueError: If ``mode`` is unknown, or if ``mode='custom'`` and either
            ``mean`` or ``std`` is missing.
    """
    if mode == 'none':
        return images

    if mode == 'imagenet':
        stats = (IMAGENET_MEAN, IMAGENET_STD)
    elif mode == 'custom':
        if mean is None or std is None:
            raise ValueError("Custom mode requires mean and std")
        stats = (mean, std)
    else:
        raise ValueError(f"Unknown normalization mode: {mode}")

    # Build broadcastable (1, C, 1, 1) tensors matching the input's placement.
    placement = dict(device=images.device, dtype=images.dtype)
    mu, sigma = (torch.tensor(v, **placement).view(1, -1, 1, 1) for v in stats)
    return (images - mu) / sigma
# =============================================================================
# PHASE 6: Hugging Face ImageProcessor Integration
# =============================================================================
class EyeCLAHEImageProcessor(BaseImageProcessor):
    """GPU-native Hugging Face image processor for Colour Fundus Photography (CFP).

    Processing pipeline (all steps optional via constructor flags):

    1. **Eye localisation** (``do_crop=True``): detects the fundus disc centre via
       gradient-based radial symmetry (dark-region centre-of-mass → Sobel gradients →
       radial alignment score → Gaussian smoothing → soft argmax) and estimates the
       disc radius from the strongest negative radial intensity gradient.
    2. **Square crop & resize**: crops a square region around the detected disc
       (``radius * crop_scale_factor``), optionally allowing overflow beyond image
       bounds (``allow_overflow``), then resamples to ``size x size`` via bilinear
       ``grid_sample``. When ``do_crop=False``, the whole image is resized directly.
    3. **CLAHE** (``do_clahe=True``): applies Contrast Limited Adaptive Histogram
       Equalisation to the CIE LAB luminance channel, using a fully-vectorized
       tile-based implementation with bilinear CDF interpolation.
    4. **Normalisation**: channel-wise ``(image - mean) / std`` with configurable
       mode (ImageNet, custom, or none).

    The processor also returns per-image coordinate-mapping scalars (``scale_x/y``,
    ``offset_x/y``) so that predictions in processed-image space can be mapped back
    to original pixel coordinates.

    All processing operations are pure PyTorch and are CUDA-compatible and
    batch-friendly. PIL / NumPy are only used, when installed, to accept such
    inputs.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        size: int = 224,
        crop_scale_factor: float = 1.1,
        clahe_grid_size: int = 8,
        clahe_clip_limit: float = 2.0,
        normalization_mode: str = "imagenet",
        custom_mean: Optional[List[float]] = None,
        custom_std: Optional[List[float]] = None,
        do_clahe: bool = True,
        do_crop: bool = True,
        min_radius_frac: float = 0.1,
        max_radius_frac: float = 0.5,
        allow_overflow: bool = False,
        softmax_temperature: float = 0.1,
        **kwargs,
    ):
        """
        Initialize the EyeCLAHEImageProcessor.

        Args:
            size: Output image size (square).
            crop_scale_factor: Scale factor for crop box (relative to detected radius).
            clahe_grid_size: Number of tiles per axis for CLAHE.
            clahe_clip_limit: Histogram clip limit for CLAHE.
            normalization_mode: 'imagenet', 'none', or 'custom'.
            custom_mean: Custom normalization mean (if mode='custom').
            custom_std: Custom normalization std (if mode='custom').
            do_clahe: Whether to apply CLAHE.
            do_crop: Whether to perform eye-centered cropping.
            min_radius_frac: Minimum radius as fraction of image size.
            max_radius_frac: Maximum radius as fraction of image size.
            allow_overflow: If True, allow crop box to extend beyond image bounds
                and fill missing regions with black. Useful for pre-cropped
                images where the fundus circle is partially cut off.
            softmax_temperature: Temperature for soft argmax in eye center detection.
                Lower values (0.01-0.1) give sharper peak detection, higher values
                (0.3-0.5) provide more averaging for noisy images. Default: 0.1.
            **kwargs: Forwarded to ``BaseImageProcessor.__init__``.
        """
        super().__init__(**kwargs)
        self.size = size
        self.crop_scale_factor = crop_scale_factor
        self.clahe_grid_size = clahe_grid_size
        self.clahe_clip_limit = clahe_clip_limit
        self.normalization_mode = normalization_mode
        self.custom_mean = custom_mean
        self.custom_std = custom_std
        self.do_clahe = do_clahe
        self.do_crop = do_crop
        self.min_radius_frac = min_radius_frac
        self.max_radius_frac = max_radius_frac
        self.allow_overflow = allow_overflow
        self.softmax_temperature = softmax_temperature

    def preprocess(
        self,
        images,
        masks=None,
        return_tensors: str = "pt",
        device: Optional[Union[str, torch.device]] = None,
        **kwargs,
    ) -> BatchFeature:
        """Run the full preprocessing pipeline on a batch of images.

        Accepts any combination of torch.Tensor, PIL.Image, or numpy.ndarray inputs
        (see ``standardize_input`` for format details). Optionally processes
        accompanying segmentation masks with matching spatial transforms.

        Args:
            images: Input images in any supported format.
            masks: Optional segmentation masks in any format accepted by
                ``standardize_mask_input``. They undergo the same crop/resize as
                the images (nearest-neighbour interpolation, label-preserving)
                and are returned as ``torch.long`` under the ``"mask"`` key.
                The key is omitted when no masks are given.
            return_tensors: Only ``"pt"`` is supported.
            device: Device for all tensor operations (e.g. ``"cuda:0"``).
                Defaults to the device of the input tensor, or CPU for PIL/numpy.
            **kwargs: Accepted for API compatibility; currently ignored.

        Returns:
            ``BatchFeature`` with keys:
              - ``pixel_values`` (B, 3, size, size): Processed float32 images.
              - ``mask`` (B, 1, size, size): Processed long masks. This key is
                present only when ``masks`` was provided.
              - ``scale_x``, ``scale_y`` (B,): Per-image scale factors.
              - ``offset_x``, ``offset_y`` (B,): Per-image offsets.

            Coordinate mapping from processed → original pixel space::

                orig_x = offset_x + proc_x * scale_x
                orig_y = offset_y + proc_y * scale_y

        Raises:
            ValueError: If ``return_tensors`` is not ``"pt"``.
        """
        if return_tensors != "pt":
            raise ValueError("Only 'pt' (PyTorch) tensors are supported")

        # Resolve the working device: explicit argument, else the input
        # tensor's device, else CPU (PIL images and numpy arrays).
        if device is not None:
            device = torch.device(device)
        elif isinstance(images, torch.Tensor):
            device = images.device
        elif isinstance(images, list) and len(images) > 0 and isinstance(images[0], torch.Tensor):
            device = images[0].device
        else:
            device = torch.device('cpu')

        images = standardize_input(images, device)
        if masks is not None:
            masks = standardize_mask_input(masks, device)

        B, _, H_orig, W_orig = images.shape

        if self.do_crop:
            cx, cy = estimate_eye_center(images, softmax_temperature=self.softmax_temperature)
            radius = estimate_radius(
                images, cx, cy,
                min_radius_frac=self.min_radius_frac,
                max_radius_frac=self.max_radius_frac,
            )
            x1, y1, x2, y2 = compute_crop_box(
                cx, cy, radius, H_orig, W_orig,
                scale_factor=self.crop_scale_factor,
                allow_overflow=self.allow_overflow,
            )

            # Crop corners (x1, y1) and (x2, y2) correspond to processed
            # pixels 0 and size-1 respectively.
            scale_x = (x2 - x1) / (self.size - 1)
            scale_y = (y2 - y1) / (self.size - 1)
            offset_x = x1
            offset_y = y1

            # 'zeros' fills out-of-bounds regions with black when overflow is allowed.
            padding_mode = 'zeros' if self.allow_overflow else 'border'
            images = batch_crop_and_resize(images, x1, y1, x2, y2, self.size, padding_mode=padding_mode)
            if masks is not None:
                masks = batch_crop_and_resize_mask(
                    masks, x1, y1, x2, y2,
                    self.size,
                    padding_mode=padding_mode,
                )
        else:
            # resize_images uses F.interpolate(align_corners=False). Under that
            # convention processed pixel i samples original coordinate
            # (i + 0.5) * (W_orig / size) - 0.5, so report exactly that affine
            # map rather than the align_corners=True form (W_orig-1)/(size-1).
            sx = W_orig / self.size
            sy = H_orig / self.size
            scale_x = torch.full((B,), sx, device=device, dtype=images.dtype)
            scale_y = torch.full((B,), sy, device=device, dtype=images.dtype)
            offset_x = torch.full((B,), 0.5 * sx - 0.5, device=device, dtype=images.dtype)
            offset_y = torch.full((B,), 0.5 * sy - 0.5, device=device, dtype=images.dtype)

            images = resize_images(images, self.size)
            if masks is not None:
                # F.interpolate requires float input; cast, resize, then restore long
                masks = resize_images(masks.float(), self.size, mode="nearest", antialias=False).round().long()

        if self.do_clahe:
            images = apply_clahe_vectorized(
                images,
                grid_size=self.clahe_grid_size,
                clip_limit=self.clahe_clip_limit,
            )

        images = normalize_images(
            images,
            mean=self.custom_mean,
            std=self.custom_std,
            mode=self.normalization_mode,
        )

        # Flat structure so BatchFeature can convert every entry to a tensor.
        data = {
            "pixel_values": images,
            "scale_x": scale_x,
            "scale_y": scale_y,
            "offset_x": offset_x,
            "offset_y": offset_y,
        }
        if masks is not None:
            data["mask"] = masks
        return BatchFeature(data=data, tensor_type="pt")

    def __call__(
        self,
        images: Union[torch.Tensor, List[torch.Tensor]],
        **kwargs,
    ) -> BatchFeature:
        """Alias for ``preprocess`` — enables ``processor(images, ...)`` call syntax."""
        return self.preprocess(images, **kwargs)
# For AutoImageProcessor registration: a second name bound to the same class
# object. Presumably referenced by name from a saved processor config — confirm
# before removing or renaming.
EyeGPUImageProcessor = EyeCLAHEImageProcessor