# Video-Fx/image_generator.py
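"""Stable Diffusion image generation helper for the Video-Fx Streamlit app.

Generates still images from text prompts, with prompt cleanup and enhancement,
aspect-ratio handling, result caching, and a drawn fallback image when model
inference fails.
"""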
import gc
import hashlib
import os
import re
import time

import numpy as np
import streamlit as st
import torch
from PIL import Image


class ImageGenerator:
    def __init__(self):
        self.model = None
        self.processor = None
        self.target_size = (512, 512)
        self.inference_steps = 30  # Increased for better quality
        self.guidance_scale = 8.5  # Increased for better adherence to prompt
        self.aspect_ratio = "1:1"  # Default aspect ratio
        self.image_cache = {}
        self.vram_optimization = False  # Default to no VRAM optimization

    def set_vram_optimization(self, enabled):
        """Enable or disable VRAM optimization techniques."""
        self.vram_optimization = enabled

    def set_aspect_ratio(self, aspect_ratio):
        """Set the aspect ratio for image generation."""
        self.aspect_ratio = aspect_ratio

    def set_target_size(self, size):
        """Set the target size for generated images."""
        self.target_size = size

    def set_inference_steps(self, steps):
        """Set the number of inference steps for image generation."""
        self.inference_steps = steps

    def get_size_for_aspect_ratio(self, base_size, aspect_ratio=None):
        """Calculate dimensions for an aspect ratio, preserving the pixel count of base_size."""
        if aspect_ratio is None:
            aspect_ratio = self.aspect_ratio
        # Total pixel budget to redistribute across the new aspect ratio
        base_pixels = base_size[0] * base_size[1]
        if aspect_ratio == "1:1":
            # Square format
            side = int(np.sqrt(base_pixels))
            # Ensure even dimensions for compatibility
            side = side if side % 2 == 0 else side + 1
            return (side, side)
        elif aspect_ratio == "16:9":
            # Landscape format
            width = int(np.sqrt(base_pixels * 16 / 9))
            height = int(width * 9 / 16)
            # Ensure even dimensions for compatibility
            width = width if width % 2 == 0 else width + 1
            height = height if height % 2 == 0 else height + 1
            return (width, height)
        elif aspect_ratio == "9:16":
            # Portrait format
            height = int(np.sqrt(base_pixels * 16 / 9))
            width = int(height * 9 / 16)
            # Ensure even dimensions for compatibility
            width = width if width % 2 == 0 else width + 1
            height = height if height % 2 == 0 else height + 1
            return (width, height)
        else:
            # Unknown ratio: fall back to the original size
            return base_size
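
    # Worked example (for illustration): with the default 512x512 base,
    # base_pixels = 262144, so "16:9" gives
    #   width  = int(sqrt(262144 * 16 / 9)) = 682        (already even)
    #   height = int(682 * 9 / 16)          = 383 -> 384 (rounded up to even)
    # i.e. (682, 384), which keeps roughly the same pixel budget as 512x512.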

    def load_model(self):
        """Load the image generation model with memory optimizations."""
        if self.model is None:
            with st.spinner("Loading image generation model..."):
                try:
                    # Free as much memory as possible before loading
                    gc.collect()
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
                    # Import here to avoid loading diffusers until needed
                    from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler

                    # A widely compatible model for Hugging Face Spaces
                    model_id = "CompVis/stable-diffusion-v1-4"
                    # float16 halves memory but is only well supported on GPU;
                    # fall back to float32 for CPU inference
                    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
                    self.model = StableDiffusionPipeline.from_pretrained(
                        model_id,
                        torch_dtype=dtype,
                        safety_checker=None,   # Disable safety checker for speed
                        use_safetensors=True,  # safetensors loads faster and uses less memory
                    )
                    # DPM-Solver++ gives higher quality at fewer steps
                    self.model.scheduler = DPMSolverMultistepScheduler.from_config(
                        self.model.scheduler.config,
                        algorithm_type="dpmsolver++",
                        solver_order=2,
                    )
                    # Use CUDA if available, otherwise CPU
                    device = "cuda" if torch.cuda.is_available() else "cpu"
                    if device == "cuda" and self.vram_optimization and hasattr(self.model, "enable_model_cpu_offload"):
                        # CPU offloading streams submodules to the GPU on demand;
                        # it replaces .to("cuda") and only applies when CUDA exists
                        self.model.enable_model_cpu_offload()
                    else:
                        self.model = self.model.to(device)
                    # Slice attention into single-head chunks for minimum memory use
                    self.model.enable_attention_slicing(slice_size=1)
                    # Use xformers memory-efficient attention if it is installed
                    try:
                        import xformers  # noqa: F401
                        self.model.enable_xformers_memory_efficient_attention()
                    except (ImportError, AttributeError):
                        pass
                    # Tiled VAE decoding allows larger images with less memory
                    if hasattr(self.model, "vae") and hasattr(self.model.vae, "enable_tiling"):
                        self.model.vae.enable_tiling()
                except Exception as e:
                    st.error(
                        f"Error loading image generation model: {str(e)}. "
                        "Please try again with VRAM optimization enabled."
                    )
                    self.model = None
        return self.model

    def generate_image(self, prompt, negative_prompt="blurry, bad quality, distorted, disfigured, low resolution, worst quality, deformed, text, watermark, writing, letters, numbers"):
        """Generate an image from a text prompt and return the saved file path."""
        # Cap inference steps: lower when VRAM optimization is on, and keep a
        # hard ceiling either way for constrained Hugging Face hardware
        inference_steps = self.inference_steps
        if self.vram_optimization:
            inference_steps = min(inference_steps, 20)
        else:
            inference_steps = min(inference_steps, 30)
        # Cache key covers the prompt and every setting that affects the output
        cache_key = f"{hashlib.md5(prompt.encode()).hexdigest()}_{self.target_size}_{inference_steps}_{self.guidance_scale}_{self.aspect_ratio}"
        if cache_key in self.image_cache:
            return self.image_cache[cache_key]
        # Ensure output directory exists
        os.makedirs("temp", exist_ok=True)
        try:
            # Load the model if not already loaded
            model = self.load_model()
            if model is not None:
                # Enhance for the aspect ratio, strip patterns that tend to
                # make the model render literal text, then simplify
                enhanced_prompt = self.enhance_prompt_for_aspect_ratio(prompt)
                enhanced_prompt = self.clean_prompt_for_image_generation(enhanced_prompt)
                simplified_prompt = self.simplify_prompt(enhanced_prompt)
                # Free memory before inference
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                with torch.no_grad():  # No gradients needed for inference
                    # A lower guidance scale behaves better with limited resources
                    guidance_scale = min(self.guidance_scale, 7.5)
                    image = model(
                        prompt=simplified_prompt,
                        negative_prompt=negative_prompt,
                        num_inference_steps=inference_steps,
                        guidance_scale=guidance_scale,
                        width=min(self.target_size[0], 512),   # Limit size for Hugging Face
                        height=min(self.target_size[1], 512),  # Limit size for Hugging Face
                    ).images[0]
                # Save as RGB JPEG with an explicit format
                output_path = f"temp/image_{int(time.time() * 1000)}.jpg"
                image = image.convert("RGB")
                image.save(output_path, format="JPEG", quality=95)
                # Verify the saved file is a readable image
                try:
                    test_load = Image.open(output_path)
                    test_load.verify()
                    test_load.close()
                except Exception as e:
                    st.error(f"Image verification failed: {str(e)}. Using fallback.")
                    return self.create_fallback_image(prompt)
                # Free memory after inference
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                # Cache the result
                self.image_cache[cache_key] = output_path
                return output_path
            else:
                # Model failed to load; try once more with reduced settings
                st.warning("Retrying with reduced settings...")
                return self.retry_with_reduced_settings(prompt)
        except Exception as e:
            st.error(f"Error generating image: {str(e)}. Retrying with reduced settings.")
            return self.retry_with_reduced_settings(prompt)

    def retry_with_reduced_settings(self, prompt):
        """Retry image generation with conservative settings before falling back."""
        try:
            # Free memory before reloading
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            from diffusers import StableDiffusionPipeline

            # Reload the most stable model with minimal settings
            model_id = "CompVis/stable-diffusion-v1-4"
            # float16 only on GPU; float32 on CPU
            dtype = torch.float16 if torch.cuda.is_available() else torch.float32
            pipe = StableDiffusionPipeline.from_pretrained(
                model_id,
                torch_dtype=dtype,
                safety_checker=None,
                use_safetensors=True,
            )
            # Move to the appropriate device
            device = "cuda" if torch.cuda.is_available() else "cpu"
            pipe = pipe.to(device)
            # Maximum memory optimization
            pipe.enable_attention_slicing(slice_size=1)
            # Reduce the prompt to its core elements
            simple_prompt = self.simplify_prompt(prompt)
            # Generate with minimal settings
            image = pipe(
                prompt=simple_prompt,
                num_inference_steps=20,
                guidance_scale=7.0,
                width=512,
                height=512,
            ).images[0]
            # Save as RGB JPEG with an explicit format
            os.makedirs("temp", exist_ok=True)
            output_path = f"temp/retry_image_{int(time.time() * 1000)}.jpg"
            image = image.convert("RGB")
            image.save(output_path, format="JPEG", quality=95)
            # Verify the saved file is a readable image
            try:
                test_load = Image.open(output_path)
                test_load.verify()
                test_load.close()
            except Exception as e:
                st.error(f"Image verification failed: {str(e)}. Using fallback.")
                return self.create_fallback_image(prompt)
            return output_path
        except Exception as e:
            st.error(f"Final attempt failed: {str(e)}. Using fallback image.")
            return self.create_fallback_image(prompt)

    def simplify_prompt(self, prompt):
        """Reduce a prompt to its core elements for better compatibility."""
        # Keep only the first sentence, capped at 100 characters
        simple = prompt.split('.')[0].strip()
        if len(simple) > 100:
            simple = simple[:100]
        # Add minimal styling
        return f"{simple}, high quality, detailed"

    def clean_prompt_for_image_generation(self, prompt):
        """Strip patterns that tend to make the model render literal text."""
        # Remove explicit text-formatting instructions
        cleaned = re.sub(r'text\s+that\s+says', '', prompt, flags=re.IGNORECASE)
        cleaned = re.sub(r'with\s+text', '', cleaned, flags=re.IGNORECASE)
        cleaned = re.sub(r'showing\s+text', '', cleaned, flags=re.IGNORECASE)
        cleaned = re.sub(r'displaying\s+text', '', cleaned, flags=re.IGNORECASE)
        cleaned = re.sub(r'with\s+the\s+words', '', cleaned, flags=re.IGNORECASE)
        # Remove quoted strings, which encourage the model to draw text
        cleaned = re.sub(r'["\'].*?["\']', '', cleaned)
        # Reinforce the negative prompt directly in the prompt
        cleaned += ", no text, no words, no writing, no letters, no numbers, no watermark"
        return cleaned
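
    # Example: clean_prompt_for_image_generation('a shop sign with the words "OPEN"')
    # drops both the instruction and the quoted string, leaving roughly
    # 'a shop sign  , no text, no words, no writing, no letters, no numbers, no watermark'
    # (the residual whitespace is harmless to the model).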

    def enhance_prompt_for_aspect_ratio(self, prompt):
        """Enhance the prompt with style cues and aspect-ratio composition hints."""
        # Base enhancement for all prompts
        base_enhancement = "hyper realistic, photo realistic, ultra detailed, hyper detailed textures, 8K resolution"
        # Cinematic lighting options
        lighting_options = [
            "golden hour glow", "moody overcast", "dramatic lighting",
            "soft natural light", "cinematic lighting", "film noir shadows",
        ]
        # Camera effects
        camera_effects = [
            "shallow depth of field", "motion blur", "film grain",
            "professional photography", "award winning photograph",
        ]
        # Environmental details
        environmental_details = [
            "atmospheric", "detailed environment", "rich textures",
            "detailed background", "immersive scene",
        ]
        import random
        # Seed from a stable digest of the prompt so the same prompt always
        # picks the same enhancements (builtin hash() is randomized per process)
        random.seed(int(hashlib.md5(prompt.encode()).hexdigest(), 16))
        selected_lighting = random.choice(lighting_options)
        selected_effect = random.choice(camera_effects)
        selected_detail = random.choice(environmental_details)
        # Aspect-ratio-specific composition hints
        if self.aspect_ratio == "16:9":
            # Landscape format: cinematic, wide view
            aspect_enhancement = "cinematic wide shot, landscape composition, panoramic view"
        elif self.aspect_ratio == "9:16":
            # Portrait format: vertical composition
            aspect_enhancement = "vertical composition, portrait framing, tall perspective"
        else:
            # Square format: balanced composition
            aspect_enhancement = "balanced composition, centered framing, square format"
        # Combine all enhancements
        return f"{prompt}, {base_enhancement}, {selected_lighting}, {selected_effect}, {selected_detail}, {aspect_enhancement}"
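
    # Example: with aspect_ratio "9:16", enhance_prompt_for_aspect_ratio("a lighthouse")
    # yields "a lighthouse, hyper realistic, ..., vertical composition, portrait
    # framing, tall perspective", with one lighting, camera, and environment cue
    # chosen deterministically from the prompt's digest.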

    def create_fallback_image(self, prompt):
        """Create a placeholder image showing the prompt text when generation fails."""
        from PIL import ImageDraw, ImageFont

        width, height = self.target_size
        image = Image.new('RGB', (width, height), color=(240, 240, 240))
        draw = ImageDraw.Draw(image)
        # Draw a vertical gradient; the color depends only on y, so one
        # horizontal line per row is equivalent to setting each pixel
        for y in range(height):
            r = int(240 * (1 - y / height))
            g = int(240 * (1 - y / height))
            b = int(255 * (1 - y / height * 0.5))
            draw.line([(0, y), (width, y)], fill=(r, g, b))
        # Pick a nicer font if available, else PIL's built-in default
        try:
            font = ImageFont.truetype("Arial", 20)
        except OSError:
            font = ImageFont.load_default()
        # Wrap the prompt text to fit the image width
        words = prompt.split()
        lines = []
        current_line = []
        for word in words:
            test_line = ' '.join(current_line + [word])
            # Rough width estimate: ~10 pixels per character, 20px side margins
            if len(test_line) * 10 < width - 40:
                current_line.append(word)
            else:
                lines.append(' '.join(current_line))
                current_line = [word]
        if current_line:
            lines.append(' '.join(current_line))
        # Draw up to 8 lines of text
        y_position = height // 4
        for line in lines[:8]:
            draw.text((20, y_position), line, fill=(0, 0, 0), font=font)
            y_position += 30
        # Save the image
        os.makedirs("temp", exist_ok=True)
        output_path = f"temp/fallback_{int(time.time() * 1000)}.png"
        image.save(output_path)
        return output_path

    def clear_cache(self):
        """Clear the image cache."""
        self.image_cache = {}
        return True
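

# Minimal usage sketch (illustrative, not part of the module itself): assumes
# a Streamlit runtime, since the class reports progress and errors through
# st.spinner/st.error, and a writable temp/ directory for output files.
#
#     generator = ImageGenerator()
#     generator.set_aspect_ratio("16:9")
#     generator.set_target_size(generator.get_size_for_aspect_ratio((512, 512)))
#     image_path = generator.generate_image("a foggy mountain lake at dawn")
#     st.image(image_path)  # generate_image always returns a usable file path
#     generator.clear_cache()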