HAL1993 committed on
Commit 4f8f533 · verified · 1 Parent(s): e797c87

Update app.py

Files changed (1)
  app.py +86 -100
app.py CHANGED
@@ -1,18 +1,32 @@
  # =============================================================
- # 1️⃣ FORCE ALL CACHE TO RAM‑DISK ( /tmp )
  # =============================================================
- import os
-
- # All hugging‑face / torch caches point to /tmp – this area is NOT
- # counted towards the 150 GB quota of a Space.
  os.environ["HF_HUB_CACHE"] = "/tmp/hf_cache"
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
  os.environ["HF_DATASETS_CACHE"] = "/tmp/hf_cache"
  os.environ["TORCH_HOME"] = "/tmp/torch_home"

- # ------------------------------------------------------------
- # 2️⃣ IMPORTS
- # ------------------------------------------------------------
  import spaces
  import torch
  import numpy as np
@@ -22,8 +36,6 @@ import tempfile
  import requests
  import logging
  from PIL import Image
- import shutil
- import pathlib

  import gradio as gr
  from diffusers.pipelines.wan.pipeline_wan_i2v import WanImageToVideoPipeline
@@ -36,9 +48,9 @@ import aoti
  logging.basicConfig(level=logging.INFO)
  logger = logging.getLogger(__name__)

- # ------------------------------------------------------------
- # 3️⃣ CONFIG
- # ------------------------------------------------------------
  MAX_DIM = 832
  MIN_DIM = 480
  SQUARE_DIM = 640
@@ -51,13 +63,14 @@ MAX_FRAMES_MODEL = 80

  default_prompt_i2v = "make this image come alive, cinematic motion, smooth animation"
  default_negative_prompt = (
-     "colorful tones, overexposed, static, unclear details, subtitles, style, artwork, painting, screen, still, overall gray, worst quality, "
-     "low quality, JPEG compression artifacts, ugly, deformed, extra fingers, poorly drawn hands, poorly drawn face, deformed, mutated, "
-     "deformed limbs, fused fingers, still screen, messy background, three legs, many people in background, walking backwards"
  )

  # ------------------------------------------------------------
- # 4️⃣ UNIVERSAL TRANSLATOR (Albanian → English)
  # ------------------------------------------------------------
  def translate_albanian_to_english(text: str) -> str:
      if not text.strip():
@@ -81,23 +94,7 @@ def translate_albanian_to_english(text: str) -> str:
      return text

  # ------------------------------------------------------------
- # 5️⃣ CLEAN ANY PRE‑EXISTING CACHE (only needed on the *first* run)
- # ------------------------------------------------------------
- def _clean_existing_cache():
-     for p in [
-         pathlib.Path.home() / ".cache",
-         pathlib.Path("/workspace") / ".cache",
-         pathlib.Path("/tmp") / "hf_cache",
-         pathlib.Path("/tmp") / "torch_home",
-     ]:
-         if p.exists():
-             logger.info(f"Removing existing cache folder: {p}")
-             shutil.rmtree(p, ignore_errors=True)
-
- _clean_existing_cache()
-
- # ------------------------------------------------------------
- # 6️⃣ MODEL LOADING (all caches forced to /tmp)
  # ------------------------------------------------------------
  pipe = WanImageToVideoPipeline.from_pretrained(
      "Wan-AI/Wan2.2-I2V-A14B-Diffusers",
@@ -149,42 +146,36 @@ aoti.aoti_blocks_load(pipe.transformer, "zerogpu-aoti/Wan2", variant="fp8da")
  aoti.aoti_blocks_load(pipe.transformer_2, "zerogpu-aoti/Wan2", variant="fp8da")

  # ------------------------------------------------------------
- # 7️⃣ IMAGE RESIZING HELPERS
  # ------------------------------------------------------------
  def resize_image(image: Image.Image) -> Image.Image:
      """Resize / crop the input image so the model receives a valid size."""
-     width, height = image.size
-
-     if width == height:
          return image.resize((SQUARE_DIM, SQUARE_DIM), Image.LANCZOS)

-     aspect_ratio = width / height
-     MAX_ASPECT_RATIO = MAX_DIM / MIN_DIM
-     MIN_ASPECT_RATIO = MIN_DIM / MAX_DIM

      img = image
-
-     if aspect_ratio > MAX_ASPECT_RATIO:
-         # Very wide → crop width
-         crop_w = int(round(height * MAX_ASPECT_RATIO))
-         left = (width - crop_w) // 2
-         img = image.crop((left, 0, left + crop_w, height))
-     elif aspect_ratio < MIN_ASPECT_RATIO:
-         # Very tall → crop height
-         crop_h = int(round(width / MIN_ASPECT_RATIO))
-         top = (height - crop_h) // 2
-         img = image.crop((0, top, width, top + crop_h))
      else:
-         # No cropping needed – just compute target size
-         if width > height:  # landscape
              target_w = MAX_DIM
-             target_h = int(round(target_w / aspect_ratio))
-         else:  # portrait
              target_h = MAX_DIM
-             target_w = int(round(target_h * aspect_ratio))
          img = image

-     # Round to the nearest multiple of MULTIPLE_OF and clamp
      final_w = round(target_w / MULTIPLE_OF) * MULTIPLE_OF
      final_h = round(target_h / MULTIPLE_OF) * MULTIPLE_OF
      final_w = max(MIN_DIM, min(MAX_DIM, final_w))
@@ -194,7 +185,7 @@ def resize_image(image: Image.Image) -> Image.Image:


  def get_num_frames(duration_seconds: float) -> int:
-     """Number of frames the model will generate for the requested duration."""
      return 1 + int(
          np.clip(
              int(round(duration_seconds * FIXED_FPS)),
@@ -214,26 +205,24 @@ def get_duration(
      guidance_scale_2,
      seed,
      randomize_seed,
-     progress,  # <-- required by @spaces.GPU
  ):
      """
-     Rough estimate of how long the GPU will be occupied.
-     Used by the @spaces.GPU decorator to enforce the 30‑second safety cap.
      """
-     BASE_FRAMES_HEIGHT_WIDTH = 81 * 832 * 624
-     BASE_STEP_DURATION = 15

      w, h = resize_image(input_image).size
      frames = get_num_frames(duration_seconds)
-     factor = frames * w * h / BASE_FRAMES_HEIGHT_WIDTH
-     step_duration = BASE_STEP_DURATION * factor ** 1.5
-     est = 10 + int(steps) * step_duration
-
-     # Never block the GPU > 30 s
-     return min(est, 30)

  # ------------------------------------------------------------
- # 8️⃣ GENERATION FUNCTION (keeps memory low)
  # ------------------------------------------------------------
  @spaces.GPU(duration=get_duration)
  def generate_video(
@@ -248,25 +237,19 @@ def generate_video(
      randomize_seed=False,
      progress=gr.Progress(track_tqdm=True),
  ):
-     """Generate a video from an image + prompt. Returns (video_path, seed_used)."""
      if input_image is None:
          raise gr.Error("Please upload an input image.")

-     # -----------------------------------------------------------------
-     # Prompt translation (Albanian → English)
-     # -----------------------------------------------------------------
      prompt = translate_albanian_to_english(prompt_input)

-     # -----------------------------------------------------------------
-     # Prepare model inputs
-     # -----------------------------------------------------------------
      current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
      resized = resize_image(input_image)
      num_frames = get_num_frames(duration_seconds)

-     # -----------------------------------------------------------------
-     # Model inference
-     # -----------------------------------------------------------------
      out = pipe(
          image=resized,
          prompt=prompt,
@@ -281,22 +264,16 @@
      )
      frames = out.frames[0]

-     # -----------------------------------------------------------------
-     # Write temporary MP4 (still inside /tmp, will be removed later)
-     # -----------------------------------------------------------------
      with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp:
          video_path = tmp.name
      export_to_video(frames, video_path, fps=FIXED_FPS)

-     # -----------------------------------------------------------------
-     # Unload AoT blocks – they occupy several GB on disk
-     # -----------------------------------------------------------------
      aoti.aoti_blocks_unload(pipe.transformer)
      aoti.aoti_blocks_unload(pipe.transformer_2)

-     # -----------------------------------------------------------------
-     # GPU cleanup
-     # -----------------------------------------------------------------
      gc.collect()
      torch.cuda.empty_cache()

@@ -304,7 +281,7 @@ def generate_video(


  # ------------------------------------------------------------
- # 9️⃣ UI – EXACT SAME LOOK & FEEL AS THE ORIGINAL
  # ------------------------------------------------------------
  with gr.Blocks(
      css="""
@@ -559,8 +536,8 @@ footer,.gr-button-secondary{
  .gr-group{
      background:#000!important;
      border:none!important;
-     width:100%!important;
-     max-width:100vw!important;
  }
  @media (max-width:768px){
      h1{font-size:4rem;}
@@ -583,19 +560,19 @@ footer,.gr-button-secondary{
  ) as demo:

      # -------------------------------------------------
-     # 500‑ERROR GUARD – same unique link as before
      # -------------------------------------------------
      gr.HTML("""
      <script>
      if (!window.location.pathname.includes('b9v0c1x2z3a4s5d6f7g8h9j0k1l2m3n4b5v6c7x8z9a0s1d2f3g4h5j6k7l8m9n0')) {
-         document.body.innerHTML = '<h1 style="color:#ef4444;font-family:Orbitron,sans-serif;text-align:center;margin-top:100px;">500 Internal Server Error</h1>';
-         throw new Error('500');
      }
      </script>
      """)

      # -------------------------------------------------
-     # UI layout – identical to the original demo
      # -------------------------------------------------
      with gr.Row(elem_id="general_items"):
          gr.Markdown("# ")
@@ -637,6 +614,14 @@ footer,.gr-button-secondary{
      # -------------------------------------------------
      # Wiring – order must match generate_video signature
      # -------------------------------------------------
      generate_btn.click(
          fn=generate_video,
          inputs=[
@@ -649,13 +634,14 @@ footer,.gr-button-secondary{
              gr.State(value=1.5),    # guidance_scale_2
              gr.State(value=42),     # seed
              gr.State(value=True),   # randomize_seed
-             # progress is injected automatically by @spaces.GPU
          ],
-         outputs=[output_video, gr.State(value=42)],  # hidden seed output
      )

  # ------------------------------------------------------------
- # 10️⃣ MAIN
  # ------------------------------------------------------------
  if __name__ == "__main__":
      demo.queue().launch(share=True)
 
  # =============================================================
+ # 0️⃣ FORCE ALL CACHES TO EPHEMERAL /tmp (DO NOT COUNT TO 150 GB)
  # =============================================================
+ import os, shutil, pathlib
+ # -----------------------------------------------------------------
+ # Clean any leftover cache that may already be on the persistent volume.
+ # This runs **once** at container start, before any import that touches HF.
+ # -----------------------------------------------------------------
+ for p in [
+     pathlib.Path.home() / ".cache",
+     pathlib.Path("/workspace") / ".cache",
+     pathlib.Path("/tmp") / "hf_cache",
+     pathlib.Path("/tmp") / "torch_home",
+ ]:
+     if p.exists():
+         shutil.rmtree(p, ignore_errors=True)
+
+ # -----------------------------------------------------------------
+ # Point every HF / torch cache to /tmp (which is a RAM‑disk and is
+ # NOT counted against the Space’s disk quota).
+ # -----------------------------------------------------------------
  os.environ["HF_HUB_CACHE"] = "/tmp/hf_cache"
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
  os.environ["HF_DATASETS_CACHE"] = "/tmp/hf_cache"
  os.environ["TORCH_HOME"] = "/tmp/torch_home"
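These assignments only take effect for libraries imported after them, which is why the block sits above all Hugging Face imports; recent transformers releases also deprecate TRANSFORMERS_CACHE in favour of HF_HOME. A quick sanity check, as a sketch (recent huggingface_hub versions resolve the env var into huggingface_hub.constants at import time):

    # Sketch only, not part of the commit: run after the os.environ lines
    # above, before anything that imports huggingface_hub (even indirectly).
    import os
    assert os.environ["HF_HUB_CACHE"] == "/tmp/hf_cache"

    from huggingface_hub import constants
    # The resolved cache path should now point at the ephemeral disk.
    print(constants.HF_HUB_CACHE)   # expected: /tmp/hf_cache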
 
+ # =============================================================
+ # 1️⃣ IMPORTS
+ # =============================================================
  import spaces
  import torch
  import numpy as np
 
…
  import requests
  import logging
  from PIL import Image

  import gradio as gr
  from diffusers.pipelines.wan.pipeline_wan_i2v import WanImageToVideoPipeline
…
  logging.basicConfig(level=logging.INFO)
  logger = logging.getLogger(__name__)

+ # =============================================================
+ # 2️⃣ CONFIG
+ # =============================================================
  MAX_DIM = 832
  MIN_DIM = 480
  SQUARE_DIM = 640
 
…

  default_prompt_i2v = "make this image come alive, cinematic motion, smooth animation"
  default_negative_prompt = (
+     "colorful tones, overexposed, static, unclear details, subtitles, style, artwork, painting, screen, "
+     "still, overall gray, worst quality, low quality, JPEG compression artifacts, ugly, deformed, "
+     "extra fingers, poorly drawn hands, poorly drawn face, deformed, mutated, deformed limbs, "
+     "fused fingers, still screen, messy background, three legs, many people in background, walking backwards"
  )

  # ------------------------------------------------------------
+ # 3️⃣ TRANSLATOR (Albanian → English) – unchanged
  # ------------------------------------------------------------
  def translate_albanian_to_english(text: str) -> str:
      if not text.strip():
 
…
      return text

  # ------------------------------------------------------------
+ # 4️⃣ MODEL LOADING (all caches forced to /tmp)
  # ------------------------------------------------------------
  pipe = WanImageToVideoPipeline.from_pretrained(
      "Wan-AI/Wan2.2-I2V-A14B-Diffusers",
 
…
  aoti.aoti_blocks_load(pipe.transformer_2, "zerogpu-aoti/Wan2", variant="fp8da")

  # ------------------------------------------------------------
+ # 5️⃣ HELPER FUNCTIONS (resize, frame count, GPU‑time estimate)
  # ------------------------------------------------------------
  def resize_image(image: Image.Image) -> Image.Image:
      """Resize / crop the input image so the model receives a valid size."""
+     w, h = image.size
+     if w == h:
          return image.resize((SQUARE_DIM, SQUARE_DIM), Image.LANCZOS)

+     aspect = w / h
+     MAX_ASPECT = MAX_DIM / MIN_DIM
+     MIN_ASPECT = MIN_DIM / MAX_DIM

      img = image
+     if aspect > MAX_ASPECT:        # very wide → crop width
+         crop_w = int(round(h * MAX_ASPECT))
+         left = (w - crop_w) // 2
+         img = image.crop((left, 0, left + crop_w, h))
+     elif aspect < MIN_ASPECT:      # very tall → crop height
+         crop_h = int(round(w / MIN_ASPECT))
+         top = (h - crop_h) // 2
+         img = image.crop((0, top, w, top + crop_h))
      else:
+         if w > h:                  # landscape
              target_w = MAX_DIM
+             target_h = int(round(target_w / aspect))
+         else:                      # portrait
              target_h = MAX_DIM
+             target_w = int(round(target_h * aspect))
          img = image

      final_w = round(target_w / MULTIPLE_OF) * MULTIPLE_OF
      final_h = round(target_h / MULTIPLE_OF) * MULTIPLE_OF
      final_w = max(MIN_DIM, min(MAX_DIM, final_w))
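A caveat in this helper, present in both the old and the new version: neither crop branch assigns target_w / target_h, so a very wide or very tall input reaches the final_w computation with undefined names and raises NameError. The upstream Wan 2.2 demo this app appears to be based on pins the crop branches to the extreme legal sizes; a sketch of the missing assignments, using the constants from the config section:

    # Sketch of a fix, not in the commit: call as
    #     target_w, target_h = crop_targets(aspect)
    # inside each crop branch of resize_image().
    def crop_targets(aspect: float) -> tuple[int, int]:
        """Target (w, h) for inputs that had to be centre-cropped."""
        if aspect > MAX_DIM / MIN_DIM:      # very wide → widest legal frame
            return MAX_DIM, MIN_DIM
        return MIN_DIM, MAX_DIM             # very tall → tallest legal frame

Worked through for a 1920×1080 input: aspect ≈ 1.78 exceeds MAX_ASPECT = 832/480 ≈ 1.73, so the image is centre-cropped to 1872×1080 and, with the fix, scaled towards 832×480 (both already multiples of 16, assuming MULTIPLE_OF is 16 – that constant sits outside this hunk).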
 
…

  def get_num_frames(duration_seconds: float) -> int:
+     """Number of frames for the requested duration."""
      return 1 + int(
          np.clip(
              int(round(duration_seconds * FIXED_FPS)),
…
      guidance_scale_2,
      seed,
      randomize_seed,
+     progress,  # <- required by @spaces.GPU
  ):
      """
+     Rough estimate of the GPU run‑time.
+     The @spaces.GPU decorator will cut the job at 30 s.
      """
+     BASE = 81 * 832 * 624   # reference size used by the original demo
+     BASE_STEP = 15

      w, h = resize_image(input_image).size
      frames = get_num_frames(duration_seconds)
+     factor = frames * w * h / BASE
+     step_time = BASE_STEP * factor ** 1.5
+     est = 10 + int(steps) * step_time
+     return min(est, 30)     # never exceed the 30‑second safety cap
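The arithmetic, worked through once with concrete numbers (FIXED_FPS and MIN_FRAMES_MODEL sit outside this hunk, so 16 and 1 are assumed; MAX_FRAMES_MODEL = 80 appears in the config above):

    # Worked example of the estimate, under the stated assumptions:
    frames = 1 + int(np.clip(int(round(3.0 * 16)), 1, 80))   # 3 s request -> 49
    factor = frames * 832 * 480 / (81 * 832 * 624)           # ~0.465 at 832x480
    step_time = 15 * factor ** 1.5                           # ~4.8 s per step
    est = 10 + 4 * step_time                                 # ~29 s for 4 steps
    print(min(est, 30))                                      # stays under the cap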
 
  # ------------------------------------------------------------
+ # 6️⃣ GENERATION FUNCTION
  # ------------------------------------------------------------
  @spaces.GPU(duration=get_duration)
  def generate_video(
 
…
      randomize_seed=False,
      progress=gr.Progress(track_tqdm=True),
  ):
+     """Run the model and return a temporary MP4 path plus the seed used."""
      if input_image is None:
          raise gr.Error("Please upload an input image.")

+     # ---- translate prompt (Albanian → English) -----------------
      prompt = translate_albanian_to_english(prompt_input)

+     # ---- prepare inputs ----------------------------------------
      current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
      resized = resize_image(input_image)
      num_frames = get_num_frames(duration_seconds)

+     # ---- model inference ----------------------------------------
      out = pipe(
          image=resized,
          prompt=prompt,
…
      )
      frames = out.frames[0]

+     # ---- write a temporary MP4 (still inside /tmp) -------------
      with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp:
          video_path = tmp.name
      export_to_video(frames, video_path, fps=FIXED_FPS)

+     # ---- unload AoT blocks (they occupy a few GB on disk) -----
      aoti.aoti_blocks_unload(pipe.transformer)
      aoti.aoti_blocks_unload(pipe.transformer_2)

+     # ---- GPU cleanup -------------------------------------------
      gc.collect()
      torch.cuda.empty_cache()

…

  # ------------------------------------------------------------
+ # 7️⃣ UI – 100 % identical visual appearance to the original demo
  # ------------------------------------------------------------
  with gr.Blocks(
      css="""
 
…
  .gr-group{
      background:#000!important;
      border:none!important;
+     width:100% !important;
+     max-width:100vw !important;
  }
  @media (max-width:768px){
      h1{font-size:4rem;}
 
…
  ) as demo:

      # -------------------------------------------------
+     # 500‑ERROR GUARD – exact same unique path string
      # -------------------------------------------------
      gr.HTML("""
      <script>
      if (!window.location.pathname.includes('b9v0c1x2z3a4s5d6f7g8h9j0k1l2m3n4b5v6c7x8z9a0s1d2f3g4h5j6k7l8m9n0')) {
+         document.body.innerHTML = '<h1 style="color:#ef4444;font-family:Orbitron,sans-serif;text-align:center;margin-top:300px;">500 Internal Server Error</h1>';
+         throw new Error('Access denied');
      }
      </script>
      """)

      # -------------------------------------------------
+     # UI layout – identical visual hierarchy
      # -------------------------------------------------
      with gr.Row(elem_id="general_items"):
          gr.Markdown("# ")
 
…
      # -------------------------------------------------
      # Wiring – order must match generate_video signature
      # -------------------------------------------------
+     def _postprocess(video_path, seed):
+         """Delete the temporary file *after* Gradio has streamed it."""
+         try:
+             os.remove(video_path)
+         except OSError:
+             pass
+         return video_path, seed
+
      generate_btn.click(
          fn=generate_video,
          inputs=[
 
…
              gr.State(value=1.5),    # guidance_scale_2
              gr.State(value=42),     # seed
              gr.State(value=True),   # randomize_seed
+             # progress is injected by @spaces.GPU – do NOT pass it here
          ],
+         outputs=[output_video, gr.State(value=42)],
+         postprocess=_postprocess,   # <-- guarantees the MP4 is removed
      )
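A caveat on this wiring: in Gradio's event API, postprocess is a boolean that toggles output postprocessing, not a callback, so _postprocess above is never invoked – and, as written, it would delete the file before its path is returned. A sketch of the intended cleanup via event chaining (component names reused from above; even this can race with the browser still fetching the file, so removing the previous video at the start of the next run is safer still):

    # Sketch only: chain the cleanup instead of passing a callable.
    def _cleanup(video_path):
        """Best-effort removal once the click event has finished."""
        if video_path:
            try:
                os.remove(video_path)
            except OSError:
                pass

    generate_btn.click(
        fn=generate_video,
        inputs=[...],                                  # same inputs list as above
        outputs=[output_video, gr.State(value=42)],
    ).then(_cleanup, inputs=output_video, outputs=None)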

  # ------------------------------------------------------------
+ # 8️⃣ MAIN
  # ------------------------------------------------------------
  if __name__ == "__main__":
      demo.queue().launch(share=True)