from dataclasses import dataclass
from typing import Dict, Any
import asyncio
import base64
import logging
import random
import traceback

import torch
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
from varnish import Varnish
from enhance_a_video import enable_enhance, inject_enhance_for_hunyuanvideo, set_enhance_weight
from teacache import enable_teacache, disable_teacache

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@dataclass
class GenerationConfig:
    """Configuration for video generation"""

    # Content settings
    prompt: str
    negative_prompt: str = ""

    # Model settings
    num_frames: int = 49  # Must follow the 4k + 1 format
    height: int = 320
    width: int = 576
    num_inference_steps: int = 50
    guidance_scale: float = 7.0

    # Reproducibility
    seed: int = -1  # -1 means "pick a random seed"

    # Varnish post-processing settings
    fps: int = 30
    double_num_frames: bool = False
    super_resolution: bool = False
    grain_amount: float = 0.0
    quality: int = 18  # CRF scale (0-51, lower is better)

    # Audio settings
    enable_audio: bool = False
    audio_prompt: str = ""
    audio_negative_prompt: str = "voices, voice, talking, speaking, speech"

    # TeaCache settings
    enable_teacache: bool = True
    # Values: 0 (original), 0.1 (~1.6x speedup), 0.15 (~2.1x speedup)
    teacache_threshold: float = 0.15

    # Enhance-A-Video settings
    enable_enhance_a_video: bool = True
    enhance_a_video_weight: float = 4.0

    def validate_and_adjust(self) -> 'GenerationConfig':
        """Validate and adjust parameters"""
        # Snap num_frames down to the nearest 4k + 1 value
        k = (self.num_frames - 1) // 4
        self.num_frames = (k * 4) + 1

        # Pick a random seed if none was specified
        if self.seed == -1:
            self.seed = random.randint(0, 2**32 - 1)

        return self
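

# A quick illustration of the adjustments above (hypothetical values):
# GenerationConfig(prompt="a fox", num_frames=50).validate_and_adjust()
# snaps num_frames down to 49 (4 * 12 + 1) and replaces the sentinel seed
# of -1 with a random 32-bit value, so downstream code can rely on both
# invariants.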

class EndpointHandler:
    """Handles video generation requests using HunyuanVideo and Varnish"""

    def __init__(self, path: str = ""):
        """Initialize handler with models

        Args:
            path: Path to model weights
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Initialize the transformer and inject the Enhance-A-Video hooks first,
        # so the pipeline is built around the patched attention blocks
        transformer = HunyuanVideoTransformer3DModel.from_pretrained(
            path,
            subfolder="transformer",
            torch_dtype=torch.bfloat16
        )
        inject_enhance_for_hunyuanvideo(transformer)

        # Initialize HunyuanVideo pipeline with the enhanced transformer
        self.pipeline = HunyuanVideoPipeline.from_pretrained(
            path,
            transformer=transformer,
            torch_dtype=torch.float16,
        ).to(self.device)

        # Text encoders and VAE run in float16; the transformer stays in bfloat16
        self.pipeline.text_encoder = self.pipeline.text_encoder.half()
        self.pipeline.text_encoder_2 = self.pipeline.text_encoder_2.half()
        self.pipeline.transformer = self.pipeline.transformer.to(torch.bfloat16)
        self.pipeline.vae = self.pipeline.vae.half()

        # Initialize Varnish for post-processing
        self.varnish = Varnish(
            device=self.device,
            model_base_dir="/repository/varnish"
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Process video generation requests

        Args:
            data: Request data containing:
                - inputs (str): Prompt for video generation
                - parameters (dict): Generation parameters

        Returns:
            Dictionary containing:
                - video: Base64-encoded MP4 data URI
                - content-type: MIME type
                - metadata: Generation metadata
        """
        # Extract inputs
        inputs = data.pop("inputs", data)
        if isinstance(inputs, dict):
            prompt = inputs.get("prompt", "")
        else:
            prompt = inputs

        params = data.get("parameters", {})

        # Create and validate config
        config = GenerationConfig(
            prompt=prompt,
            negative_prompt=params.get("negative_prompt", ""),
            num_frames=params.get("num_frames", 49),
            height=params.get("height", 320),
            width=params.get("width", 576),
            num_inference_steps=params.get("num_inference_steps", 50),
            guidance_scale=params.get("guidance_scale", 7.0),
            seed=params.get("seed", -1),
            fps=params.get("fps", 30),
            double_num_frames=params.get("double_num_frames", False),
            super_resolution=params.get("super_resolution", False),
            grain_amount=params.get("grain_amount", 0.0),
            quality=params.get("quality", 18),
            enable_audio=params.get("enable_audio", False),
            audio_prompt=params.get("audio_prompt", ""),
            audio_negative_prompt=params.get("audio_negative_prompt", "voices, voice, talking, speaking, speech"),
            enable_teacache=params.get("enable_teacache", True),
            # Values: 0 (original), 0.1 (~1.6x speedup), 0.15 (~2.1x speedup)
            teacache_threshold=params.get("teacache_threshold", 0.15),
            enable_enhance_a_video=params.get("enable_enhance_a_video", True),
            enhance_a_video_weight=params.get("enhance_a_video_weight", 4.0)
        ).validate_and_adjust()

        try:
            # Seed all RNGs; validate_and_adjust() guarantees a concrete seed,
            # so the generator is always created
            torch.manual_seed(config.seed)
            random.seed(config.seed)
            generator = torch.Generator(device=self.device).manual_seed(config.seed)

            # Configure TeaCache (currently disabled)
            # if config.enable_teacache:
            #     enable_teacache(
            #         self.pipeline.transformer,
            #         num_inference_steps=config.num_inference_steps,
            #         rel_l1_thresh=config.teacache_threshold
            #     )
            # else:
            #     disable_teacache(self.pipeline.transformer)

            # Configure Enhance-A-Video weight if enabled
            if config.enable_enhance_a_video:
                set_enhance_weight(config.enhance_a_video_weight)
                enable_enhance()
            else:
                # Reset enhance weight to 0 to effectively disable it
                set_enhance_weight(0)

            # Generate video frames
            with torch.inference_mode():
                output = self.pipeline(
                    prompt=config.prompt,
                    # Note: HunyuanVideoPipeline.__call__() does not accept a
                    # negative_prompt argument, so config.negative_prompt is unused
                    # negative_prompt=config.negative_prompt,
                    num_frames=config.num_frames,
                    height=config.height,
                    width=config.width,
                    num_inference_steps=config.num_inference_steps,
                    guidance_scale=config.guidance_scale,
                    generator=generator,
                    output_type="pt",
                ).frames

            # Process with Varnish (an async API driven from this sync handler)
            try:
                loop = asyncio.get_event_loop()
            except RuntimeError:
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)

            result = loop.run_until_complete(
                self.varnish(
                    input_data=output,
                    fps=config.fps,
                    double_num_frames=config.double_num_frames,
                    super_resolution=config.super_resolution,
                    grain_amount=config.grain_amount,
                    enable_audio=config.enable_audio,
                    audio_prompt=config.audio_prompt,
                    audio_negative_prompt=config.audio_negative_prompt,
                )
            )

            # Get video data URI
            video_uri = loop.run_until_complete(
                result.write(
                    type="data-uri",
                    quality=config.quality
                )
            )

            return {
                "video": video_uri,
                "content-type": "video/mp4",
                "metadata": {
                    "width": result.metadata.width,
                    "height": result.metadata.height,
                    "num_frames": result.metadata.frame_count,
                    "fps": result.metadata.fps,
                    "duration": result.metadata.duration,
                    "seed": config.seed,
                    "enable_teacache": config.enable_teacache,
                    "teacache_threshold": config.teacache_threshold if config.enable_teacache else 0,
                    "enable_enhance_a_video": config.enable_enhance_a_video,
                    "enhance_a_video_weight": config.enhance_a_video_weight if config.enable_enhance_a_video else 0,
                }
            }

        except Exception as e:
            message = f"Error generating video ({str(e)})\n{traceback.format_exc()}"
            logger.error(message)
            raise RuntimeError(message)
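

# Minimal local usage sketch (assumptions: model weights at a hypothetical
# "./hunyuan-video" path, a CUDA GPU with enough VRAM, and the Varnish assets
# under /repository/varnish). It mirrors the request shape the handler expects;
# an illustration, not part of the serving contract.
if __name__ == "__main__":
    handler = EndpointHandler(path="./hunyuan-video")  # hypothetical path
    response = handler({
        "inputs": "A red fox running through a snowy forest at dawn",
        "parameters": {
            "num_frames": 49,  # already 4k + 1, so left unchanged
            "width": 576,
            "height": 320,
            "seed": 42,  # fixed seed for reproducibility
        },
    })
    # response["video"] is a data URI; strip the header to recover raw MP4 bytes
    _, _, payload = response["video"].partition(",")
    with open("sample.mp4", "wb") as f:
        f.write(base64.b64decode(payload))
    print(response["metadata"])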