# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import warnings

import gradio as gr
import spaces
from huggingface_hub import snapshot_download

# Silence warnings before importing the wan package
warnings.filterwarnings("ignore")

import wan
from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS
from wan.utils.prompt_extend import QwenPromptExpander
from wan.utils.utils import cache_video

# Model ID on the Hugging Face Hub
MODEL_ID = "Wan-AI/Wan2.1-FLF2V-14B-720P"

# Global state: the prompt expander is created lazily; the pipeline is built below
prompt_expander = None

# Download (or reuse cached) checkpoints from the Hugging Face Hub
print(f"Downloading/loading checkpoints for {MODEL_ID}...")
ckpt_dir = snapshot_download(MODEL_ID)
print(f"Using checkpoints from {ckpt_dir}")

# Load the model configuration
cfg = WAN_CONFIGS["flf2v-14B"]

# Instantiate the pipeline once in the global scope so every request reuses it
print("Initializing WanFLF2V pipeline...")
wan_flf2v_720P = wan.WanFLF2V(
    config=cfg,
    checkpoint_dir=ckpt_dir,
    device_id=0,
    rank=0,
    t5_fsdp=False,
    dit_fsdp=False,
    use_usp=False,
)
print("Pipeline initialized and ready.")


def prompt_enhance(prompt, img_first, img_last, tar_lang):
    """Enhance the prompt using the Qwen vision-language model."""
    print("Enhancing prompt...")
    if img_first is None or img_last is None:
        print("Please upload the first and last frames")
        return prompt

    global prompt_expander
    if prompt_expander is None:
        try:
            # Lazily initialize the expander; model_name=None selects the
            # default local Qwen checkpoint
            prompt_expander = QwenPromptExpander(model_name=None, is_vl=True, device=0)
        except Exception as e:
            print(f"Warning: Could not initialize prompt expander: {e}")
            return prompt

    try:
        prompt_output = prompt_expander(
            prompt, image=[img_first, img_last], tar_lang=tar_lang.lower()
        )
        # Fall back to the original prompt if expansion failed
        return prompt_output.prompt if prompt_output.status else prompt
    except Exception as e:
        print(f"Error enhancing prompt: {e}")
        return prompt


def get_duration(
    flf2vid_prompt,
    flf2vid_image_first,
    flf2vid_image_last,
    resolution,
    sd_steps,
    guide_scale,
    shift_scale,
    seed,
    n_prompt,
    sample_solver,
    frame_num,
    progress=None,
):
    """Estimate the GPU seconds to request from ZeroGPU.

    spaces invokes this with the same arguments as the decorated function,
    so the unused parameters must remain in the signature.
    """
    # Calibration constants: an 81-frame clip at the base pixel count costs
    # roughly BASE_STEP_DURATION seconds per sampling step
    BASE_FRAMES_HEIGHT_WIDTH = 81 * 832 * 624
    BASE_STEP_DURATION = 15

    # Prefer the uploaded first frame's dimensions; fall back to the
    # resolution dropdown
    if flf2vid_image_first is not None:
        width, height = flf2vid_image_first.size
    else:
        resolution_map = {
            "720P": (1280, 720),
            "1280x720": (1280, 720),
            "480P": (832, 480),
            "832x480": (832, 480),
        }
        width, height = resolution_map.get(resolution, (1280, 720))

    frames = int(frame_num) if frame_num else 81

    # Scale the per-step cost superlinearly with the total voxel count
    factor = frames * width * height / BASE_FRAMES_HEIGHT_WIDTH
    step_duration = BASE_STEP_DURATION * (factor**1.5)

    # Fixed overhead plus per-step cost, in seconds
    return 10 + int(sd_steps) * step_duration


@spaces.GPU(duration=get_duration)
def flf2v_generation(
    flf2vid_prompt,
    flf2vid_image_first,
    flf2vid_image_last,
    resolution,
    sd_steps,
    guide_scale,
    shift_scale,
    seed,
    n_prompt,
    sample_solver,
    frame_num,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate a video from first/last frame images plus a text prompt."""
    if wan_flf2v_720P is None:
        return None, "Model failed to load. Please check the logs."
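
    # Input validation. The 4n+1 frame-count rule enforced below is assumed to
    # reflect Wan2.1's 4x temporal VAE compression: N frames map to
    # (N - 1) / 4 + 1 latent frames, which must be a whole number
    # (17, 21, ..., 81 are valid; 20, for example, is not).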
    if flf2vid_image_first is None or flf2vid_image_last is None:
        return None, "Please upload both first and last frame images"
    if not flf2vid_prompt or not flf2vid_prompt.strip():
        return None, "Please provide a text prompt"

    # Sliders may deliver floats; the model expects an integer frame count
    frame_num = int(frame_num)
    if (frame_num - 1) % 4 != 0:
        return (
            None,
            f"Frame number must be 4n+1 (e.g., 17, 21, 25, ..., 81). Got {frame_num}",
        )

    try:
        print("Generating video with parameters:")
        print(f"  Resolution: {resolution}")
        print(f"  Steps: {sd_steps}")
        print(f"  Guide scale: {guide_scale}")
        print(f"  Shift scale: {shift_scale}")
        print(f"  Seed: {seed}")
        print(f"  Solver: {sample_solver}")
        print(f"  Frame num: {frame_num}")

        # Map the UI resolution choice to the model's max-area config;
        # "720P"/"480P" select the portrait orientation, the explicit
        # "1280x720"/"832x480" choices select landscape
        max_area_map = {
            "720P": MAX_AREA_CONFIGS["720*1280"],
            "1280x720": MAX_AREA_CONFIGS["1280*720"],
            "480P": MAX_AREA_CONFIGS["480*832"],
            "832x480": MAX_AREA_CONFIGS["832*480"],
        }
        max_area = max_area_map.get(resolution, MAX_AREA_CONFIGS["720*1280"])

        video = wan_flf2v_720P.generate(
            flf2vid_prompt,
            flf2vid_image_first,
            flf2vid_image_last,
            max_area=max_area,
            frame_num=frame_num,
            shift=shift_scale,
            sample_solver=sample_solver,
            sampling_steps=int(sd_steps),
            guide_scale=guide_scale,
            n_prompt=n_prompt if n_prompt else "",
            seed=seed,
            offload_model=True,
        )

        # Save the video; video[None] adds the batch dimension cache_video expects
        output_path = "generated_video.mp4"
        cache_video(
            tensor=video[None],
            save_file=output_path,
            fps=16,
            nrow=1,
            normalize=True,
            value_range=(-1, 1),
        )
        return output_path, "Video generated successfully!"

    except Exception as e:
        error_msg = f"Error generating video: {e}"
        print(error_msg)
        import traceback

        traceback.print_exc()
        return None, error_msg


def create_interface():
    """Create the Gradio interface."""
    with gr.Blocks(title="Wan2.1 FLF2V - First & Last Frame to Video") as demo:
        gr.Markdown(
            """

# 🎬 Wan2.1 FLF2V-14B-720P

Generate videos from first & last frame images plus a text prompt, using the Wan2.1 model.

**Model:** [Wan-AI/Wan2.1-FLF2V-14B-720P](https://huggingface.co/Wan-AI/Wan2.1-FLF2V-14B-720P)
            """
        )

        with gr.Row():
            with gr.Column(scale=1):
                # Input section
                with gr.Group():
                    gr.Markdown("### 📸 Input Images")
                    flf2vid_image_first = gr.Image(
                        type="pil",
                        label="First Frame",
                        height=300,
                    )
                    flf2vid_image_last = gr.Image(
                        type="pil",
                        label="Last Frame",
                        height=300,
                    )

                # Prompt section
                with gr.Group():
                    gr.Markdown("### ✍️ Text Prompt")
                    flf2vid_prompt = gr.Textbox(
                        label="Prompt",
                        placeholder="Describe the video you want to generate...",
                        lines=3,
                    )
                    tar_lang = gr.Radio(
                        choices=["ZH", "EN"],
                        label="Prompt Enhancement Language",
                        value="ZH",
                        info="Language for prompt enhancement",
                    )
                    enhance_prompt_btn = gr.Button(
                        "✨ Enhance Prompt", variant="secondary"
                    )

                # Advanced options
                with gr.Accordion("⚙️ Advanced Options", open=False):
                    resolution = gr.Dropdown(
                        label="Resolution",
                        choices=["720P", "1280x720", "480P", "832x480"],
                        value="720P",
                        info="Output video resolution",
                    )
                    with gr.Row():
                        sd_steps = gr.Slider(
                            label="Diffusion Steps",
                            minimum=1,
                            maximum=100,
                            value=50,
                            step=1,
                            info="Number of diffusion sampling steps",
                        )
                        guide_scale = gr.Slider(
                            label="Guide Scale",
                            minimum=0.0,
                            maximum=20.0,
                            value=5.0,
                            step=0.1,
                            info="Classifier-free guidance scale",
                        )
                    with gr.Row():
                        shift_scale = gr.Slider(
                            label="Shift Scale",
                            minimum=0.0,
                            maximum=20.0,
                            value=5.0,
                            step=0.1,
                            info="Noise schedule shift parameter",
                        )
                        seed = gr.Number(
                            label="Seed",
                            value=-1,
                            precision=0,
                            info="Random seed (-1 for random)",
                        )
                    sample_solver = gr.Dropdown(
                        label="Sample Solver",
                        choices=["unipc", "dpm++"],
                        value="unipc",
                        info="Solver used for sampling",
                    )
                    frame_num = gr.Slider(
                        label="Number of Frames",
                        minimum=17,
                        maximum=81,
                        value=81,
                        step=4,
                        info="Number of frames to generate (must be 4n+1)",
                    )
                    n_prompt = gr.Textbox(
                        label="Negative Prompt",
                        placeholder="Describe what you want to avoid in the video...",
                        lines=2,
                    )

                generate_btn = gr.Button(
                    "🎬 Generate Video", variant="primary", size="lg"
                )

            with gr.Column(scale=1):
                gr.Markdown("### 🎥 Generated Video")
                result_video = gr.Video(label="Output Video", height=600)
                result_status = gr.Textbox(label="Status", interactive=False)

        # Event handlers
        enhance_prompt_btn.click(
            fn=prompt_enhance,
            inputs=[flf2vid_prompt, flf2vid_image_first, flf2vid_image_last, tar_lang],
            outputs=[flf2vid_prompt],
        )
        generate_btn.click(
            fn=flf2v_generation,
            inputs=[
                flf2vid_prompt,
                flf2vid_image_first,
                flf2vid_image_last,
                resolution,
                sd_steps,
                guide_scale,
                shift_scale,
                seed,
                n_prompt,
                sample_solver,
                frame_num,
            ],
            outputs=[result_video, result_status],
        )

        # Examples
        gr.Markdown("### 📚 Examples")
        gr.Examples(
            examples=[
                [
                    "examples/flf2v_input_first_frame.png",
                    "examples/flf2v_input_last_frame.png",
                    "A beautiful scene transition",
                ]
            ],
            inputs=[flf2vid_image_first, flf2vid_image_last, flf2vid_prompt],
            label="Example inputs",
        )

    return demo


if __name__ == "__main__":
    # Initialize prompt expander on startup (optional; prompt_enhance creates
    # it lazily on first use if this fails)
    try:
        print("Initializing prompt expander...", end="", flush=True)
        prompt_expander = QwenPromptExpander(model_name=None, is_vl=True, device=0)
        print(" done", flush=True)
    except Exception as e:
        print(f"Warning: Could not initialize prompt expander on startup: {e}")
        prompt_expander = None

    demo = create_interface()

    # Launch with ZeroGPU support: ZeroGPU Spaces allocate a GPU per call via
    # the @spaces.GPU decorator
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)