import gc
import hashlib
import os
import random
import re
import tempfile
import time

import numpy as np
import streamlit as st
import torch
from PIL import Image, ImageDraw, ImageFont

# Hugging Face model repository used for every generation attempt.
_MODEL_ID = "CompVis/stable-diffusion-v1-4"

# Directory (relative to the working directory) that receives generated files.
_OUTPUT_DIR = "temp"

# Upper bound applied to the requested width/height before calling the pipeline.
_MAX_DIMENSION = 512

# Stable Diffusion pipelines reject sizes that are not a multiple of 8.
_DIMENSION_MULTIPLE = 8

_DEFAULT_NEGATIVE_PROMPT = (
    "blurry, bad quality, distorted, disfigured, low resolution, worst quality, "
    "deformed, text, watermark, writing, letters, numbers"
)

# Phrases that explicitly ask for rendered text; removed one after another,
# in this order, so that a later pattern can match what an earlier one left.
_TEXT_INSTRUCTION_PATTERNS = tuple(
    re.compile(pattern, re.IGNORECASE)
    for pattern in (
        r'text\s+that\s+says',
        r'with\s+text',
        r'showing\s+text',
        r'displaying\s+text',
        r'with\s+the\s+words',
    )
)

# Quoted passages. Single quotes only count when they are not glued to a word
# character, so apostrophes in "dog's" or "don't" are left alone.
_DOUBLE_QUOTED_RE = re.compile(r'"[^"]*"')
_SINGLE_QUOTED_RE = re.compile(r"(?<!\w)'[^']*'(?!\w)")

_LIGHTING_OPTIONS = (
    "golden hour glow",
    "moody overcast",
    "dramatic lighting",
    "soft natural light",
    "cinematic lighting",
    "film noir shadows",
)

_CAMERA_EFFECTS = (
    "shallow depth of field",
    "motion blur",
    "film grain",
    "professional photography",
    "award winning photograph",
)

_ENVIRONMENTAL_DETAILS = (
    "atmospheric",
    "detailed environment",
    "rich textures",
    "detailed background",
    "immersive scene",
)


class ImageGenerator:
    """Generate images from text prompts with a Stable Diffusion pipeline.

    The pipeline is loaded lazily on the first call to ``load_model`` and kept
    on the instance. Generated files are written below ``temp/`` and their
    paths are memoised in ``image_cache``. When generation fails, a second
    attempt with reduced settings is made, and as a last resort a placeholder
    image containing the prompt text is drawn with Pillow.
    """

    def __init__(self):
        self.model = None
        self.processor = None
        self.target_size = (512, 512)
        # generate_image caps this at 30 (20 when vram_optimization is on).
        self.inference_steps = 30
        # generate_image caps this at 7.5 before calling the pipeline.
        self.guidance_scale = 8.5
        self.aspect_ratio = "1:1"
        # Maps a settings-derived key to the path of a generated file.
        self.image_cache = {}
        self.vram_optimization = False

    # ------------------------------------------------------------------
    # Configuration
    # ------------------------------------------------------------------

    def set_vram_optimization(self, enabled):
        """Enable or disable VRAM optimization techniques"""
        self.vram_optimization = enabled

    def set_aspect_ratio(self, aspect_ratio):
        """Set the aspect ratio for image generation"""
        self.aspect_ratio = aspect_ratio

    def set_target_size(self, size):
        """Set the target size for generated images"""
        self.target_size = size

    def set_inference_steps(self, steps):
        """Set the number of inference steps for image generation"""
        self.inference_steps = steps

    def get_size_for_aspect_ratio(self, base_size, aspect_ratio=None):
        """Calculate image dimensions based on aspect ratio.

        Args:
            base_size: ``(width, height)`` whose pixel count is the budget.
            aspect_ratio: ``"1:1"``, ``"16:9"`` or ``"9:16"``; defaults to
                ``self.aspect_ratio``.

        Returns:
            ``(width, height)`` with roughly the same pixel count as
            ``base_size`` and both sides rounded up to an even number. Any
            other aspect ratio returns ``base_size`` unchanged.
        """
        if aspect_ratio is None:
            aspect_ratio = self.aspect_ratio

        base_pixels = base_size[0] * base_size[1]

        if aspect_ratio == "1:1":
            side = self._round_up_to_even(int(np.sqrt(base_pixels)))
            return (side, side)

        if aspect_ratio in ("16:9", "9:16"):
            # The short side is derived from the long side *before* rounding.
            long_side = int(np.sqrt(base_pixels * 16 / 9))
            short_side = int(long_side * 9 / 16)
            long_side = self._round_up_to_even(long_side)
            short_side = self._round_up_to_even(short_side)
            if aspect_ratio == "16:9":
                return (long_side, short_side)
            return (short_side, long_side)

        return base_size

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _round_up_to_even(value):
        """Return ``value`` if even, otherwise ``value + 1``."""
        return value + (value % 2)

    @staticmethod
    def _snap_dimension(value):
        """Cap ``value`` at 512 and round it down to a multiple of 8."""
        capped = min(int(value), _MAX_DIMENSION)
        return max(_DIMENSION_MULTIPLE, capped - capped % _DIMENSION_MULTIPLE)

    @staticmethod
    def _free_memory():
        """Run the garbage collector and release cached CUDA memory."""
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    @staticmethod
    def _select_device():
        """Return ``"cuda"`` when a CUDA device is available, else ``"cpu"``."""
        return "cuda" if torch.cuda.is_available() else "cpu"

    @staticmethod
    def _dtype_for(device):
        """Return the weight dtype to load for ``device``.

        Half precision is only used on CUDA; on CPU the weights are loaded in
        float32 because half-precision inference is not generally supported
        there.
        """
        return torch.float16 if device == "cuda" else torch.float32

    def _cache_key(self, prompt, negative_prompt, inference_steps):
        """Build the ``image_cache`` key for the current request."""
        prompt_digest = hashlib.md5(prompt.encode()).hexdigest()
        negative_digest = hashlib.md5((negative_prompt or "").encode()).hexdigest()
        return (
            f"{prompt_digest}_{negative_digest}_{self.target_size}_{inference_steps}"
            f"_{self.guidance_scale}_{self.aspect_ratio}"
        )

    @staticmethod
    def _save_jpeg(image, stem):
        """Save ``image`` as ``temp/<stem>_<millis>.jpg`` and return the path."""
        os.makedirs(_OUTPUT_DIR, exist_ok=True)
        output_path = f"{_OUTPUT_DIR}/{stem}_{int(time.time() * 1000)}.jpg"
        image.convert("RGB").save(output_path, format="JPEG", quality=95)
        return output_path

    @staticmethod
    def _is_valid_image(path):
        """Return True when Pillow can open and verify the file at ``path``.

        A failure is reported through ``st.error`` and yields False.
        """
        try:
            with Image.open(path) as handle:
                handle.verify()
        except Exception as e:
            st.error(f"Image verification failed: {str(e)}. Using fallback.")
            return False
        return True

    # ------------------------------------------------------------------
    # Model handling
    # ------------------------------------------------------------------

    def load_model(self):
        """Load the Stable Diffusion pipeline once and return it.

        Returns:
            The loaded pipeline, or ``None`` when loading failed (the error is
            reported through ``st.error``).
        """
        if self.model is not None:
            return self.model

        with st.spinner("Loading image generation model..."):
            try:
                self._free_memory()

                # Imported lazily so the heavy dependency is only paid for
                # when an image is actually requested.
                from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline

                device = self._select_device()
                pipe = StableDiffusionPipeline.from_pretrained(
                    _MODEL_ID,
                    torch_dtype=self._dtype_for(device),
                    safety_checker=None,
                    use_safetensors=True,
                )
                pipe.scheduler = DPMSolverMultistepScheduler.from_config(
                    pipe.scheduler.config,
                    algorithm_type="dpmsolver++",
                    solver_order=2,
                )
                pipe = pipe.to(device)
                pipe.enable_attention_slicing(slice_size=1)

                # xformers attention is attempted on CUDA only. The CPU
                # offload helpers are intentionally not enabled: they move
                # sub-models onto an accelerator on demand, so they cannot
                # work when the pipeline runs on the CPU.
                if device == "cuda":
                    try:
                        import xformers  # noqa: F401
                        pipe.enable_xformers_memory_efficient_attention()
                    except (ImportError, AttributeError, ValueError, RuntimeError):
                        # Optional optimisation; continue without it.
                        pass

                if hasattr(pipe, "vae") and hasattr(pipe.vae, "enable_tiling"):
                    pipe.vae.enable_tiling()

                self.model = pipe
            except Exception as e:
                st.error(
                    f"Error loading image generation model: {str(e)}. "
                    "Please try again with VRAM optimization enabled."
                )
                self.model = None

        return self.model

    # ------------------------------------------------------------------
    # Generation
    # ------------------------------------------------------------------

    def generate_image(self, prompt, negative_prompt=_DEFAULT_NEGATIVE_PROMPT):
        """Generate an image for ``prompt`` and return the path of the file.

        Args:
            prompt: Text description of the desired image.
            negative_prompt: Text describing what the image should not contain.

        Returns:
            Path of a JPEG below ``temp/``. When the main attempt fails, the
            result of ``retry_with_reduced_settings`` is returned; when the
            saved file cannot be verified, the result of
            ``create_fallback_image`` is returned. Only successful main
            attempts are stored in ``image_cache``.
        """
        step_limit = 20 if self.vram_optimization else 30
        inference_steps = min(self.inference_steps, step_limit)

        cache_key = self._cache_key(prompt, negative_prompt, inference_steps)
        cached_path = self.image_cache.get(cache_key)
        # A cached path is only useful while the file is still on disk.
        if cached_path is not None and os.path.exists(cached_path):
            return cached_path

        os.makedirs(_OUTPUT_DIR, exist_ok=True)

        try:
            model = self.load_model()
            if model is None:
                st.warning("Retrying with reduced settings...")
                return self.retry_with_reduced_settings(prompt)

            enhanced_prompt = self.enhance_prompt_for_aspect_ratio(prompt)
            enhanced_prompt = self.clean_prompt_for_image_generation(enhanced_prompt)
            simplified_prompt = self.simplify_prompt(enhanced_prompt)

            self._free_memory()

            with torch.no_grad():
                image = model(
                    prompt=simplified_prompt,
                    negative_prompt=negative_prompt,
                    num_inference_steps=inference_steps,
                    guidance_scale=min(self.guidance_scale, 7.5),
                    width=self._snap_dimension(self.target_size[0]),
                    height=self._snap_dimension(self.target_size[1]),
                ).images[0]

            output_path = self._save_jpeg(image, "image")
            if not self._is_valid_image(output_path):
                return self.create_fallback_image(prompt)

            self._free_memory()

            self.image_cache[cache_key] = output_path
            return output_path
        except Exception as e:
            st.error(f"Error generating image: {str(e)}. Retrying with reduced settings.")
            return self.retry_with_reduced_settings(prompt)

    def retry_with_reduced_settings(self, prompt):
        """Retry generation with a freshly loaded pipeline and fixed settings.

        Uses 20 inference steps, a guidance scale of 7.0, a 512x512 canvas and
        no negative prompt. The temporary pipeline is released afterwards.

        Returns:
            Path of the generated JPEG, or the result of
            ``create_fallback_image`` when this attempt fails as well.
        """
        pipe = None
        try:
            self._free_memory()

            from diffusers import StableDiffusionPipeline

            device = self._select_device()
            pipe = StableDiffusionPipeline.from_pretrained(
                _MODEL_ID,
                torch_dtype=self._dtype_for(device),
                safety_checker=None,
                use_safetensors=True,
            )
            pipe = pipe.to(device)
            pipe.enable_attention_slicing(slice_size=1)

            image = pipe(
                prompt=self.simplify_prompt(prompt),
                num_inference_steps=20,
                guidance_scale=7.0,
                width=512,
                height=512,
            ).images[0]

            output_path = self._save_jpeg(image, "retry_image")
            if not self._is_valid_image(output_path):
                return self.create_fallback_image(prompt)
            return output_path
        except Exception as e:
            st.error(f"Final attempt failed: {str(e)}. Using fallback image.")
            return self.create_fallback_image(prompt)
        finally:
            # Drop the throwaway pipeline so repeated retries do not pile up
            # model copies in memory.
            del pipe
            self._free_memory()

    # ------------------------------------------------------------------
    # Prompt processing
    # ------------------------------------------------------------------

    def simplify_prompt(self, prompt):
        """Reduce ``prompt`` to its first sentence (at most 100 characters).

        Returns:
            The text before the first ``.``, stripped, truncated to 100
            characters and followed by ``", high quality, detailed"``.
        """
        first_sentence = prompt.split('.')[0].strip()
        return f"{first_sentence[:100]}, high quality, detailed"

    def clean_prompt_for_image_generation(self, prompt):
        """Remove phrases that tend to make the model render text.

        Explicit text instructions ("with text", "text that says", ...) and
        quoted passages are removed, leftover whitespace is tidied, and a
        "no text ..." suffix is appended. Apostrophes inside words are not
        treated as quotation marks.
        """
        cleaned = prompt
        for pattern in _TEXT_INSTRUCTION_PATTERNS:
            cleaned = pattern.sub('', cleaned)

        cleaned = _DOUBLE_QUOTED_RE.sub('', cleaned)
        cleaned = _SINGLE_QUOTED_RE.sub('', cleaned)

        # Tidy the gaps left behind by the removals above.
        cleaned = re.sub(r'\s{2,}', ' ', cleaned)
        cleaned = re.sub(r'\s+,', ',', cleaned).strip()

        cleaned += ", no text, no words, no writing, no letters, no numbers, no watermark"
        return cleaned

    def enhance_prompt_for_aspect_ratio(self, prompt):
        """Append quality, lighting, camera and composition hints to ``prompt``.

        The lighting, camera and environment hints are chosen pseudo-randomly
        from fixed lists using a private generator seeded with the prompt
        text, so the same prompt always gets the same hints and the global
        ``random`` state is left untouched.
        """
        base_enhancement = (
            "hyper realistic, photo realistic, ultra detailed, "
            "hyper detailed textures, 8K resolution"
        )

        # A str seed is turned into an integer deterministically, unlike
        # hash(str), which is salted per interpreter process.
        rng = random.Random(prompt)
        selected_lighting = rng.choice(_LIGHTING_OPTIONS)
        selected_effect = rng.choice(_CAMERA_EFFECTS)
        selected_detail = rng.choice(_ENVIRONMENTAL_DETAILS)

        if self.aspect_ratio == "16:9":
            aspect_enhancement = "cinematic wide shot, landscape composition, panoramic view"
        elif self.aspect_ratio == "9:16":
            aspect_enhancement = "vertical composition, portrait framing, tall perspective"
        else:
            aspect_enhancement = "balanced composition, centered framing, square format"

        return (
            f"{prompt}, {base_enhancement}, {selected_lighting}, {selected_effect}, "
            f"{selected_detail}, {aspect_enhancement}"
        )

    # ------------------------------------------------------------------
    # Fallback and cache
    # ------------------------------------------------------------------

    def create_fallback_image(self, prompt):
        """Draw a placeholder image showing ``prompt`` and return its path.

        The image has the size of ``self.target_size``, a vertical gradient
        background and up to eight lines of word-wrapped prompt text. It is
        saved as ``temp/fallback_<millis>.png``.
        """
        width, height = self.target_size
        image = Image.new('RGB', (width, height), color=(240, 240, 240))
        draw = ImageDraw.Draw(image)

        # One horizontal line per row instead of one call per pixel.
        for y in range(height):
            shade = int(240 * (1 - y / height))
            blue = int(255 * (1 - y / height * 0.5))
            draw.line([(0, y), (width - 1, y)], fill=(shade, shade, blue))

        try:
            font = ImageFont.truetype("Arial", 20)
        except (OSError, ImportError):
            # Font file missing or FreeType support unavailable.
            font = ImageFont.load_default()

        # Greedy word wrap using a rough estimate of 10 pixels per character
        # and a 20 pixel margin on each side.
        lines = []
        current_line = []
        for word in prompt.split():
            candidate = ' '.join(current_line + [word])
            if len(candidate) * 10 < width - 40:
                current_line.append(word)
                continue
            if current_line:  # avoid emitting a blank line for an over-long word
                lines.append(' '.join(current_line))
            current_line = [word]
        if current_line:
            lines.append(' '.join(current_line))

        y_position = height // 4
        for line in lines[:8]:
            draw.text((20, y_position), line, fill=(0, 0, 0), font=font)
            y_position += 30

        os.makedirs(_OUTPUT_DIR, exist_ok=True)
        output_path = f"{_OUTPUT_DIR}/fallback_{int(time.time() * 1000)}.png"
        image.save(output_path)
        return output_path

    def clear_cache(self):
        """Forget all cached image paths (files on disk are not removed).

        Returns:
            Always ``True``.
        """
        self.image_cache = {}
        return True