# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import gc
import os
import sys
import warnings
from pathlib import Path

import gradio as gr
from huggingface_hub import snapshot_download
import spaces

warnings.filterwarnings("ignore")

import wan
from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS
from wan.utils.prompt_extend import QwenPromptExpander
from wan.utils.utils import cache_video

# Model ID from Hugging Face Hub
MODEL_ID = "Wan-AI/Wan2.1-FLF2V-14B-720P"

# Global variables
prompt_expander = None
wan_flf2v_720P = None

# Download model snapshots from Hugging Face Hub
print(f"Downloading/loading checkpoints for {MODEL_ID}...")
ckpt_dir = snapshot_download(MODEL_ID, local_dir_use_symlinks=False)
print(f"Using checkpoints from {ckpt_dir}")

# Load the model configuration
cfg = WAN_CONFIGS["flf2v-14B"]

# Instantiate the model in the global scope
print("Initializing WanFLF2V pipeline...")
wan_flf2v_720P = wan.WanFLF2V(
    config=cfg,
    checkpoint_dir=ckpt_dir,
    device_id=0,
    rank=0,
    t5_fsdp=False,
    dit_fsdp=False,
    use_usp=False,
)
print("Pipeline initialized and ready.")


def prompt_enhance(prompt, img_first, img_last, tar_lang):
    """Enhance prompt using Qwen vision model"""
    print("Enhancing prompt...")
    if img_first is None or img_last is None:
        print("Please upload the first and last frames")
        return prompt

    global prompt_expander
    if prompt_expander is None:
        try:
            # Initialize prompt expander (local Qwen model);
            # model_name=None will use the default model
            prompt_expander = QwenPromptExpander(model_name=None, is_vl=True, device=0)
        except Exception as e:
            print(f"Warning: Could not initialize prompt expander: {e}")
            return prompt

    try:
        prompt_output = prompt_expander(
            prompt, image=[img_first, img_last], tar_lang=tar_lang.lower()
        )
        if not prompt_output.status:
            return prompt
        return prompt_output.prompt
    except Exception as e:
        print(f"Error enhancing prompt: {e}")
        return prompt


def get_duration(
    flf2vid_prompt,
    flf2vid_image_first,
    flf2vid_image_last,
    resolution,
    sd_steps,
    guide_scale,
    shift_scale,
    seed,
    n_prompt,
    sample_solver,
    frame_num,
    progress=None,
):
    """Calculate dynamic GPU duration based on parameters."""
    BASE_FRAMES_HEIGHT_WIDTH = 81 * 832 * 624
    BASE_STEP_DURATION = 15

    # Get dimensions from the first image, or fall back to the resolution string
    if flf2vid_image_first is not None:
        width, height = flf2vid_image_first.size
    else:
        resolution_map = {
            "720P": (1280, 720),
            "1280x720": (1280, 720),
            "480P": (832, 480),
            "832x480": (832, 480),
        }
        width, height = resolution_map.get(resolution, (1280, 720))

    # Use frame_num directly (already provided)
    frames = int(frame_num) if frame_num else 81

    # Scale the per-step budget by pixel volume relative to the baseline
    factor = frames * width * height / BASE_FRAMES_HEIGHT_WIDTH
    step_duration = BASE_STEP_DURATION * (factor**1.5)

    # Return total duration in seconds: fixed overhead plus per-step cost
    return 10 + int(sd_steps) * step_duration
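

# Worked example of the duration heuristic above, with assumed inputs
# (illustration only; 40 steps is a hypothetical value): a 1280x720 first
# frame at the default 81 frames gives
# factor = 81 * 1280 * 720 / (81 * 832 * 624) ≈ 1.78, so each sampling step
# is budgeted ~15 * 1.78**1.5 ≈ 35 s, and a 40-step run requests roughly
# 10 + 40 * 35 ≈ 1410 s of GPU time. The `spaces.GPU` decorator below
# evaluates this callable with the same arguments as the decorated
# generation function, which is why the two signatures match.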


@spaces.GPU(duration=get_duration)
def flf2v_generation(
    flf2vid_prompt,
    flf2vid_image_first,
    flf2vid_image_last,
    resolution,
    sd_steps,
    guide_scale,
    shift_scale,
    seed,
    n_prompt,
    sample_solver,
    frame_num,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate video from first and last frame images + text prompt"""
    if wan_flf2v_720P is None:
        return None, "Model failed to load. Please check the logs."

    if flf2vid_image_first is None or flf2vid_image_last is None:
        return None, "Please upload both first and last frame images"

    if not flf2vid_prompt or flf2vid_prompt.strip() == "":
        return None, "Please provide a text prompt"

    # Validate frame_num (must be 4n+1)
    if (frame_num - 1) % 4 != 0:
        return (
            None,
            f"Frame number must be 4n+1 (e.g., 17, 21, 25, ..., 81). Got {frame_num}",
        )

    try:
        print("Generating video with parameters:")
        print(f"  Resolution: {resolution}")
        print(f"  Steps: {sd_steps}")
        print(f"  Guide scale: {guide_scale}")
        print(f"  Shift scale: {shift_scale}")
        print(f"  Seed: {seed}")
        print(f"  Solver: {sample_solver}")
        print(f"  Frame num: {frame_num}")

        # Map the UI resolution choice to the model's max-area configuration
        if resolution == "720P":
            max_area = MAX_AREA_CONFIGS["720*1280"]
        elif resolution == "1280x720":
            max_area = MAX_AREA_CONFIGS["1280*720"]
        elif resolution == "480P":
            max_area = MAX_AREA_CONFIGS["480*832"]
        elif resolution == "832x480":
            max_area = MAX_AREA_CONFIGS["832*480"]
        else:
            max_area = MAX_AREA_CONFIGS["720*1280"]

        video = wan_flf2v_720P.generate(
            flf2vid_prompt,
            flf2vid_image_first,
            flf2vid_image_last,
            max_area=max_area,
            frame_num=frame_num,
            shift=shift_scale,
            sample_solver=sample_solver,
            sampling_steps=sd_steps,
            guide_scale=guide_scale,
            n_prompt=n_prompt if n_prompt else "",
            seed=seed,
            offload_model=True,
        )

        # Save the video tensor to disk as an MP4
        output_path = "generated_video.mp4"
        cache_video(
            tensor=video[None],
            save_file=output_path,
            fps=16,
            nrow=1,
            normalize=True,
            value_range=(-1, 1),
        )

        return output_path, "Video generated successfully!"

    except Exception as e:
        error_msg = f"Error generating video: {str(e)}"
        print(error_msg)
        import traceback

        traceback.print_exc()
        return None, error_msg


def create_interface():
    """Create the Gradio interface"""
    with gr.Blocks(title="Wan2.1 FLF2V - First & Last Frame to Video") as demo:
        gr.Markdown(
            """
Generate videos from first & last frame images plus a text prompt, using the Wan2.1 model.
Model: Wan-AI/Wan2.1-FLF2V-14B-720P
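
Frame counts must be 4n+1 (17, 21, ..., 81); videos are saved at 16 fps, so 81 frames is about 5 seconds.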