# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import gc
import os
import sys
import traceback
import warnings

import gradio as gr

warnings.filterwarnings("ignore")

import wan
from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS
from wan.utils.prompt_extend import QwenPromptExpander
from wan.utils.utils import cache_video

# Model ID from Hugging Face Hub
MODEL_ID = "Wan-AI/Wan2.1-FLF2V-14B-720P"

# Global variables
prompt_expander = None
wan_flf2v_720P = None
model_loaded = False
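
# Note: the FLF2V model is loaded eagerly at startup (see __main__ below),
# while the Qwen prompt expander is initialized at startup too but falls back
# to lazy initialization inside prompt_enhance() if that fails.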


def load_model():
    """Load the model from Hugging Face Hub (ZeroGPU compatible)."""
    global wan_flf2v_720P, model_loaded

    if model_loaded and wan_flf2v_720P is not None:
        return "Model already loaded"

    try:
        gc.collect()
        print(
            "Loading Wan2.1-FLF2V-14B-720P model from Hugging Face Hub...",
            end="",
            flush=True,
        )
        cfg = WAN_CONFIGS["flf2v-14B"]

        # Load from Hugging Face Hub
        checkpoint_dir = MODEL_ID
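        # Assumption: wan.WanFLF2V may require a local checkpoint directory
        # rather than a Hub model ID; if so, fetch the weights first, e.g.:
        #   from huggingface_hub import snapshot_download
        #   checkpoint_dir = snapshot_download(MODEL_ID)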
        wan_flf2v_720P = wan.WanFLF2V(
            config=cfg,
            checkpoint_dir=checkpoint_dir,
            device_id=0,
            rank=0,
            t5_fsdp=False,
            dit_fsdp=False,
            use_usp=False,
        )
        model_loaded = True
        print(" done", flush=True)
        return "Model loaded successfully!"
    except Exception as e:
        error_msg = f"Error loading model: {str(e)}"
        print(error_msg)
        return error_msg


def prompt_enhance(prompt, img_first, img_last, tar_lang):
    """Enhance the prompt using the Qwen vision model."""
    print("Enhancing prompt...")
    if img_first is None or img_last is None:
        print("Please upload the first and last frames")
        return prompt

    global prompt_expander
    if prompt_expander is None:
        try:
            # Initialize prompt expander (local Qwen model, default weights)
            prompt_expander = QwenPromptExpander(
                model_name=None, is_vl=True, device=0
            )
        except Exception as e:
            print(f"Warning: Could not initialize prompt expander: {e}")
            return prompt

    try:
        prompt_output = prompt_expander(
            prompt, image=[img_first, img_last], tar_lang=tar_lang.lower()
        )
        # Fall back to the original prompt if enhancement failed
        if not prompt_output.status:
            return prompt
        return prompt_output.prompt
    except Exception as e:
        print(f"Error enhancing prompt: {e}")
        return prompt


def flf2v_generation(
    flf2vid_prompt,
    flf2vid_image_first,
    flf2vid_image_last,
    resolution,
    sd_steps,
    guide_scale,
    shift_scale,
    seed,
    n_prompt,
    sample_solver,
    frame_num,
):
    """Generate a video from first and last frame images plus a text prompt."""
    if wan_flf2v_720P is None:
        return None, "Model is still loading. Please wait a moment and try again."
    if flf2vid_image_first is None or flf2vid_image_last is None:
        return None, "Please upload both first and last frame images"
    if not flf2vid_prompt or flf2vid_prompt.strip() == "":
        return None, "Please provide a text prompt"

    # Validate frame_num (must be 4n+1)
    if (frame_num - 1) % 4 != 0:
        return (
            None,
            f"Frame number must be 4n+1 (e.g., 17, 21, 25, ..., 81). Got {frame_num}",
        )
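    # The 4n+1 rule presumably reflects the model's causal video VAE, which
    # compresses time 4x: n latent frames decode to 4(n-1)+1 pixel frames
    # (e.g., 81 output frames <-> 21 latent frames).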

    try:
        print("Generating video with parameters:")
        print(f"  Resolution: {resolution}")
        print(f"  Steps: {sd_steps}")
        print(f"  Guide scale: {guide_scale}")
        print(f"  Shift scale: {shift_scale}")
        print(f"  Seed: {seed}")
        print(f"  Solver: {sample_solver}")
        print(f"  Frame num: {frame_num}")

        # Map the UI resolution choice to a pixel-area budget
        resolution_to_area = {
            "720P": MAX_AREA_CONFIGS["720*1280"],
            "1280x720": MAX_AREA_CONFIGS["1280*720"],
            "480P": MAX_AREA_CONFIGS["480*832"],
            "832x480": MAX_AREA_CONFIGS["832*480"],
        }
        max_area = resolution_to_area.get(resolution, MAX_AREA_CONFIGS["720*1280"])
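
        # Assumption: max_area acts as a total-pixel budget; the generator
        # scales the input frames to fit it while preserving their aspect
        # ratio, rather than forcing an exact output size.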
        video = wan_flf2v_720P.generate(
            flf2vid_prompt,
            flf2vid_image_first,
            flf2vid_image_last,
            max_area=max_area,
            frame_num=frame_num,
            shift=shift_scale,
            sample_solver=sample_solver,
            sampling_steps=sd_steps,
            guide_scale=guide_scale,
            n_prompt=n_prompt if n_prompt else "",
            seed=seed,
            offload_model=True,
        )

        # Save video
        output_path = "generated_video.mp4"
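        # At fps=16, the default 81 frames produce roughly a 5-second clip
        # (81 / 16 ≈ 5.1 s).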
        cache_video(
            tensor=video[None],
            save_file=output_path,
            fps=16,
            nrow=1,
            normalize=True,
            value_range=(-1, 1),
        )
        return output_path, "Video generated successfully!"
    except Exception as e:
        error_msg = f"Error generating video: {str(e)}"
        print(error_msg)
        traceback.print_exc()
        return None, error_msg


def create_interface():
    """Create the Gradio interface."""
    with gr.Blocks(
        title="Wan2.1 FLF2V - First & Last Frame to Video", theme=gr.themes.Soft()
    ) as demo:
        gr.Markdown(
            """
            <div style="text-align: center;">
                <h1>🎬 Wan2.1 FLF2V-14B-720P</h1>
                <p style="font-size: 18px; color: #666;">
                    Generate videos from first & last frame images + text prompt using the Wan2.1 model
                </p>
                <p style="font-size: 14px; color: #888;">
                    Model: <a href="https://huggingface.co/Wan-AI/Wan2.1-FLF2V-14B-720P" target="_blank">Wan-AI/Wan2.1-FLF2V-14B-720P</a>
                </p>
            </div>
            """
        )

        with gr.Row():
            with gr.Column(scale=1):
                # Input section
                with gr.Group():
                    gr.Markdown("### 📸 Input Images")
                    flf2vid_image_first = gr.Image(
                        type="pil",
                        label="First Frame",
                        height=300,
                    )
                    flf2vid_image_last = gr.Image(
                        type="pil",
                        label="Last Frame",
                        height=300,
                    )

                # Prompt section
                with gr.Group():
                    gr.Markdown("### ✍️ Text Prompt")
                    flf2vid_prompt = gr.Textbox(
                        label="Prompt",
                        placeholder="Describe the video you want to generate...",
                        lines=3,
                    )
                    tar_lang = gr.Radio(
                        choices=["ZH", "EN"],
                        label="Prompt Enhancement Language",
                        value="ZH",
                        info="Language for prompt enhancement",
                    )
                    enhance_prompt_btn = gr.Button(
                        "✨ Enhance Prompt", variant="secondary"
                    )

                # Advanced options
                with gr.Accordion("⚙️ Advanced Options", open=False):
                    resolution = gr.Dropdown(
                        label="Resolution",
                        choices=["720P", "1280x720", "480P", "832x480"],
                        value="720P",
                        info="Output video resolution",
                    )
                    with gr.Row():
                        sd_steps = gr.Slider(
                            label="Diffusion Steps",
                            minimum=1,
                            maximum=100,
                            value=50,
                            step=1,
                            info="Number of diffusion sampling steps",
                        )
                        guide_scale = gr.Slider(
                            label="Guide Scale",
                            minimum=0.0,
                            maximum=20.0,
                            value=5.0,
                            step=0.1,
                            info="Classifier-free guidance scale",
                        )
                    with gr.Row():
                        shift_scale = gr.Slider(
                            label="Shift Scale",
                            minimum=0.0,
                            maximum=20.0,
                            value=5.0,
                            step=0.1,
                            info="Noise schedule shift parameter",
                        )
                        seed = gr.Number(
                            label="Seed",
                            value=-1,
                            precision=0,
                            info="Random seed (-1 for random)",
                        )
                    sample_solver = gr.Dropdown(
                        label="Sample Solver",
                        choices=["unipc", "dpm++"],
                        value="unipc",
                        info="Solver used for sampling",
                    )
                    frame_num = gr.Slider(
                        label="Number of Frames",
                        minimum=17,
                        maximum=81,
                        value=81,
                        step=4,
                        info="Number of frames to generate (must be 4n+1)",
                    )
                    n_prompt = gr.Textbox(
                        label="Negative Prompt",
                        placeholder="Describe what you want to avoid in the video...",
                        lines=2,
                    )

                generate_btn = gr.Button(
                    "🎬 Generate Video", variant="primary", size="lg"
                )

            with gr.Column(scale=1):
                gr.Markdown("### 🎥 Generated Video")
                result_video = gr.Video(
                    label="Output Video", height=600, show_download_button=True
                )
                result_status = gr.Textbox(label="Status", interactive=False)

        # Event handlers
        enhance_prompt_btn.click(
            fn=prompt_enhance,
            inputs=[flf2vid_prompt, flf2vid_image_first, flf2vid_image_last, tar_lang],
            outputs=[flf2vid_prompt],
        )
        generate_btn.click(
            fn=flf2v_generation,
            inputs=[
                flf2vid_prompt,
                flf2vid_image_first,
                flf2vid_image_last,
                resolution,
                sd_steps,
                guide_scale,
                shift_scale,
                seed,
                n_prompt,
                sample_solver,
                frame_num,
            ],
            outputs=[result_video, result_status],
        )

        # Examples
        gr.Markdown("### 📋 Examples")
        gr.Examples(
            examples=[
                [
                    "examples/flf2v_input_first_frame.png",
                    "examples/flf2v_input_last_frame.png",
                    "A beautiful scene transition",
                ]
            ],
            inputs=[flf2vid_image_first, flf2vid_image_last, flf2vid_prompt],
            label="Example inputs",
        )

    return demo


if __name__ == "__main__":
    # Initialize prompt expander on startup (optional, can be lazily loaded)
    try:
        print("Initializing prompt expander...", end="", flush=True)
        prompt_expander = QwenPromptExpander(model_name=None, is_vl=True, device=0)
        print(" done", flush=True)
    except Exception as e:
        print(f"Warning: Could not initialize prompt expander on startup: {e}")
        print("Prompt enhancement will be disabled until the model is loaded.")
        prompt_expander = None

    # Load model automatically on startup
    print("\n" + "=" * 50)
    print("Loading Wan2.1-FLF2V-14B-720P model...")
    print("=" * 50)
    load_model()
    if wan_flf2v_720P is not None:
        print("✅ Model loaded successfully!")
    else:
        print(
            "❌ Failed to load model. The app will still start, but video generation will not work."
        )
    print("=" * 50 + "\n")
    demo = create_interface()

    # Launch with ZeroGPU support
    # ZeroGPU Spaces automatically handle GPU allocation
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)