Upload 5 files
Browse files- README.md +18 -0
- handler.py +255 -0
- model_index.json +32 -0
- requirements.txt +20 -0
- teacache.py +146 -0
    	
        README.md
    ADDED
    
    | @@ -0,0 +1,18 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            ---
         | 
| 2 | 
            +
            language:
         | 
| 3 | 
            +
            - en
         | 
| 4 | 
            +
            base_model:
         | 
| 5 | 
            +
            - tencent/HunyuanVideo
         | 
| 6 | 
            +
            pipeline_tag: text-to-video
         | 
| 7 | 
            +
            library_name: diffusers
         | 
| 8 | 
            +
            tags:
         | 
| 9 | 
            +
            - HunyuanVideo
         | 
| 10 | 
            +
            - Tencent
         | 
| 11 | 
            +
            - Video
         | 
| 12 | 
            +
            license: other
         | 
| 13 | 
            +
            license_name: tencent-hunyuan-community
         | 
| 14 | 
            +
            license_link: LICENSE
         | 
| 15 | 
            +
            ---
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            This model is [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo), adapted to run on Hugging Face Inference Endpoints.
         | 
| 18 | 
            +
             | 
    	
        handler.py
    ADDED
    
    | @@ -0,0 +1,255 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from dataclasses import dataclass
         | 
| 2 | 
            +
            from typing import Dict, Any, Optional
         | 
| 3 | 
            +
            import base64
         | 
| 4 | 
            +
            import logging
         | 
| 5 | 
            +
            import random
         | 
| 6 | 
            +
            import traceback
         | 
| 7 | 
            +
            import torch
         | 
| 8 | 
            +
            from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
         | 
| 9 | 
            +
            from varnish import Varnish
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            from enhance_a_video import enable_enhance, inject_enhance_for_hunyuanvideo, set_enhance_weight
         | 
| 12 | 
            +
            from teacache import enable_teacache, disable_teacache
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            # Configure logging
         | 
| 15 | 
            +
            logging.basicConfig(level=logging.INFO)
         | 
| 16 | 
            +
            logger = logging.getLogger(__name__)
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            @dataclass
            class GenerationConfig:
                """Per-request settings for one video generation call.

                Only ``prompt`` is required; every other field has a default.
                Call ``validate_and_adjust()`` after construction to normalise
                ``num_frames`` and resolve a random seed.
                """
                # Content settings
                prompt: str
                negative_prompt: str = ""

                # Model settings
                num_frames: int = 49  # Should be 4k + 1 format
                height: int = 320
                width: int = 576
                num_inference_steps: int = 50
                guidance_scale: float = 7.0

                # Reproducibility (-1 or None means "pick a random seed")
                seed: int = -1

                # Varnish post-processing settings
                fps: int = 30
                double_num_frames: bool = False
                super_resolution: bool = False
                grain_amount: float = 0.0
                quality: int = 18  # CRF scale (0-51, lower is better)

                # Audio settings
                enable_audio: bool = False
                audio_prompt: str = ""
                audio_negative_prompt: str = "voices, voice, talking, speaking, speech"

                # TeaCache settings
                enable_teacache: bool = True
                teacache_threshold: float = 0.15 # values: 0 (original), 0.1 (1.6x speedup), 0.15 (2.1x speedup)

                # Enhance-A-Video settings
                enable_enhance_a_video: bool = True
                enhance_a_video_weight: float = 4.0

                def validate_and_adjust(self) -> 'GenerationConfig':
                    """Normalise ``num_frames`` and ``seed`` in place.

                    ``num_frames`` is rounded down to the nearest value of the
                    form ``4k + 1`` and clamped so it is never below 1 (a
                    non-positive request would otherwise become a negative frame
                    count). It is coerced to ``int`` so JSON floats such as
                    ``49.0`` do not leak through as floats.

                    A ``seed`` of ``-1`` or ``None`` is replaced by a random
                    value in ``[0, 2**32 - 1]``.

                    Returns:
                        This same instance, to allow call chaining.
                    """
                    # Ensure num_frames follows 4k + 1 format, with k >= 0
                    k = max(0, (int(self.num_frames) - 1) // 4)
                    self.num_frames = (k * 4) + 1

                    # Set random seed if not specified
                    if self.seed is None or self.seed == -1:
                        self.seed = random.randint(0, 2**32 - 1)

                    return self
         | 
| 67 | 
            +
             | 
| 68 | 
            +
            class EndpointHandler:
                """Handles video generation requests using HunyuanVideo and Varnish"""

                def __init__(self, path: str = ""):
                    """Initialize handler with models

                    Loads the HunyuanVideo transformer (with Enhance-A-Video
                    injected), builds the pipeline around it, fixes the dtype of
                    each sub-model, and creates the Varnish post-processor.

                    Args:
                        path: Path to model weights
                    """
                    self.device = "cuda" if torch.cuda.is_available() else "cpu"

                    # Initialize transformer with Enhance-A-Video injection first
                    transformer = HunyuanVideoTransformer3DModel.from_pretrained(
                        path,
                        subfolder="transformer",
                        torch_dtype=torch.bfloat16
                    )
                    inject_enhance_for_hunyuanvideo(transformer)

                    # Initialize HunyuanVideo pipeline with the enhanced transformer
                    self.pipeline = HunyuanVideoPipeline.from_pretrained(
                        path,
                        transformer=transformer,
                        torch_dtype=torch.float16,
                    ).to(self.device)

                    # Initialize text encoders in float16
                    self.pipeline.text_encoder = self.pipeline.text_encoder.half()
                    self.pipeline.text_encoder_2 = self.pipeline.text_encoder_2.half()

                    # Initialize transformer in bfloat16
                    self.pipeline.transformer = self.pipeline.transformer.to(torch.bfloat16)

                    # Initialize VAE in float16
                    self.pipeline.vae = self.pipeline.vae.half()

                    # Initialize Varnish for post-processing
                    self.varnish = Varnish(
                        device=self.device,
                        model_base_dir="/repository/varnish"
                    )

                def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
                    """Process video generation requests

                    Args:
                        data: Request data containing:
                            - inputs (str): Prompt for video generation
                            - parameters (dict): Generation parameters

                    Returns:
                        Dictionary containing:
                            - video: Base64 encoded MP4 data URI
                            - content-type: MIME type
                            - metadata: Generation metadata

                    Raises:
                        RuntimeError: If generation or post-processing fails. The
                            message embeds the formatted traceback and the
                            original exception is attached as ``__cause__``.
                    """
                    # Extract inputs. "inputs" may be a plain prompt string or a
                    # dict carrying a "prompt" key; when it is absent the request
                    # dict itself is searched for "prompt".
                    inputs = data.pop("inputs", data)
                    if isinstance(inputs, dict):
                        prompt = inputs.get("prompt", "")
                    else:
                        prompt = inputs

                    # `or {}` also covers an explicit `"parameters": null`, which
                    # would otherwise make every params.get(...) below fail.
                    params = data.get("parameters") or {}

                    # Create and validate config
                    config = GenerationConfig(
                        prompt=prompt,
                        negative_prompt=params.get("negative_prompt", ""),
                        num_frames=params.get("num_frames", 49),
                        height=params.get("height", 320),
                        width=params.get("width", 576),
                        num_inference_steps=params.get("num_inference_steps", 50),
                        guidance_scale=params.get("guidance_scale", 7.0),
                        seed=params.get("seed", -1),
                        fps=params.get("fps", 30),
                        double_num_frames=params.get("double_num_frames", False),
                        super_resolution=params.get("super_resolution", False),
                        grain_amount=params.get("grain_amount", 0.0),
                        quality=params.get("quality", 18),
                        enable_audio=params.get("enable_audio", False),
                        audio_prompt=params.get("audio_prompt", ""),
                        audio_negative_prompt=params.get("audio_negative_prompt", "voices, voice, talking, speaking, speech"),
                        enable_teacache=params.get("enable_teacache", True),

                        # values: 0 (original), 0.1 (1.6x speedup), 0.15 (2.1x speedup).
                        teacache_threshold=params.get("teacache_threshold", 0.15),

                        enable_enhance_a_video=params.get("enable_enhance_a_video", True),
                        enhance_a_video_weight=params.get("enhance_a_video_weight", 4.0)
                    ).validate_and_adjust()

                    try:
                        # Set random seeds. validate_and_adjust() replaces a seed
                        # of -1 with a random value, so the else branch is only a
                        # safety net.
                        if config.seed != -1:
                            torch.manual_seed(config.seed)
                            random.seed(config.seed)
                            generator = torch.Generator(device=self.device).manual_seed(config.seed)
                        else:
                            generator = None

                        # Configure TeaCache
                        # NOTE(review): TeaCache is currently switched off here,
                        # yet the response metadata below still echoes the
                        # requested enable_teacache / teacache_threshold values.
                        #if config.enable_teacache:
                        #    enable_teacache(
                        #        self.pipeline.transformer,
                        #        num_inference_steps=config.num_inference_steps,
                        #        rel_l1_thresh=config.teacache_threshold
                        #    )
                        #else:
                        #    disable_teacache(self.pipeline.transformer)

                        # Configure Enhance-A-Video weight if enabled
                        if config.enable_enhance_a_video:
                            set_enhance_weight(config.enhance_a_video_weight)
                            enable_enhance()
                        else:
                            # Reset enhance weight to 0 to effectively disable it
                            set_enhance_weight(0)

                        # Generate video frames
                        with torch.inference_mode():
                            output = self.pipeline(
                                prompt=config.prompt,

                                # Failed to generate video: HunyuanVideoPipeline.__call__() got an unexpected keyword argument 'negative_prompt'
                                #negative_prompt=config.negative_prompt,

                                num_frames=config.num_frames,
                                height=config.height,
                                width=config.width,
                                num_inference_steps=config.num_inference_steps,
                                guidance_scale=config.guidance_scale,
                                generator=generator,
                                output_type="pt",
                            ).frames

                            # Process with Varnish. Its calls are awaitables, so
                            # they are driven to completion on an event loop
                            # (reusing the current one when available).
                            import asyncio
                            try:
                                loop = asyncio.get_event_loop()
                            except RuntimeError:
                                loop = asyncio.new_event_loop()
                                asyncio.set_event_loop(loop)

                            result = loop.run_until_complete(
                                self.varnish(
                                    input_data=output,
                                    fps=config.fps,
                                    double_num_frames=config.double_num_frames,
                                    super_resolution=config.super_resolution,
                                    grain_amount=config.grain_amount,
                                    enable_audio=config.enable_audio,
                                    audio_prompt=config.audio_prompt,
                                    audio_negative_prompt=config.audio_negative_prompt,
                                )
                            )

                            # Get video data URI
                            video_uri = loop.run_until_complete(
                                result.write(
                                    type="data-uri",
                                    quality=config.quality
                                )
                            )

                            return {
                                "video": video_uri,
                                "content-type": "video/mp4",
                                "metadata": {
                                    "width": result.metadata.width,
                                    "height": result.metadata.height,
                                    "num_frames": result.metadata.frame_count,
                                    "fps": result.metadata.fps,
                                    "duration": result.metadata.duration,
                                    "seed": config.seed,
                                    "enable_teacache": config.enable_teacache,
                                    "teacache_threshold": config.teacache_threshold if config.enable_teacache else 0,
                                    "enable_enhance_a_video": config.enable_enhance_a_video,
                                    "enhance_a_video_weight": config.enhance_a_video_weight if config.enable_enhance_a_video else 0,
                                }
                            }

                    except Exception as e:
                        message = f"Error generating video ({str(e)})\n{traceback.format_exc()}"
                        logger.error(message)
                        # Chain explicitly so the original exception stays
                        # reachable as __cause__ for callers and debuggers.
                        raise RuntimeError(message) from e
         | 
    	
        model_index.json
    ADDED
    
    | @@ -0,0 +1,32 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            {
         | 
| 2 | 
            +
              "_class_name": "HunyuanVideoPipeline",
         | 
| 3 | 
            +
              "_diffusers_version": "0.32.0.dev0",
         | 
| 4 | 
            +
              "scheduler": [
         | 
| 5 | 
            +
                "diffusers",
         | 
| 6 | 
            +
                "FlowMatchEulerDiscreteScheduler"
         | 
| 7 | 
            +
              ],
         | 
| 8 | 
            +
              "text_encoder": [
         | 
| 9 | 
            +
                "transformers",
         | 
| 10 | 
            +
                "LlamaModel"
         | 
| 11 | 
            +
              ],
         | 
| 12 | 
            +
              "text_encoder_2": [
         | 
| 13 | 
            +
                "transformers",
         | 
| 14 | 
            +
                "CLIPTextModel"
         | 
| 15 | 
            +
              ],
         | 
| 16 | 
            +
              "tokenizer": [
         | 
| 17 | 
            +
                "transformers",
         | 
| 18 | 
            +
                "LlamaTokenizerFast"
         | 
| 19 | 
            +
              ],
         | 
| 20 | 
            +
              "tokenizer_2": [
         | 
| 21 | 
            +
                "transformers",
         | 
| 22 | 
            +
                "CLIPTokenizer"
         | 
| 23 | 
            +
              ],
         | 
| 24 | 
            +
              "transformer": [
         | 
| 25 | 
            +
                "diffusers",
         | 
| 26 | 
            +
                "HunyuanVideoTransformer3DModel"
         | 
| 27 | 
            +
              ],
         | 
| 28 | 
            +
              "vae": [
         | 
| 29 | 
            +
                "diffusers",
         | 
| 30 | 
            +
                "AutoencoderKLHunyuanVideo"
         | 
| 31 | 
            +
              ]
         | 
| 32 | 
            +
            }
         | 
    	
        requirements.txt
    ADDED
    
    | @@ -0,0 +1,20 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            diffusers @ git+https://github.com/huggingface/diffusers.git@main
         | 
| 2 | 
            +
            varnish @ git+https://github.com/jbilcke-hf/varnish.git@main
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            opencv-python>=4.10.0.84
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            transformers==4.48.0
         | 
| 7 | 
            +
            huggingface_hub==0.27.1
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            tokenizers>=0.20.3
         | 
| 10 | 
            +
            accelerate>=1.1.1
         | 
| 11 | 
            +
            pandas>=2.0.3
         | 
| 12 | 
            +
            numpy
         | 
| 13 | 
            +
            einops==0.7.0
         | 
| 14 | 
            +
            tqdm>=4.66.5
         | 
| 15 | 
            +
            loguru>=0.7.2
         | 
| 16 | 
            +
            imageio>=2.34.2
         | 
| 17 | 
            +
            imageio-ffmpeg>=0.5.1
         | 
| 18 | 
            +
            safetensors>=0.4.5
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            moviepy==1.0.3
         | 
    	
        teacache.py
    ADDED
    
    | @@ -0,0 +1,146 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # teacache.py
         | 
| 2 | 
            +
            import torch
         | 
| 3 | 
            +
            import numpy as np
         | 
| 4 | 
            +
            from typing import Optional, Dict, Union, Any
         | 
| 5 | 
            +
            from functools import wraps
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            class TeaCacheConfig:
                """Settings plus mutable bookkeeping for TeaCache acceleration.

                Attributes set by the constructor:
                    rel_l1_thresh: threshold value stored for the cache logic.
                    enable: flag recording whether TeaCache is switched on.

                The remaining attributes are bookkeeping fields that
                ``_reset_state`` returns to their initial values.
                """

                def __init__(self, rel_l1_thresh: float = 0.15, enable: bool = True):
                    # Start from clean bookkeeping, then record the settings.
                    self._reset_state()
                    self.enable = enable
                    self.rel_l1_thresh = rel_l1_thresh

                def _reset_state(self):
                    """Put every bookkeeping field back to its initial value."""
                    self.previous_modulated_input = None
                    self.previous_residual = None
                    self.accumulated_rel_l1_distance = 0
                    self.cnt = 0
         | 
| 24 | 
            +
             | 
| 25 | 
            +
            def create_teacache_forward(original_forward):
                """Build a TeaCache-enabled replacement for a transformer ``forward``.

                The returned function takes the module as its first argument
                (``self``), so it can be bound to a model with ``__get__``. On each
                call it measures how much the modulated input changed since the
                previous call. While the accumulated, polynomial-rescaled relative L1
                change stays below ``self.teacache_config.rel_l1_thresh`` it skips
                ``original_forward`` and instead adds the last stored residual
                (output minus input) to the input.

                The module is expected to carry ``teacache_config`` and
                ``num_inference_steps`` attributes.

                Args:
                    original_forward: The forward to wrap. Either a plain function
                        taking the module as its first argument, or an already-bound
                        method (in which case the module is not passed again).

                Returns:
                    The wrapping function. With TeaCache active it returns
                    ``{"sample": tensor}``, or ``(tensor,)`` when ``return_dict`` is
                    False. With TeaCache absent or disabled it returns whatever
                    ``original_forward`` returns.
                """
                # Built once per wrapper rather than on every denoising step.
                rescale_func = np.poly1d([
                    7.33226126e+02,
                    -4.01131952e+02,
                    6.75869174e+01,
                    -3.14987800e+00,
                    9.61237896e-02,
                ])

                # A bound method already carries its instance. Passing the module a
                # second time would land it in the ``hidden_states`` slot and raise
                # "got multiple values for argument 'hidden_states'".
                forward_is_bound = getattr(original_forward, '__self__', None) is not None

                def call_original(module, **kwargs):
                    """Invoke the wrapped forward, passing ``module`` only if unbound."""
                    if forward_is_bound:
                        return original_forward(**kwargs)
                    return original_forward(module, **kwargs)

                @wraps(original_forward)
                def teacache_forward(
                    self,
                    hidden_states: torch.Tensor,
                    timestep: torch.Tensor,
                    encoder_hidden_states: Optional[torch.Tensor] = None,
                    encoder_attention_mask: Optional[torch.Tensor] = None,
                    pooled_projections: Optional[torch.Tensor] = None,
                    guidance: Optional[torch.Tensor] = None,
                    attention_kwargs: Optional[Dict[str, Any]] = None,
                    return_dict: bool = True,
                ):
                    forward_kwargs = dict(
                        hidden_states=hidden_states,
                        timestep=timestep,
                        encoder_hidden_states=encoder_hidden_states,
                        encoder_attention_mask=encoder_attention_mask,
                        pooled_projections=pooled_projections,
                        guidance=guidance,
                        attention_kwargs=attention_kwargs,
                    )

                    # Skip TeaCache if not configured or not enabled.
                    if not hasattr(self, 'teacache_config') or not self.teacache_config.enable:
                        return call_original(self, return_dict=return_dict, **forward_kwargs)

                    config = self.teacache_config

                    # Prepare the modulation vector. It must start as None so the
                    # guidance-only case does not hit an unbound local.
                    vec = None
                    if pooled_projections is not None:
                        vec = self.vector_in(pooled_projections)
                    if guidance is not None:
                        guidance_vec = self.guidance_in(guidance)
                        vec = guidance_vec if vec is None else vec + guidance_vec

                    # Work on a copy so the stored "previous input" never aliases the
                    # caller's tensor.
                    inp = hidden_states.clone()

                    # Only index ``double_blocks`` when the model actually has some;
                    # otherwise fall through to the generic branch instead of raising.
                    double_blocks = getattr(self, 'double_blocks', None)
                    first_double_block = None
                    if double_blocks is not None and len(double_blocks) > 0:
                        first_double_block = double_blocks[0]

                    if first_double_block is not None and hasattr(first_double_block, 'img_norm1'):
                        # HunyuanVideo specific modulation: only the first shift/scale
                        # pair of the six modulation chunks is used.
                        img_mod1_shift, img_mod1_scale, _, _, _, _ = (
                            first_double_block.img_mod(vec).chunk(6, dim=-1)
                        )
                        normed_inp = first_double_block.img_norm1(inp)
                        modulated_inp = normed_inp * (1 + img_mod1_scale) + img_mod1_shift
                    else:
                        # Fallback: use the normalised input without modulation.
                        modulated_inp = self.transformer_blocks[0].norm1(inp)

                    # Decide whether to run the real forward or reuse the cache.
                    should_calc = True
                    if config.cnt == 0 or config.cnt == self.num_inference_steps - 1:
                        # First and last step are always computed.
                        config.accumulated_rel_l1_distance = 0
                    elif config.previous_modulated_input is not None:
                        previous = config.previous_modulated_input
                        rel_l1 = (
                            (modulated_inp - previous).abs().mean() / previous.abs().mean()
                        ).item()
                        config.accumulated_rel_l1_distance += float(rescale_func(rel_l1))

                        should_calc = config.accumulated_rel_l1_distance >= config.rel_l1_thresh
                        if should_calc:
                            config.accumulated_rel_l1_distance = 0

                    config.previous_modulated_input = modulated_inp
                    config.cnt += 1
                    if config.cnt >= self.num_inference_steps:
                        # Wrap around so the next run starts at step 0 again.
                        config.cnt = 0

                    if not should_calc and config.previous_residual is not None:
                        # Out-of-place add: an in-place ``+=`` would silently modify
                        # the tensor the caller passed in.
                        hidden_states = hidden_states + config.previous_residual
                    else:
                        ori_hidden_states = hidden_states.clone()

                        # Always request the dict-style output so the sample can be
                        # looked up by key regardless of the caller's return_dict.
                        out = call_original(self, return_dict=True, **forward_kwargs)
                        hidden_states = out["sample"]

                        # Store residual for reuse on later skipped steps.
                        config.previous_residual = hidden_states - ori_hidden_states

                    if not return_dict:
                        return (hidden_states,)

                    return {"sample": hidden_states}

                return teacache_forward
         | 
| 129 | 
            +
             | 
| 130 | 
            +
            def enable_teacache(model: Any, num_inference_steps: int, rel_l1_thresh: float = 0.15):
                """Patch ``model.forward`` with the TeaCache wrapper.

                Args:
                    model: Object whose ``forward`` is replaced. The pre-patch forward
                        is kept on it as ``_original_forward``.
                    num_inference_steps: Stored on the model as ``num_inference_steps``.
                    rel_l1_thresh: Threshold handed to a fresh ``TeaCacheConfig``.
                """
                already_patched = hasattr(model, '_original_forward')
                if not already_patched:
                    # First activation: remember the untouched forward so repeated
                    # calls never wrap an already-wrapped forward.
                    model._original_forward = model.forward

                # Every call installs a brand-new config, i.e. fresh cache state.
                fresh_config = TeaCacheConfig(rel_l1_thresh=rel_l1_thresh)
                model.teacache_config = fresh_config
                model.num_inference_steps = num_inference_steps

                cached_forward = create_teacache_forward(model._original_forward)
                # Bind the plain function to this model instance.
                model.forward = cached_forward.__get__(model)
         | 
| 138 | 
            +
             | 
| 139 | 
            +
            def disable_teacache(model: Any):
                """Undo the TeaCache patch on ``model``; a no-op if it was never applied.

                Restores ``model.forward`` from ``model._original_forward`` (removing
                the saved reference) and removes ``model.teacache_config``. Other
                attributes are left untouched.
                """
                try:
                    saved_forward = model._original_forward
                except AttributeError:
                    # Nothing was saved, so there is nothing to restore.
                    pass
                else:
                    model.forward = saved_forward
                    delattr(model, '_original_forward')

                if hasattr(model, 'teacache_config'):
                    delattr(model, 'teacache_config')
         | 
