vaibhavpandeyvpz committed
Commit d3022d8 · 1 Parent(s): 747f11e

Improve & simplify model loading

Files changed (2):
  1. app.py (+29, -67)
  2. requirements.txt (+1, -0)
app.py CHANGED

@@ -7,6 +7,7 @@ from pathlib import Path
 
 import gradio as gr
 from huggingface_hub import snapshot_download
+import spaces
 
 warnings.filterwarnings("ignore")
 
@@ -21,60 +22,27 @@ MODEL_ID = "Wan-AI/Wan2.1-FLF2V-14B-720P"
 # Global variables
 prompt_expander = None
 wan_flf2v_720P = None
-model_loaded = False
 
-
-def load_model():
-    """Load the model from Hugging Face Hub (ZeroGPU compatible)"""
-    global wan_flf2v_720P, model_loaded
-
-    if model_loaded and wan_flf2v_720P is not None:
-        return "Model already loaded"
-
-    try:
-        gc.collect()
-
-        print(
-            "Loading Wan2.1-FLF2V-14B-720P model from Hugging Face Hub...",
-            end="",
-            flush=True,
-        )
-        cfg = WAN_CONFIGS["flf2v-14B"]
-
-        # Download model from Hugging Face Hub to local cache
-        print("\nDownloading model files from Hugging Face Hub...", flush=True)
-        # Use HF_HOME environment variable if set (for Hugging Face Spaces)
-        # Otherwise use default cache location
-        cache_base = os.environ.get("HF_HOME")
-        if cache_base:
-            cache_dir_path = Path(cache_base) / "hub"
-        else:
-            cache_dir_path = Path.home() / ".cache" / "huggingface" / "hub"
-
-        checkpoint_dir = snapshot_download(
-            repo_id=MODEL_ID,
-            cache_dir=str(cache_dir_path),
-            local_files_only=False,
-        )
-        print(f"Model downloaded to: {checkpoint_dir}", flush=True)
-
-        wan_flf2v_720P = wan.WanFLF2V(
-            config=cfg,
-            checkpoint_dir=checkpoint_dir,
-            device_id=0,
-            rank=0,
-            t5_fsdp=False,
-            dit_fsdp=False,
-            use_usp=False,
-        )
-
-        model_loaded = True
-        print(" done", flush=True)
-        return "Model loaded successfully!"
-    except Exception as e:
-        error_msg = f"Error loading model: {str(e)}"
-        print(error_msg)
-        return error_msg
+# Download model snapshots from Hugging Face Hub
+print(f"Downloading/loading checkpoints for {MODEL_ID}...")
+ckpt_dir = snapshot_download(MODEL_ID, local_dir_use_symlinks=False)
+print(f"Using checkpoints from {ckpt_dir}")
+
+# Load the model configuration
+cfg = WAN_CONFIGS["flf2v-14B"]
+
+# Instantiate the model in the global scope
+print("Initializing WanFLF2V pipeline...")
+wan_flf2v_720P = wan.WanFLF2V(
+    config=cfg,
+    checkpoint_dir=ckpt_dir,
+    device_id=0,
+    rank=0,
+    t5_fsdp=False,
+    dit_fsdp=False,
+    use_usp=False,
+)
+print("Pipeline initialized and ready.")
 
 
 def prompt_enhance(prompt, img_first, img_last, tar_lang):
@@ -108,6 +76,12 @@ def prompt_enhance(prompt, img_first, img_last, tar_lang):
     return prompt
 
 
+def get_duration(sd_steps, *args, **kwargs):
+    """Calculate dynamic GPU duration based on parameters."""
+    return sd_steps * 15
+
+
+@spaces.GPU(duration=get_duration)
 def flf2v_generation(
     flf2vid_prompt,
     flf2vid_image_first,
@@ -120,11 +94,12 @@
     n_prompt,
     sample_solver,
     frame_num,
+    progress=gr.Progress(track_tqdm=True),
 ):
     """Generate video from first and last frame images + text prompt"""
 
     if wan_flf2v_720P is None:
-        return None, "Model is still loading. Please wait a moment and try again."
+        return None, "Model failed to load. Please check the logs."
 
     if flf2vid_image_first is None or flf2vid_image_last is None:
         return None, "Please upload both first and last frame images"
@@ -374,24 +349,11 @@ if __name__ == "__main__":
         print(" done", flush=True)
     except Exception as e:
         print(f"Warning: Could not initialize prompt expander on startup: {e}")
-        print("Prompt enhancement will be disabled until model is loaded.")
+        print("Prompt enhancement will be disabled.")
         prompt_expander = None
 
-    # Load model automatically on startup
-    print("\n" + "=" * 50)
-    print("Loading Wan2.1-FLF2V-14B-720P model...")
-    print("=" * 50)
-    load_model()
-    if wan_flf2v_720P is not None:
-        print("✓ Model loaded successfully!")
-    else:
-        print(
-            "✗ Failed to load model. The app will still start, but video generation will not work."
-        )
-    print("=" * 50 + "\n")
-
     demo = create_interface()
 
     # Launch with ZeroGPU support
-    # ZeroGPU spaces automatically handle GPU allocation
+    # ZeroGPU spaces automatically handle GPU allocation via @spaces.GPU decorator
     demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
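
Note: the app.py changes above follow the usual ZeroGPU layout: checkpoints are downloaded and the pipeline is built once at import time, and the GPU is only attached for the duration of the @spaces.GPU-decorated call. Below is a minimal, self-contained sketch of that pattern; names such as estimate_duration and generate are illustrative (not from this repo), and it assumes the spaces package resolves a callable duration by invoking it with the same positional arguments as the decorated function.

import gradio as gr
import spaces  # ZeroGPU helper; added to requirements.txt below


def estimate_duration(steps, *args, **kwargs):
    # Upper bound on GPU seconds; ZeroGPU detaches the device when the call returns.
    return int(steps) * 15


@spaces.GPU(duration=estimate_duration)
def generate(steps, prompt, progress=gr.Progress(track_tqdm=True)):
    # Heavy CUDA work goes here; tqdm loops inside are mirrored to the Gradio UI
    # because of track_tqdm=True.
    ...

One detail worth double-checking in the committed version: the duration callable is handed the handler's arguments in declaration order, so get_duration(sd_steps, *args, **kwargs) would receive flf2vid_prompt as its first positional argument, since that is the first parameter of flf2v_generation.
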
requirements.txt CHANGED

@@ -15,4 +15,5 @@ gradio>=5.0.0
 numpy>=1.23.5,<2
 huggingface-hub
 Pillow
+spaces
 https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.6.3+cu128torch2.8-cp310-cp310-linux_x86_64.whl
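
Note: the prebuilt flash-attn wheel at the end of requirements.txt is tied to a specific runtime (the filename advertises cu128, torch 2.8, cp310), so a mismatch surfaces as an import error rather than a dependency-resolution error. The following diagnostic sketch is not part of this commit; it could be printed once at startup to confirm the environment matches.

import sys

import torch

# The wheel filename advertises cu128 / torch 2.8 / cp310; confirm the runtime agrees.
print("python:", sys.version_info[:2])   # expected: (3, 10)
print("torch :", torch.__version__)      # expected: 2.8.x
print("cuda  :", torch.version.cuda)     # expected: 12.8

try:
    import flash_attn
    print("flash_attn:", flash_attn.__version__)  # expected: 2.6.3
except ImportError as exc:
    print("flash_attn failed to import:", exc)
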