garyuzair committed on
Commit
a4f9b16
·
verified ·
1 Parent(s): fd65cf5

Update image_generator.py

Browse files
Files changed (1) hide show
  1. image_generator.py +163 -61
image_generator.py CHANGED
@@ -12,41 +12,47 @@ class ImageGenerator:
12
  self.model = None
13
  self.inference_steps = 20
14
  self.target_size = (384, 384)
 
15
 
16
  def load_model(self):
17
  """Load a lightweight image generation model"""
18
  if self.model is None:
19
  with st.spinner("Loading image generation model... This may take a moment."):
20
- # Using a lightweight model for image generation
21
- from diffusers import StableDiffusionPipeline
22
-
23
- model_id = "runwayml/stable-diffusion-v1-5"
24
-
25
- # Load with memory optimization settings
26
- self.model = StableDiffusionPipeline.from_pretrained(
27
- model_id,
28
- torch_dtype=torch.float32,
29
- safety_checker=None,
30
- requires_safety_checker=False
31
- )
32
-
33
- # Use CPU for inference to save memory
34
- self.model = self.model.to("cpu")
35
-
36
- # Enable memory efficient attention if available
37
- if hasattr(self.model, 'enable_attention_slicing'):
38
- self.model.enable_attention_slicing()
39
-
40
- # Enable memory efficient attention
41
- if hasattr(self.model, 'enable_vae_slicing'):
42
- self.model.enable_vae_slicing()
43
-
44
- # Enable xformers memory efficient attention if available
45
  try:
46
- if hasattr(self.model, 'enable_xformers_memory_efficient_attention'):
47
- self.model.enable_xformers_memory_efficient_attention()
48
- except:
49
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  return self.model
52
 
@@ -57,25 +63,140 @@ class ImageGenerator:
57
  def set_target_size(self, size):
58
  """Set the target image size"""
59
  self.target_size = size
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
  def generate_image(self, prompt, output_dir="temp"):
62
  """Generate a single image from a prompt"""
63
- # Load the model if not already loaded
64
- model = self.load_model()
65
-
66
  # Ensure output directory exists
67
  os.makedirs(output_dir, exist_ok=True)
68
 
69
- # Generate image with minimal inference steps to save resources
70
- image = model(
71
- prompt,
72
- num_inference_steps=self.inference_steps,
73
- guidance_scale=7.5
74
- ).images[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
- # Resize to target size for consistency and performance
77
- if image.size != self.target_size:
78
- image = image.resize(self.target_size, Image.LANCZOS)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
  # Save the image
81
  image_path = f"{output_dir}/image_{int(time.time() * 1000)}.png"
@@ -85,9 +206,6 @@ class ImageGenerator:
85
 
86
  def generate_images(self, prompts, output_dir="temp", progress_callback=None, parallel=False, max_workers=4):
87
  """Generate images from the prompts"""
88
- # Load the model if not already loaded
89
- model = self.load_model()
90
-
91
  # Ensure output directory exists
92
  os.makedirs(output_dir, exist_ok=True)
93
 
@@ -153,9 +271,6 @@ class ImageGenerator:
153
 
154
  def batch_generate_images(self, prompts, batch_size=2, output_dir="temp", progress_callback=None):
155
  """Generate images in batches to optimize memory usage"""
156
- # Load the model if not already loaded
157
- model = self.load_model()
158
-
159
  # Ensure output directory exists
160
  os.makedirs(output_dir, exist_ok=True)
161
 
@@ -171,20 +286,7 @@ class ImageGenerator:
171
  # Generate images for this batch
172
  batch_images = []
173
  for j, prompt in enumerate(batch_prompts):
174
- # Generate image
175
- image = model(
176
- prompt,
177
- num_inference_steps=self.inference_steps,
178
- guidance_scale=7.5
179
- ).images[0]
180
-
181
- # Resize to target size
182
- if image.size != self.target_size:
183
- image = image.resize(self.target_size, Image.LANCZOS)
184
-
185
- # Save the image
186
- image_path = f"{output_dir}/image_{i+j}_{int(time.time() * 1000)}.png"
187
- image.save(image_path)
188
  batch_images.append(image_path)
189
 
190
  # Add batch results to overall results
 
12
  self.model = None
13
  self.inference_steps = 20
14
  self.target_size = (384, 384)
15
+ self.aspect_ratio = "1:1" # Default aspect ratio
16
 
17
  def load_model(self):
18
  """Load a lightweight image generation model"""
19
  if self.model is None:
20
  with st.spinner("Loading image generation model... This may take a moment."):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  try:
22
+ # Using a lightweight model for image generation
23
+ from diffusers import StableDiffusionPipeline
24
+
25
+ model_id = "runwayml/stable-diffusion-v1-5"
26
+
27
+ # Load with memory optimization settings
28
+ self.model = StableDiffusionPipeline.from_pretrained(
29
+ model_id,
30
+ torch_dtype=torch.float32,
31
+ safety_checker=None,
32
+ requires_safety_checker=False,
33
+ low_cpu_mem_usage=True
34
+ )
35
+
36
+ # Use CPU for inference to save memory
37
+ self.model = self.model.to("cpu")
38
+
39
+ # Enable memory efficient attention if available
40
+ if hasattr(self.model, 'enable_attention_slicing'):
41
+ self.model.enable_attention_slicing()
42
+
43
+ # Enable VAE slicing if available
44
+ if hasattr(self.model, 'enable_vae_slicing'):
45
+ self.model.enable_vae_slicing()
46
+
47
+ # Enable xformers memory efficient attention if available
48
+ try:
49
+ if hasattr(self.model, 'enable_xformers_memory_efficient_attention'):
50
+ self.model.enable_xformers_memory_efficient_attention()
51
+ except:
52
+ pass
53
+ except Exception as e:
54
+ st.warning(f"Error loading image generation model: {str(e)}. Using fallback method.")
55
+ self.model = None
56
 
57
  return self.model
58
 
 
63
  def set_target_size(self, size):
64
  """Set the target image size"""
65
  self.target_size = size
66
+
67
+ def set_aspect_ratio(self, aspect_ratio):
68
+ """Set the aspect ratio for generated images"""
69
+ self.aspect_ratio = aspect_ratio
70
+
71
+ # Update target size based on aspect ratio while maintaining total pixels
72
+ base_pixels = self.target_size[0] * self.target_size[1]
73
+
74
+ if aspect_ratio == "1:1":
75
+ # Square format
76
+ side = int(np.sqrt(base_pixels))
77
+ self.target_size = (side, side)
78
+ elif aspect_ratio == "16:9":
79
+ # Landscape format
80
+ width = int(np.sqrt(base_pixels * 16 / 9))
81
+ height = int(width * 9 / 16)
82
+ self.target_size = (width, height)
83
+ elif aspect_ratio == "9:16":
84
+ # Portrait format
85
+ height = int(np.sqrt(base_pixels * 16 / 9))
86
+ width = int(height * 9 / 16)
87
+ self.target_size = (width, height)
88
+
89
+ def get_size_for_aspect_ratio(self, base_size, aspect_ratio):
90
+ """Calculate dimensions for a given aspect ratio while maintaining approximate total pixels"""
91
+ base_pixels = base_size[0] * base_size[1]
92
+
93
+ if aspect_ratio == "1:1":
94
+ # Square format
95
+ side = int(np.sqrt(base_pixels))
96
+ return (side, side)
97
+ elif aspect_ratio == "16:9":
98
+ # Landscape format
99
+ width = int(np.sqrt(base_pixels * 16 / 9))
100
+ height = int(width * 9 / 16)
101
+ # Ensure dimensions are even numbers for video compatibility
102
+ width = width if width % 2 == 0 else width + 1
103
+ height = height if height % 2 == 0 else height + 1
104
+ return (width, height)
105
+ elif aspect_ratio == "9:16":
106
+ # Portrait format
107
+ height = int(np.sqrt(base_pixels * 16 / 9))
108
+ width = int(height * 9 / 16)
109
+ # Ensure dimensions are even numbers for video compatibility
110
+ width = width if width % 2 == 0 else width + 1
111
+ height = height if height % 2 == 0 else height + 1
112
+ return (width, height)
113
+ else:
114
+ # Default to original size
115
+ return base_size
116
 
117
  def generate_image(self, prompt, output_dir="temp"):
118
  """Generate a single image from a prompt"""
 
 
 
119
  # Ensure output directory exists
120
  os.makedirs(output_dir, exist_ok=True)
121
 
122
+ try:
123
+ # Load the model if not already loaded
124
+ model = self.load_model()
125
+
126
+ if model is not None:
127
+ # Generate image with minimal inference steps to save resources
128
+ image = model(
129
+ prompt,
130
+ num_inference_steps=self.inference_steps,
131
+ guidance_scale=7.5
132
+ ).images[0]
133
+
134
+ # Resize to target size for consistency and performance
135
+ if image.size != self.target_size:
136
+ image = image.resize(self.target_size, Image.LANCZOS)
137
+ else:
138
+ # Fallback: Create a colored gradient image with text
139
+ from PIL import Image, ImageDraw, ImageFont, ImageFilter
140
+
141
+ # Create a base image with gradient background
142
+ image = Image.new('RGB', self.target_size, color=(240, 240, 240))
143
+ draw = ImageDraw.Draw(image)
144
+
145
+ # Create a gradient background
146
+ for y in range(image.height):
147
+ for x in range(image.width):
148
+ # Create a simple gradient
149
+ r = int(200 + (x * 55 / image.width))
150
+ g = int(200 + (y * 55 / image.height))
151
+ b = 240
152
+ draw.point((x, y), fill=(r, g, b))
153
+
154
+ # Soften the gradient with a slight Gaussian blur
155
+ image = image.filter(ImageFilter.GaussianBlur(radius=1))
156
+
157
+ # Add text from prompt (truncated)
158
+ draw = ImageDraw.Draw(image)
159
+ text = prompt[:50] + "..." if len(prompt) > 50 else prompt
160
+
161
+ # Position text
162
+ text_width = draw.textlength(text, font=None)
163
+ text_position = ((image.width - text_width) / 2, image.height / 2)
164
+
165
+ # Draw text
166
+ draw.text(text_position, text, fill=(0, 0, 0))
167
 
168
+ except Exception as e:
169
+ st.warning(f"Error generating image: {str(e)}. Using fallback method.")
170
+
171
+ # Fallback: Create a colored gradient image with text
172
+ from PIL import Image, ImageDraw, ImageFilter
173
+
174
+ # Create a base image with gradient background
175
+ image = Image.new('RGB', self.target_size, color=(240, 240, 240))
176
+ draw = ImageDraw.Draw(image)
177
+
178
+ # Create a gradient background
179
+ for y in range(image.height):
180
+ for x in range(image.width):
181
+ # Create a simple gradient
182
+ r = int(200 + (x * 55 / image.width))
183
+ g = int(200 + (y * 55 / image.height))
184
+ b = 240
185
+ draw.point((x, y), fill=(r, g, b))
186
+
187
+ # Soften the gradient with a slight Gaussian blur
188
+ image = image.filter(ImageFilter.GaussianBlur(radius=1))
189
+
190
+ # Add text from prompt (truncated)
191
+ draw = ImageDraw.Draw(image)
192
+ text = prompt[:50] + "..." if len(prompt) > 50 else prompt
193
+
194
+ # Position text
195
+ text_width = draw.textlength(text, font=None)
196
+ text_position = ((image.width - text_width) / 2, image.height / 2)
197
+
198
+ # Draw text
199
+ draw.text(text_position, text, fill=(0, 0, 0))
200
 
201
  # Save the image
202
  image_path = f"{output_dir}/image_{int(time.time() * 1000)}.png"
 
206
 
207
  def generate_images(self, prompts, output_dir="temp", progress_callback=None, parallel=False, max_workers=4):
208
  """Generate images from the prompts"""
 
 
 
209
  # Ensure output directory exists
210
  os.makedirs(output_dir, exist_ok=True)
211
 
 
271
 
272
  def batch_generate_images(self, prompts, batch_size=2, output_dir="temp", progress_callback=None):
273
  """Generate images in batches to optimize memory usage"""
 
 
 
274
  # Ensure output directory exists
275
  os.makedirs(output_dir, exist_ok=True)
276
 
 
286
  # Generate images for this batch
287
  batch_images = []
288
  for j, prompt in enumerate(batch_prompts):
289
+ image_path = self.generate_image(prompt, output_dir)
 
 
 
 
 
 
 
 
 
 
 
 
 
290
  batch_images.append(image_path)
291
 
292
  # Add batch results to overall results