caarleexx committed on
Commit 8f69c0a · verified · 1 Parent(s): 961e81a

Upload 4 files
api/aduc_ltx_latent_patch.py ADDED
@@ -0,0 +1,260 @@
+ # aduc_ltx_latent_patch.py
+ #
+ # This module provides a monkey patch for the LTXVideoPipeline class from the
+ # ltx_video library. Its main feature is optimizing the conditioning process by
+ # letting the pipeline accept pre-computed latent tensors directly through a
+ # modified `ConditioningItem`. This avoids needlessly re-encoding media
+ # (images/videos) through the VAE, which yields a significant performance gain
+ # when the latents are already available.
+
+ import torch
+ from torch import Tensor
+ from typing import Optional, List, Tuple
+ from pathlib import Path
+ import os
+ import sys
+ from dataclasses import dataclass, replace
+
+ # --- PATH CONFIGURATION (assumes LTXV_DEBUG and _run_setup_script exist in the scope that loads this module) ---
+ # DEPS_DIR = Path("/data")
+ # LTX_VIDEO_REPO_DIR = DEPS_DIR / "LTX-Video"
+ # def add_deps_to_path(repo_path: Path):
+ #     """Adds the repository directory to sys.path for local imports."""
+ #     resolved_path = str(repo_path.resolve())
+ #     if resolved_path not in sys.path:
+ #         sys.path.insert(0, resolved_path)
+ # add_deps_to_path(LTX_VIDEO_REPO_DIR)
+
+
+ # Try to import the required dependencies from the original module that will be patched.
+ try:
+     from ltx_video.pipelines.pipeline_ltx_video import (
+         LTXVideoPipeline,
+         ConditioningItem as OriginalConditioningItem
+     )
+     from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
+     from ltx_video.models.autoencoders.vae_encode import vae_encode, latent_to_pixel_coords
+     from diffusers.utils.torch_utils import randn_tensor
+ except ImportError as e:
+     print(f"FATAL ERROR: Could not import dependencies from 'ltx_video'. "
+           f"Please ensure the environment is correctly set up. Error: {e}")
+     raise
+
+ print("[INFO] Patch module 'aduc_ltx_latent_patch' loaded successfully.")
+
+ # ==============================================================================
+ # 1. NEW `PatchedConditioningItem` DATACLASS DEFINITION
+ # ==============================================================================
+
+ @dataclass
+ class PatchedConditioningItem:
+     """
+     Modified version of `ConditioningItem` that accepts either pixel tensors
+     (`media_item`) or pre-encoded latent tensors (`latents`).
+
+     Attributes:
+         media_frame_number (int): Starting frame of the conditioning item within the video.
+         conditioning_strength (float): Conditioning strength (0.0 to 1.0).
+         media_item (Optional[Tensor]): Media (pixel) tensor. Used when `latents` is None.
+         media_x (Optional[int]): X coordinate (left) for spatial placement.
+         media_y (Optional[int]): Y coordinate (top) for spatial placement.
+         latents (Optional[Tensor]): Pre-encoded latent tensor. Takes precedence over `media_item`.
+     """
+     media_frame_number: int
+     conditioning_strength: float
+     media_item: Optional[Tensor] = None
+     media_x: Optional[int] = None
+     media_y: Optional[int] = None
+     latents: Optional[Tensor] = None
+
+     def __post_init__(self):
+         """Validates the object's state right after initialization."""
+         if self.media_item is None and self.latents is None:
+             raise ValueError("A `PatchedConditioningItem` must have either 'media_item' or 'latents' defined.")
+         if self.media_item is not None and self.latents is not None:
+             print("[WARNING] `PatchedConditioningItem` received both 'media_item' and 'latents'. "
+                   "The 'latents' tensor will take precedence.")
+
+ # ==============================================================================
+ # 2. NEW `prepare_conditioning` IMPLEMENTATION
+ # ==============================================================================
+
+ def prepare_conditioning_with_latents(
+     self: LTXVideoPipeline,
+     conditioning_items: Optional[List[PatchedConditioningItem]],
+     init_latents: Tensor,
+     num_frames: int,
+     height: int,
+     width: int,
+     vae_per_channel_normalize: bool = False,
+     generator: Optional[torch.Generator] = None,
+ ) -> Tuple[Tensor, Tensor, Optional[Tensor], int]:
+     """
+     Modified version of `prepare_conditioning` that prefers the pre-computed
+     latents carried by the `conditioning_items`, avoiding needless re-encoding
+     through the VAE.
+     """
+     assert isinstance(self, LTXVideoPipeline), "This function must be called as a method of LTXVideoPipeline."
+     assert isinstance(self.vae, CausalVideoAutoencoder), "VAE must be of type CausalVideoAutoencoder."
+
+     if not conditioning_items:
+         init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents)
+         init_pixel_coords = latent_to_pixel_coords(
+             init_latent_coords, self.vae,
+             causal_fix=self.transformer.config.causal_temporal_positioning
+         )
+         return init_latents, init_pixel_coords, None, 0
+
+     init_conditioning_mask = torch.zeros(
+         init_latents[:, 0, :, :, :].shape, dtype=torch.float32, device=init_latents.device
+     )
+     extra_conditioning_latents = []
+     extra_conditioning_pixel_coords = []
+     extra_conditioning_mask = []
+     extra_conditioning_num_latents = 0
+
+     for item in conditioning_items:
+         item_latents: Tensor
+
+         if item.latents is not None:
+             item_latents = item.latents.to(dtype=init_latents.dtype, device=init_latents.device)
+             if item_latents.ndim != 5:
+                 raise ValueError(f"Latents must have 5 dimensions (b, c, f, h, w), but got {item_latents.ndim}")
+         elif item.media_item is not None:
+             resized_item = self._resize_conditioning_item(item, height, width)
+             media_item = resized_item.media_item
+             assert media_item.ndim == 5, f"media_item must have 5 dims, but got {media_item.ndim}"
+             item_latents = vae_encode(
+                 media_item.to(dtype=self.vae.dtype, device=self.vae.device),
+                 self.vae,
+                 vae_per_channel_normalize=vae_per_channel_normalize,
+             ).to(dtype=init_latents.dtype)
+         else:
+             raise ValueError("ConditioningItem is invalid: it has neither 'latents' nor 'media_item'.")
+
+         media_frame_number = item.media_frame_number
+         strength = item.conditioning_strength
+
+         if media_frame_number == 0:
+             # --- START OF MODIFICATION ---
+             # If `item.media_item` is None (our optimized use case), the original
+             # `_get_latent_spatial_position` would break. To avoid that, we build a
+             # temporary item carrying a placeholder tensor with the correct
+             # dimension information, inferred from the latents themselves.
+
+             item_for_spatial_position = item
+             if item.media_item is None:
+                 # Infer the pixel dimensions from the latent shape
+                 latent_h, latent_w = item_latents.shape[-2:]
+                 pixel_h = latent_h * self.vae_scale_factor
+                 pixel_w = latent_w * self.vae_scale_factor
+
+                 # Create a placeholder tensor with the expected shape (its contents do not matter)
+                 placeholder_media_item = torch.empty(
+                     (1, 1, 1, pixel_h, pixel_w), device=item_latents.device, dtype=item_latents.dtype
+                 )
+
+                 # Use `dataclasses.replace` to make a temporary copy of the item with the placeholder
+                 item_for_spatial_position = replace(item, media_item=placeholder_media_item)
+
+             # Call the original helper with an item it can process without errors
+             item_latents, l_x, l_y = self._get_latent_spatial_position(
+                 item_latents, item_for_spatial_position, height, width, strip_latent_border=True
+             )
+             # --- END OF MODIFICATION ---
+
+             _, _, f_l, h_l, w_l = item_latents.shape
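+             # torch.lerp(start, end, weight) computes start + weight * (end - start):
+             # strength=1.0 fully replaces the initial latents in this region,
+             # while strength=0.0 leaves them untouched.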
+             init_latents[:, :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l] = torch.lerp(
+                 init_latents[:, :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l], item_latents, strength
+             )
+             init_conditioning_mask[:, :f_l, l_y : l_y + h_l, l_x : l_x + w_l] = strength
+         else:
+             if item_latents.shape[2] > 1:
+                 (init_latents, init_conditioning_mask, item_latents) = self._handle_non_first_conditioning_sequence(
+                     init_latents, init_conditioning_mask, item_latents, media_frame_number, strength
+                 )
+
+             if item_latents is not None:
+                 noise = randn_tensor(
+                     item_latents.shape, generator=generator,
+                     device=item_latents.device, dtype=item_latents.dtype
+                 )
+                 item_latents = torch.lerp(noise, item_latents, strength)
+                 item_latents, latent_coords = self.patchifier.patchify(latents=item_latents)
+                 pixel_coords = latent_to_pixel_coords(
+                     latent_coords, self.vae,
+                     causal_fix=self.transformer.config.causal_temporal_positioning
+                 )
+                 pixel_coords[:, 0] += media_frame_number
+                 extra_conditioning_num_latents += item_latents.shape[1]
+                 conditioning_mask = torch.full(
+                     item_latents.shape[:2], strength,
+                     dtype=torch.float32, device=init_latents.device
+                 )
+                 extra_conditioning_latents.append(item_latents)
+                 extra_conditioning_pixel_coords.append(pixel_coords)
+                 extra_conditioning_mask.append(conditioning_mask)
+
+     init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents)
+     init_pixel_coords = latent_to_pixel_coords(
+         init_latent_coords, self.vae,
+         causal_fix=self.transformer.config.causal_temporal_positioning
+     )
+     init_conditioning_mask, _ = self.patchifier.patchify(latents=init_conditioning_mask.unsqueeze(1))
+     init_conditioning_mask = init_conditioning_mask.squeeze(-1)
+
+     if extra_conditioning_latents:
+         init_latents = torch.cat([*extra_conditioning_latents, init_latents], dim=1)
+         init_pixel_coords = torch.cat([*extra_conditioning_pixel_coords, init_pixel_coords], dim=2)
+         init_conditioning_mask = torch.cat([*extra_conditioning_mask, init_conditioning_mask], dim=1)
+
+         if self.transformer.use_tpu_flash_attention:
+             init_latents = init_latents[:, :-extra_conditioning_num_latents]
+             init_pixel_coords = init_pixel_coords[:, :, :-extra_conditioning_num_latents]
+             init_conditioning_mask = init_conditioning_mask[:, :-extra_conditioning_num_latents]
+
+     return init_latents, init_pixel_coords, init_conditioning_mask, extra_conditioning_num_latents
+
+ # ==============================================================================
+ # 3. MONKEY PATCHER CLASS
+ # ==============================================================================
+
+ class LTXLatentConditioningPatch:
+     """
+     Static class used to apply and revert the monkey patch on the LTX-Video pipeline.
+     """
+     _original_prepare_conditioning = None
+     _is_patched = False
+
+     @staticmethod
+     def apply():
+         """
+         Applies the monkey patch to the `LTXVideoPipeline` class.
+         """
+         if LTXLatentConditioningPatch._is_patched:
+             print("[WARNING] LTXLatentConditioningPatch has already been applied. Ignoring.")
+             return
+
+         print("[INFO] Applying monkey patch for latent-based conditioning...")
+
+         LTXLatentConditioningPatch._original_prepare_conditioning = LTXVideoPipeline.prepare_conditioning
+         LTXVideoPipeline.prepare_conditioning = prepare_conditioning_with_latents
+
+         LTXLatentConditioningPatch._is_patched = True
+         print("[SUCCESS] Monkey patch applied successfully.")
+         print("  - `LTXVideoPipeline.prepare_conditioning` has been updated.")
+         print("  - NOTE: Remember to use `aduc_ltx_latent_patch.PatchedConditioningItem` when creating conditioning items.")
+
+     @staticmethod
+     def revert():
+         """
+         Reverts the monkey patch, restoring the original implementation.
+         """
+         if not LTXLatentConditioningPatch._is_patched:
+             print("[WARNING] Patch is not currently applied. No action taken.")
+             return
+
+         if LTXLatentConditioningPatch._original_prepare_conditioning:
+             print("[INFO] Reverting LTXLatentConditioningPatch...")
+             LTXVideoPipeline.prepare_conditioning = LTXLatentConditioningPatch._original_prepare_conditioning
+             LTXLatentConditioningPatch._is_patched = False
+             print("[SUCCESS] Patch reverted successfully. Original functionality restored.")
+         else:
+             print("[ERROR] Cannot revert: original implementation was not saved.")
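For reference, a minimal usage sketch of this patch (not part of the commit); the latent channel count and spatial size below are illustrative assumptions, not values mandated by the pipeline:

```python
import torch

from api.aduc_ltx_latent_patch import LTXLatentConditioningPatch, PatchedConditioningItem

# Swap in the latent-aware prepare_conditioning implementation.
LTXLatentConditioningPatch.apply()

# A pre-encoded conditioning latent, shaped (batch, channels, frames, height, width).
# The 128-channel, 16x24 spatial size is a made-up example.
precomputed_latents = torch.randn(1, 128, 1, 16, 24)

item = PatchedConditioningItem(
    media_frame_number=0,       # condition the very first frame
    conditioning_strength=1.0,  # 1.0 fully replaces the initial latents there
    media_item=None,            # no pixel tensor: the VAE encode step is skipped
    latents=precomputed_latents,
)

# `item` can now be passed in the `conditioning_items` list of the patched
# pipeline call. Afterwards, the patch can be undone:
LTXLatentConditioningPatch.revert()
```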
api/ltx_server_refactored.py ADDED
@@ -0,0 +1,601 @@
+ # ltx_server_clean_refactor.py — VideoService (Modular Version with Simple Overlap Chunking)
+
+ # ==============================================================================
+ # 0. ENVIRONMENT CONFIGURATION AND IMPORTS
+ # ==============================================================================
+ import os
+ import sys
+ import gc
+ import yaml
+ import time
+ import json
+ import random
+ import shutil
+ import warnings
+ import tempfile
+ import traceback
+ import subprocess
+ from pathlib import Path
+ from typing import List, Dict, Optional, Tuple, Union
+ import cv2
+
+ # --- Logging and warning settings ---
+ warnings.filterwarnings("ignore", category=UserWarning)
+ warnings.filterwarnings("ignore", category=FutureWarning)
+ from huggingface_hub import logging as hf_logging
+ hf_logging.set_verbosity_error()
+
+ # --- ML/processing library imports ---
+ import torch
+ import torch.nn.functional as F
+ import numpy as np
+ from PIL import Image
+ from einops import rearrange
+ from huggingface_hub import hf_hub_download
+ from safetensors import safe_open
+
+ from managers.vae_manager import vae_manager_singleton
+ from tools.video_encode_tool import video_encode_tool_singleton
+
+ from api.aduc_ltx_latent_patch import LTXLatentConditioningPatch, PatchedConditioningItem
+
+ # --- Global constants ---
+ LTXV_DEBUG = True  # Set to False to disable debug logs
+ LTXV_FRAME_LOG_EVERY = 8
+ DEPS_DIR = Path("/data")
+ LTX_VIDEO_REPO_DIR = DEPS_DIR / "LTX-Video"
+ RESULTS_DIR = Path("/app/output")
+ DEFAULT_FPS = 24.0
+
+ # ==============================================================================
+ # 1. SETUP AND ENVIRONMENT HELPERS
+ # ==============================================================================
+
+ def _run_setup_script():
+     """Runs the setup.py script if the LTX-Video repository does not exist."""
+     setup_script_path = "setup.py"
+     if not os.path.exists(setup_script_path):
+         print("[DEBUG] 'setup.py' not found. Skipping dependency cloning.")
+         return
+
+     print(f"[DEBUG] Repository not found at {LTX_VIDEO_REPO_DIR}. Running setup.py...")
+     try:
+         subprocess.run([sys.executable, setup_script_path], check=True, capture_output=True, text=True)
+         print("[DEBUG] 'setup.py' script completed successfully.")
+     except subprocess.CalledProcessError as e:
+         print(f"[ERROR] Failed to run 'setup.py' (exit code {e.returncode}).\nOutput:\n{e.stdout}\n{e.stderr}")
+         sys.exit(1)
+
+ def add_deps_to_path(repo_path: Path):
+     """Adds the repository directory to sys.path for local imports."""
+     resolved_path = str(repo_path.resolve())
+     if resolved_path not in sys.path:
+         sys.path.insert(0, resolved_path)
+     if LTXV_DEBUG:
+         print(f"[DEBUG] Added to sys.path: {resolved_path}")
+
+ # --- Initial setup ---
+ if not LTX_VIDEO_REPO_DIR.exists():
+     _run_setup_script()
+ add_deps_to_path(LTX_VIDEO_REPO_DIR)
+
+ # --- Imports that depend on the path added above ---
+ from ltx_video.models.autoencoders.vae_encode import un_normalize_latents, normalize_latents
+ from ltx_video.pipelines.pipeline_ltx_video import adain_filter_latent
+ from ltx_video.models.autoencoders.latent_upsampler import LatentUpsampler
+ from ltx_video.pipelines.pipeline_ltx_video import ConditioningItem, LTXVideoPipeline
+ from transformers import T5EncoderModel, T5Tokenizer, AutoModelForCausalLM, AutoProcessor, AutoTokenizer
+ from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
+ from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
+ from ltx_video.models.transformers.transformer3d import Transformer3DModel
+ from ltx_video.schedulers.rf import RectifiedFlowScheduler
+ from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
+ import ltx_video.pipelines.crf_compressor as crf_compressor
+
+
+ def create_latent_upsampler(latent_upsampler_model_path: str, device: str):
+     latent_upsampler = LatentUpsampler.from_pretrained(latent_upsampler_model_path)
+     latent_upsampler.to(device)
+     latent_upsampler.eval()
+     return latent_upsampler
+
+ def create_ltx_video_pipeline(
+     ckpt_path: str,
+     precision: str,
+     text_encoder_model_name_or_path: str,
+     sampler: Optional[str] = None,
+     device: Optional[str] = None,
+     enhance_prompt: bool = False,
+     prompt_enhancer_image_caption_model_name_or_path: Optional[str] = None,
+     prompt_enhancer_llm_model_name_or_path: Optional[str] = None,
+ ) -> LTXVideoPipeline:
+     ckpt_path = Path(ckpt_path)
+     assert os.path.exists(
+         ckpt_path
+     ), f"Ckpt path provided (--ckpt_path) {ckpt_path} does not exist"
+
+     with safe_open(ckpt_path, framework="pt") as f:
+         metadata = f.metadata()
+         config_str = metadata.get("config")
+         configs = json.loads(config_str)
+         allowed_inference_steps = configs.get("allowed_inference_steps", None)
+
+     vae = CausalVideoAutoencoder.from_pretrained(ckpt_path)
+     transformer = Transformer3DModel.from_pretrained(ckpt_path)
+
+     # Use constructor if sampler is specified, otherwise use from_pretrained
+     if sampler == "from_checkpoint" or not sampler:
+         scheduler = RectifiedFlowScheduler.from_pretrained(ckpt_path)
+     else:
+         scheduler = RectifiedFlowScheduler(
+             sampler=("Uniform" if sampler.lower() == "uniform" else "LinearQuadratic")
+         )
+
+     text_encoder = T5EncoderModel.from_pretrained(
+         text_encoder_model_name_or_path, subfolder="text_encoder"
+     )
+     patchifier = SymmetricPatchifier(patch_size=1)
+     tokenizer = T5Tokenizer.from_pretrained(
+         text_encoder_model_name_or_path, subfolder="tokenizer"
+     )
+
+     transformer = transformer.to(device)
+     vae = vae.to(device)
+     text_encoder = text_encoder.to(device)
+
+     if enhance_prompt:
+         prompt_enhancer_image_caption_model = AutoModelForCausalLM.from_pretrained(
+             prompt_enhancer_image_caption_model_name_or_path, trust_remote_code=True
+         )
+         prompt_enhancer_image_caption_processor = AutoProcessor.from_pretrained(
+             prompt_enhancer_image_caption_model_name_or_path, trust_remote_code=True
+         )
+         prompt_enhancer_llm_model = AutoModelForCausalLM.from_pretrained(
+             prompt_enhancer_llm_model_name_or_path,
+             torch_dtype="bfloat16",
+         )
+         prompt_enhancer_llm_tokenizer = AutoTokenizer.from_pretrained(
+             prompt_enhancer_llm_model_name_or_path,
+         )
+     else:
+         prompt_enhancer_image_caption_model = None
+         prompt_enhancer_image_caption_processor = None
+         prompt_enhancer_llm_model = None
+         prompt_enhancer_llm_tokenizer = None
+
+     vae = vae.to(torch.bfloat16)
+     if precision == "bfloat16" and transformer.dtype != torch.bfloat16:
+         transformer = transformer.to(torch.bfloat16)
+     text_encoder = text_encoder.to(torch.bfloat16)
+
+     # Use submodels for the pipeline
+     submodel_dict = {
+         "transformer": transformer,
+         "patchifier": patchifier,
+         "text_encoder": text_encoder,
+         "tokenizer": tokenizer,
+         "scheduler": scheduler,
+         "vae": vae,
+         "prompt_enhancer_image_caption_model": prompt_enhancer_image_caption_model,
+         "prompt_enhancer_image_caption_processor": prompt_enhancer_image_caption_processor,
+         "prompt_enhancer_llm_model": prompt_enhancer_llm_model,
+         "prompt_enhancer_llm_tokenizer": prompt_enhancer_llm_tokenizer,
+         "allowed_inference_steps": allowed_inference_steps,
+     }
+
+     pipeline = LTXVideoPipeline(**submodel_dict)
+
+     LTXLatentConditioningPatch.apply()
+
+     pipeline = pipeline.to(device)
+     return pipeline
+
+ # ==============================================================================
+ # 2. PROCESSING HELPERS
+ # ==============================================================================
+
+ def calculate_padding(orig_h: int, orig_w: int, target_h: int, target_w: int) -> Tuple[int, int, int, int]:
+     """Computes the padding needed to center an image inside new dimensions."""
+     pad_h = target_h - orig_h
+     pad_w = target_w - orig_w
+     pad_top = pad_h // 2
+     pad_bottom = pad_h - pad_top
+     pad_left = pad_w // 2
+     pad_right = pad_w - pad_left
+     return (pad_left, pad_right, pad_top, pad_bottom)
+
+ def log_tensor_info(tensor: torch.Tensor, name: str = "Tensor"):
+     """Prints detailed information about a tensor for debugging."""
+     if not isinstance(tensor, torch.Tensor):
+         print(f"\n[INFO] '{name}' is not a tensor.")
+         return
+     print(f"\n--- Tensor Info: {name} ---")
+     print(f"  - Shape: {tuple(tensor.shape)}")
+     print(f"  - Dtype: {tensor.dtype}")
+     print(f"  - Device: {tensor.device}")
+     if tensor.numel() > 0:
+         try:
+             print(f"  - Stats: Min={tensor.min().item():.4f}, Max={tensor.max().item():.4f}, Mean={tensor.mean().item():.4f}")
+         except RuntimeError:
+             print("  - Stats: could not be computed (e.g. bool tensors).")
+     print("-" * 30)
+
+ # ==============================================================================
+ # 3. MAIN VIDEO SERVICE CLASS
+ # ==============================================================================
+
+ class VideoService:
+     """
+     Self-contained service for generating videos with the LTX-Video pipeline.
+     It manages model loading, pre-processing, multi-step generation (low
+     resolution, then upscale with denoise) and post-processing.
+     """
+     def __init__(self):
+         """Initializes the service, loading configuration and models."""
+         t0 = time.perf_counter()
+         print("[INFO] Initializing VideoService...")
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         self.config = self._load_config("ltxv-13b-0.9.8-distilled-fp8.yaml")
+
+         self.pipeline, self.latent_upsampler = self._load_models_from_hub()
+         self._move_models_to_device()
+
+         self.runtime_autocast_dtype = self._get_precision_dtype()
+         vae_manager_singleton.attach_pipeline(
+             self.pipeline,
+             device=self.device,
+             autocast_dtype=self.runtime_autocast_dtype
+         )
+         self._tmp_dirs = set()
+         RESULTS_DIR.mkdir(exist_ok=True)
+         print(f"[INFO] VideoService ready. Startup time: {time.perf_counter()-t0:.2f}s")
+
+     # --------------------------------------------------------------------------
+     # --- Public Methods (Service API) ---
+     # --------------------------------------------------------------------------
+
+     def _load_image_to_tensor_with_resize_and_crop(
+         self,
+         image_input: Union[str, Image.Image],
+         target_height: int = 512,
+         target_width: int = 768,
+         just_crop: bool = False,
+     ) -> torch.Tensor:
+         """Load and process an image into a tensor.
+
+         Args:
+             image_input: Either a file path (str) or a PIL Image object
+             target_height: Desired height of output tensor
+             target_width: Desired width of output tensor
+             just_crop: If True, only crop the image to the target size without resizing
+         """
+         if isinstance(image_input, str):
+             image = Image.open(image_input).convert("RGB")
+         elif isinstance(image_input, Image.Image):
+             image = image_input
+         else:
+             raise ValueError("image_input must be either a file path or a PIL Image object")
+
+         input_width, input_height = image.size
+         aspect_ratio_target = target_width / target_height
+         aspect_ratio_frame = input_width / input_height
+         if aspect_ratio_frame > aspect_ratio_target:
+             new_width = int(input_height * aspect_ratio_target)
+             new_height = input_height
+             x_start = (input_width - new_width) // 2
+             y_start = 0
+         else:
+             new_width = input_width
+             new_height = int(input_width / aspect_ratio_target)
+             x_start = 0
+             y_start = (input_height - new_height) // 2
+
+         image = image.crop((x_start, y_start, x_start + new_width, y_start + new_height))
+         if not just_crop:
+             image = image.resize((target_width, target_height))
+
+         image = np.array(image)
+         image = cv2.GaussianBlur(image, (3, 3), 0)
+         frame_tensor = torch.from_numpy(image).float()
+         frame_tensor = crf_compressor.compress(frame_tensor / 255.0) * 255.0
+         frame_tensor = frame_tensor.permute(2, 0, 1)
+         frame_tensor = (frame_tensor / 127.5) - 1.0
+         # Create 5D tensor: (batch_size=1, channels=3, num_frames=1, height, width)
+         return frame_tensor.unsqueeze(0).unsqueeze(2)
+
+     @torch.no_grad()
+     def _image_to_latents(self, image_input: Union[str, Image.Image], height: int, width: int) -> torch.Tensor:
+         """
+         Converts an image (path or PIL) into a 5D latent tensor.
+         Returns: tensor of shape [1, C_lat, 1, H_lat, W_lat]
+         """
+         print(f"[DEBUG] Encoding image to latents ({height}x{width})...")
+         # 1. Load the image and turn it into a 5D pixel tensor
+         pixel_tensor = self._load_image_to_tensor_with_resize_and_crop(
+             image_input, target_height=height, target_width=width
+         )
+         pixel_tensor_gpu = pixel_tensor.to(self.device, dtype=self.pipeline.vae.dtype)
+
+         # 2. Encode the pixel tensor into latents with the VAE
+         with torch.autocast(device_type=self.device.split(':')[0], dtype=self.runtime_autocast_dtype, enabled=(self.device == 'cuda')):
+             # The pipeline's VAE already handles 5D tensors
+             latents = self.pipeline.vae.encode(pixel_tensor_gpu).latent_dist.sample()
+
+         # 3. Apply the scaling factor (important for consistency)
+         if hasattr(self.pipeline.vae.config, "scaling_factor"):
+             latents = latents * self.pipeline.vae.config.scaling_factor
+
+         print(f"[DEBUG] Image encoded to latents with shape: {latents.shape}")
+         return latents
+
+     def _prepare_condition_items(self, items_list: List[Tuple], height: int, width: int) -> List[PatchedConditioningItem]:
+         """
+         Prepares the conditioning items.
+         Takes a list of (image, frame, weight) triples, converts each image to
+         LATENTS and builds a list of PatchedConditioningItem carrying the tensor
+         in `latents`.
+         """
+         if not items_list:
+             return []
+
+         conditioning_items = []
+         for media_input, frame_idx, weight in items_list:
+             # 1. Use the helper above to get the latent tensor directly
+             latent_tensor = self._image_to_latents(media_input, height, width)
+
+             safe_frame_idx = int(frame_idx)
+
+             # 2. Build the PatchedConditioningItem with the `latents` field filled in
+             item = PatchedConditioningItem(
+                 media_frame_number=safe_frame_idx,
+                 conditioning_strength=float(weight),
+                 media_item=None,          # Important: media_item stays None
+                 latents=latent_tensor     # The pre-computed latent goes here!
+             )
+             conditioning_items.append(item)
+
+         print(f"[INFO] Prepared {len(conditioning_items)} conditioning items with pre-encoded latents.")
+         return conditioning_items
+
+     def generate_low_resolution(
+         self,
+         prompt: str,
+         negative_prompt: str,
+         height: int,
+         width: int,
+         duration_secs: float,
+         guidance_scale: float,
+         seed: Optional[int] = None,
+         conditioning_items: Optional[List[PatchedConditioningItem]] = None
+     ) -> Tuple[str, str, int]:
+         """
+         STAGE 1: Generates a base-resolution video and its latents from a prompt
+         and optional conditioning items.
+         """
+         print("[INFO] Starting STAGE 1: Low-Resolution Generation...")
+
+         # --- Seed and directory setup ---
+         used_seed = random.randint(0, 2**32 - 1) if seed is None else int(seed)
+         self._seed_everything(used_seed)
+         print(f"  - Using seed: {used_seed}")
+
+         temp_dir = tempfile.mkdtemp(prefix="ltxv_low_")
+         self._register_tmp_dir(temp_dir)
+         results_dir = "/app/output"
+         os.makedirs(results_dir, exist_ok=True)
+
+         # --- Frame and dimension math ---
+         actual_num_frames = int(round(duration_secs * DEFAULT_FPS))
+         downscaled_height = height
+         downscaled_width = width
+         # downscaled_height, downscaled_width = self._calculate_downscaled_dims(height, width)
+
+         print(f"  - Frames: {actual_num_frames}, Duration: {duration_secs}s")
+         print(f"  - Output dimensions: {downscaled_height}x{downscaled_width}")
+
+         # --- Pipeline execution ---
+         with torch.autocast(device_type=self.device.split(':')[0], dtype=self.runtime_autocast_dtype, enabled=(self.device == 'cuda')):
+
+             first_pass_kwargs = {
+                 "prompt": prompt,
+                 "negative_prompt": negative_prompt,
+                 "height": downscaled_height,
+                 "width": downscaled_width,
+                 "num_frames": (actual_num_frames // 8) + 1,
+                 "frame_rate": int(DEFAULT_FPS),
+                 "generator": torch.Generator(device=self.device).manual_seed(used_seed),
+                 "output_type": "latent",
+                 "vae_per_channel_normalize": True,
+                 "is_video": True,
+                 "conditioning_items": conditioning_items,
+                 "guidance_scale": float(guidance_scale),
+                 **(self.config.get("first_pass", {}))
+             }
+
+             print("  - Dispatching to the LTX pipeline...")
+             latents = self.pipeline(**first_pass_kwargs).images
+             print(f"  - Latents generated with shape: {latents.shape}")
+
+             # Decode the latents to pixels to build the preview video
+             pixel_tensor = vae_manager_singleton.decode(latents, decode_timestep=float(self.config.get("decode_timestep", 0.05)))
+             tensor_path = self._save_latents_to_disk(latents, "latents_low_res", used_seed)
+
+             final_video_path = self._save_video_from_tensor(pixel_tensor, "final_video", used_seed, temp_dir, fps=DEFAULT_FPS)
+
+         # --- Cleanup ---
+         self._finalize()
+
+         print("[SUCCESS] STAGE 1 complete.")
+         return final_video_path, tensor_path, used_seed
+
+     # --------------------------------------------------------------------------
+     # --- Internal Helper Methods ---
+     # --------------------------------------------------------------------------
+
+     def _finalize(self):
+         """Frees GPU memory and removes temporary directories."""
+         if LTXV_DEBUG:
+             print("[DEBUG] Finalize: starting cleanup...")
+
+         gc.collect()
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+             torch.cuda.ipc_collect()
+
+         # Remove every registered temporary directory
+         for d in list(self._tmp_dirs):
+             shutil.rmtree(d, ignore_errors=True)
+             self._tmp_dirs.remove(d)
+             if LTXV_DEBUG:
+                 print(f"[DEBUG] Temporary directory removed: {d}")
+
+     def _save_latents_to_disk(self, latents_tensor: torch.Tensor, base_filename: str, seed: int) -> str:
+         """Saves a latent tensor to a .pt file."""
+         latents_cpu = latents_tensor.detach().to("cpu")
+         tensor_path = RESULTS_DIR / f"{base_filename}_{seed}.pt"
+         torch.save(latents_cpu, tensor_path)
+         if LTXV_DEBUG:
+             print(f"[DEBUG] Latents saved to: {tensor_path}")
+         return str(tensor_path)
+
+     def _save_video_from_tensor(self, pixel_tensor: torch.Tensor, base_filename: str, seed: int, temp_dir: str, fps: float = DEFAULT_FPS) -> str:
+         """Saves a pixel tensor as an MP4 video file."""
+         temp_path = os.path.join(temp_dir, f"{base_filename}_{seed}.mp4")
+         video_encode_tool_singleton.save_video_from_tensor(pixel_tensor, temp_path, fps=fps)
+
+         final_path = RESULTS_DIR / f"{base_filename}_{seed}.mp4"
+         shutil.move(temp_path, final_path)
+         print(f"[INFO] Final video saved to: {final_path}")
+         return str(final_path)
+
+     def _load_config(self, config_filename: str) -> Dict:
+         """Loads the YAML configuration file."""
+         config_path = LTX_VIDEO_REPO_DIR / "configs" / config_filename
+         print(f"[INFO] Loading configuration from: {config_path}")
+         with open(config_path, "r") as file:
+             return yaml.safe_load(file)
+
+     def _load_models_from_hub(self):
+         """Downloads and builds the pipeline and upsampler instances."""
+         t0 = time.perf_counter()
+         LTX_REPO = "Lightricks/LTX-Video"
+
+         print("[INFO] Downloading main checkpoint...")
+         self.config["checkpoint_path"] = hf_hub_download(
+             repo_id=LTX_REPO, filename=self.config["checkpoint_path"],
+             token=os.getenv("HF_TOKEN")
+         )
+         print(f"[INFO] Main checkpoint at: {self.config['checkpoint_path']}")
+
+         print("[INFO] Building pipeline...")
+         pipeline = create_ltx_video_pipeline(
+             ckpt_path=self.config["checkpoint_path"],
+             precision=self.config["precision"],
+             text_encoder_model_name_or_path=self.config["text_encoder_model_name_or_path"],
+             sampler=self.config["sampler"],
+             device="cpu",  # Load on CPU first
+             enhance_prompt=False
+         )
+         print("[INFO] Pipeline built.")
+
+         latent_upsampler = None
+         if self.config.get("spatial_upscaler_model_path"):
+             print("[INFO] Downloading spatial upscaler...")
+             self.config["spatial_upscaler_model_path"] = hf_hub_download(
+                 repo_id=LTX_REPO, filename=self.config["spatial_upscaler_model_path"],
+                 token=os.getenv("HF_TOKEN")
+             )
+             print(f"[INFO] Upscaler at: {self.config['spatial_upscaler_model_path']}")
+
+             print("[INFO] Building latent_upsampler...")
+             latent_upsampler = create_latent_upsampler(self.config["spatial_upscaler_model_path"], device="cpu")
+             print("[INFO] Latent upsampler built.")
+
+         print(f"[INFO] Model loading finished in {time.perf_counter()-t0:.2f}s")
+         return pipeline, latent_upsampler
+
+     def _move_models_to_device(self):
+         """Moves the loaded models to the compute device (GPU/CPU)."""
+         print(f"[INFO] Moving models to device: {self.device}")
+         self.pipeline.to(self.device)
+         if self.latent_upsampler:
+             self.latent_upsampler.to(self.device)
+
+     def _get_precision_dtype(self) -> torch.dtype:
+         """Determines the autocast dtype based on the configured precision."""
+         prec = str(self.config.get("precision", "")).lower()
+         if prec in ["float8_e4m3fn", "bfloat16"]:
+             return torch.bfloat16
+         elif prec == "mixed_precision":
+             return torch.float16
+         return torch.float32
+
+     @torch.no_grad()
+     def _upsample_and_filter_latents(self, latents: torch.Tensor) -> torch.Tensor:
+         """Applies spatial upsampling and the AdaIN filter to the latents."""
+         if not self.latent_upsampler:
+             raise ValueError("The Latent Upsampler is not loaded, so the upscale operation cannot run.")
+
+         latents_unnormalized = un_normalize_latents(latents, self.pipeline.vae, vae_per_channel_normalize=True)
+         upsampled_latents_unnormalized = self.latent_upsampler(latents_unnormalized)
+         upsampled_latents_normalized = normalize_latents(upsampled_latents_unnormalized, self.pipeline.vae, vae_per_channel_normalize=True)
+
+         # AdaIN filter to keep color/style consistent with the low-resolution video
+         return adain_filter_latent(latents=upsampled_latents_normalized, reference_latents=latents)
+
+     def _prepare_conditioning_tensor_from_path(self, filepath: str, height: int, width: int, padding: Tuple) -> torch.Tensor:
+         """Loads an image, resizes it, applies padding and moves it to the device."""
+         tensor = self._load_image_to_tensor_with_resize_and_crop(filepath, height, width)
+         tensor = F.pad(tensor, padding)
+         return tensor.to(self.device, dtype=self.runtime_autocast_dtype)
+
+     def _calculate_downscaled_dims(self, height: int, width: int) -> Tuple[int, int]:
+         """Computes the dimensions for the first (low-resolution) pass."""
+         height_padded = ((height - 1) // 8 + 1) * 8
+         width_padded = ((width - 1) // 8 + 1) * 8
+
+         downscale_factor = self.config.get("downscale_factor", 0.6666666)
+         vae_scale_factor = self.pipeline.vae_scale_factor
+
+         target_w = int(width_padded * downscale_factor)
+         downscaled_width = target_w - (target_w % vae_scale_factor)
+
+         target_h = int(height_padded * downscale_factor)
+         downscaled_height = target_h - (target_h % vae_scale_factor)
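+         # Example: width 768 → padded to 768 → int(768 * 0.6666666) = 511 →
+         # snapped down to a multiple of the VAE scale factor
+         # (511 - 511 % 32 = 480, assuming vae_scale_factor == 32).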
+
+         return downscaled_height, downscaled_width
+
+     def _seed_everything(self, seed: int):
+         """Seeds all relevant RNGs for reproducibility."""
+         random.seed(seed)
+         np.random.seed(seed)
+         torch.manual_seed(seed)
+         if torch.cuda.is_available():
+             torch.cuda.manual_seed(seed)
+         if torch.backends.mps.is_available():
+             torch.mps.manual_seed(seed)
+
+     def _register_tmp_dir(self, dir_path: str):
+         """Registers a temporary directory for later cleanup."""
+         if dir_path and os.path.isdir(dir_path):
+             self._tmp_dirs.add(dir_path)
+             if LTXV_DEBUG:
+                 print(f"[DEBUG] Temporary directory registered: {dir_path}")
+
+ # ==============================================================================
+ # 4. INSTANTIATION AND ENTRY POINT (example)
+ # ==============================================================================
+
+ print("Creating the VideoService instance. Model loading will start now...")
+ video_generation_service = VideoService()
+ print("VideoService instance ready for use.")
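To close, a minimal calling sketch for the service defined above (not part of the commit); the image path and generation parameters are placeholders, and `_prepare_condition_items` is used here only because the module exposes no public wrapper for it yet:

```python
from api.ltx_server_refactored import video_generation_service

# Encode a reference image into latent conditioning items (path is a placeholder).
items = video_generation_service._prepare_condition_items(
    [("/path/to/first_frame.png", 0, 1.0)],  # (image, frame index, strength)
    height=512,
    width=768,
)

video_path, latents_path, used_seed = video_generation_service.generate_low_resolution(
    prompt="a slow dolly shot through a foggy forest at dawn",
    negative_prompt="blurry, distorted, low quality",
    height=512,
    width=768,
    duration_secs=5.0,
    guidance_scale=3.0,
    conditioning_items=items,
)
print(video_path, latents_path, used_seed)
```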