import gc
import logging
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, Optional, Tuple

from diffusers import StableDiffusionXLControlNetInpaintPipeline
import torch

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class ImageMode(Enum):
    """Image style modes for model selection."""

    PHOTO = "photo"
    ANIME = "anime"


@dataclass
class ModelConfig:
    """Configuration for an inpainting model."""

    model_id: str  # Hugging Face repository id passed to from_pretrained
    name: str  # Human-readable name shown in the UI
    description: str
    mode: ImageMode
    # When True, `variant` is forwarded to from_pretrained (e.g. "fp16" weights)
    requires_variant: bool = True
    variant: str = "fp16"
    recommended_for: str = ""
    # Model-specific settings
    default_guidance_scale: float = 7.5
    default_num_inference_steps: int = 25


class InpaintingModelManager:
    """
    Manages multiple inpainting models for different image styles.

    Provides lazy loading and switching between models optimized for
    photorealistic images vs anime/illustration styles. At most one base
    pipeline is held at a time; the ControlNet is loaded once and reused
    by every base model.

    Attributes:
        AVAILABLE_MODELS: Dictionary of all supported models, keyed by
            model key (e.g. "juggernaut_xl")
        DEFAULT_MODELS: Default model key for each ImageMode
        device: Device models are loaded on ("cuda", "mps" or "cpu")

    Example:
        >>> manager = InpaintingModelManager(device="cuda")
        >>> pipeline = manager.load_pipeline(mode=ImageMode.PHOTO)
        >>> # Use pipeline for inpainting
        >>> pipeline = manager.switch_model("animagine_xl")
    """

    # Available models configuration
    AVAILABLE_MODELS: Dict[str, ModelConfig] = {
        # Photo-realistic models
        "juggernaut_xl": ModelConfig(
            model_id="RunDiffusion/Juggernaut-XL-v9",
            name="JuggernautXL v9",
            description="Best for photorealistic images, portraits, and real photos",
            mode=ImageMode.PHOTO,
            requires_variant=True,
            variant="fp16",
            recommended_for="Real photos, portraits, professional photography",
            default_guidance_scale=7.0,
            default_num_inference_steps=25
        ),
        "realvis_xl": ModelConfig(
            model_id="SG161222/RealVisXL_V4.0",
            name="RealVisXL v4",
            description="Excellent for realistic images with fine details",
            mode=ImageMode.PHOTO,
            requires_variant=True,
            variant="fp16",
            recommended_for="Realistic scenes, product photos, nature",
            default_guidance_scale=7.0,
            default_num_inference_steps=25
        ),
        # Anime/Illustration models
        "sdxl_base": ModelConfig(
            model_id="stabilityai/stable-diffusion-xl-base-1.0",
            name="SDXL Base",
            description="Versatile model for general use and illustrations",
            mode=ImageMode.ANIME,
            requires_variant=True,
            variant="fp16",
            recommended_for="General illustrations, digital art, versatile use",
            default_guidance_scale=7.5,
            default_num_inference_steps=25
        ),
        "animagine_xl": ModelConfig(
            model_id="cagliostrolab/animagine-xl-3.1",
            name="Animagine XL 3.1",
            description="Specialized for anime and manga style images",
            mode=ImageMode.ANIME,
            requires_variant=False,
            recommended_for="Anime, manga, cartoon style images",
            default_guidance_scale=7.0,
            default_num_inference_steps=25
        ),
    }

    # Default model for each mode
    DEFAULT_MODELS: Dict[ImageMode, str] = {
        ImageMode.PHOTO: "juggernaut_xl",
        ImageMode.ANIME: "sdxl_base"
    }

    def __init__(self, device: Optional[str] = None):
        """
        Initialize the model manager.

        No model is loaded here; pipelines are loaded lazily on demand.

        Parameters
        ----------
        device : str, optional
            Device to load models on. Auto-detected if not specified.
        """
        self.device = device or self._detect_device()
        self._current_model_key: Optional[str] = None
        self._pipeline: Optional[Any] = None
        self._controlnet: Optional[Any] = None
        self._controlnet_loaded: bool = False

        logger.info(f"InpaintingModelManager initialized on device: {self.device}")

    def _detect_device(self) -> str:
        """Detect the best available device (cuda, then mps, then cpu)."""
        if torch.cuda.is_available():
            return "cuda"
        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            return "mps"
        return "cpu"

    def get_models_for_mode(self, mode: ImageMode) -> Dict[str, ModelConfig]:
        """
        Get all available models for a specific mode.

        Parameters
        ----------
        mode : ImageMode
            The image mode (PHOTO or ANIME)

        Returns
        -------
        dict
            Dictionary of model configs for the mode
        """
        return {
            key: config
            for key, config in self.AVAILABLE_MODELS.items()
            if config.mode == mode
        }

    def get_model_choices(self) -> Dict[str, list]:
        """
        Get model choices formatted for UI dropdown.

        Returns
        -------
        dict
            Dictionary with 'photo' and 'anime' lists of (display_name, key) tuples.
            Any model whose mode is not PHOTO is listed under 'anime'.
        """
        choices = {
            "photo": [],
            "anime": []
        }

        for key, config in self.AVAILABLE_MODELS.items():
            display = f"{config.name} - {config.description}"
            if config.mode == ImageMode.PHOTO:
                choices["photo"].append((display, key))
            else:
                choices["anime"].append((display, key))

        return choices

    def get_default_model(self, mode: ImageMode) -> str:
        """Get the default model key for a mode (falls back to "sdxl_base")."""
        return self.DEFAULT_MODELS.get(mode, "sdxl_base")

    def load_controlnet(self) -> Any:
        """
        Load the ControlNet model (shared across all base models).

        The model is loaded once and cached on the manager; later calls
        return the cached instance.

        Returns
        -------
        ControlNetModel
            Loaded ControlNet model

        Raises
        ------
        Exception
            Re-raises whatever ControlNetModel.from_pretrained raises, after logging it.
        """
        if self._controlnet_loaded and self._controlnet is not None:
            return self._controlnet

        try:
            from diffusers import ControlNetModel

            logger.info("Loading ControlNet Canny model...")
            self._controlnet = ControlNetModel.from_pretrained(
                "diffusers/controlnet-canny-sdxl-1.0",
                # Half precision only on CUDA; other devices use float32
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                use_safetensors=True
            )
            self._controlnet_loaded = True
            logger.info("ControlNet loaded successfully")
            return self._controlnet

        except Exception as e:
            logger.error(f"Failed to load ControlNet: {e}")
            raise

    def load_pipeline(
        self,
        model_key: Optional[str] = None,
        mode: Optional[ImageMode] = None
    ) -> Any:
        """
        Load an inpainting pipeline for the specified model.

        If a different model is currently loaded, it is unloaded first to
        free memory. Manager state is only updated once the new pipeline
        is fully loaded, moved to the device and optimized.

        Parameters
        ----------
        model_key : str, optional
            Specific model key to load
        mode : ImageMode, optional
            If model_key not specified, load default for this mode
            (ImageMode.PHOTO when both are omitted)

        Returns
        -------
        StableDiffusionXLControlNetInpaintPipeline
            Loaded pipeline ready for inference

        Raises
        ------
        ValueError
            If model_key is not in AVAILABLE_MODELS. The currently loaded
            pipeline is left untouched in that case.
        """
        # Determine which model to load
        if model_key is None:
            if mode is None:
                mode = ImageMode.PHOTO
            model_key = self.get_default_model(mode)

        # Check if already loaded
        if self._current_model_key == model_key and self._pipeline is not None:
            logger.info(f"Model {model_key} already loaded")
            return self._pipeline

        # Validate the key *before* unloading, so a typo does not throw away
        # the model that is currently in memory.
        config = self.AVAILABLE_MODELS.get(model_key)
        if config is None:
            raise ValueError(f"Unknown model key: {model_key}")

        # Unload current model if different, to free memory for the new one
        if self._current_model_key != model_key:
            self.unload_pipeline()

        logger.info(f"Loading model: {config.name} ({config.model_id})")

        try:
            # Ensure ControlNet is loaded
            controlnet = self.load_controlnet()

            # Load pipeline
            dtype = torch.float16 if self.device == "cuda" else torch.float32
            load_kwargs: Dict[str, Any] = {
                "controlnet": controlnet,
                "torch_dtype": dtype,
                "use_safetensors": True,
            }
            if config.requires_variant:
                load_kwargs["variant"] = config.variant

            # Build into a local so a failure below cannot leave a
            # half-initialized pipeline published on the manager.
            pipeline = StableDiffusionXLControlNetInpaintPipeline.from_pretrained(
                config.model_id,
                **load_kwargs
            )

            # Move to device and optimize
            pipeline = pipeline.to(self.device)

            if self.device == "cuda":
                pipeline.enable_vae_tiling()
                # xformers is optional; fall back silently to default attention
                try:
                    pipeline.enable_xformers_memory_efficient_attention()
                    logger.info("xformers enabled")
                except Exception:
                    logger.info("xformers not available, using default attention")

        except Exception as e:
            logger.error(f"Failed to load model {model_key}: {e}")
            raise

        # Publish the pipeline and its key together, only on full success
        self._pipeline = pipeline
        self._current_model_key = model_key
        logger.info(f"Model {config.name} loaded successfully")
        return self._pipeline

    def unload_pipeline(self) -> None:
        """
        Unload the current pipeline to free memory.

        The shared ControlNet stays loaded. Does nothing when no pipeline
        is loaded.
        """
        if self._pipeline is None:
            return

        logger.info(f"Unloading model: {self._current_model_key}")
        self._pipeline = None
        self._current_model_key = None

        # Collect first so tensors kept alive only by reference cycles are
        # released to the CUDA caching allocator before empty_cache() runs.
        gc.collect()
        if self.device == "cuda":
            torch.cuda.empty_cache()

    def switch_model(self, model_key: str) -> Any:
        """
        Switch to a different model.

        Parameters
        ----------
        model_key : str
            Model key to switch to

        Returns
        -------
        Pipeline
            Newly loaded pipeline

        Raises
        ------
        ValueError
            If model_key is not in AVAILABLE_MODELS.
        """
        return self.load_pipeline(model_key=model_key)

    def get_current_model_config(self) -> Optional[ModelConfig]:
        """Get the configuration of the currently loaded model (None if none)."""
        if self._current_model_key is None:
            return None
        return self.AVAILABLE_MODELS.get(self._current_model_key)

    def get_pipeline(self) -> Optional[Any]:
        """Get the currently loaded pipeline, or None if nothing is loaded."""
        return self._pipeline

    def is_loaded(self) -> bool:
        """Check if a pipeline is currently loaded."""
        return self._pipeline is not None

    def get_status(self) -> Dict[str, Any]:
        """
        Get current status of the model manager.

        Returns
        -------
        dict
            Status information: device, current model key and name,
            load flags and the list of available model keys
        """
        current_config = self.get_current_model_config()
        return {
            "device": self.device,
            "current_model": self._current_model_key,
            "current_model_name": current_config.name if current_config else None,
            "is_loaded": self.is_loaded(),
            "controlnet_loaded": self._controlnet_loaded,
            "available_models": list(self.AVAILABLE_MODELS.keys())
        }


def get_model_selection_guide() -> str:
    """
    Get HTML guide for model selection to display in UI.

    Returns
    -------
    str
        HTML formatted guide
    """
    return """

📸 Model Selection Guide

🖼️ Photo Mode

Best for: Real photographs, portraits, product shots, nature photos

Recommended: JuggernautXL for portraits, RealVisXL for scenes

🎨 Anime Mode

Best for: Anime, manga, illustrations, digital art, cartoons

Recommended: Animagine XL for anime, SDXL Base for general art

"""