File size: 18,452 Bytes
15ca2ca
 
ad6d387
 
 
15ca2ca
ad6d387
 
15ca2ca
 
 
 
ad6d387
 
b4d330b
 
a4f9b16
ad6d387
b4d330b
 
 
 
 
15ca2ca
ad6d387
 
 
7323bbb
 
ad6d387
7323bbb
a4f9b16
ad6d387
 
 
a4f9b16
ad6d387
 
 
 
 
 
a4f9b16
 
 
 
 
ad6d387
 
a4f9b16
 
 
 
 
ad6d387
a4f9b16
 
 
 
 
 
 
ad6d387
a4f9b16
 
 
 
 
 
ad6d387
 
 
 
 
 
 
 
 
 
7bf78f6
9ed05b7
ad6d387
7bf78f6
 
ad6d387
b4d330b
ad6d387
 
b4d330b
ad6d387
b4d330b
ad6d387
 
9ed05b7
 
 
 
 
 
 
b4d330b
 
 
ad6d387
7bf78f6
 
ad6d387
7bf78f6
b4d330b
 
 
 
 
ad6d387
b4d330b
 
ad6d387
 
b4d330b
 
 
 
 
 
 
ad6d387
 
9ed05b7
ad6d387
 
 
7323bbb
7bf78f6
b4d330b
 
 
 
 
7bf78f6
 
 
 
 
ad6d387
 
7bf78f6
ad6d387
 
 
 
 
7323bbb
ad6d387
7323bbb
a4f9b16
 
 
 
 
9ed05b7
ad6d387
a4f9b16
9ed05b7
 
 
7bf78f6
 
 
ad6d387
 
 
a4f9b16
b4d330b
ad6d387
b4d330b
 
7bf78f6
 
 
 
 
 
 
 
 
 
 
 
 
a4f9b16
56f6fcc
 
 
 
 
 
 
 
 
 
 
 
 
 
a4f9b16
ad6d387
 
 
a4f9b16
ad6d387
 
a4f9b16
ad6d387
 
9ed05b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7bf78f6
9ed05b7
 
 
 
56f6fcc
 
 
 
 
 
 
 
 
 
 
 
 
 
9ed05b7
 
a4f9b16
9ed05b7
ad6d387
9ed05b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f3247a
9ed05b7
 
 
 
 
ad6d387
 
 
 
 
7323bbb
ad6d387
 
 
 
 
7323bbb
ad6d387
 
 
 
 
15ca2ca
ad6d387
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7323bbb
ad6d387
 
 
 
 
15ca2ca
ad6d387
15ca2ca
ad6d387
 
 
15ca2ca
ad6d387
 
 
 
15ca2ca
ad6d387
 
 
 
 
 
 
a87d440
ad6d387
 
 
 
 
 
 
15ca2ca
ad6d387
 
 
 
15ca2ca
ad6d387
 
 
 
 
 
 
 
7323bbb
ad6d387
 
7323bbb
ad6d387
 
 
 
 
7323bbb
ad6d387
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
import streamlit as st
import os
import tempfile
from PIL import Image
import torch
import time
import numpy as np
import gc

class ImageGenerator:
    """Text-to-image generation built on Stable Diffusion, tuned for
    memory-constrained hosts (e.g. Hugging Face Spaces).

    Features:
      * lazy model loading with aggressive memory optimizations,
      * an in-memory cache keyed on prompt + generation settings,
      * graceful degradation: retry with reduced settings, then a
        locally drawn fallback image — a file path is always returned.
    """

    def __init__(self):
        self.model = None                 # lazily loaded diffusers pipeline
        self.processor = None             # reserved; unused by current code
        self.target_size = (512, 512)     # (width, height) of generated images
        self.inference_steps = 30         # higher step count for better quality
        self.guidance_scale = 8.5         # stronger adherence to the prompt
        self.aspect_ratio = "1:1"         # default aspect ratio
        self.image_cache = {}             # cache_key -> path of saved image
        self.vram_optimization = False    # extra step-count reduction when True

    def set_vram_optimization(self, enabled):
        """Enable or disable VRAM optimization techniques."""
        self.vram_optimization = enabled

    def set_aspect_ratio(self, aspect_ratio):
        """Set the aspect ratio ("1:1", "16:9" or "9:16") for generation."""
        self.aspect_ratio = aspect_ratio

    def set_target_size(self, size):
        """Set the (width, height) target size for generated images."""
        self.target_size = size

    def set_inference_steps(self, steps):
        """Set the number of diffusion inference steps."""
        self.inference_steps = steps

    @staticmethod
    def _free_memory():
        """Force garbage collection and release cached CUDA memory (if any)."""
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def get_size_for_aspect_ratio(self, base_size, aspect_ratio=None):
        """Return (width, height) with roughly the same pixel count as
        *base_size* but shaped to *aspect_ratio* (defaults to the instance's
        current setting). Dimensions are rounded up to even numbers.

        NOTE(review): Stable Diffusion generally expects dimensions that are
        multiples of 8; callers cap sizes at 512 so this holds in practice,
        but arbitrary base sizes may need stricter rounding — confirm.
        """
        if aspect_ratio is None:
            aspect_ratio = self.aspect_ratio

        # Total pixel budget to preserve across shapes.
        base_pixels = base_size[0] * base_size[1]

        def _even(value):
            # Round an odd dimension up to the next even number.
            return value if value % 2 == 0 else value + 1

        if aspect_ratio == "1:1":
            side = _even(int(np.sqrt(base_pixels)))
            return (side, side)
        elif aspect_ratio == "16:9":
            width = int(np.sqrt(base_pixels * 16 / 9))
            height = int(width * 9 / 16)
            return (_even(width), _even(height))
        elif aspect_ratio == "9:16":
            height = int(np.sqrt(base_pixels * 16 / 9))
            width = int(height * 9 / 16)
            return (_even(width), _even(height))
        else:
            # Unknown ratio: keep the caller's size unchanged.
            return base_size

    def load_model(self):
        """Load (once) and return the Stable Diffusion pipeline, or None on
        failure. Applies every memory optimization available on this host."""
        if self.model is None:
            with st.spinner("Loading image generation model..."):
                try:
                    self._free_memory()

                    # Import here to avoid paying the cost until needed.
                    from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler

                    # Most compatible/stable checkpoint for Spaces.
                    model_id = "CompVis/stable-diffusion-v1-4"

                    device = "cuda" if torch.cuda.is_available() else "cpu"
                    # fp16 is unsupported for most CPU ops — only use it on CUDA.
                    dtype = torch.float16 if device == "cuda" else torch.float32

                    self.model = StableDiffusionPipeline.from_pretrained(
                        model_id,
                        torch_dtype=dtype,
                        safety_checker=None,     # disabled for speed
                        use_safetensors=True,    # better memory behavior
                    )

                    # DPM-Solver++ (order 2) gives better quality per step.
                    self.model.scheduler = DPMSolverMultistepScheduler.from_config(
                        self.model.scheduler.config,
                        algorithm_type="dpmsolver++",
                        solver_order=2,
                    )

                    self.model = self.model.to(device)

                    # Maximum attention slicing to minimize peak memory.
                    self.model.enable_attention_slicing(slice_size=1)

                    # xformers attention, if the package is installed.
                    try:
                        import xformers  # noqa: F401
                        self.model.enable_xformers_memory_efficient_attention()
                    except (ImportError, AttributeError):
                        pass

                    # Best-effort offloading: these require `accelerate` and
                    # may not apply on every host, so failures are ignored.
                    if device == "cpu":
                        for opt_name in ("enable_model_cpu_offload",
                                         "enable_sequential_cpu_offload"):
                            try:
                                getattr(self.model, opt_name)()
                            except Exception:
                                pass

                    # Tiled VAE decodes large images in chunks.
                    if hasattr(self.model, "vae") and hasattr(self.model.vae, "enable_tiling"):
                        self.model.vae.enable_tiling()

                except Exception as e:
                    st.error(f"Error loading image generation model: {str(e)}. Please try again with VRAM optimization enabled.")
                    self.model = None

        return self.model

    def _save_and_verify(self, image, output_path):
        """Save *image* as a JPEG at *output_path* and confirm the file
        decodes. Returns True on success, False if verification fails."""
        image = image.convert("RGB")  # JPEG requires RGB mode
        image.save(output_path, format="JPEG", quality=95)
        try:
            from PIL import Image as PILImage
            with PILImage.open(output_path) as check:
                check.verify()  # raises if the file is corrupt
            return True
        except Exception as e:
            st.error(f"Image verification failed: {str(e)}. Using fallback.")
            return False

    def generate_image(self, prompt, negative_prompt="blurry, bad quality, distorted, disfigured, low resolution, worst quality, deformed, text, watermark, writing, letters, numbers"):
        """Generate an image for *prompt* and return the saved file's path.

        Results are cached per (prompt, negative prompt, settings). On any
        failure the method degrades to :meth:`retry_with_reduced_settings`
        and ultimately to :meth:`create_fallback_image`.
        """
        import hashlib

        # Cap the step count: tighter when VRAM optimization is requested,
        # and never above 30 on constrained hosts.
        steps = min(self.inference_steps, 20 if self.vram_optimization else 30)

        # Cache key covers every input that influences the output image
        # (including the negative prompt, which also changes the result).
        prompt_digest = hashlib.md5(prompt.encode()).hexdigest()
        negative_digest = hashlib.md5(negative_prompt.encode()).hexdigest()
        cache_key = (
            f"{prompt_digest}_{negative_digest}_{self.target_size}_"
            f"{steps}_{self.guidance_scale}_{self.aspect_ratio}"
        )
        if cache_key in self.image_cache:
            return self.image_cache[cache_key]

        os.makedirs("temp", exist_ok=True)

        try:
            model = self.load_model()
            if model is None:
                # Model never loaded; try the lighter-weight path once.
                st.warning("Retrying with reduced settings...")
                return self.retry_with_reduced_settings(prompt)

            # Build the final prompt: aspect-ratio styling, then strip
            # text-rendering triggers, then shorten for constrained hosts.
            final_prompt = self.simplify_prompt(
                self.clean_prompt_for_image_generation(
                    self.enhance_prompt_for_aspect_ratio(prompt)
                )
            )

            self._free_memory()

            with torch.no_grad():  # inference only; skip autograd bookkeeping
                image = model(
                    prompt=final_prompt,
                    negative_prompt=negative_prompt,
                    num_inference_steps=steps,
                    # Lower guidance behaves better with limited resources.
                    guidance_scale=min(self.guidance_scale, 7.5),
                    width=min(self.target_size[0], 512),   # size cap for Spaces
                    height=min(self.target_size[1], 512),  # size cap for Spaces
                ).images[0]

            output_path = f"temp/image_{int(time.time() * 1000)}.jpg"
            if not self._save_and_verify(image, output_path):
                return self.create_fallback_image(prompt)

            self._free_memory()

            self.image_cache[cache_key] = output_path
            return output_path
        except Exception as e:
            st.error(f"Error generating image: {str(e)}. Retrying with reduced settings.")
            return self.retry_with_reduced_settings(prompt)

    def retry_with_reduced_settings(self, prompt):
        """Last-chance generation with a freshly loaded, minimal pipeline.
        Falls back to a locally drawn image if this also fails."""
        try:
            self._free_memory()

            from diffusers import StableDiffusionPipeline

            # Most stable checkpoint, minimal loading options.
            model_id = "CompVis/stable-diffusion-v1-4"
            device = "cuda" if torch.cuda.is_available() else "cpu"
            # fp16 is unsupported for most CPU ops — only use it on CUDA.
            dtype = torch.float16 if device == "cuda" else torch.float32

            pipe = StableDiffusionPipeline.from_pretrained(
                model_id,
                torch_dtype=dtype,
                safety_checker=None,
                use_safetensors=True,
            )
            pipe = pipe.to(device)
            pipe.enable_attention_slicing(slice_size=1)

            # Boil the prompt down to its simplest usable form.
            simple_prompt = self.simplify_prompt(prompt)

            image = pipe(
                prompt=simple_prompt,
                num_inference_steps=20,
                guidance_scale=7.0,
                width=512,
                height=512,
            ).images[0]

            # This method may be reached without generate_image having
            # created the output directory first.
            os.makedirs("temp", exist_ok=True)
            output_path = f"temp/retry_image_{int(time.time() * 1000)}.jpg"
            if not self._save_and_verify(image, output_path):
                return self.create_fallback_image(prompt)

            return output_path
        except Exception as e:
            st.error(f"Final attempt failed: {str(e)}. Using fallback image.")
            return self.create_fallback_image(prompt)

    def simplify_prompt(self, prompt):
        """Reduce *prompt* to its first sentence (max 100 chars) and append
        minimal quality styling."""
        simple = prompt.split('.')[0].strip()
        if len(simple) > 100:
            simple = simple[:100]
        return f"{simple}, high quality, detailed"

    def clean_prompt_for_image_generation(self, prompt):
        """Strip phrasing that tends to make diffusion models render literal
        text, and append anti-text keywords to the prompt itself."""
        import re

        # Remove explicit text-formatting instructions.
        cleaned = re.sub(r'text\s+that\s+says', '', prompt, flags=re.IGNORECASE)
        cleaned = re.sub(r'with\s+text', '', cleaned, flags=re.IGNORECASE)
        cleaned = re.sub(r'showing\s+text', '', cleaned, flags=re.IGNORECASE)
        cleaned = re.sub(r'displaying\s+text', '', cleaned, flags=re.IGNORECASE)
        cleaned = re.sub(r'with\s+the\s+words', '', cleaned, flags=re.IGNORECASE)

        # Quoted spans often get rendered verbatim — drop them.
        cleaned = re.sub(r'["\'].*?["\']', '', cleaned)

        # Reinforce the negative prompt inside the positive prompt too.
        cleaned += ", no text, no words, no writing, no letters, no numbers, no watermark"

        return cleaned

    def enhance_prompt_for_aspect_ratio(self, prompt):
        """Append deterministic quality/lighting/camera styling plus
        composition hints matching the current aspect ratio."""
        import random
        import hashlib

        base_enhancement = "hyper realistic, photo realistic, ultra detailed, hyper detailed textures, 8K resolution"

        lighting_options = [
            "golden hour glow", "moody overcast", "dramatic lighting",
            "soft natural light", "cinematic lighting", "film noir shadows"
        ]
        camera_effects = [
            "shallow depth of field", "motion blur", "film grain",
            "professional photography", "award winning photograph"
        ]
        environmental_details = [
            "atmospheric", "detailed environment", "rich textures",
            "detailed background", "immersive scene"
        ]

        # hash() on str is salted per process (PYTHONHASHSEED), so it is NOT
        # stable across runs; derive the seed from a content digest instead.
        # A local Random instance also avoids clobbering global RNG state.
        seed = int(hashlib.md5(prompt.encode()).hexdigest(), 16)
        rng = random.Random(seed)

        selected_lighting = rng.choice(lighting_options)
        selected_effect = rng.choice(camera_effects)
        selected_detail = rng.choice(environmental_details)

        if self.aspect_ratio == "16:9":
            # Landscape — cinematic, wide view.
            aspect_enhancement = "cinematic wide shot, landscape composition, panoramic view"
        elif self.aspect_ratio == "9:16":
            # Portrait — vertical composition.
            aspect_enhancement = "vertical composition, portrait framing, tall perspective"
        else:
            # Square — balanced composition.
            aspect_enhancement = "balanced composition, centered framing, square format"

        return f"{prompt}, {base_enhancement}, {selected_lighting}, {selected_effect}, {selected_detail}, {aspect_enhancement}"

    def create_fallback_image(self, prompt):
        """Draw a simple gradient card containing the prompt text and return
        the saved file's path. Used when model generation fails entirely."""
        from PIL import Image, ImageDraw, ImageFont

        width, height = self.target_size
        image = Image.new('RGB', (width, height), color=(240, 240, 240))
        draw = ImageDraw.Draw(image)

        # Vertical gradient, one horizontal line per row — same output as a
        # per-pixel fill but O(height) draw calls instead of O(w*h).
        for y in range(height):
            fade = 1 - y / height
            r = int(240 * fade)
            g = int(240 * fade)
            b = int(255 * (1 - y / height * 0.5))
            draw.line([(0, y), (width - 1, y)], fill=(r, g, b))

        try:
            # Prefer a nicer font when the platform provides it.
            font = ImageFont.truetype("Arial", 20)
        except Exception:  # font lookup is platform dependent
            font = ImageFont.load_default()

        # Greedy word wrap using a rough 10px-per-character estimate with a
        # 20px margin on each side.
        lines = []
        current_line = []
        for word in prompt.split():
            test_line = ' '.join(current_line + [word])
            if len(test_line) * 10 < width - 40:
                current_line.append(word)
            else:
                lines.append(' '.join(current_line))
                current_line = [word]
        if current_line:
            lines.append(' '.join(current_line))

        y_position = height // 4
        for line in lines[:8]:  # at most 8 lines fit comfortably
            draw.text((20, y_position), line, fill=(0, 0, 0), font=font)
            y_position += 30

        # May be reached before any other method created the directory.
        os.makedirs("temp", exist_ok=True)
        output_path = f"temp/fallback_{int(time.time() * 1000)}.png"
        image.save(output_path)

        return output_path

    def clear_cache(self):
        """Clear the in-memory image cache. Returns True."""
        self.image_cache = {}
        return True