HAL1993 committed
Commit 52c9b64 · verified · Parent: 3db3884

Update app.py

Files changed (1): app.py (+70, -147)
app.py CHANGED
@@ -9,17 +9,19 @@ from custom_pipeline import FluxWithCFGPipeline
 
 # --- Torch Optimizations ---
 torch.backends.cuda.matmul.allow_tf32 = True
-torch.backends.cudnn.benchmark = True # Enable cuDNN benchmark for potentially faster convolutions
+torch.backends.cudnn.benchmark = True
 
 # --- Constants ---
 MAX_SEED = np.iinfo(np.int32).max
-MAX_IMAGE_SIZE = 2048 # Keep a reasonable limit to prevent OOMs
+MAX_IMAGE_SIZE = 1024
 DEFAULT_WIDTH = 1024
-DEFAULT_HEIGHT = 1024
-DEFAULT_INFERENCE_STEPS = 1 # FLUX Schnell is designed for few steps
-MIN_INFERENCE_STEPS = 1
-MAX_INFERENCE_STEPS = 8 # Allow slightly more steps for potential quality boost
-ENHANCE_STEPS = 2 # Fixed steps for the enhance button
+DEFAULT_HEIGHT = 576
+ASPECT_RATIOS = {
+    "16:9": (1024, 576),
+    "1:1": (1024, 1024),
+    "9:16": (576, 1024)
+}
+INFERENCE_STEPS = 8
 
 # --- Device and Model Setup ---
 dtype = torch.float16
@@ -27,174 +29,95 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
 
 pipe = FluxWithCFGPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype)
 pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype)
-
 pipe.to(device)
 
 # --- Inference Function ---
 @spaces.GPU
-def generate_image(prompt: str, seed: int = 42, width: int = DEFAULT_WIDTH, height: int = DEFAULT_HEIGHT, randomize_seed: bool = False, num_inference_steps: int = DEFAULT_INFERENCE_STEPS, is_enhance: bool = False):
-    """Generates an image using the FLUX pipeline with error handling."""
-
+def generate_image(prompt: str, seed: int = 42, aspect_ratio: str = "16:9", randomize_seed: bool = False):
     if pipe is None:
-        raise gr.Error("Diffusion pipeline failed to load. Cannot generate images.")
-
+        raise gr.Error("Pipelinei nuk u ngarkua.")
+
     if not prompt or prompt.strip() == "":
-        gr.Warning("Prompt is empty. Please enter a description.")
-        return None, seed, "Error: Empty prompt"
+        return None, seed, "Gabim: Plotësoni përshkrimin."
 
-    start_time = time.time()
-
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
-
-    # Clamp dimensions to avoid excessive memory usage
-    width = min(width, MAX_IMAGE_SIZE)
-    height = min(height, MAX_IMAGE_SIZE)
-
-    # Use fixed steps for enhance button, otherwise use slider value
-    steps_to_use = ENHANCE_STEPS if is_enhance else num_inference_steps
-    # Clamp steps
-    steps_to_use = max(MIN_INFERENCE_STEPS, min(steps_to_use, MAX_INFERENCE_STEPS))
+
+    width, height = ASPECT_RATIOS.get(aspect_ratio, (DEFAULT_WIDTH, DEFAULT_HEIGHT))
+
+    # Allow prompts in Albanian or English and still enhance them
+    enhanced_prompt = prompt.strip() + ", ultra realistic, sharp, 8k resolution, cinematic lighting"
 
     try:
-        # Ensure generator is on the correct device
         generator = torch.Generator(device=device).manual_seed(int(float(seed)))
-
-        # Use inference_mode for efficiency
+        start_time = time.time()
         with torch.inference_mode():
-            # Generate the image (assuming pipe returns list/tuple with image first)
-            # Modify pipe call based on its actual signature if needed
-            result_img = pipe(
-                prompt=prompt,
+            image = pipe(
+                prompt=enhanced_prompt,
                 width=width,
                 height=height,
-                num_inference_steps=steps_to_use,
+                num_inference_steps=INFERENCE_STEPS,
                 generator=generator,
-                output_type="pil", # Ensure PIL output for Gradio Image component
-                return_dict=False # Assuming the custom pipeline supports this for direct output
-            )[0][0] # Assuming the output structure is [[img]]
-
+                output_type="pil",
+                return_dict=False
+            )[0][0]
         latency = time.time() - start_time
-        latency_str = f"Latency: {latency:.2f} seconds (Steps: {steps_to_use})"
-        return result_img, seed, latency_str
-
-    except torch.cuda.OutOfMemoryError as e:
-        # Clear cache and suggest reducing size/steps
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-        raise gr.Error("GPU ran out of memory. Try reducing the image width/height or the number of inference steps.")
-
+        return image, seed, f"Koha e përpunimit: {latency:.2f} sekonda"
     except Exception as e:
-        # Clear cache just in case
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
-        raise gr.Error(f"An error occurred during generation: {e}")
+        raise gr.Error(f"Gabim gjatë gjenerimit: {e}")
 
-
-# --- Example Prompts ---
 examples = [
-    "a tiny astronaut hatching from an egg on the moon",
-    "a cute white cat holding a sign that says hello world",
-    "an anime illustration of Steve Jobs",
-    "Create image of Modern house in minecraft style",
-    "photo of a woman on the beach, shot from above. She is facing the sea, while wearing a white dress. She has long blonde hair",
-    "Selfie photo of a wizard with long beard and purple robes, he is apparently in the middle of Tokyo. Probably taken from a phone.",
-    "Photo of a young woman with long, wavy brown hair tied in a bun and glasses. She has a fair complexion and is wearing subtle makeup, emphasizing her eyes and lips. She is dressed in a black top. The background appears to be an urban setting with a building facade, and the sunlight casts a warm glow on her face.",
-    "High-resolution photorealistic render of a sleek, futuristic motorcycle parked on a neon-lit street at night, rain reflecting the lights.",
-    "Watercolor painting of a cozy bookstore interior with overflowing shelves and a cat sleeping in a sunbeam.",
+    "Qytet futuristik natën me drita neon",
+    "Një mace e bardhë mban një tabelë përshëndetëse",
+    "Një astronaut del nga një vezë në Hënë",
+    "Pamje nga një shtëpi moderne stilin Minecraft"
 ]
 
-# --- Gradio UI ---
-with gr.Blocks() as demo:
-    with gr.Column(elem_id="app-container"):
-        gr.Markdown("# 🎨 Realtime FLUX Image Generator")
-        gr.Markdown("Generate stunning images in real-time with Modified Flux.Schnell pipeline.")
-        gr.Markdown("<span style='color: red;'>Note: Sometimes it stucks or stops generating images (I don't know why). In that situation just refresh the site.</span>")
-
-        with gr.Row():
-            with gr.Column(scale=2.5):
-                result = gr.Image(label="Generated Image", show_label=False, interactive=False)
-            with gr.Column(scale=1):
-                prompt = gr.Text(
-                    label="Prompt",
-                    placeholder="Describe the image you want to generate...",
-                    lines=3,
-                    show_label=False,
-                    container=False,
-                )
-                generateBtn = gr.Button("🖼️ Generate Image")
-                enhanceBtn = gr.Button("🚀 Enhance Image")
-
-        with gr.Column("Advanced Options"):
-            with gr.Row():
-                realtime = gr.Checkbox(label="Realtime Toggler", info="If TRUE then uses more GPU but create image in realtime.", value=False)
-                latency = gr.Text(label="Latency")
-            with gr.Row():
-                seed = gr.Number(label="Seed", value=42)
-                randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
-            with gr.Row():
-                width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=DEFAULT_WIDTH)
-                height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=DEFAULT_HEIGHT)
-                num_inference_steps = gr.Slider(label="Inference Steps", minimum=MIN_INFERENCE_STEPS, maximum=MAX_INFERENCE_STEPS, step=1, value=DEFAULT_INFERENCE_STEPS)
-
-        with gr.Row():
-            gr.Markdown("### 🌟 Inspiration Gallery")
-        with gr.Row():
-            gr.Examples(
-                examples=examples,
-                fn=generate_image,
-                inputs=[prompt],
-                outputs=[result, seed, latency],
-                cache_examples=True,
-                cache_mode="eager"
-            )
-
-        enhanceBtn.click(
-            fn=generate_image,
-            inputs=[prompt, seed, width, height],
-            outputs=[result, seed, latency],
-            show_progress="full"
-        )
-
-        generateBtn.click(
+# --- App Layout ---
+with gr.Blocks(css="""
+body::before {
+    content: "";
+    display: block;
+    height: 640px;
+    background-color: #0f1117;
+}
+button[aria-label="Download"] {
+    transform: scale(1.5);
+    transform-origin: top right;
+    margin: 0 !important;
+    padding: 6px !important;
+}
+""") as app:
+    gr.Markdown("# 🖼️ Gjenerues Imazhesh FLUX")
+    gr.Markdown("Përdor modelin FLUX për të krijuar imazhe fantastike nga përshkrime në **gjuhën shqipe ose angleze**.")
+
+    with gr.Row():
+        with gr.Column(scale=2):
+            output_image = gr.Image(label="Imazhi i Gjeneruar", interactive=False, show_download_button=True)
+        with gr.Column(scale=1):
+            prompt = gr.Text(label="Përshkrimi", placeholder="Shkruani se çfarë doni të krijoni...", lines=3)
+            generate_btn = gr.Button("🎨 Gjenero")
+            aspect_ratio = gr.Radio(label="Raporti i Imazhit", choices=list(ASPECT_RATIOS.keys()), value="16:9")
+            randomize_seed = gr.Checkbox(label="Përdor numër të rastësishëm", value=True)
+            latency = gr.Text(label="Koha", interactive=False)
+
+    gr.Markdown("### 📌 Shembuj Frymëzues")
+    gr.Examples(
+        examples=examples,
         fn=generate_image,
-            inputs=[prompt, seed, width, height, randomize_seed, num_inference_steps],
-            outputs=[result, seed, latency],
-            show_progress="full",
-            api_name="RealtimeFlux",
-        )
-
-        def update_ui(realtime_enabled):
-            return {
-                prompt: gr.update(interactive=True),
-                generateBtn: gr.update(visible=not realtime_enabled)
-            }
-
-        def realtime_generation(*args):
-            if args[0]: # If realtime is enabled
-                return next(generate_image(*args[1:]))
-
-        realtime.change(
-            fn=update_ui,
-            inputs=[realtime],
-            outputs=[prompt, generateBtn]
+        inputs=[prompt],
+        outputs=[output_image, gr.Number(visible=False), latency],
+        cache_examples=True,
+        cache_mode="eager"
     )
 
-        prompt.submit(
+    generate_btn.click(
         fn=generate_image,
-            inputs=[prompt, seed, width, height, randomize_seed, num_inference_steps],
-            outputs=[result, seed, latency],
+        inputs=[prompt, gr.Number(value=42, visible=False), aspect_ratio, randomize_seed],
+        outputs=[output_image, gr.Number(visible=False), latency],
         show_progress="full"
     )
 
-        for component in [prompt, width, height, num_inference_steps]:
-            component.input(
-                fn=realtime_generation,
-                inputs=[realtime, prompt, seed, width, height, randomize_seed, num_inference_steps],
-                outputs=[result, seed, latency],
-                show_progress="hidden",
-                trigger_mode="always_last"
-            )
-
-# Launch the app
-demo.launch()
+app.launch(share=True)
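For reference, a minimal sketch of exercising the new generate_image entry point outside the Gradio UI. This is hypothetical usage, not part of the commit: it assumes the Space's dependencies (custom_pipeline, spaces, the FLUX.1-schnell weights, a GPU) are available locally, and that the trailing app.launch(share=True) is guarded or commented out so importing app does not start the server.

# Hypothetical smoke test for the updated generate_image() signature.
from app import generate_image

# aspect_ratio is looked up in ASPECT_RATIOS: "16:9" -> (1024, 576);
# unknown keys fall back to (DEFAULT_WIDTH, DEFAULT_HEIGHT).
image, seed, msg = generate_image(
    prompt="Qytet futuristik natën me drita neon",  # first example prompt from the app
    seed=42,
    aspect_ratio="16:9",
    randomize_seed=False,
)

print(seed, msg)          # msg reads "Koha e përpunimit: ... sekonda" ("processing time")
image.save("sample.png")  # a PIL image, since the pipeline is called with output_type="pil"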