import gc
import logging
import os
import time
import traceback
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Optional, Tuple

import cv2
import numpy as np
import torch
from PIL import Image

from diffusers import (
    AutoPipelineForInpainting,
    ControlNetModel,
    DPMSolverMultistepScheduler,
    StableDiffusionXLControlNetInpaintPipeline,
)
from transformers import (
    AutoImageProcessor,
    AutoModelForDepthEstimation,
    DPTForDepthEstimation,
    DPTImageProcessor,
)

from control_image_processor import ControlImageProcessor
from inpainting_blender import InpaintingBlender

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Dedicated SDXL inpainting model - trained specifically for inpainting
SDXL_INPAINTING_MODEL = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"


@dataclass
class InpaintingConfig:
    """Configuration for inpainting operations."""

    # ControlNet settings (ControlNet mode only)
    controlnet_conditioning_scale: float = 0.7
    conditioning_type: str = "canny"

    # Canny edge detection parameters
    canny_low_threshold: int = 100
    canny_high_threshold: int = 200

    # Mask settings
    feather_radius: int = 3
    min_mask_coverage: float = 0.01
    max_mask_coverage: float = 0.95

    # Generation settings
    num_inference_steps: int = 25
    guidance_scale: float = 7.5
    strength: float = 0.99  # 0.99 avoids the noise artifacts that strength=1.0 can produce

    # Memory settings
    enable_vae_tiling: bool = True
    max_resolution: int = 1024


@dataclass
class InpaintingResult:
    """Result container for inpainting operations."""

    success: bool
    result_image: Optional[Image.Image] = None
    preview_image: Optional[Image.Image] = None
    control_image: Optional[Image.Image] = None
    blended_image: Optional[Image.Image] = None
    quality_score: float = 0.0
    generation_time: float = 0.0
    error_message: str = ""
    metadata: Dict[str, Any] = field(default_factory=dict)
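
# Usage sketch (illustrative only; the parameter values below are assumptions,
# not tuned recommendations): callers can override any subset of the defaults,
# e.g. a low-VRAM configuration for quick previews:
#
#   config = InpaintingConfig(
#       num_inference_steps=15,   # fewer steps -> faster, rougher output
#       max_resolution=768,       # smaller working resolution saves memory
#       feather_radius=5,         # softer mask edge for smoother blending
#   )
#   module = InpaintingModule(device="auto", config=config)
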

class InpaintingModule:
    """
    Dual-mode inpainting module for SceneWeaver.

    Supports two modes:

    1. Pure Inpainting (use_controlnet=False): uses the dedicated SDXL
       Inpainting model.
       - Best for: object replacement, object removal
       - More stable, with better edge blending

    2. ControlNet Inpainting (use_controlnet=True): uses ControlNet + SDXL.
       - Best for: clothing change (depth), color change (canny)
       - Preserves structure in the masked region

    Example:
        >>> module = InpaintingModule(device="cuda")
        >>> # For object replacement (no ControlNet)
        >>> module.load_pipeline(use_controlnet=False)
        >>> result = module.execute_inpainting(image, mask, "a vase with flowers")
    """

    # ControlNet model identifiers
    CONTROLNET_CANNY_MODEL = "diffusers/controlnet-canny-sdxl-1.0"
    CONTROLNET_DEPTH_MODEL = "diffusers/controlnet-depth-sdxl-1.0"
    DEPTH_MODEL_PRIMARY = "LiheYoung/depth-anything-small-hf"
    DEPTH_MODEL_FALLBACK = "Intel/dpt-hybrid-midas"

    # Base models for ControlNet mode
    SUPPORTED_MODELS = {
        "juggernaut_xl": "RunDiffusion/Juggernaut-XL-v9",
        "realvis_xl": "SG161222/RealVisXL_V4.0",
        "sdxl_base": "stabilityai/stable-diffusion-xl-base-1.0",
        "animagine_xl": "cagliostrolab/animagine-xl-3.1",
    }

    def __init__(
        self,
        device: str = "auto",
        config: Optional[InpaintingConfig] = None
    ):
        """Initialize the InpaintingModule."""
        self.device = self._setup_device(device)
        self.config = config or InpaintingConfig()

        # Sub-modules
        self._control_processor = ControlImageProcessor(
            device=self.device,
            canny_low_threshold=self.config.canny_low_threshold,
            canny_high_threshold=self.config.canny_high_threshold
        )
        self._blender = InpaintingBlender(
            min_mask_coverage=self.config.min_mask_coverage,
            max_mask_coverage=self.config.max_mask_coverage
        )

        # Pipeline instances
        self._pipeline = None
        self._controlnet = None
        self._depth_estimator = None
        self._depth_processor = None

        # State tracking
        self.is_initialized = False
        self._current_mode = None  # "pure" or "controlnet"
        self._current_conditioning_type = None
        self._current_model_key = None

        logger.info(f"InpaintingModule initialized on {self.device}")

    def _setup_device(self, device: str) -> str:
        """Setup computation device."""
        if device == "auto":
            if torch.cuda.is_available():
                return "cuda"
            elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
                return "mps"
            return "cpu"
        return device

    def _memory_cleanup(self, aggressive: bool = False) -> None:
        """Perform memory cleanup."""
        for _ in range(5 if aggressive else 2):
            gc.collect()
        is_spaces = os.getenv('SPACE_ID') is not None
        if not is_spaces and torch.cuda.is_available():
            torch.cuda.empty_cache()
            if aggressive:
                torch.cuda.ipc_collect()
    def load_pipeline(
        self,
        use_controlnet: bool = False,
        conditioning_type: str = "canny",
        model_key: str = "sdxl_base",
        progress_callback: Optional[Callable[[str, int], None]] = None
    ) -> Tuple[bool, str]:
        """
        Load the appropriate inpainting pipeline.

        Parameters
        ----------
        use_controlnet : bool
            If False, use the dedicated SDXL Inpainting model (for replacement/removal).
            If True, use the ControlNet pipeline (for clothing/color change).
        conditioning_type : str
            ControlNet type: "canny" or "depth" (only used when use_controlnet=True).
        model_key : str
            Base model for ControlNet mode.
        progress_callback : callable, optional
            Progress update function.

        Returns
        -------
        tuple
            (success: bool, error_message: str)
        """
        mode = "controlnet" if use_controlnet else "pure"

        # Check if already loaded with the same config
        if (self.is_initialized
                and self._current_mode == mode
                and (not use_controlnet
                     or (self._current_conditioning_type == conditioning_type
                         and self._current_model_key == model_key))):
            logger.info(f"Pipeline already loaded: mode={mode}")
            return True, ""

        logger.info(f"Loading pipeline: mode={mode}, conditioning={conditioning_type}")

        try:
            self._memory_cleanup(aggressive=True)
            if progress_callback:
                progress_callback("Preparing pipeline...", 10)

            # Unload existing pipeline
            self._unload_pipeline()

            dtype = torch.float16 if self.device == "cuda" else torch.float32

            if not use_controlnet:
                # Mode A: pure SDXL Inpainting (for replacement/removal)
                if progress_callback:
                    progress_callback("Loading SDXL Inpainting model...", 30)
                self._pipeline = AutoPipelineForInpainting.from_pretrained(
                    SDXL_INPAINTING_MODEL,
                    torch_dtype=dtype,
                    variant="fp16" if dtype == torch.float16 else None,
                )
                self._current_mode = "pure"
                self._current_conditioning_type = None
                self._current_model_key = None  # No base-model choice in pure mode
                logger.info("Loaded pure SDXL Inpainting pipeline")
            else:
                # Mode B: ControlNet inpainting (for structure-preserving tasks)
                if model_key not in self.SUPPORTED_MODELS:
                    model_key = "sdxl_base"
                base_model_id = self.SUPPORTED_MODELS[model_key]

                if progress_callback:
                    progress_callback("Loading ControlNet model...", 30)

                # Load ControlNet
                if conditioning_type == "canny":
                    self._controlnet = ControlNetModel.from_pretrained(
                        self.CONTROLNET_CANNY_MODEL,
                        torch_dtype=dtype,
                        use_safetensors=True
                    )
                elif conditioning_type == "depth":
                    self._controlnet = ControlNetModel.from_pretrained(
                        self.CONTROLNET_DEPTH_MODEL,
                        torch_dtype=dtype,
                        use_safetensors=True
                    )
                    self._load_depth_estimator()
                else:
                    raise ValueError(f"Unknown conditioning type: {conditioning_type}")

                if progress_callback:
                    progress_callback(f"Loading {model_key}...", 60)

                # Load pipeline with ControlNet
                use_variant = model_key != "animagine_xl"
                load_kwargs = {
                    "controlnet": self._controlnet,
                    "torch_dtype": dtype,
                    "use_safetensors": True,
                }
                if use_variant and dtype == torch.float16:
                    load_kwargs["variant"] = "fp16"

                self._pipeline = StableDiffusionXLControlNetInpaintPipeline.from_pretrained(
                    base_model_id,
                    **load_kwargs
                )
                self._current_mode = "controlnet"
                self._current_conditioning_type = conditioning_type
                self._current_model_key = model_key
                logger.info(f"Loaded ControlNet pipeline: {model_key} + {conditioning_type}")

            if progress_callback:
                progress_callback("Configuring pipeline...", 80)

            # Configure scheduler
            self._pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
                self._pipeline.scheduler.config
            )

            # Move to device and optimize
            self._pipeline = self._pipeline.to(self.device)
            self._apply_optimizations()

            self.is_initialized = True
            if progress_callback:
                progress_callback("Pipeline ready!", 100)
            return True, ""

        except Exception as e:
            error_msg = str(e)
            logger.error(f"Failed to load pipeline: {error_msg}")
            traceback.print_exc()
            self._unload_pipeline()
            return False, error_msg
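
    # Usage sketch (illustrative; assumes the model weights are cached locally
    # or downloadable from the Hugging Face Hub):
    #
    #   ok, err = module.load_pipeline(use_controlnet=False)       # replacement/removal
    #   ok, err = module.load_pipeline(use_controlnet=True,
    #                                  conditioning_type="depth",
    #                                  model_key="realvis_xl")     # clothing change
    #
    # Calling load_pipeline() with a different configuration unloads the
    # previous pipeline first, so only one pipeline resides in memory at a time.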
    def _load_depth_estimator(self) -> None:
        """Load depth estimation model."""
        try:
            self._depth_processor = AutoImageProcessor.from_pretrained(
                self.DEPTH_MODEL_PRIMARY
            )
            self._depth_estimator = AutoModelForDepthEstimation.from_pretrained(
                self.DEPTH_MODEL_PRIMARY,
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
            )
            self._depth_estimator.to(self.device)
            self._depth_estimator.eval()
            logger.info("Loaded Depth-Anything model")
        except Exception as e:
            logger.warning(f"Primary depth model failed: {e}, trying fallback...")
            self._depth_processor = DPTImageProcessor.from_pretrained(
                self.DEPTH_MODEL_FALLBACK
            )
            self._depth_estimator = DPTForDepthEstimation.from_pretrained(
                self.DEPTH_MODEL_FALLBACK,
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
            )
            self._depth_estimator.to(self.device)
            self._depth_estimator.eval()
            logger.info("Loaded MiDaS fallback model")

    def _apply_optimizations(self) -> None:
        """Apply memory and performance optimizations."""
        if self._pipeline is None:
            return
        try:
            self._pipeline.enable_xformers_memory_efficient_attention()
            logger.info("Enabled xformers attention")
        except Exception:
            try:
                self._pipeline.enable_attention_slicing()
                logger.info("Enabled attention slicing")
            except Exception:
                pass
        if self.config.enable_vae_tiling:
            if hasattr(self._pipeline, 'enable_vae_tiling'):
                self._pipeline.enable_vae_tiling()
            if hasattr(self._pipeline, 'enable_vae_slicing'):
                self._pipeline.enable_vae_slicing()

    def _unload_pipeline(self) -> None:
        """Unload pipeline and free memory."""
        if self._pipeline is not None:
            del self._pipeline
            self._pipeline = None
        if self._controlnet is not None:
            del self._controlnet
            self._controlnet = None
        if self._depth_estimator is not None:
            del self._depth_estimator
            self._depth_estimator = None
        if self._depth_processor is not None:
            del self._depth_processor
            self._depth_processor = None
        self.is_initialized = False
        self._current_mode = None
        self._current_conditioning_type = None
        self._current_model_key = None  # Reset alongside the other state fields
        self._memory_cleanup(aggressive=True)
        logger.info("Pipeline unloaded")
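
    # Minimal sketch of how a depth estimator like the one loaded above is
    # typically run with the transformers API. The actual control-image
    # construction lives in ControlImageProcessor, so this is an assumption
    # about its internals, shown only for illustration:
    #
    #   inputs = self._depth_processor(images=image, return_tensors="pt")
    #   inputs = {k: v.to(self.device, dtype=self._depth_estimator.dtype)
    #             for k, v in inputs.items()}
    #   with torch.no_grad():
    #       depth = self._depth_estimator(**inputs).predicted_depth   # (1, H', W')
    #   depth = (depth - depth.min()) / (depth.max() - depth.min())   # normalize to [0, 1]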
    def execute_inpainting(
        self,
        image: Image.Image,
        mask: Image.Image,
        prompt: str,
        progress_callback: Optional[Callable[[str, int], None]] = None,
        **kwargs
    ) -> InpaintingResult:
        """
        Execute inpainting operation.

        Parameters
        ----------
        image : PIL.Image
            Original image.
        mask : PIL.Image
            Inpainting mask (white = area to regenerate).
        prompt : str
            Text description.
        progress_callback : callable, optional
            Progress update function.
        **kwargs
            Additional parameters from the template.

        Returns
        -------
        InpaintingResult
            Result with the generated image.
        """
        start_time = time.time()

        if not self.is_initialized:
            return InpaintingResult(
                success=False,
                error_message="Pipeline not initialized. Call load_pipeline() first."
            )

        logger.info(f"Inpainting: mode={self._current_mode}, prompt='{prompt[:50]}...'")

        try:
            if progress_callback:
                progress_callback("Preparing images...", 10)

            # Prepare image
            if image.mode != 'RGB':
                image = image.convert('RGB')

            # Store original size for later restoration
            original_size = image.size  # (width, height)

            # Ensure dimensions are a multiple of 8 for model compatibility
            width, height = image.size
            new_width = (width // 8) * 8
            new_height = (height // 8) * 8
            if new_width != width or new_height != height:
                image = image.resize((new_width, new_height), Image.LANCZOS)

            # Limit resolution for memory efficiency
            max_res = self.config.max_resolution
            if max(new_width, new_height) > max_res:
                scale = max_res / max(new_width, new_height)
                new_width = int(new_width * scale) // 8 * 8
                new_height = int(new_height * scale) // 8 * 8
                image = image.resize((new_width, new_height), Image.LANCZOS)

            # Prepare mask with dilation
            mask_dilation = kwargs.get('mask_dilation', 0)
            processed_mask = self._prepare_mask(
                mask,
                (new_width, new_height),
                dilation=mask_dilation,
                feather_radius=kwargs.get('feather_radius', self.config.feather_radius)
            )

            # Get generation parameters
            strength = kwargs.get('strength', self.config.strength)
            guidance_scale = kwargs.get('guidance_scale', self.config.guidance_scale)
            num_steps = kwargs.get('num_inference_steps', self.config.num_inference_steps)
            negative_prompt = kwargs.get('negative_prompt', "")

            # Optimize for Hugging Face Spaces
            is_spaces = os.getenv('SPACE_ID') is not None
            if is_spaces:
                num_steps = min(num_steps, 15)

            # Setup generator with seed.
            # If seed is -1 or None, derive a pseudo-random seed from the current time.
            input_seed = kwargs.get('seed', -1)
            if input_seed is None or input_seed < 0:
                seed = int(time.time() * 1000) % (2**32)
            else:
                seed = int(input_seed)
            generator = torch.Generator(device=self.device).manual_seed(seed)
            logger.info(f"Using seed: {seed}")

            # Generate based on mode
            if self._current_mode == "pure":
                # Pure inpainting - no ControlNet
                if progress_callback:
                    progress_callback("Generating (Pure Inpainting)...", 40)
                result_image = self._generate_pure_inpaint(
                    image=image,
                    mask=processed_mask,
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    num_steps=num_steps,
                    guidance_scale=guidance_scale,
                    strength=strength,
                    generator=generator
                )
                control_image = None
            else:
                # ControlNet inpainting
                if progress_callback:
                    progress_callback("Generating control image...", 30)

                # Prepare control image
                preserve_structure = kwargs.get('preserve_structure_in_mask', False)
                edge_guidance_mode = kwargs.get('edge_guidance_mode', 'boundary')
                control_image = self._control_processor.prepare_control_image(
                    image=image,
                    mode=self._current_conditioning_type,
                    mask=processed_mask,
                    preserve_structure=preserve_structure,
                    edge_guidance_mode=edge_guidance_mode
                )

                if progress_callback:
                    progress_callback("Generating (ControlNet)...", 50)

                conditioning_scale = kwargs.get(
                    'controlnet_conditioning_scale',
                    self.config.controlnet_conditioning_scale
                )
                result_image = self._generate_controlnet_inpaint(
                    image=image,
                    mask=processed_mask,
                    control_image=control_image,
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    num_steps=num_steps,
                    guidance_scale=guidance_scale,
                    conditioning_scale=conditioning_scale,
                    strength=strength,
                    generator=generator
                )

            generation_time = time.time() - start_time

            # Restore original size if it was changed
            if result_image.size != original_size:
                result_image = result_image.resize(original_size, Image.LANCZOS)
                logger.info(f"Restored result to original size: {original_size}")

            if progress_callback:
                progress_callback("Complete!", 100)

            return InpaintingResult(
                success=True,
                result_image=result_image,
                blended_image=result_image,  # Pipeline output is already blended
                control_image=control_image,
                generation_time=generation_time,
                metadata={
                    "seed": seed,
                    "prompt": prompt,
                    "mode": self._current_mode,
                    "num_steps": num_steps,
                    "guidance_scale": guidance_scale,
                    "strength": strength,
                    "original_size": original_size,
                }
            )

        except torch.cuda.OutOfMemoryError:
            logger.error("CUDA out of memory")
            self._memory_cleanup(aggressive=True)
            return InpaintingResult(
                success=False,
                error_message="GPU memory exhausted."
            )
        except Exception as e:
            logger.error(f"Inpainting failed: {e}")
            traceback.print_exc()
            return InpaintingResult(
                success=False,
                error_message=str(e)
            )
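
    # Usage sketch (the keyword names below are the ones execute_inpainting
    # actually reads from **kwargs; the values are illustrative, not tuned):
    #
    #   result = module.execute_inpainting(
    #       image, mask, "a red leather armchair",
    #       negative_prompt="blurry, low quality",
    #       seed=42,                 # fixed seed for reproducibility; -1 = random
    #       mask_dilation=8,         # expand the mask 8 px before feathering
    #       strength=0.85,           # lower = stay closer to the original pixels
    #       controlnet_conditioning_scale=0.6,  # ControlNet mode only
    #   )
    #   if result.success:
    #       result.result_image.save("out.png")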
    def _prepare_mask(
        self,
        mask: Image.Image,
        target_size: Tuple[int, int],
        dilation: int = 0,
        feather_radius: int = 3
    ) -> Image.Image:
        """Prepare mask with optional dilation and feathering."""
        # Convert and resize
        if mask.mode != 'L':
            mask = mask.convert('L')
        if mask.size != target_size:
            mask = mask.resize(target_size, Image.LANCZOS)

        mask_array = np.array(mask)

        # Apply dilation to expand the mask
        if dilation > 0:
            kernel = cv2.getStructuringElement(
                cv2.MORPH_ELLIPSE,
                (dilation * 2 + 1, dilation * 2 + 1)
            )
            mask_array = cv2.dilate(mask_array, kernel, iterations=1)
            logger.debug(f"Applied mask dilation: {dilation}px")

        # Apply feathering
        if feather_radius > 0:
            mask_array = cv2.GaussianBlur(
                mask_array,
                (feather_radius * 2 + 1, feather_radius * 2 + 1),
                feather_radius / 2
            )

        return Image.fromarray(mask_array, mode='L')

    def _generate_pure_inpaint(
        self,
        image: Image.Image,
        mask: Image.Image,
        prompt: str,
        negative_prompt: str,
        num_steps: int,
        guidance_scale: float,
        strength: float,
        generator: torch.Generator
    ) -> Image.Image:
        """Generate using the pure SDXL Inpainting pipeline with a DPM++ scheduler for speed."""
        # Use the DPM++ 2M Karras scheduler for faster generation
        original_scheduler = self._pipeline.scheduler
        self._pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
            self._pipeline.scheduler.config,
            use_karras_sigmas=True,
            algorithm_type="dpmsolver++"
        )
        logger.info("Switched to DPM++ 2M Karras scheduler for Pure Inpainting")
        try:
            with torch.inference_mode():
                result = self._pipeline(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    image=image,
                    mask_image=mask,
                    num_inference_steps=num_steps,
                    guidance_scale=guidance_scale,
                    strength=strength,
                    generator=generator
                )
            return result.images[0]
        finally:
            # Restore the original scheduler
            self._pipeline.scheduler = original_scheduler

    def _generate_controlnet_inpaint(
        self,
        image: Image.Image,
        mask: Image.Image,
        control_image: Image.Image,
        prompt: str,
        negative_prompt: str,
        num_steps: int,
        guidance_scale: float,
        conditioning_scale: float,
        strength: float,
        generator: torch.Generator
    ) -> Image.Image:
        """Generate using the ControlNet inpainting pipeline."""
        with torch.inference_mode():
            result = self._pipeline(
                prompt=prompt,
                negative_prompt=negative_prompt,
                image=image,
                mask_image=mask,
                control_image=control_image,
                num_inference_steps=num_steps,
                guidance_scale=guidance_scale,
                controlnet_conditioning_scale=conditioning_scale,
                strength=strength,
                generator=generator
            )
        return result.images[0]

    def get_status(self) -> Dict[str, Any]:
        """Get current module status."""
        return {
            "initialized": self.is_initialized,
            "device": self.device,
            "mode": self._current_mode,
            "conditioning_type": self._current_conditioning_type,
            "model_key": self._current_model_key,
        }
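

# Hedged end-to-end sketch: this __main__ block is illustrative, not part of
# the SceneWeaver entry points. The file names "scene.png" and "mask.png" are
# assumptions; any RGB image plus an L-mode mask (white = regenerate) will do.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    module = InpaintingModule(device="auto")

    # Mode A: pure inpainting for object replacement
    ok, err = module.load_pipeline(use_controlnet=False)
    if not ok:
        raise SystemExit(f"Pipeline load failed: {err}")

    image = Image.open("scene.png").convert("RGB")
    mask = Image.open("mask.png").convert("L")

    result = module.execute_inpainting(image, mask, "a vase with flowers", seed=42)
    if result.success:
        result.result_image.save("inpainted.png")
        print(f"Done in {result.generation_time:.1f}s, seed={result.metadata['seed']}")
    else:
        print(f"Failed: {result.error_message}")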