# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import gc
import tempfile
import traceback
import warnings

import gradio as gr

warnings.filterwarnings("ignore")

import wan
from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS
from wan.utils.prompt_extend import QwenPromptExpander
from wan.utils.utils import cache_video
# Model ID from Hugging Face Hub
MODEL_ID = "Wan-AI/Wan2.1-FLF2V-14B-720P"
# Module-level singletons, created once per process and reused across requests
prompt_expander = None
wan_flf2v_720P = None
model_loaded = False


def load_model():
"""Load the model from Hugging Face Hub (ZeroGPU compatible)"""
global wan_flf2v_720P, model_loaded
if model_loaded and wan_flf2v_720P is not None:
return "Model already loaded"
try:
gc.collect()
print(
"Loading Wan2.1-FLF2V-14B-720P model from Hugging Face Hub...",
end="",
flush=True,
)
cfg = WAN_CONFIGS["flf2v-14B"]
# Load from Hugging Face Hub
checkpoint_dir = MODEL_ID
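        # checkpoint_dir receives the Hub model ID directly, which assumes
        # wan.WanFLF2V can resolve Hub IDs. If it only accepts local paths,
        # the weights would need downloading first, e.g. (sketch, not in the
        # original code):
        #   from huggingface_hub import snapshot_download
        #   checkpoint_dir = snapshot_download(MODEL_ID)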
wan_flf2v_720P = wan.WanFLF2V(
config=cfg,
checkpoint_dir=checkpoint_dir,
device_id=0,
rank=0,
            t5_fsdp=False,  # FSDP-shard the T5 text encoder (multi-GPU only)
            dit_fsdp=False,  # FSDP-shard the DiT backbone (multi-GPU only)
            use_usp=False,  # sequence-parallel (USP) inference (multi-GPU only)
)
model_loaded = True
print(" done", flush=True)
return "Model loaded successfully!"
except Exception as e:
error_msg = f"Error loading model: {str(e)}"
print(error_msg)
return error_msg


def prompt_enhance(prompt, img_first, img_last, tar_lang):
"""Enhance prompt using Qwen vision model"""
print("Enhancing prompt...")
if img_first is None or img_last is None:
print("Please upload the first and last frames")
return prompt
global prompt_expander
if prompt_expander is None:
try:
            # Lazily initialize the prompt expander (local Qwen-VL model)
            prompt_expander = QwenPromptExpander(
                model_name=None,  # None selects the library's default model
                is_vl=True,
                device=0,
            )
except Exception as e:
print(f"Warning: Could not initialize prompt expander: {e}")
return prompt
try:
prompt_output = prompt_expander(
prompt, image=[img_first, img_last], tar_lang=tar_lang.lower()
)
        if not prompt_output.status:
            return prompt
        return prompt_output.prompt
except Exception as e:
print(f"Error enhancing prompt: {e}")
return prompt


def flf2v_generation(
flf2vid_prompt,
flf2vid_image_first,
flf2vid_image_last,
resolution,
sd_steps,
guide_scale,
shift_scale,
seed,
n_prompt,
sample_solver,
frame_num,
):
"""Generate video from first and last frame images + text prompt"""
if wan_flf2v_720P is None:
return None, "Model is still loading. Please wait a moment and try again."
if flf2vid_image_first is None or flf2vid_image_last is None:
return None, "Please upload both first and last frame images"
if not flf2vid_prompt or flf2vid_prompt.strip() == "":
return None, "Please provide a text prompt"
# Validate frame_num (must be 4n+1)
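    # Wan's causal video VAE downsamples time by a factor of 4, so a valid
    # clip is one initial frame plus n groups of four decoded frames, i.e.
    # 4n+1 frames total.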
if (frame_num - 1) % 4 != 0:
return (
None,
f"Frame number must be 4n+1 (e.g., 17, 21, 25, ..., 81). Got {frame_num}",
)
try:
print(f"Generating video with parameters:")
print(f" Resolution: {resolution}")
print(f" Steps: {sd_steps}")
print(f" Guide scale: {guide_scale}")
print(f" Shift scale: {shift_scale}")
print(f" Seed: {seed}")
print(f" Solver: {sample_solver}")
print(f" Frame num: {frame_num}")
        # Map the UI resolution choice to the library's max-area key; "720P"
        # and "480P" are portrait, the explicit WxH options are landscape.
        resolution_to_key = {
            "720P": "720*1280",
            "1280x720": "1280*720",
            "480P": "480*832",
            "832x480": "832*480",
        }
        max_area = MAX_AREA_CONFIGS[resolution_to_key.get(resolution, "720*1280")]
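        # offload_model=True (below) moves weights to CPU between pipeline
        # stages to reduce peak GPU memory, at some speed cost; the usual
        # trade-off for single-GPU deployments of a 14B model.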
video = wan_flf2v_720P.generate(
flf2vid_prompt,
flf2vid_image_first,
flf2vid_image_last,
max_area=max_area,
frame_num=frame_num,
shift=shift_scale,
sample_solver=sample_solver,
sampling_steps=sd_steps,
guide_scale=guide_scale,
n_prompt=n_prompt if n_prompt else "",
seed=seed,
offload_model=True,
)
        # Save the video to a unique temp file so concurrent requests do not
        # overwrite each other's output.
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f:
            output_path = f.name
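        # cache_video expects a batched tensor: video[None] adds the batch
        # axis (assuming the pipeline returns a (C, T, H, W) tensor, which is
        # our reading of typical Wan outputs), and value_range=(-1, 1) tells
        # it how to rescale to displayable pixel values.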
cache_video(
tensor=video[None],
save_file=output_path,
fps=16,
nrow=1,
normalize=True,
value_range=(-1, 1),
)
return output_path, "Video generated successfully!"
except Exception as e:
error_msg = f"Error generating video: {str(e)}"
print(error_msg)
        traceback.print_exc()
return None, error_msg


def create_interface():
"""Create the Gradio interface"""
with gr.Blocks(
title="Wan2.1 FLF2V - First & Last Frame to Video", theme=gr.themes.Soft()
) as demo:
gr.Markdown(
"""
<div style="text-align: center;">
<h1>🎬 Wan2.1 FLF2V-14B-720P</h1>
<p style="font-size: 18px; color: #666;">
                Generate videos from first and last frame images plus a text prompt, using the Wan2.1 model
</p>
<p style="font-size: 14px; color: #888;">
Model: <a href="https://huggingface.co/Wan-AI/Wan2.1-FLF2V-14B-720P" target="_blank">Wan-AI/Wan2.1-FLF2V-14B-720P</a>
</p>
</div>
"""
)
with gr.Row():
with gr.Column(scale=1):
# Input section
with gr.Group():
gr.Markdown("### πŸ“Έ Input Images")
flf2vid_image_first = gr.Image(
type="pil",
label="First Frame",
height=300,
)
flf2vid_image_last = gr.Image(
type="pil",
label="Last Frame",
height=300,
)
# Prompt section
with gr.Group():
gr.Markdown("### ✍️ Text Prompt")
flf2vid_prompt = gr.Textbox(
label="Prompt",
placeholder="Describe the video you want to generate...",
lines=3,
)
tar_lang = gr.Radio(
choices=["ZH", "EN"],
label="Prompt Enhancement Language",
value="ZH",
info="Language for prompt enhancement",
)
enhance_prompt_btn = gr.Button(
"✨ Enhance Prompt", variant="secondary"
)
# Advanced options
with gr.Accordion("βš™οΈ Advanced Options", open=False):
resolution = gr.Dropdown(
label="Resolution",
choices=["720P", "1280x720", "480P", "832x480"],
value="720P",
info="Output video resolution",
)
with gr.Row():
sd_steps = gr.Slider(
label="Diffusion Steps",
minimum=1,
maximum=100,
value=50,
step=1,
info="Number of diffusion sampling steps",
)
guide_scale = gr.Slider(
label="Guide Scale",
minimum=0.0,
maximum=20.0,
value=5.0,
step=0.1,
info="Classifier-free guidance scale",
)
with gr.Row():
shift_scale = gr.Slider(
label="Shift Scale",
minimum=0.0,
maximum=20.0,
value=5.0,
step=0.1,
info="Noise schedule shift parameter",
)
seed = gr.Number(
label="Seed",
value=-1,
precision=0,
info="Random seed (-1 for random)",
)
sample_solver = gr.Dropdown(
label="Sample Solver",
choices=["unipc", "dpm++"],
value="unipc",
info="Solver used for sampling",
)
frame_num = gr.Slider(
label="Number of Frames",
minimum=17,
maximum=81,
value=81,
step=4,
info="Number of frames to generate (must be 4n+1)",
)
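                    # (min=17 with step=4 keeps every slider value at 4n+1,
                    # matching the check in flf2v_generation.)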
n_prompt = gr.Textbox(
label="Negative Prompt",
placeholder="Describe what you want to avoid in the video...",
lines=2,
)
generate_btn = gr.Button(
"🎬 Generate Video", variant="primary", size="lg"
)
with gr.Column(scale=1):
gr.Markdown("### πŸŽ₯ Generated Video")
result_video = gr.Video(
label="Output Video", height=600, show_download_button=True
)
result_status = gr.Textbox(label="Status", interactive=False)
# Event handlers
enhance_prompt_btn.click(
fn=prompt_enhance,
inputs=[flf2vid_prompt, flf2vid_image_first, flf2vid_image_last, tar_lang],
outputs=[flf2vid_prompt],
)
generate_btn.click(
fn=flf2v_generation,
inputs=[
flf2vid_prompt,
flf2vid_image_first,
flf2vid_image_last,
resolution,
sd_steps,
guide_scale,
shift_scale,
seed,
n_prompt,
sample_solver,
frame_num,
],
outputs=[result_video, result_status],
)
# Examples
gr.Markdown("### πŸ“š Examples")
gr.Examples(
examples=[
[
"examples/flf2v_input_first_frame.png",
"examples/flf2v_input_last_frame.png",
"A beautiful scene transition",
]
],
inputs=[flf2vid_image_first, flf2vid_image_last, flf2vid_prompt],
label="Example inputs",
)
return demo


if __name__ == "__main__":
    # Initialize the prompt expander at startup (optional; prompt_enhance can
    # also load it lazily on first use)
try:
print("Initializing prompt expander...", end="", flush=True)
prompt_expander = QwenPromptExpander(model_name=None, is_vl=True, device=0)
print(" done", flush=True)
except Exception as e:
print(f"Warning: Could not initialize prompt expander on startup: {e}")
print("Prompt enhancement will be disabled until model is loaded.")
prompt_expander = None
# Load model automatically on startup
print("\n" + "=" * 50)
print("Loading Wan2.1-FLF2V-14B-720P model...")
print("=" * 50)
load_model()
if wan_flf2v_720P is not None:
print("βœ“ Model loaded successfully!")
else:
print(
"βœ— Failed to load model. The app will still start, but video generation will not work."
)
print("=" * 50 + "\n")
demo = create_interface()
    # Launch on the default Spaces host/port. (True ZeroGPU support would
    # normally also require wrapping the GPU-bound functions with the
    # @spaces.GPU decorator.)
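    # For long-running jobs like video generation, Gradio's request queue is
    # typically enabled as well, e.g. demo.queue(max_size=8).launch(...).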
demo.launch(server_name="0.0.0.0", server_port=7860, share=False)