import logging
from typing import Any, Dict, Tuple

import cv2
import numpy as np
from PIL import Image

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class InpaintingBlender:
    """
    Handles mask processing, prompt enhancement, and result blending for inpainting.

    This class encapsulates all pre-processing and post-processing operations
    needed for inpainting, separate from the main generation pipeline.

    Attributes:
        min_mask_coverage: Minimum mask coverage threshold
        max_mask_coverage: Maximum mask coverage threshold

    Example:
        >>> blender = InpaintingBlender()
        >>> processed_mask, info = blender.prepare_mask(mask, (512, 512), feather_radius=8)
        >>> enhanced, negative = blender.enhance_prompt_for_inpainting("a flower", image, mask)
        >>> result = blender.blend_result(original, generated, mask)
    """

    def __init__(
        self,
        min_mask_coverage: float = 0.01,
        max_mask_coverage: float = 0.95
    ):
        """
        Initialize the InpaintingBlender.

        Parameters
        ----------
        min_mask_coverage : float
            Minimum mask coverage (default: 1%)
        max_mask_coverage : float
            Maximum mask coverage (default: 95%)
        """
        self.min_mask_coverage = min_mask_coverage
        self.max_mask_coverage = max_mask_coverage
        logger.info("InpaintingBlender initialized")

    def prepare_mask(
        self,
        mask: Image.Image,
        target_size: Tuple[int, int],
        feather_radius: int = 8
    ) -> Tuple[Image.Image, Dict[str, Any]]:
        """
        Prepare and validate a mask for inpainting.

        Parameters
        ----------
        mask : PIL.Image
            Input mask (white = inpaint area)
        target_size : tuple
            Target (width, height) to match the input image
        feather_radius : int
            Feathering radius in pixels

        Returns
        -------
        tuple
            (processed_mask, validation_info). Coverage problems are not
            raised as exceptions; they are reported via
            validation_info["valid"] and validation_info["warning"].
        """
        # Convert to grayscale
        if mask.mode != 'L':
            mask = mask.convert('L')

        # Resize to match target
        if mask.size != target_size:
            mask = mask.resize(target_size, Image.LANCZOS)

        # Convert to array for processing
        mask_array = np.array(mask)

        # Calculate coverage
        total_pixels = mask_array.size
        white_pixels = np.count_nonzero(mask_array > 127)
        coverage = white_pixels / total_pixels

        validation_info = {
            "coverage": coverage,
            "white_pixels": white_pixels,
            "total_pixels": total_pixels,
            "feather_radius": feather_radius,
            "valid": True,
            "warning": ""
        }

        # Validate coverage
        if coverage < self.min_mask_coverage:
            validation_info["valid"] = False
            validation_info["warning"] = (
                f"Mask coverage too low ({coverage:.1%}). "
                f"Please select a larger area to inpaint."
            )
            logger.warning(f"Mask coverage {coverage:.1%} below minimum {self.min_mask_coverage:.1%}")
        elif coverage > self.max_mask_coverage:
            validation_info["valid"] = False
            validation_info["warning"] = (
                f"Mask coverage too high ({coverage:.1%}). "
                f"Consider using background generation instead."
            )
            logger.warning(f"Mask coverage {coverage:.1%} above maximum {self.max_mask_coverage:.1%}")

        # Apply feathering
        if feather_radius > 0:
            mask_array = cv2.GaussianBlur(
                mask_array,
                (feather_radius * 2 + 1, feather_radius * 2 + 1),
                feather_radius / 2
            )
            logger.debug(f"Applied {feather_radius}px feathering to mask")

        processed_mask = Image.fromarray(mask_array, mode='L')
        return processed_mask, validation_info
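
    # Usage sketch for prepare_mask. The caller is expected to check
    # validation_info rather than catch an exception; the helper names
    # below (show_error_to_user, run_inpainting) are hypothetical
    # stand-ins for whatever UI and pipeline hooks the host app provides:
    #
    #     processed_mask, info = blender.prepare_mask(user_mask, image.size, feather_radius=8)
    #     if not info["valid"]:
    #         show_error_to_user(info["warning"])    # hypothetical UI hook
    #     else:
    #         run_inpainting(image, processed_mask)  # hypothetical pipeline call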

    def enhance_prompt_for_inpainting(
        self,
        prompt: str,
        image: Image.Image,
        mask: Image.Image
    ) -> Tuple[str, str]:
        """
        Enhance a prompt based on analysis of the non-masked region.

        Analyzes the surrounding context to generate appropriate lighting
        and color descriptors.

        Parameters
        ----------
        prompt : str
            User-provided prompt
        image : PIL.Image
            Original image
        mask : PIL.Image
            Inpainting mask

        Returns
        -------
        tuple
            (enhanced_prompt, negative_prompt)
        """
        logger.info("Enhancing prompt for inpainting context...")

        # Convert to arrays
        img_array = np.array(image.convert('RGB'))
        mask_array = np.array(mask.convert('L'))

        # Analyze non-masked regions
        non_masked = mask_array < 127
        if not np.any(non_masked):
            # No context available
            enhanced_prompt = f"{prompt}, high quality, detailed, photorealistic"
            negative_prompt = self._get_inpainting_negative_prompt()
            return enhanced_prompt, negative_prompt

        # Extract context pixels
        context_pixels = img_array[non_masked]

        # Convert to Lab for analysis
        context_lab = cv2.cvtColor(
            context_pixels.reshape(-1, 1, 3),
            cv2.COLOR_RGB2LAB
        ).reshape(-1, 3)

        # Use robust statistics (median) to avoid outlier influence
        median_l = np.median(context_lab[:, 0])
        median_b = np.median(context_lab[:, 2])

        # Analyze lighting conditions
        lighting_descriptors = []
        if median_l > 170:
            lighting_descriptors.append("bright")
        elif median_l > 130:
            lighting_descriptors.append("well-lit")
        elif median_l > 80:
            lighting_descriptors.append("moderate lighting")
        else:
            lighting_descriptors.append("dim lighting")

        # Analyze color temperature (b channel: blue(-) to yellow(+))
        if median_b > 140:
            lighting_descriptors.append("warm golden tones")
        elif median_b > 120:
            lighting_descriptors.append("warm afternoon light")
        elif median_b < 110:
            lighting_descriptors.append("cool neutral tones")

        # Calculate saturation from context
        hsv = cv2.cvtColor(context_pixels.reshape(-1, 1, 3), cv2.COLOR_RGB2HSV)
        median_saturation = np.median(hsv[:, :, 1])
        if median_saturation > 150:
            lighting_descriptors.append("vibrant colors")
        elif median_saturation < 80:
            lighting_descriptors.append("subtle muted colors")

        # Build enhanced prompt
        lighting_desc = ", ".join(lighting_descriptors) if lighting_descriptors else ""
        quality_suffix = "high quality, detailed, photorealistic, seamless integration"
        if lighting_desc:
            enhanced_prompt = f"{prompt}, {lighting_desc}, {quality_suffix}"
        else:
            enhanced_prompt = f"{prompt}, {quality_suffix}"

        negative_prompt = self._get_inpainting_negative_prompt()
        logger.info(f"Enhanced prompt with context: {lighting_desc}")
        return enhanced_prompt, negative_prompt

    def _get_inpainting_negative_prompt(self) -> str:
        """Get the standard negative prompt for inpainting."""
        return (
            "inconsistent lighting, wrong perspective, mismatched colors, "
            "visible seams, blending artifacts, color bleeding, "
            "blurry, low quality, distorted, deformed, "
            "harsh edges, unnatural transition"
        )
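
    # Reading the thresholds in enhance_prompt_for_inpainting: OpenCV's
    # 8-bit Lab encoding scales L* to 0-255 (L_cv = L* * 255/100) and
    # offsets a*/b* by +128. So "median_b > 140" means b* > +12 (a clear
    # yellow cast) and "median_b < 110" means b* < -18 (a blue cast);
    # the silent 110-120 band sits near neutral and adds no descriptor.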

    def blend_result(
        self,
        original: Image.Image,
        generated: Image.Image,
        mask: Image.Image
    ) -> Image.Image:
        """
        Blend generated content with the original image.

        Uses color matching and linear color space blending for seamless
        results.

        Parameters
        ----------
        original : PIL.Image
            Original image
        generated : PIL.Image
            Generated inpainted image
        mask : PIL.Image
            Blending mask (white = use generated)

        Returns
        -------
        PIL.Image
            Blended result
        """
        logger.info("Blending inpainting result with color matching...")

        # Ensure same size
        if generated.size != original.size:
            generated = generated.resize(original.size, Image.LANCZOS)
        if mask.size != original.size:
            mask = mask.resize(original.size, Image.LANCZOS)

        # Convert to arrays
        orig_array = np.array(original.convert('RGB')).astype(np.float32)
        gen_array = np.array(generated.convert('RGB')).astype(np.float32)
        mask_array = np.array(mask.convert('L')).astype(np.float32) / 255.0

        # Apply color matching to the generated region
        # (use the original mask for accurate boundary detection)
        gen_array = self._match_colors_at_boundary(orig_array, gen_array, mask_array)

        # Create blend mask: soften edges ONLY for blending (not for generation).
        # This preserves full generation coverage while blending smoothly at edges.
        blend_mask = self._create_blend_mask(mask_array)

        # sRGB to linear conversion
        def srgb_to_linear(img: np.ndarray) -> np.ndarray:
            img_norm = img / 255.0
            return np.where(
                img_norm <= 0.04045,
                img_norm / 12.92,
                np.power((img_norm + 0.055) / 1.055, 2.4)
            )

        def linear_to_srgb(img: np.ndarray) -> np.ndarray:
            img_clipped = np.clip(img, 0, 1)
            return np.where(
                img_clipped <= 0.0031308,
                12.92 * img_clipped,
                1.055 * np.power(img_clipped, 1 / 2.4) - 0.055
            )

        # Convert to linear space
        orig_linear = srgb_to_linear(orig_array)
        gen_linear = srgb_to_linear(gen_array)

        # Alpha blending in linear space using the blend mask (with softened edges)
        alpha = blend_mask[:, :, np.newaxis]
        result_linear = gen_linear * alpha + orig_linear * (1 - alpha)

        # Convert back to sRGB (round rather than truncate to avoid a
        # half-level darkening bias)
        result_srgb = linear_to_srgb(result_linear)
        result_array = np.round(result_srgb * 255).astype(np.uint8)

        logger.debug("Blending completed with color matching")
        return Image.fromarray(result_array)
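
    # Why blend_result decodes to linear light first: a 50/50 mix of sRGB
    # 0 and sRGB 255 computed directly on the encoded values gives 127,
    # which carries only ~21% of white's physical luminance, so blended
    # edges look too dark. Mixing the decoded linear values (0.0 and 1.0)
    # gives 0.5, which encodes back to sRGB ~188, the physically correct
    # midpoint.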

    def _match_colors_at_boundary(
        self,
        original: np.ndarray,
        generated: np.ndarray,
        mask: np.ndarray
    ) -> np.ndarray:
        """
        Match colors of generated content to the original at the boundary.

        Compares mean Lab statistics sampled on both sides of the mask
        boundary and applies a partial global shift to the generated
        region for natural blending.

        Parameters
        ----------
        original : np.ndarray
            Original image array (float32, 0-255)
        generated : np.ndarray
            Generated image array (float32, 0-255)
        mask : np.ndarray
            Mask array (float32, 0-1)

        Returns
        -------
        np.ndarray
            Color-matched generated image
        """
        # Create boundary region masks (dilated mask minus eroded mask)
        mask_binary = (mask > 0.5).astype(np.uint8) * 255

        # Create a narrow boundary band for sampling colors
        kernel_size = 25  # pixels to sample around the boundary
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
        dilated = cv2.dilate(mask_binary, kernel, iterations=1)
        eroded = cv2.erode(mask_binary, kernel, iterations=1)

        # Outer boundary (original side)
        outer_boundary = (dilated > 0) & (mask_binary == 0)
        # Inner boundary (generated side)
        inner_boundary = (mask_binary > 0) & (eroded == 0)

        if not np.any(outer_boundary) or not np.any(inner_boundary):
            logger.debug("No boundary region found, skipping color matching")
            return generated

        # Convert to Lab color space
        orig_lab = cv2.cvtColor(original.astype(np.uint8), cv2.COLOR_RGB2LAB).astype(np.float32)
        gen_lab = cv2.cvtColor(generated.astype(np.uint8), cv2.COLOR_RGB2LAB).astype(np.float32)

        # Sample colors from the boundary bands
        orig_boundary_pixels = orig_lab[outer_boundary]
        gen_boundary_pixels = gen_lab[inner_boundary]

        if len(orig_boundary_pixels) < 10 or len(gen_boundary_pixels) < 10:
            logger.debug("Not enough boundary pixels, skipping color matching")
            return generated

        # Calculate mean statistics on each side of the boundary
        orig_mean = np.mean(orig_boundary_pixels, axis=0)
        gen_mean = np.mean(gen_boundary_pixels, axis=0)

        # Calculate partial correction factors for the L (lightness)
        # and a/b (color) channels
        l_correction = (orig_mean[0] - gen_mean[0]) * 0.7  # 70% correction for lightness
        a_correction = (orig_mean[1] - gen_mean[1]) * 0.5  # 50% correction for color
        b_correction = (orig_mean[2] - gen_mean[2]) * 0.5

        logger.debug(f"Color correction: L={l_correction:.1f}, a={a_correction:.1f}, b={b_correction:.1f}")

        # Apply the correction to the masked region only
        corrected_lab = gen_lab.copy()
        mask_region = mask > 0.3  # apply to most of the masked region
        corrected_lab[mask_region, 0] = np.clip(
            corrected_lab[mask_region, 0] + l_correction, 0, 255
        )
        corrected_lab[mask_region, 1] = np.clip(
            corrected_lab[mask_region, 1] + a_correction, 0, 255
        )
        corrected_lab[mask_region, 2] = np.clip(
            corrected_lab[mask_region, 2] + b_correction, 0, 255
        )

        # Convert back to RGB
        corrected_rgb = cv2.cvtColor(
            corrected_lab.astype(np.uint8),
            cv2.COLOR_LAB2RGB
        ).astype(np.float32)

        logger.info("Applied boundary color matching")
        return corrected_rgb
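
    # Geometry of the sampling bands above: a 25x25 elliptical kernel
    # shifts the mask edge roughly 12px in each direction, so
    # outer_boundary and inner_boundary are ~12px-wide rings hugging the
    # seam. Only a partial mean shift (70% for L, 50% for a/b) is applied
    # so intentional color differences in the generated content are not
    # fully cancelled out.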

    def _create_blend_mask(self, mask: np.ndarray) -> np.ndarray:
        """
        Create a blend mask with softened edges for natural compositing.

        The mask interior stays fully opaque (1.0) while only the edges
        get a smooth transition. This preserves the full generated content
        while blending naturally at the boundaries.

        Parameters
        ----------
        mask : np.ndarray
            Original mask array (float32, 0-1)

        Returns
        -------
        np.ndarray
            Blend mask with soft edges but a solid interior
        """
        # Convert to uint8 for morphological operations
        mask_uint8 = (mask * 255).astype(np.uint8)

        # Create an eroded version (solid interior)
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15))
        eroded = cv2.erode(mask_uint8, kernel, iterations=1)

        # Create a smooth transition zone at the edges only:
        # blur the original mask for edge softness
        blurred = cv2.GaussianBlur(mask_uint8, (15, 15), 4)

        # Combine: keep the original (solid) values in the interior,
        # use the blurred transition elsewhere
        result = np.where(eroded > 128, mask_uint8, blurred)

        # Final light smoothing
        result = cv2.GaussianBlur(result, (5, 5), 1)

        # Convert back to float
        blend_mask = result.astype(np.float32) / 255.0
        logger.debug("Created blend mask with soft edges and solid interior")
        return blend_mask

    def validate_inputs(
        self,
        image: Image.Image,
        mask: Image.Image
    ) -> Tuple[bool, str]:
        """
        Validate image and mask inputs before processing.

        Parameters
        ----------
        image : PIL.Image
            Input image
        mask : PIL.Image
            Input mask

        Returns
        -------
        tuple
            (is_valid, error_message)
        """
        if image is None:
            return False, "No image provided"
        if mask is None:
            return False, "No mask provided"

        # Sizes are allowed to differ; they will be resized later
        if image.size != mask.size:
            logger.warning(f"Image size {image.size} != mask size {mask.size}, will resize")

        return True, ""
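

# Minimal end-to-end sketch, assuming synthetic inputs. This is a smoke
# test of the public workflow, not a real generation pipeline: the
# "generated" image below is just a solid color standing in for a
# diffusion model's output.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    blender = InpaintingBlender()

    # Synthetic 512x512 original, stand-in "generated" image, and a square mask
    original = Image.new('RGB', (512, 512), (120, 140, 160))
    generated = Image.new('RGB', (512, 512), (180, 150, 100))
    mask = Image.new('L', (512, 512), 0)
    mask.paste(255, (128, 128, 384, 384))  # white square = inpaint region (25% coverage)

    ok, error = blender.validate_inputs(original, mask)
    assert ok, error

    processed_mask, info = blender.prepare_mask(mask, original.size, feather_radius=8)
    print(f"Mask coverage: {info['coverage']:.1%}, valid: {info['valid']}")

    enhanced, negative = blender.enhance_prompt_for_inpainting("a flower", original, mask)
    print(f"Enhanced prompt: {enhanced}")

    # In a real pipeline the enhanced prompt and processed mask would drive
    # the inpainting model; here we blend the stand-in image directly.
    result = blender.blend_result(original, generated, processed_mask)
    print(f"Blended result: size={result.size}, mode={result.mode}")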