caarleexx commited on
Commit
5a15d3e
·
verified ·
1 Parent(s): b670def

Upload 4 files

Browse files
api/aduc_ltx_latent_patch.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # aduc_ltx_latent_patch.py
2
+ #
3
+ # Este módulo fornece um monkey patch para a classe LTXVideoPipeline da biblioteca ltx_video.
4
+ # A principal funcionalidade deste patch é otimizar o processo de condicionamento, permitindo
5
+ # que a pipeline aceite tensores de latentes pré-calculados diretamente através de um
6
+ # `ConditioningItem` modificado. Isso evita a re-codificação desnecessária de mídias (imagens/vídeos)
7
+ # pela VAE, resultando em um ganho de performance significativo quando os latentes já estão disponíveis.
8
+
9
+ import torch
10
+ from torch import Tensor
11
+ from typing import Optional, List, Tuple
12
+ from pathlib import Path
13
+ import os
14
+ import sys
15
+
16
+ DEPS_DIR = Path("/data")
17
+ LTX_VIDEO_REPO_DIR = DEPS_DIR / "LTX-Video"
18
+ def add_deps_to_path(repo_path: Path):
19
+ """Adiciona o diretório do repositório ao sys.path para importações locais."""
20
+ resolved_path = str(repo_path.resolve())
21
+ if resolved_path not in sys.path:
22
+ sys.path.insert(0, resolved_path)
23
+ if LTXV_DEBUG:
24
+ print(f"[DEBUG] Adicionado ao sys.path: {resolved_path}")
25
+
26
+ # --- Execução da configuração inicial ---
27
+ if not LTX_VIDEO_REPO_DIR.exists():
28
+ _run_setup_script()
29
+ add_deps_to_path(LTX_VIDEO_REPO_DIR)
30
+
31
+
32
+ # Tenta importar as dependências necessárias do módulo original que será modificado.
33
+ # Isso requer que o ambiente Python tenha o pacote `ltx_video` acessível em seu sys.path.
34
+ try:
35
+ from ltx_video.pipelines.pipeline_ltx_video import (
36
+ LTXVideoPipeline,
37
+ ConditioningItem as OriginalConditioningItem
38
+ )
39
+ from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
40
+ from ltx_video.models.autoencoders.vae_encode import vae_encode, latent_to_pixel_coords
41
+ from diffusers.utils.torch_utils import randn_tensor
42
+ except ImportError as e:
43
+ print(f"FATAL ERROR: Could not import dependencies from 'ltx_video'. "
44
+ f"Please ensure the environment is correctly set up. Error: {e}")
45
+ # Interrompe a execução se as dependências essenciais não puderem ser encontradas.
46
+ raise
47
+
48
+ print("[INFO] Patch module 'aduc_ltx_latent_patch' loaded successfully.")
49
+
50
+ # ==============================================================================
51
+ # 1. NOVA DEFINIÇÃO DA DATACLASS `ConditioningItem`
52
+ # ==============================================================================
53
+
54
+ from dataclasses import dataclass
55
+
56
+ @dataclass
57
+ class PatchedConditioningItem:
58
+ """
59
+ Versão modificada do `ConditioningItem` que aceita tensores de pixel (`media_item`)
60
+ ou tensores de latentes pré-codificados (`latents`).
61
+
62
+ Attributes:
63
+ media_frame_number (int): Quadro inicial do item de condicionamento no vídeo.
64
+ conditioning_strength (float): Força do condicionamento (0.0 a 1.0).
65
+ media_item (Optional[Tensor]): Tensor de mídia (pixels). Usado se `latents` for None.
66
+ media_x (Optional[int]): Coordenada X (esquerda) para posicionamento espacial.
67
+ media_y (Optional[int]): Coordenada Y (topo) para posicionamento espacial.
68
+ latents (Optional[Tensor]): Tensor de latentes pré-codificado. Terá precedência sobre `media_item`.
69
+ """
70
+ media_frame_number: int
71
+ conditioning_strength: float
72
+ media_item: Optional[Tensor] = None
73
+ media_x: Optional[int] = None
74
+ media_y: Optional[int] = None
75
+ latents: Optional[Tensor] = None
76
+
77
+ def __post_init__(self):
78
+ """Valida o estado do objeto após a inicialização."""
79
+ if self.media_item is None and self.latents is None:
80
+ raise ValueError("A `PatchedConditioningItem` must have either 'media_item' or 'latents' defined.")
81
+ if self.media_item is not None and self.latents is not None:
82
+ print("[WARNING] `PatchedConditioningItem` received both 'media_item' and 'latents'. "
83
+ "The 'latents' tensor will take precedence.")
84
+
85
+ # ==============================================================================
86
+ # 2. NOVA IMPLEMENTAÇÃO DA FUNÇÃO `prepare_conditioning`
87
+ # ==============================================================================
88
+
89
+ def prepare_conditioning_with_latents(
90
+ self: LTXVideoPipeline,
91
+ conditioning_items: Optional[List[PatchedConditioningItem]],
92
+ init_latents: Tensor,
93
+ num_frames: int,
94
+ height: int,
95
+ width: int,
96
+ vae_per_channel_normalize: bool = False,
97
+ generator: Optional[torch.Generator] = None,
98
+ ) -> Tuple[Tensor, Tensor, Optional[Tensor], int]:
99
+ """
100
+ Versão modificada de `prepare_conditioning` que prioriza o uso de latentes pré-calculados
101
+ dos `conditioning_items`, evitando a re-codificação desnecessária pela VAE.
102
+ """
103
+ assert isinstance(self, LTXVideoPipeline), "This function must be called as a method of LTXVideoPipeline."
104
+ assert isinstance(self.vae, CausalVideoAutoencoder), "VAE must be of type CausalVideoAutoencoder."
105
+
106
+ # Se não há itens de condicionamento, apenas patchifica os latentes e retorna.
107
+ if not conditioning_items:
108
+ init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents)
109
+ init_pixel_coords = latent_to_pixel_coords(
110
+ init_latent_coords, self.vae,
111
+ causal_fix=self.transformer.config.causal_temporal_positioning
112
+ )
113
+ return init_latents, init_pixel_coords, None, 0
114
+
115
+ # Inicializa tensores para acumular resultados
116
+ init_conditioning_mask = torch.zeros(
117
+ init_latents[:, 0, :, :, :].shape, dtype=torch.float32, device=init_latents.device
118
+ )
119
+ extra_conditioning_latents = []
120
+ extra_conditioning_pixel_coords = []
121
+ extra_conditioning_mask = []
122
+ extra_conditioning_num_latents = 0
123
+
124
+ for item in conditioning_items:
125
+ item_latents: Tensor
126
+
127
+ # --- LÓGICA CENTRAL DO PATCH ---
128
+ if item.latents is not None:
129
+ # 1. Se latentes pré-calculados existem, use-os diretamente.
130
+ item_latents = item.latents.to(dtype=init_latents.dtype, device=init_latents.device)
131
+ if item_latents.ndim != 5:
132
+ raise ValueError(f"Latents must have 5 dimensions (b, c, f, h, w), but got {item_latents.ndim}")
133
+ elif item.media_item is not None:
134
+ # 2. Caso contrário, volte para o fluxo original de codificação da VAE.
135
+ resized_item = self._resize_conditioning_item(item, height, width)
136
+ media_item = resized_item.media_item
137
+ assert media_item.ndim == 5, f"media_item must have 5 dims, but got {media_item.ndim}"
138
+
139
+ item_latents = vae_encode(
140
+ media_item.to(dtype=self.vae.dtype, device=self.vae.device),
141
+ self.vae,
142
+ vae_per_channel_normalize=vae_per_channel_normalize,
143
+ ).to(dtype=init_latents.dtype)
144
+ else:
145
+ # Este caso é prevenido pelo __post_init__ do dataclass, mas é bom ter uma checagem.
146
+ raise ValueError("ConditioningItem is invalid: it has neither 'latents' nor 'media_item'.")
147
+ # --- FIM DA LÓGICA DO PATCH ---
148
+
149
+ media_frame_number = item.media_frame_number
150
+ strength = item.conditioning_strength
151
+
152
+ # O resto da lógica da função original é aplicado sobre `item_latents`.
153
+ if media_frame_number == 0:
154
+ item_latents, l_x, l_y = self._get_latent_spatial_position(
155
+ item_latents, item, height, width, strip_latent_border=True
156
+ )
157
+ _, _, f_l, h_l, w_l = item_latents.shape
158
+ init_latents[:, :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l] = torch.lerp(
159
+ init_latents[:, :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l], item_latents, strength
160
+ )
161
+ init_conditioning_mask[:, :f_l, l_y : l_y + h_l, l_x : l_x + w_l] = strength
162
+ else:
163
+ if item_latents.shape[2] > 1:
164
+ (init_latents, init_conditioning_mask, item_latents) = self._handle_non_first_conditioning_sequence(
165
+ init_latents, init_conditioning_mask, item_latents, media_frame_number, strength
166
+ )
167
+
168
+ if item_latents is not None:
169
+ noise = randn_tensor(
170
+ item_latents.shape, generator=generator,
171
+ device=item_latents.device, dtype=item_latents.dtype
172
+ )
173
+ item_latents = torch.lerp(noise, item_latents, strength)
174
+ item_latents, latent_coords = self.patchifier.patchify(latents=item_latents)
175
+ pixel_coords = latent_to_pixel_coords(
176
+ latent_coords, self.vae,
177
+ causal_fix=self.transformer.config.causal_temporal_positioning
178
+ )
179
+ pixel_coords[:, 0] += media_frame_number
180
+ extra_conditioning_num_latents += item_latents.shape[1]
181
+ conditioning_mask = torch.full(
182
+ item_latents.shape[:2], strength,
183
+ dtype=torch.float32, device=init_latents.device
184
+ )
185
+ extra_conditioning_latents.append(item_latents)
186
+ extra_conditioning_pixel_coords.append(pixel_coords)
187
+ extra_conditioning_mask.append(conditioning_mask)
188
+
189
+ # Patchifica os latentes principais e a máscara de condicionamento
190
+ init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents)
191
+ init_pixel_coords = latent_to_pixel_coords(
192
+ init_latent_coords, self.vae,
193
+ causal_fix=self.transformer.config.causal_temporal_positioning
194
+ )
195
+ init_conditioning_mask, _ = self.patchifier.patchify(latents=init_conditioning_mask.unsqueeze(1))
196
+ init_conditioning_mask = init_conditioning_mask.squeeze(-1)
197
+
198
+ # Concatena os latentes extras (se houver)
199
+ if extra_conditioning_latents:
200
+ init_latents = torch.cat([*extra_conditioning_latents, init_latents], dim=1)
201
+ init_pixel_coords = torch.cat([*extra_conditioning_pixel_coords, init_pixel_coords], dim=2)
202
+ init_conditioning_mask = torch.cat([*extra_conditioning_mask, init_conditioning_mask], dim=1)
203
+
204
+ if self.transformer.use_tpu_flash_attention:
205
+ init_latents = init_latents[:, :-extra_conditioning_num_latents]
206
+ init_pixel_coords = init_pixel_coords[:, :, :-extra_conditioning_num_latents]
207
+ init_conditioning_mask = init_conditioning_mask[:, :-extra_conditioning_num_latents]
208
+
209
+ return init_latents, init_pixel_coords, init_conditioning_mask, extra_conditioning_num_latents
210
+
211
+
212
+ # ==============================================================================
213
+ # 3. CLASSE DO MONKEY PATCHER
214
+ # ==============================================================================
215
+
216
+ class LTXLatentConditioningPatch:
217
+ """
218
+ Classe estática para aplicar e reverter o monkey patch na pipeline LTX-Video.
219
+
220
+ Esta classe substitui o método `prepare_conditioning` da `LTXVideoPipeline`
221
+ pela versão otimizada que suporta latentes pré-calculados, e implicitamente
222
+ requer o uso da `PatchedConditioningItem`.
223
+ """
224
+ _original_prepare_conditioning = None
225
+ _is_patched = False
226
+
227
+ @staticmethod
228
+ def apply():
229
+ """
230
+ Aplica o monkey patch à classe `LTXVideoPipeline`.
231
+
232
+ Guarda o método original e o substitui pela nova implementação.
233
+ É idempotente; aplicar múltiplas vezes não causa efeito adicional.
234
+ """
235
+ if LTXLatentConditioningPatch._is_patched:
236
+ print("[WARNING] LTXLatentConditioningPatch has already been applied. Ignoring.")
237
+ return
238
+
239
+ print("[INFO] Applying monkey patch for latent-based conditioning...")
240
+
241
+ # Guarda a implementação original para permitir a reversão.
242
+ LTXLatentConditioningPatch._original_prepare_conditioning = LTXVideoPipeline.prepare_conditioning
243
+
244
+ # Substitui o método na classe LTXVideoPipeline.
245
+ # Todas as instâncias futuras e existentes da classe usarão este novo método.
246
+ LTXVideoPipeline.prepare_conditioning = prepare_conditioning_with_latents
247
+
248
+ LTXLatentConditioningPatch._is_patched = True
249
+ print("[SUCCESS] Monkey patch applied successfully.")
250
+ print(" - `LTXVideoPipeline.prepare_conditioning` has been updated.")
251
+ print(" - NOTE: Remember to use `aduc_ltx_latent_patch.PatchedConditioningItem` when creating conditioning items.")
252
+
253
+ @staticmethod
254
+ def revert():
255
+ """
256
+ Reverte o monkey patch, restaurando a implementação original.
257
+ """
258
+ if not LTXLatentConditioningPatch._is_patched:
259
+ print("[WARNING] Patch is not currently applied. No action taken.")
260
+ return
261
+
262
+ if LTXLatentConditioningPatch._original_prepare_conditioning:
263
+ print("[INFO] Reverting LTXLatentConditioningPatch...")
264
+ LTXVideoPipeline.prepare_conditioning = LTXLatentConditioningPatch._original_prepare_conditioning
265
+ LTXLatentConditioningPatch._is_patched = False
266
+ print("[SUCCESS] Patch reverted successfully. Original functionality restored.")
267
+ else:
268
+ print("[ERROR] Cannot revert: original implementation was not saved.")
api/gpu_manager.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # api/gpu_manager.py
2
+
3
+ import os
4
+ import torch
5
+
6
+ class GPUManager:
7
+ """
8
+ Gerencia e aloca GPUs disponíveis para diferentes serviços (LTX, SeedVR).
9
+ """
10
+ def __init__(self):
11
+ self.total_gpus = torch.cuda.device_count()
12
+ self.ltx_gpus = []
13
+ self.seedvr_gpus = []
14
+ self._allocate_gpus()
15
+
16
+ def _allocate_gpus(self):
17
+ """
18
+ Divide as GPUs disponíveis entre os serviços LTX e SeedVR.
19
+ """
20
+ print("="*50)
21
+ print("🤖 Gerenciador de GPUs inicializado.")
22
+ print(f" > Total de GPUs detectadas: {self.total_gpus}")
23
+
24
+ if self.total_gpus == 0:
25
+ print(" > Nenhuma GPU detectada. Operando em modo CPU.")
26
+ elif self.total_gpus == 1:
27
+ print(" > 1 GPU detectada. Modo de compartilhamento de memória será usado.")
28
+ # Ambos usarão a GPU 0, mas precisarão gerenciar a memória
29
+ self.ltx_gpus = [0]
30
+ self.seedvr_gpus = [0]
31
+ else:
32
+ # Divide as GPUs entre os dois serviços
33
+ mid_point = self.total_gpus // 2
34
+ self.ltx_gpus = list(range(0, mid_point))
35
+ self.seedvr_gpus = list(range(mid_point, self.total_gpus))
36
+ print(f" > Alocação: LTX usará GPUs {self.ltx_gpus}, SeedVR usará GPUs {self.seedvr_gpus}.")
37
+
38
+ print("="*50)
39
+
40
+ def get_ltx_device(self):
41
+ """Retorna o dispositivo principal para o LTX."""
42
+ if not self.ltx_gpus:
43
+ return torch.device("cpu")
44
+ # Por padrão, o modelo principal do LTX roda na primeira GPU do seu grupo
45
+ return torch.device(f"cuda:{self.ltx_gpus[0]}")
46
+
47
+ def get_seedvr_devices(self) -> list:
48
+ """Retorna a lista de IDs de GPU para o SeedVR."""
49
+ return self.seedvr_gpus
50
+
51
+ def requires_memory_swap(self) -> bool:
52
+ """Verifica se é necessário mover modelos entre CPU e GPU."""
53
+ return self.total_gpus < 2
54
+
55
+ # Instância global para ser importada por outros módulos
56
+ gpu_manager = GPUManager()
api/ltx_server_refactored.py ADDED
@@ -0,0 +1,814 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ltx_server_clean_refactor.py — VideoService (Modular Version with Simple Overlap Chunking)
2
+
3
+ # ==============================================================================
4
+ # 0. CONFIGURAÇÃO DE AMBIENTE E IMPORTAÇÕES
5
+ # ==============================================================================
6
+ import os
7
+ import sys
8
+ import gc
9
+ import yaml
10
+ import time
11
+ import json
12
+ import random
13
+ import shutil
14
+ import warnings
15
+ import tempfile
16
+ import traceback
17
+ import subprocess
18
+ from pathlib import Path
19
+ from typing import List, Dict, Optional, Tuple, Union
20
+ import cv2
21
+
22
+ # --- Configurações de Logging e Avisos ---
23
+ warnings.filterwarnings("ignore", category=UserWarning)
24
+ warnings.filterwarnings("ignore", category=FutureWarning)
25
+ from huggingface_hub import logging as hf_logging
26
+ hf_logging.set_verbosity_error()
27
+
28
+ # --- Importações de Bibliotecas de ML/Processamento ---
29
+ import torch
30
+ import torch.nn.functional as F
31
+ import numpy as np
32
+ from PIL import Image
33
+ from einops import rearrange
34
+ from huggingface_hub import hf_hub_download
35
+ from safetensors import safe_open
36
+
37
+ from managers.vae_manager import vae_manager_singleton
38
+ from tools.video_encode_tool import video_encode_tool_singleton
39
+
40
+ from api.aduc_ltx_latent_patch import LTXLatentConditioningPatch, PatchedConditioningItem
41
+
42
+ # --- Constantes Globais ---
43
+ LTXV_DEBUG = True # Mude para False para desativar logs de debug
44
+ LTXV_FRAME_LOG_EVERY = 8
45
+ DEPS_DIR = Path("/data")
46
+ LTX_VIDEO_REPO_DIR = DEPS_DIR / "LTX-Video"
47
+ RESULTS_DIR = Path("/app/output")
48
+ DEFAULT_FPS = 24.0
49
+
50
+ # ==============================================================================
51
+ # 1. SETUP E FUNÇÕES AUXILIARES DE AMBIENTE
52
+ # ==============================================================================
53
+
54
+ def _run_setup_script():
55
+ """Executa o script setup.py se o repositório LTX-Video não existir."""
56
+ setup_script_path = "setup.py"
57
+ if not os.path.exists(setup_script_path):
58
+ print("[DEBUG] 'setup.py' não encontrado. Pulando clonagem de dependências.")
59
+ return
60
+
61
+ print(f"[DEBUG] Repositório não encontrado em {LTX_VIDEO_REPO_DIR}. Executando setup.py...")
62
+ try:
63
+ subprocess.run([sys.executable, setup_script_path], check=True, capture_output=True, text=True)
64
+ print("[DEBUG] Script 'setup.py' concluído com sucesso.")
65
+ except subprocess.CalledProcessError as e:
66
+ print(f"[ERROR] Falha ao executar 'setup.py' (código {e.returncode}).\nOutput:\n{e.stdout}\n{e.stderr}")
67
+ sys.exit(1)
68
+
69
+ def add_deps_to_path(repo_path: Path):
70
+ """Adiciona o diretório do repositório ao sys.path para importações locais."""
71
+ resolved_path = str(repo_path.resolve())
72
+ if resolved_path not in sys.path:
73
+ sys.path.insert(0, resolved_path)
74
+ if LTXV_DEBUG:
75
+ print(f"[DEBUG] Adicionado ao sys.path: {resolved_path}")
76
+
77
+ # --- Execução da configuração inicial ---
78
+ if not LTX_VIDEO_REPO_DIR.exists():
79
+ _run_setup_script()
80
+ add_deps_to_path(LTX_VIDEO_REPO_DIR)
81
+
82
+ # --- Importações Dependentes do Path Adicionado ---
83
+ from ltx_video.models.autoencoders.vae_encode import un_normalize_latents, normalize_latents
84
+ from ltx_video.pipelines.pipeline_ltx_video import adain_filter_latent
85
+ from ltx_video.models.autoencoders.latent_upsampler import LatentUpsampler
86
+ from ltx_video.pipelines.pipeline_ltx_video import ConditioningItem, LTXVideoPipeline
87
+ from transformers import T5EncoderModel, T5Tokenizer, AutoModelForCausalLM, AutoProcessor, AutoTokenizer
88
+ from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
89
+ from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
90
+ from ltx_video.models.transformers.transformer3d import Transformer3DModel
91
+ from ltx_video.schedulers.rf import RectifiedFlowScheduler
92
+ from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
93
+ import ltx_video.pipelines.crf_compressor as crf_compressor
94
+
95
+
96
+ def create_latent_upsampler(latent_upsampler_model_path: str, device: str):
97
+ latent_upsampler = LatentUpsampler.from_pretrained(latent_upsampler_model_path)
98
+ latent_upsampler.to(device)
99
+ latent_upsampler.eval()
100
+ return latent_upsampler
101
+
102
+ def create_ltx_video_pipeline(
103
+ ckpt_path: str,
104
+ precision: str,
105
+ text_encoder_model_name_or_path: str,
106
+ sampler: Optional[str] = None,
107
+ device: Optional[str] = None,
108
+ enhance_prompt: bool = False,
109
+ prompt_enhancer_image_caption_model_name_or_path: Optional[str] = None,
110
+ prompt_enhancer_llm_model_name_or_path: Optional[str] = None,
111
+ ) -> LTXVideoPipeline:
112
+ ckpt_path = Path(ckpt_path)
113
+ assert os.path.exists(
114
+ ckpt_path
115
+ ), f"Ckpt path provided (--ckpt_path) {ckpt_path} does not exist"
116
+
117
+ with safe_open(ckpt_path, framework="pt") as f:
118
+ metadata = f.metadata()
119
+ config_str = metadata.get("config")
120
+ configs = json.loads(config_str)
121
+ allowed_inference_steps = configs.get("allowed_inference_steps", None)
122
+
123
+ vae = CausalVideoAutoencoder.from_pretrained(ckpt_path)
124
+ transformer = Transformer3DModel.from_pretrained(ckpt_path)
125
+
126
+ # Use constructor if sampler is specified, otherwise use from_pretrained
127
+ if sampler == "from_checkpoint" or not sampler:
128
+ scheduler = RectifiedFlowScheduler.from_pretrained(ckpt_path)
129
+ else:
130
+ scheduler = RectifiedFlowScheduler(
131
+ sampler=("Uniform" if sampler.lower() == "uniform" else "LinearQuadratic")
132
+ )
133
+
134
+ text_encoder = T5EncoderModel.from_pretrained(
135
+ text_encoder_model_name_or_path, subfolder="text_encoder"
136
+ )
137
+ patchifier = SymmetricPatchifier(patch_size=1)
138
+ tokenizer = T5Tokenizer.from_pretrained(
139
+ text_encoder_model_name_or_path, subfolder="tokenizer"
140
+ )
141
+
142
+ transformer = transformer.to(device)
143
+ vae = vae.to(device)
144
+ text_encoder = text_encoder.to(device)
145
+
146
+ if enhance_prompt:
147
+ prompt_enhancer_image_caption_model = AutoModelForCausalLM.from_pretrained(
148
+ prompt_enhancer_image_caption_model_name_or_path, trust_remote_code=True
149
+ )
150
+ prompt_enhancer_image_caption_processor = AutoProcessor.from_pretrained(
151
+ prompt_enhancer_image_caption_model_name_or_path, trust_remote_code=True
152
+ )
153
+ prompt_enhancer_llm_model = AutoModelForCausalLM.from_pretrained(
154
+ prompt_enhancer_llm_model_name_or_path,
155
+ torch_dtype="bfloat16",
156
+ )
157
+ prompt_enhancer_llm_tokenizer = AutoTokenizer.from_pretrained(
158
+ prompt_enhancer_llm_model_name_or_path,
159
+ )
160
+ else:
161
+ prompt_enhancer_image_caption_model = None
162
+ prompt_enhancer_image_caption_processor = None
163
+ prompt_enhancer_llm_model = None
164
+ prompt_enhancer_llm_tokenizer = None
165
+
166
+ vae = vae.to(torch.bfloat16)
167
+ if precision == "bfloat16" and transformer.dtype != torch.bfloat16:
168
+ transformer = transformer.to(torch.bfloat16)
169
+ text_encoder = text_encoder.to(torch.bfloat16)
170
+
171
+ # Use submodels for the pipeline
172
+ submodel_dict = {
173
+ "transformer": transformer,
174
+ "patchifier": patchifier,
175
+ "text_encoder": text_encoder,
176
+ "tokenizer": tokenizer,
177
+ "scheduler": scheduler,
178
+ "vae": vae,
179
+ "prompt_enhancer_image_caption_model": prompt_enhancer_image_caption_model,
180
+ "prompt_enhancer_image_caption_processor": prompt_enhancer_image_caption_processor,
181
+ "prompt_enhancer_llm_model": prompt_enhancer_llm_model,
182
+ "prompt_enhancer_llm_tokenizer": prompt_enhancer_llm_tokenizer,
183
+ "allowed_inference_steps": allowed_inference_steps,
184
+ }
185
+
186
+ pipeline = LTXVideoPipeline(**submodel_dict)
187
+
188
+ LTXLatentConditioningPatch.apply()
189
+
190
+ pipeline = pipeline.to(device)
191
+ return pipeline
192
+
193
+ # ==============================================================================
194
+ # 2. FUNÇÕES AUXILIARES DE PROCESSAMENTO
195
+ # ==============================================================================
196
+
197
+ def calculate_padding(orig_h: int, orig_w: int, target_h: int, target_w: int) -> Tuple[int, int, int, int]:
198
+ """Calcula o preenchimento para centralizar uma imagem em uma nova dimensão."""
199
+ pad_h = target_h - orig_h
200
+ pad_w = target_w - orig_w
201
+ pad_top = pad_h // 2
202
+ pad_bottom = pad_h - pad_top
203
+ pad_left = pad_w // 2
204
+ pad_right = pad_w - pad_left
205
+ return (pad_left, pad_right, pad_top, pad_bottom)
206
+
207
+ def log_tensor_info(tensor: torch.Tensor, name: str = "Tensor"):
208
+ """Exibe informações detalhadas sobre um tensor para depuração."""
209
+ if not isinstance(tensor, torch.Tensor):
210
+ print(f"\n[INFO] '{name}' não é um tensor.")
211
+ return
212
+ print(f"\n--- Tensor Info: {name} ---")
213
+ print(f" - Shape: {tuple(tensor.shape)}")
214
+ print(f" - Dtype: {tensor.dtype}")
215
+ print(f" - Device: {tensor.device}")
216
+ if tensor.numel() > 0:
217
+ try:
218
+ print(f" - Stats: Min={tensor.min().item():.4f}, Max={tensor.max().item():.4f}, Mean={tensor.mean().item():.4f}")
219
+ except RuntimeError:
220
+ print(" - Stats: Não foi possível calcular (ex: tensores bool).")
221
+ print("-" * 30)
222
+
223
+ # ==============================================================================
224
+ # 3. CLASSE PRINCIPAL DO SERVIÇO DE VÍDEO
225
+ # ==============================================================================
226
+
227
+ class VideoService:
228
+ """
229
+ Serviço encapsulado para gerar vídeos usando a pipeline LTX-Video.
230
+ Gerencia o carregamento de modelos, pré-processamento, geração em múltiplos
231
+ passos (baixa resolução, upscale com denoise) e pós-processamento.
232
+ """
233
+ def __init__(self):
234
+ """Inicializa o serviço, carregando configurações e modelos."""
235
+ t0 = time.perf_counter()
236
+ print("[INFO] Inicializando VideoService...")
237
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
238
+ self.config = self._load_config("ltxv-13b-0.9.8-distilled-fp8.yaml")
239
+
240
+ self.pipeline, self.latent_upsampler = self._load_models_from_hub()
241
+ self._move_models_to_device()
242
+
243
+ self.runtime_autocast_dtype = self._get_precision_dtype()
244
+ vae_manager_singleton.attach_pipeline(
245
+ self.pipeline,
246
+ device=self.device,
247
+ autocast_dtype=self.runtime_autocast_dtype
248
+ )
249
+ self._tmp_dirs = set()
250
+ RESULTS_DIR.mkdir(exist_ok=True)
251
+ print(f"[INFO] VideoService pronto. Tempo de inicialização: {time.perf_counter()-t0:.2f}s")
252
+
253
+ # --------------------------------------------------------------------------
254
+ # --- Métodos Públicos (API do Serviço) ---
255
+ # --------------------------------------------------------------------------
256
+
257
+ def _load_image_to_tensor_with_resize_and_crop(
258
+ self,
259
+ image_input: Union[str, Image.Image],
260
+ target_height: int = 512,
261
+ target_width: int = 768,
262
+ just_crop: bool = False,
263
+ ) -> torch.Tensor:
264
+ """Load and process an image into a tensor.
265
+
266
+ Args:
267
+ image_input: Either a file path (str) or a PIL Image object
268
+ target_height: Desired height of output tensor
269
+ target_width: Desired width of output tensor
270
+ just_crop: If True, only crop the image to the target size without resizing
271
+ """
272
+ if isinstance(image_input, str):
273
+ image = Image.open(image_input).convert("RGB")
274
+ elif isinstance(image_input, Image.Image):
275
+ image = image_input
276
+ else:
277
+ raise ValueError("image_input must be either a file path or a PIL Image object")
278
+
279
+ input_width, input_height = image.size
280
+ aspect_ratio_target = target_width / target_height
281
+ aspect_ratio_frame = input_width / input_height
282
+ if aspect_ratio_frame > aspect_ratio_target:
283
+ new_width = int(input_height * aspect_ratio_target)
284
+ new_height = input_height
285
+ x_start = (input_width - new_width) // 2
286
+ y_start = 0
287
+ else:
288
+ new_width = input_width
289
+ new_height = int(input_width / aspect_ratio_target)
290
+ x_start = 0
291
+ y_start = (input_height - new_height) // 2
292
+
293
+ image = image.crop((x_start, y_start, x_start + new_width, y_start + new_height))
294
+ if not just_crop:
295
+ image = image.resize((target_width, target_height))
296
+
297
+ image = np.array(image)
298
+ image = cv2.GaussianBlur(image, (3, 3), 0)
299
+ frame_tensor = torch.from_numpy(image).float()
300
+ frame_tensor = crf_compressor.compress(frame_tensor / 255.0) * 255.0
301
+ frame_tensor = frame_tensor.permute(2, 0, 1)
302
+ frame_tensor = (frame_tensor / 127.5) - 1.0
303
+ # Create 5D tensor: (batch_size=1, channels=3, num_frames=1, height, width)
304
+ return frame_tensor.unsqueeze(0).unsqueeze(2)
305
+
306
+
307
+
308
+ def generate_low_resolution1(self, prompt: str, negative_prompt: str, height: int, width: int, duration_secs: float, guidance_scale: float, seed: Optional[int] = None, conditioning_items: Optional[List[PatchedConditioningItem]] = None) -> Tuple[str, str, int]:
309
+ """
310
+ Gera um vídeo de baixa resolução e retorna os caminhos para o vídeo e os latentes.
311
+ """
312
+ used_seed = random.randint(0, 2**32 - 1) if seed is None else int(seed)
313
+ #self._seed_everething(used_seed)
314
+
315
+ actual_num_frames = max(9, int(round((round(duration_secs * DEFAULT_FPS) - 1) / 8.0) * 8 + 1))
316
+
317
+ downscaled_height, downscaled_width = self._calculate_downscaled_dims(height, width)
318
+
319
+ first_pass_kwargs = {
320
+ "prompt": prompt, "negative_prompt": negative_prompt, "height": downscaled_height,
321
+ "width": downscaled_width, "num_frames": actual_num_frames, "frame_rate": int(DEFAULT_FPS),
322
+ "generator": torch.Generator(device=self.device).manual_seed(used_seed),
323
+ "output_type": "latent", "conditioning_items": conditioning_items,
324
+ "guidance_scale": float(guidance_scale), **(self.config.get("first_pass", {}))
325
+ }
326
+
327
+ temp_dir = tempfile.mkdtemp(prefix="ltxv_low_")
328
+ self._register_tmp_dir(temp_dir)
329
+
330
+ try:
331
+ with torch.autocast(device_type=self.device.split(':')[0], dtype=self.runtime_autocast_dtype, enabled=(self.device == 'cuda')):
332
+ latents = self.pipeline(**first_pass_kwargs).images
333
+ #pixel_tensor = vae_manager_singleton.decode(latents.clone(), decode_timestep=float(self.config.get("decode_timestep", 0.05)))
334
+ #video_path = self._save_video_from_tensor(pixel_tensor, "low_res_video", used_seed, temp_dir)
335
+ latents_path = self._save_latents_to_disk(latents, "latents_low_res", used_seed)
336
+
337
+ log_tensor_info(latents, "first_pass_lat" )
338
+ self._finalize()
339
+
340
+ final_video_path, final_latents_path = self.generate_upscale_denoise(
341
+ latents_path=latents_path,
342
+ prompt=prompt,
343
+ negative_prompt=negative_prompt,
344
+ guidance_scale=guidance_scale,
345
+ seed=used_seed
346
+ )
347
+
348
+ print(f"[SUCCESS] PASSO 2 concluído. Vídeo final em: {final_video_path}")
349
+
350
+ return final_video_path, final_latents_path, used_seed
351
+
352
+ except Exception as e:
353
+ print(f"[ERROR] Falha na geração de baixa resolução: {e}")
354
+ traceback.print_exc()
355
+ raise
356
+ finally:
357
+ self._finalize()
358
+
359
+
360
+ def _prepare_condition_items(self, items_list: List[Tuple], height: int, width: int) -> List[PatchedConditioningItem]:
361
+ """Prepara os tensores de condicionamento a partir de imagens ou tensores."""
362
+ if not items_list:
363
+ return []
364
+
365
+ height_padded = ((height - 1) // 8 + 1) * 8
366
+ width_padded = ((width - 1) // 8 + 1) * 8
367
+ padding_values = calculate_padding(height, width, height_padded, width_padded)
368
+
369
+ conditioning_items = []
370
+ for media, frame_idx, weight in items_list:
371
+ if isinstance(media, str):
372
+ tensor = self._prepare_conditioning_tensor_from_path(media, height, width, padding_values)
373
+ else: # Assume que é um tensor
374
+ tensor = media.to(self.device, dtype=self.runtime_autocast_dtype)
375
+
376
+ # Garante que o frame de condicionamento esteja dentro dos limites do vídeo
377
+ safe_frame_idx = int(frame_idx)
378
+ conditioning_items.append(PatchedConditioningItem(tensor, safe_frame_idx, float(weight)))
379
+
380
+ return PatchedConditioningItem
381
+
382
+
383
+ def generate_upscale_denoise(self, latents_path: str, prompt: str, negative_prompt: str, guidance_scale: float, seed: Optional[int] = None) -> Tuple[str, str]:
384
+ """
385
+ Aplica upscale, AdaIN e Denoise em latentes de baixa resolução usando um processo de chunking.
386
+ """
387
+ used_seed = random.randint(0, 2**32 - 1) if seed is None else int(seed)
388
+ #seed_everything(used_seed)
389
+
390
+ temp_dir = tempfile.mkdtemp(prefix="ltxv_up_")
391
+ self._register_tmp_dir(temp_dir)
392
+
393
+ try:
394
+ latents_low = torch.load(latents_path).to(self.device)
395
+ with torch.autocast(device_type=self.device.split(':')[0], dtype=self.runtime_autocast_dtype, enabled=(self.device == 'cuda')):
396
+ upsampled_latents = self._upsample_and_filter_latents(latents_low)
397
+ del latents_low; torch.cuda.empty_cache()
398
+
399
+ chunks = self._split_latents_with_overlap(upsampled_latents)
400
+ refined_chunks = []
401
+
402
+ for chunk in chunks:
403
+ if chunk.shape[2] <= 1: continue # Pula chunks inválidos
404
+
405
+ second_pass_height = chunk.shape[3] * self.pipeline.vae_scale_factor
406
+ second_pass_width = chunk.shape[4] * self.pipeline.vae_scale_factor
407
+
408
+ second_pass_kwargs = {
409
+ "prompt": prompt, "negative_prompt": negative_prompt, "height": second_pass_height,
410
+ "width": second_pass_width, "num_frames": chunk.shape[2], "latents": chunk,
411
+ "guidance_scale": float(guidance_scale), "output_type": "latent",
412
+ "generator": torch.Generator(device=self.device).manual_seed(used_seed),
413
+ **(self.config.get("second_pass", {}))
414
+ }
415
+ refined_chunk = self.pipeline(**second_pass_kwargs).images
416
+ refined_chunks.append(refined_chunk)
417
+
418
+ log_tensor_info(refined_chunk, "refined_chunk" )
419
+
420
+ final_latents = self._merge_chunks_with_overlap(refined_chunks)
421
+
422
+ if LTXV_DEBUG:
423
+ log_tensor_info(final_latents, "Latentes Upscaled/Refinados Finais")
424
+
425
+ latents_path = self._save_latents_to_disk(final_latents, "latents_refined", used_seed)
426
+ pixel_tensor = vae_manager_singleton.decode(final_latents, decode_timestep=float(self.config.get("decode_timestep", 0.05)))
427
+ video_path = self._save_video_from_tensor(pixel_tensor, "refined_video", used_seed, temp_dir)
428
+
429
+ return video_path, latents_path
430
+
431
+ except Exception as e:
432
+ print(f"[ERROR] Falha no processo de upscale e denoise: {e}")
433
+ traceback.print_exc()
434
+ raise
435
+ finally:
436
+ self._finalize()
437
+
438
+ def generate_low_resolution(
439
+ self,
440
+ prompt: str,
441
+ negative_prompt: str,
442
+ height: int,
443
+ width: int,
444
+ duration_secs: float,
445
+ guidance_scale: float,
446
+ seed: Optional[int] = None,
447
+ conditioning_items: Optional[List[PatchedConditioningItem]] = None
448
+ ) -> Tuple[str, str, int]:
449
+ """
450
+ ETAPA 1: Gera um vídeo e latentes em resolução base a partir de um prompt e
451
+ condicionamentos opcionais.
452
+ """
453
+ print("[INFO] Iniciando ETAPA 1: Geração de Baixa Resolução...")
454
+
455
+ # --- Configuração de Seed e Diretórios ---
456
+ used_seed = random.randint(0, 2**32 - 1) if seed is None else int(seed)
457
+ #seed_everything(used_seed)
458
+ print(f" - Usando Seed: {used_seed}")
459
+
460
+ temp_dir = tempfile.mkdtemp(prefix="ltxv_low_")
461
+ self._register_tmp_dir(temp_dir)
462
+ results_dir = "/app/output"
463
+ os.makedirs(results_dir, exist_ok=True)
464
+
465
+ # --- Cálculo de Dimensões e Frames ---
466
+ actual_num_frames = int(round(duration_secs * DEFAULT_FPS))
467
+ downscaled_height = height
468
+ downscaled_width = width
469
+ #self._calculate_downscaled_dims(height, width)
470
+
471
+
472
+ print(f" - Frames: {actual_num_frames}, Duração: {duration_secs}s")
473
+ print(f" - Dimensões de Saída: {downscaled_height}x{downscaled_width}")
474
+
475
+ # --- Execução da Pipeline ---
476
+ with torch.autocast(device_type=self.device.split(':')[0], dtype=self.runtime_autocast_dtype, enabled=(self.device == 'cuda')):
477
+
478
+ first_pass_kwargs = {
479
+ "prompt": prompt,
480
+ "negative_prompt": negative_prompt,
481
+ "height": downscaled_height,
482
+ "width": downscaled_width,
483
+ "num_frames": (actual_num_frames//8)+1,
484
+ "frame_rate": int(DEFAULT_FPS),
485
+ "generator": torch.Generator(device=self.device).manual_seed(used_seed),
486
+ "output_type": "latent",
487
+ "conditioning_items": conditioning_items,
488
+ "guidance_scale": float(guidance_scale),
489
+ **(self.config.get("first_pass", {}))
490
+ }
491
+
492
+ print(" - Enviando para a pipeline LTX...")
493
+ latents = self.pipeline(**first_pass_kwargs).images
494
+ print(f" - Latentes gerados com shape: {latents.shape}")
495
+
496
+ # Decodifica os latentes para pixels para criar o vídeo de preview
497
+ #pixel_tensor = vae_manager_singleton.decode(latents.clone(), decode_timestep=float(self.config.get("decode_timestep", 0.05)))
498
+
499
+ # Salva os artefatos de saída (vídeo e tensor de latentes)
500
+ #video_path = self._save_video_from_tensor(pixel_tensor, "low_res_video", used_seed, temp_dir)
501
+ tensor_path = self._save_latents_to_disk(latents, "latents_low_res", used_seed)
502
+
503
+ self._finalize()
504
+
505
+ final_video_path, final_latents_path = self.refine_texture_only(
506
+ latents_path=tensor_path,
507
+ prompt=prompt,
508
+ negative_prompt=negative_prompt,
509
+ guidance_scale=guidance_scale,
510
+ seed=used_seed,
511
+ conditioning_items=conditioning_items,
512
+ )
513
+
514
+ # --- Limpeza ---
515
+ self._finalize()
516
+
517
+ print("[SUCCESS] ETAPA 1 Concluída.")
518
+ return final_video_path, final_latents_path, used_seed
519
+
520
+
521
+ def refine_texture_only(
522
+ self,
523
+ latents_path: str,
524
+ prompt: str,
525
+ negative_prompt: str,
526
+ guidance_scale: float,
527
+ seed: Optional[int] = None,
528
+ conditioning_items: Optional[List[PatchedConditioningItem]] = None
529
+ ) -> Tuple[str, str]:
530
+ """
531
+ ETAPA 2: Refina a textura dos latentes existentes SEM alterar sua resolução
532
+ e SEM dividi-los em pedaços. O tensor inteiro é processado de uma só vez para
533
+ garantir máxima consistência temporal.
534
+ """
535
+ print("[INFO] Iniciando ETAPA 2: Refinamento de Textura...")
536
+
537
+ # --- Configuração de Seed e Diretórios ---
538
+ used_seed = random.randint(0, 2**32 - 1) if seed is None else int(seed)
539
+ #seed_everything(used_seed)
540
+ print(f" - Usando Seed (consistente com Etapa 1): {used_seed}")
541
+
542
+ temp_dir = tempfile.mkdtemp(prefix="ltxv_refine_single_")
543
+ self._register_tmp_dir(temp_dir)
544
+
545
+ # --- Carregamento dos Latentes ---
546
+ latents_to_refine = torch.load(latents_path).to(self.device)
547
+ print(f" - Shape dos latentes de entrada: {latents_to_refine.shape}")
548
+
549
+ if conditioning_items:
550
+ print(f" - Usando {len(conditioning_items)} item(ns) de condicionamento para o refinamento.")
551
+
552
+ # --- Execução da Pipeline ---
553
+ with torch.autocast(device_type=self.device.split(':')[0], dtype=self.runtime_autocast_dtype, enabled=(self.device == 'cuda')):
554
+
555
+ # As dimensões são as mesmas do tensor de entrada
556
+ refine_height = latents_to_refine.shape[3] * self.pipeline.vae_scale_factor
557
+ refine_width = latents_to_refine.shape[4] * self.pipeline.vae_scale_factor
558
+
559
+ second_pass_kwargs = {
560
+ "prompt": prompt,
561
+ "negative_prompt": negative_prompt,
562
+ "height": refine_height,
563
+ "width": refine_width,
564
+ "frame_rate": int(DEFAULT_FPS),
565
+ "num_frames": latents_to_refine.shape[2],
566
+ "latents": latents_to_refine, # O tensor completo é passado aqui
567
+ "guidance_scale": float(guidance_scale),
568
+ "output_type": "latent",
569
+ "generator": torch.Generator(device=self.device).manual_seed(used_seed),
570
+ "conditioning_items": conditioning_items,
571
+ **(self.config.get("second_pass", {}))
572
+ }
573
+
574
+ print(" - Enviando tensor completo para a pipeline de refinamento...")
575
+ final_latents = self.pipeline(**second_pass_kwargs).images
576
+ print(f" - Latentes refinados com shape: {final_latents.shape}")
577
+
578
+ # Decodifica os latentes refinados para pixels
579
+ pixel_tensor = vae_manager_singleton.decode(final_latents, decode_timestep=float(self.config.get("decode_timestep", 0.05)))
580
+
581
+ # Salva os artefatos de saída
582
+ video_path_out = self._save_video_from_tensor(pixel_tensor, "refined_video_single_pass", used_seed, temp_dir)
583
+ latents_path_out = self._save_latents_to_disk(final_latents, "latents_refined_single_pass", used_seed)
584
+
585
+ # --- Limpeza ---
586
+ # Libera os tensores da memória da GPU antes de finalizar.
587
+ del latents_to_refine
588
+ if 'final_latents' in locals():
589
+ del final_latents
590
+ if 'pixel_tensor' in locals():
591
+ del pixel_tensor
592
+ self._finalize()
593
+
594
+ print("[SUCCESS] ETAPA 2 Concluída.")
595
+ return video_path_out, latents_path_out
596
+
597
+
598
+ def encode_latents_to_mp4(self, latents_path: str, fps: int = int(DEFAULT_FPS)) -> str:
599
+ """Decodifica um tensor de latentes salvo e o salva como um vídeo MP4."""
600
+ latents = torch.load(latents_path)
601
+ temp_dir = tempfile.mkdtemp(prefix="ltxv_enc_")
602
+ self._register_tmp_dir(temp_dir)
603
+ seed = random.randint(0, 99999) # Seed apenas para nome do arquivo
604
+
605
+ try:
606
+ chunks = self._split_latents_with_overlap(latents)
607
+ pixel_chunks = []
608
+
609
+ with torch.autocast(device_type=self.device.split(':')[0], dtype=self.runtime_autocast_dtype, enabled=(self.device == 'cuda')):
610
+ for chunk in chunks:
611
+ if chunk.shape[2] == 0: continue
612
+ pixel_chunk = vae_manager_singleton.decode(chunk.to(self.device), decode_timestep=float(self.config.get("decode_timestep", 0.05)))
613
+ pixel_chunks.append(pixel_chunk)
614
+
615
+ final_pixel_tensor = self._merge_chunks_with_overlap(pixel_chunks)
616
+ final_video_path = self._save_video_from_tensor(final_pixel_tensor, f"final_video_{seed}", seed, temp_dir, fps=fps)
617
+ return final_video_path
618
+
619
+ except Exception as e:
620
+ print(f"[ERROR] Falha ao encodar latentes para MP4: {e}")
621
+ traceback.print_exc()
622
+ raise
623
+ finally:
624
+ self._finalize()
625
+
626
+ # --------------------------------------------------------------------------
627
+ # --- Métodos Internos e Auxiliares ---
628
+ # --------------------------------------------------------------------------
629
+
630
+ def _finalize(self):
631
+ """Limpa a memória da GPU e os diretórios temporários."""
632
+ if LTXV_DEBUG:
633
+ print("[DEBUG] Finalize: iniciando limpeza...")
634
+
635
+ gc.collect()
636
+ if torch.cuda.is_available():
637
+ torch.cuda.empty_cache()
638
+ torch.cuda.ipc_collect()
639
+
640
+ # Limpa todos os diretórios temporários registrados
641
+ for d in list(self._tmp_dirs):
642
+ shutil.rmtree(d, ignore_errors=True)
643
+ self._tmp_dirs.remove(d)
644
+ if LTXV_DEBUG:
645
+ print(f"[DEBUG] Diretório temporário removido: {d}")
646
+
647
+ def _load_config(self, config_filename: str) -> Dict:
648
+ """Carrega o arquivo de configuração YAML."""
649
+ config_path = LTX_VIDEO_REPO_DIR / "configs" / config_filename
650
+ print(f"[INFO] Carregando configuração de: {config_path}")
651
+ with open(config_path, "r") as file:
652
+ return yaml.safe_load(file)
653
+
654
+ def _load_models_from_hub(self):
655
+ """Baixa e cria as instâncias da pipeline e do upsampler."""
656
+ t0 = time.perf_counter()
657
+ LTX_REPO = "Lightricks/LTX-Video"
658
+
659
+ print("[INFO] Baixando checkpoint principal...")
660
+ self.config["checkpoint_path"] = hf_hub_download(
661
+ repo_id=LTX_REPO, filename=self.config["checkpoint_path"],
662
+ token=os.getenv("HF_TOKEN")
663
+ )
664
+ print(f"[INFO] Checkpoint principal em: {self.config['checkpoint_path']}")
665
+
666
+ print("[INFO] Construindo pipeline...")
667
+ pipeline = create_ltx_video_pipeline(
668
+ ckpt_path=self.config["checkpoint_path"],
669
+ precision=self.config["precision"],
670
+ text_encoder_model_name_or_path=self.config["text_encoder_model_name_or_path"],
671
+ sampler=self.config["sampler"],
672
+ device="cpu", # Carrega em CPU primeiro
673
+ enhance_prompt=False
674
+ )
675
+ print("[INFO] Pipeline construída.")
676
+
677
+ latent_upsampler = None
678
+ if self.config.get("spatial_upscaler_model_path"):
679
+ print("[INFO] Baixando upscaler espacial...")
680
+ self.config["spatial_upscaler_model_path"] = hf_hub_download(
681
+ repo_id=LTX_REPO, filename=self.config["spatial_upscaler_model_path"],
682
+ token=os.getenv("HF_TOKEN")
683
+ )
684
+ print(f"[INFO] Upscaler em: {self.config['spatial_upscaler_model_path']}")
685
+
686
+ print("[INFO] Construindo latent_upsampler...")
687
+ latent_upsampler = create_latent_upsampler(self.config["spatial_upscaler_model_path"], device="cpu")
688
+ print("[INFO] Latent upsampler construído.")
689
+
690
+ print(f"[INFO] Carregamento de modelos concluído em {time.perf_counter()-t0:.2f}s")
691
+ return pipeline, latent_upsampler
692
+
693
+ def _move_models_to_device(self):
694
+ """Move os modelos carregados para o dispositivo de computação (GPU/CPU)."""
695
+ print(f"[INFO] Movendo modelos para o dispositivo: {self.device}")
696
+ self.pipeline.to(self.device)
697
+ if self.latent_upsampler:
698
+ self.latent_upsampler.to(self.device)
699
+
700
+ def _get_precision_dtype(self) -> torch.dtype:
701
+ """Determina o dtype para autocast com base na configuração de precisão."""
702
+ prec = str(self.config.get("precision", "")).lower()
703
+ if prec in ["float8_e4m3fn", "bfloat16"]:
704
+ return torch.bfloat16
705
+ elif prec == "mixed_precision":
706
+ return torch.float16
707
+ return torch.float32
708
+
709
+ @torch.no_grad()
710
+ def _upsample_and_filter_latents(self, latents: torch.Tensor) -> torch.Tensor:
711
+ """Aplica o upsample espacial e o filtro AdaIN aos latentes."""
712
+ if not self.latent_upsampler:
713
+ raise ValueError("Latent Upsampler não está carregado para a operação de upscale.")
714
+
715
+ latents_unnormalized = un_normalize_latents(latents, self.pipeline.vae, vae_per_channel_normalize=True)
716
+ upsampled_latents_unnormalized = self.latent_upsampler(latents_unnormalized)
717
+ upsampled_latents_normalized = normalize_latents(upsampled_latents_unnormalized, self.pipeline.vae, vae_per_channel_normalize=True)
718
+
719
+ # Filtro AdaIN para manter consistência de cor/estilo com o vídeo de baixa resolução
720
+ return adain_filter_latent(latents=upsampled_latents_normalized, reference_latents=latents)
721
+
722
+ def _prepare_conditioning_tensor_from_path(self, filepath: str, height: int, width: int, padding: Tuple) -> torch.Tensor:
723
+ """Carrega uma imagem, redimensiona, aplica padding e move para o dispositivo."""
724
+ tensor = self._load_image_to_tensor_with_resize_and_crop(filepath, height, width)
725
+ tensor = F.pad(tensor, padding)
726
+ return tensor.to(self.device, dtype=self.runtime_autocast_dtype)
727
+
728
+ def _calculate_downscaled_dims(self, height: int, width: int) -> Tuple[int, int]:
729
+ """Calcula as dimensões para o primeiro passo (baixa resolução)."""
730
+ height_padded = ((height - 1) // 8 + 1) * 8
731
+ width_padded = ((width - 1) // 8 + 1) * 8
732
+
733
+ downscale_factor = self.config.get("downscale_factor", 0.6666666)
734
+ vae_scale_factor = self.pipeline.vae_scale_factor
735
+
736
+ target_w = int(width_padded * downscale_factor)
737
+ downscaled_width = target_w - (target_w % vae_scale_factor)
738
+
739
+ target_h = int(height_padded * downscale_factor)
740
+ downscaled_height = target_h - (target_h % vae_scale_factor)
741
+
742
+ return downscaled_height, downscaled_width
743
+
744
+ def _split_latents_with_overlap(self, latents: torch.Tensor, overlap: int = 1) -> List[torch.Tensor]:
745
+ """Divide um tensor de latentes em dois chunks com sobreposição."""
746
+ total_frames = latents.shape[2]
747
+ if total_frames <= overlap:
748
+ return [latents]
749
+
750
+ mid_point = max(overlap, total_frames // 2)
751
+ chunk1 = latents[:, :, :mid_point, :, :]
752
+ # O segundo chunk começa 'overlap' frames antes para criar a sobreposição
753
+ chunk2 = latents[:, :, mid_point - overlap:, :, :]
754
+
755
+ return [c for c in [chunk1, chunk2] if c.shape[2] > 0]
756
+
757
+ def _merge_chunks_with_overlap(self, chunks: List[torch.Tensor], overlap: int = 1) -> torch.Tensor:
758
+ """Junta uma lista de chunks, removendo a sobreposição."""
759
+ if not chunks:
760
+ return torch.empty(0)
761
+ if len(chunks) == 1:
762
+ return chunks[0]
763
+
764
+ # Pega o primeiro chunk sem o frame de sobreposição final
765
+ merged_list = [chunks[0][:, :, :-overlap, :, :]]
766
+ # Adiciona os chunks restantes
767
+ merged_list.extend(chunks[1:])
768
+
769
+ return torch.cat(merged_list, dim=2)
770
+
771
+ def _save_latents_to_disk(self, latents_tensor: torch.Tensor, base_filename: str, seed: int) -> str:
772
+ """Salva um tensor de latentes em um arquivo .pt."""
773
+ latents_cpu = latents_tensor.detach().to("cpu")
774
+ tensor_path = RESULTS_DIR / f"{base_filename}_{seed}.pt"
775
+ torch.save(latents_cpu, tensor_path)
776
+ if LTXV_DEBUG:
777
+ print(f"[DEBUG] Latentes salvos em: {tensor_path}")
778
+ return str(tensor_path)
779
+
780
+ def _save_video_from_tensor(self, pixel_tensor: torch.Tensor, base_filename: str, seed: int, temp_dir: str, fps: int = int(DEFAULT_FPS)) -> str:
781
+ """Salva um tensor de pixels como um arquivo de vídeo MP4."""
782
+ temp_path = os.path.join(temp_dir, f"{base_filename}_{seed}.mp4")
783
+ video_encode_tool_singleton.save_video_from_tensor(pixel_tensor, temp_path, fps=fps)
784
+
785
+ final_path = RESULTS_DIR / f"{base_filename}_{seed}.mp4"
786
+ shutil.move(temp_path, final_path)
787
+ print(f"[INFO] Vídeo final salvo em: {final_path}")
788
+ return str(final_path)
789
+
790
+
791
+ def _seed_everething(self, seed: int):
792
+ random.seed(seed)
793
+ np.random.seed(seed)
794
+ torch.manual_seed(seed)
795
+ if torch.cuda.is_available():
796
+ torch.cuda.manual_seed(seed)
797
+ if torch.backends.mps.is_available():
798
+ torch.mps.manual_seed(seed)
799
+
800
+
801
+ def _register_tmp_dir(self, dir_path: str):
802
+ """Registra um diretório temporário para limpeza posterior."""
803
+ if dir_path and os.path.isdir(dir_path):
804
+ self._tmp_dirs.add(dir_path)
805
+ if LTXV_DEBUG:
806
+ print(f"[DEBUG] Diretório temporário registrado: {dir_path}")
807
+
808
+ # ==============================================================================
809
+ # 4. INSTANCIAÇÃO E PONTO DE ENTRADA (Exemplo)
810
+ # ==============================================================================
811
+
812
+ print("Criando instância do VideoService. O carregamento do modelo começará agora...")
813
+ video_generation_service = VideoService()
814
+ print("Instância do VideoService pronta para uso.")
api/seedvr_server.py CHANGED
@@ -1,111 +1,268 @@
 
 
1
  import os
2
- import shutil
3
- import subprocess
4
  import sys
5
  import time
6
- import mimetypes
 
 
7
  from pathlib import Path
8
- from typing import List, Optional, Tuple
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
- from huggingface_hub import hf_hub_download
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  class SeedVRServer:
13
  def __init__(self, **kwargs):
14
- self.SEEDVR_ROOT = Path(os.getenv("SEEDVR_ROOT", "/data/SeedVR"))
15
- # Apontamos para o nosso diretório de checkpoints customizado
16
- self.CKPTS_ROOT = Path("/data/seedvr_models_fp16")
17
- self.OUTPUT_ROOT = Path(os.getenv("OUTPUT_ROOT", "/app/outputs"))
18
- self.INPUT_ROOT = Path(os.getenv("INPUT_ROOT", "/app/inputs"))
19
  self.HF_HOME_CACHE = Path(os.getenv("HF_HOME", "/data/.cache/huggingface"))
20
  self.REPO_URL = os.getenv("SEEDVR_GIT_URL", "https://github.com/numz/ComfyUI-SeedVR2_VideoUpscaler")
21
- self.NUM_GPUS_TOTAL = int(os.getenv("NUM_GPUS", "4"))
22
 
23
- print("🚀 SeedVRServer (FP16) inicializando e preparando o ambiente...")
24
- for p in [self.SEEDVR_ROOT.parent, self.CKPTS_ROOT, self.OUTPUT_ROOT, self.INPUT_ROOT, self.HF_HOME_CACHE]:
 
 
 
 
25
  p.mkdir(parents=True, exist_ok=True)
26
-
27
  self.setup_dependencies()
28
- print(" SeedVRServer (FP16) pronto.")
29
 
30
  def setup_dependencies(self):
31
- self._ensure_repo()
32
- # O monkey patch agora é feito pelo start_seedvr.sh, não mais aqui.
33
- self._ensure_model()
34
-
35
- def _ensure_repo(self) -> None:
36
  if not (self.SEEDVR_ROOT / ".git").exists():
37
- print(f"[SeedVRServer] Clonando repositório para {self.SEEDVR_ROOT}...")
38
  subprocess.run(["git", "clone", "--depth", "1", self.REPO_URL, str(self.SEEDVR_ROOT)], check=True)
39
- else:
40
- print("[SeedVRServer] Repositório SeedVR já existe.")
41
-
42
- def _ensure_model(self) -> None:
43
- """Baixa os arquivos de modelo FP16 otimizados e suas dependências."""
44
- print(f"[SeedVRServer] Verificando checkpoints (FP16) em {self.CKPTS_ROOT}...")
45
 
46
  model_files = {
47
- "seedvr2_ema_3b_fp16.safetensors": "MonsterMMORPG/SeedVR2_SECourses", "ema_vae_fp16.safetensors": "MonsterMMORPG/SeedVR2_SECourses",
48
- "pos_emb.pt": "ByteDance-Seed/SeedVR2-3B", "neg_emb.pt": "ByteDance-Seed/SeedVR2-3B"
49
  }
50
-
51
  for filename, repo_id in model_files.items():
52
  if not (self.CKPTS_ROOT / filename).exists():
53
- print(f"Baixando {filename} de {repo_id}...")
54
- hf_hub_download(repo_id=repo_id, filename=filename, local_dir=str(self.CKPTS_ROOT), cache_dir=str(self.HF_HOME_CACHE), token=os.getenv("HF_TOKEN"))
55
- print("[SeedVRServer] Checkpoints (FP16) estão no local correto.")
56
-
57
- def _prepare_job(self, input_file: str) -> Tuple[Path, Path]:
58
- ts = f"{int(time.time())}_{os.urandom(4).hex()}"
59
- job_input_dir = self.INPUT_ROOT / f"job_{ts}"
60
- out_dir = self.OUTPUT_ROOT / f"run_{ts}"
61
- job_input_dir.mkdir(parents=True, exist_ok=True)
62
- out_dir.mkdir(parents=True, exist_ok=True)
63
- shutil.copy2(input_file, job_input_dir / Path(input_file).name)
64
- return job_input_dir, out_dir
65
-
66
- def run_inference(self, filepath: str, *, seed: int, resh: int, resw: int, spsize: int, fps: Optional[float] = None):
67
- script = self.SEEDVR_ROOT / "inference_cli.py"
68
- job_input_dir, outdir = self._prepare_job(filepath)
69
- mediatype, _ = mimetypes.guess_type(filepath)
70
- is_image = mediatype and mediatype.startswith("image")
71
-
72
- effective_nproc = 1 if is_image else self.NUM_GPUS_TOTAL
73
- effective_spsize = 1 if is_image else spsize
74
-
75
- output_filename = f"result_{Path(filepath).stem}.mp4" if not is_image else f"{Path(filepath).stem}_upscaled"
76
- output_filepath = outdir / output_filename
77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
-
80
- cmd = [
81
- "torchrun", "--standalone", "--nnodes=1",
82
- f"--nproc-per-node={effective_nproc}",
83
- str(script),
84
- "--video_path", str(filepath),
85
- "--output", str(output_filepath),
86
- "--model_dir", str(self.CKPTS_ROOT),
87
- "--seed", str(seed),
88
- "--cuda_device", "0",
89
- "--resolution", str(resh),
90
- "--batch_size", str(effective_spsize),
91
- "--model", "seedvr2_ema_3b_fp16.safetensors",
92
- "--preserve_vram",
93
- "--debug",
94
- "--output_format", "video" if not is_image else "png",
95
- ]
96
-
97
-
98
- print("SeedVRServer Comando:", " ".join(cmd))
99
  try:
100
- subprocess.run(cmd, cwd=str(self.SEEDVR_ROOT), check=True, env=os.environ.copy(), stdout=sys.stdout, stderr=sys.stderr)
101
- # Constrói a tupla de retorno de forma determinística
102
- if is_image:
103
- # CLI salva PNGs em diretório args.output (tratado como diretório quando outputformat=png)
104
- image_dir = output_filepath if output_filepath.suffix == "" else output_filepath.with_suffix("")
105
- return str(image_dir), None, outdir
106
- else:
107
- # CLI salva vídeo exatamente em output_filepath
108
- return None, str(output_filepath), outdir
109
- except Exception as e:
110
- print(f"[UI ERROR] A inferência falhou: {e}")
111
- return None, None, None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # api/seedvr_server.py
2
+
3
  import os
 
 
4
  import sys
5
  import time
6
+ import subprocess
7
+ import queue
8
+ import multiprocessing as mp
9
  from pathlib import Path
10
+ from typing import Optional, Callable
11
+
12
+ # --- 1. Import dos Módulos Compartilhados ---
13
+ # É crucial que estes imports venham antes dos imports pesados (torch, etc.)
14
+ # para que o ambiente de multiprocessing seja configurado corretamente.
15
+
16
+ try:
17
+ # Importa o gerenciador de GPUs que centraliza a lógica de alocação
18
+ from api.gpu_manager import gpu_manager
19
+ # Importa o serviço do LTX para podermos comandá-lo a liberar a VRAM
20
+ from api.ltx_server_refactored import video_generation_service
21
+ except ImportError:
22
+ print("ERRO FATAL: Não foi possível importar `gpu_manager` ou `video_generation_service`.")
23
+ print("Certifique-se de que os arquivos `gpu_manager.py` e `ltx_server_refactored.py` existem em `api/`.")
24
+ sys.exit(1)
25
+
26
+
27
+ # --- 2. Configuração de Ambiente e CUDA ---
28
+ if mp.get_start_method(allow_none=True) != 'spawn':
29
+ mp.set_start_method('spawn', force=True)
30
+
31
+ os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "backend:cudaMallocAsync")
32
+
33
+ # Adiciona o caminho do repositório SeedVR
34
+ SEEDVR_REPO_PATH = Path(os.getenv("SEEDVR_ROOT", "/data/SeedVR"))
35
+ if str(SEEDVR_REPO_PATH) not in sys.path:
36
+ sys.path.insert(0, str(SEEDVR_REPO_PATH))
37
+
38
+ # Imports pesados
39
+ import torch
40
+ import cv2
41
+ import numpy as np
42
+ from datetime import datetime
43
+
44
 
45
+ # --- 3. Funções Auxiliares de Processamento (Workers e I/O) ---
46
+ # (Estas funções não precisam de alteração)
47
+
48
+ def extract_frames_from_video(video_path, debug=False, skip_first_frames=0, load_cap=None):
49
+ if debug: print(f"🎬 [SeedVR] Extraindo frames de: {video_path}")
50
+ if not os.path.exists(video_path): raise FileNotFoundError(f"Arquivo de vídeo não encontrado: {video_path}")
51
+ cap = cv2.VideoCapture(video_path)
52
+ if not cap.isOpened(): raise ValueError(f"Não foi possível abrir o vídeo: {video_path}")
53
+
54
+ fps = cap.get(cv2.CAP_PROP_FPS)
55
+ frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
56
+ frames = []
57
+ frames_loaded = 0
58
+ for i in range(frame_count):
59
+ ret, frame = cap.read()
60
+ if not ret: break
61
+ if i < skip_first_frames: continue
62
+ if load_cap and frames_loaded >= load_cap: break
63
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
64
+ frames.append(frame.astype(np.float32) / 255.0)
65
+ frames_loaded += 1
66
+ cap.release()
67
+ if not frames: raise ValueError(f"Nenhum frame extraído de: {video_path}")
68
+ if debug: print(f"✅ [SeedVR] {len(frames)} frames extraídos com sucesso.")
69
+ return torch.from_numpy(np.stack(frames)).to(torch.float16), fps
70
+
71
+ def save_frames_to_video(frames_tensor, output_path, fps=30.0, debug=False):
72
+ if debug: print(f"💾 [SeedVR] Salvando {frames_tensor.shape[0]} frames em: {output_path}")
73
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
74
+ frames_np = (frames_tensor.cpu().numpy() * 255.0).astype(np.uint8)
75
+ T, H, W, _ = frames_np.shape
76
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
77
+ out = cv2.VideoWriter(output_path, fourcc, fps, (W, H))
78
+ if not out.isOpened(): raise ValueError(f"Não foi possível criar o vídeo: {output_path}")
79
+ for frame in frames_np:
80
+ out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
81
+ out.release()
82
+ if debug: print(f"✅ [SeedVR] Vídeo salvo com sucesso: {output_path}")
83
+
84
+ def _worker_process(proc_idx, device_id, frames_np, shared_args, return_queue, progress_queue=None):
85
+ """Processo filho (worker) que executa o upscaling em uma GPU dedicada."""
86
+ os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id)
87
+ os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "backend:cudaMallocAsync")
88
+
89
+ import torch
90
+ from src.core.model_manager import configure_runner
91
+ from src.core.generation import generation_loop
92
+
93
+ try:
94
+ frames_tensor = torch.from_numpy(frames_np).to(torch.float16)
95
+ callback = (lambda b, t, _, m: progress_queue.put((proc_idx, b, t, m))) if progress_queue else None
96
+
97
+ runner = configure_runner(shared_args["model"], shared_args["model_dir"], shared_args["preserve_vram"], shared_args["debug"])
98
+ result_tensor = generation_loop(
99
+ runner=runner, images=frames_tensor, cfg_scale=1.0, seed=shared_args["seed"],
100
+ res_w=shared_args["resolution"], batch_size=shared_args["batch_size"],
101
+ preserve_vram=shared_args["preserve_vram"], temporal_overlap=0,
102
+ debug=shared_args["debug"], progress_callback=callback
103
+ )
104
+ return_queue.put((proc_idx, result_tensor.cpu().numpy()))
105
+ except Exception as e:
106
+ import traceback
107
+ error_msg = f"ERRO no worker {proc_idx} (GPU {device_id}): {e}\n{traceback.format_exc()}"
108
+ print(error_msg)
109
+ if progress_queue: progress_queue.put((proc_idx, -1, -1, error_msg))
110
+ return_queue.put((proc_idx, error_msg))
111
+
112
+ # --- 4. CLASSE DO SERVIDOR PRINCIPAL ---
113
 
114
  class SeedVRServer:
115
  def __init__(self, **kwargs):
116
+ """Inicializa o servidor, define os caminhos e prepara o ambiente."""
117
+ print("⚙️ SeedVRServer inicializando...")
118
+ self.SEEDVR_ROOT = SEEDVR_REPO_PATH
119
+ self.CKPTS_ROOT = Path("/data/seedvr_models_fp16")
120
+ self.OUTPUT_ROOT = Path(os.getenv("OUTPUT_ROOT", "/app/output"))
121
  self.HF_HOME_CACHE = Path(os.getenv("HF_HOME", "/data/.cache/huggingface"))
122
  self.REPO_URL = os.getenv("SEEDVR_GIT_URL", "https://github.com/numz/ComfyUI-SeedVR2_VideoUpscaler")
 
123
 
124
+ # OBTÉM AS GPUS ALOCADAS PELO GERENCIADOR CENTRAL
125
+ self.device_list = gpu_manager.get_seedvr_devices()
126
+ self.num_gpus = len(self.device_list)
127
+ print(f"[SeedVR] Alocado para usar {self.num_gpus} GPU(s): {self.device_list}")
128
+
129
+ for p in [self.CKPTS_ROOT, self.OUTPUT_ROOT, self.HF_HOME_CACHE]:
130
  p.mkdir(parents=True, exist_ok=True)
131
+
132
  self.setup_dependencies()
133
+ print("📦 SeedVRServer pronto.")
134
 
135
  def setup_dependencies(self):
136
+ """Garante que o repositório e os modelos estão presentes."""
 
 
 
 
137
  if not (self.SEEDVR_ROOT / ".git").exists():
138
+ print(f"[SeedVR] Clonando repositório para {self.SEEDVR_ROOT}...")
139
  subprocess.run(["git", "clone", "--depth", "1", self.REPO_URL, str(self.SEEDVR_ROOT)], check=True)
 
 
 
 
 
 
140
 
141
  model_files = {
142
+ "seedvr2_ema_7b_sharp_fp16.safetensors": "MonsterMMORPG/SeedVR2_SECourses",
143
+ "ema_vae_fp16.safetensors": "MonsterMMORPG/SeedVR2_SECourses"
144
  }
 
145
  for filename, repo_id in model_files.items():
146
  if not (self.CKPTS_ROOT / filename).exists():
147
+ print(f"Baixando {filename}...")
148
+ from huggingface_hub import hf_hub_download
149
+ hf_hub_download(
150
+ repo_id=repo_id, filename=filename, local_dir=str(self.CKPTS_ROOT),
151
+ cache_dir=str(self.HF_HOME_CACHE), token=os.getenv("HF_TOKEN")
152
+ )
153
+ print("[SeedVR] Checkpoints verificados.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
 
155
+ def run_inference(
156
+ self,
157
+ file_path: str, *,
158
+ seed: int,
159
+ resolution: int,
160
+ batch_size: int,
161
+ model: str = "seedvr2_ema_7b_sharp_fp16.safetensors",
162
+ fps: Optional[float] = None,
163
+ debug: bool = True,
164
+ preserve_vram: bool = True,
165
+ progress: Optional[Callable] = None
166
+ ) -> str:
167
+ """
168
+ Executa o pipeline completo de upscaling de vídeo, gerenciando a memória da GPU.
169
+ """
170
+ if progress: progress(0.01, "⌛ Inicializando inferência SeedVR...")
171
+
172
+ # --- NÓ 1: GERENCIAMENTO DE MEMÓRIA (SWAP) ---
173
+ if gpu_manager.requires_memory_swap():
174
+ print("[SWAP] SeedVR precisa da GPU. Movendo LTX para a CPU...")
175
+ if progress: progress(0.02, "🔄 Liberando VRAM para o SeedVR...")
176
+ video_generation_service.move_to_cpu()
177
+ print("[SWAP] LTX movido para a CPU. VRAM liberada.")
178
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
  try:
180
+ # --- 2: EXTRAÇÃO DE FRAMES ---
181
+ if progress: progress(0.05, "🎬 Extraindo frames do vídeo...")
182
+ frames_tensor, original_fps = extract_frames_from_video(file_path, debug)
183
+
184
+ # --- 3: DIVISÃO PARA MULTI-GPU ---
185
+ if self.num_gpus == 0:
186
+ raise RuntimeError("SeedVR requer pelo menos 1 GPU alocada, mas não encontrou nenhuma.")
187
+
188
+ print(f"[SeedVR] Dividindo {frames_tensor.shape[0]} frames em {self.num_gpus} chunks para processamento paralelo.")
189
+ chunks = torch.chunk(frames_tensor, self.num_gpus, dim=0)
190
+
191
+ manager = mp.Manager()
192
+ return_queue = manager.Queue()
193
+ progress_queue = manager.Queue() if progress else None
194
+
195
+ shared_args = {
196
+ "model": model, "model_dir": str(self.CKPTS_ROOT), "preserve_vram": preserve_vram,
197
+ "debug": debug, "seed": seed, "resolution": resolution, "batch_size": batch_size
198
+ }
199
+
200
+ # --- NÓ 4: INÍCIO DOS WORKERS ---
201
+ if progress: progress(0.1, f"🚀 Iniciando geração em {self.num_gpus} GPU(s)...")
202
+ workers = []
203
+ for idx, device_id in enumerate(self.device_list):
204
+ p = mp.Process(target=_worker_process, args=(idx, device_id, chunks[idx].cpu().numpy(), shared_args, return_queue, progress_queue))
205
+ p.start()
206
+ workers.append(p)
207
+
208
+ # --- NÓ 5: COLETA DE RESULTADOS E MONITORAMENTO ---
209
+ results_np = [None] * self.num_gpus
210
+ finished_workers = 0
211
+ worker_progress = [0.0] * self.num_gpus
212
+ while finished_workers < self.num_gpus:
213
+ if progress_queue:
214
+ while not progress_queue.empty():
215
+ try:
216
+ p_idx, b_idx, b_total, msg = progress_queue.get_nowait()
217
+ if b_idx == -1: raise RuntimeError(f"Erro no Worker {p_idx}: {msg}")
218
+ if b_total > 0: worker_progress[p_idx] = b_idx / b_total
219
+ total_progress = sum(worker_progress) / self.num_gpus
220
+ progress(0.1 + total_progress * 0.85, desc=f"GPU {p_idx+1}/{self.num_gpus}: {msg}")
221
+ except queue.Empty: pass
222
+
223
+ try:
224
+ proc_idx, result = return_queue.get(timeout=0.2)
225
+ if isinstance(result, str): raise RuntimeError(f"Worker {proc_idx} falhou: {result}")
226
+ results_np[proc_idx] = result
227
+ worker_progress[proc_idx] = 1.0
228
+ finished_workers += 1
229
+ except queue.Empty: pass
230
+
231
+ for p in workers: p.join()
232
+
233
+ # --- NÓ 6: FINALIZAÇÃO ---
234
+ if any(r is None for r in results_np):
235
+ raise RuntimeError("Um ou mais workers falharam ao retornar um resultado.")
236
+
237
+ result_tensor = torch.from_numpy(np.concatenate(results_np, axis=0)).to(torch.float16)
238
+ if progress: progress(0.95, "💾 Salvando o vídeo final...")
239
+
240
+ out_dir = self.OUTPUT_ROOT / f"run_{int(time.time())}_{Path(file_path).stem}"
241
+ out_dir.mkdir(parents=True, exist_ok=True)
242
+ output_filepath = out_dir / f"result_{Path(file_path).stem}.mp4"
243
+
244
+ final_fps = fps if fps and fps > 0 else original_fps
245
+ save_frames_to_video(result_tensor, str(output_filepath), final_fps, debug)
246
+
247
+ print(f"✅ Vídeo salvo com sucesso em: {output_filepath}")
248
+ return str(output_filepath)
249
+
250
+ finally:
251
+ # --- NÓ 7: RESTAURAÇÃO DE MEMÓRIA (SWAP BACK) ---
252
+ if gpu_manager.requires_memory_swap():
253
+ print("[SWAP] Inferência do SeedVR concluída. Movendo LTX de volta para a GPU...")
254
+ if progress: progress(0.99, "🔄 Restaurando o ambiente LTX...")
255
+ ltx_device = gpu_manager.get_ltx_device()
256
+ video_generation_service.move_to_device(ltx_device)
257
+ print(f"[SWAP] LTX de volta em {ltx_device}.")
258
+
259
+ # --- PONTO DE ENTRADA ---
260
+ if __name__ == "__main__":
261
+ print("🚀 Executando o servidor SeedVR em modo autônomo para inicialização...")
262
+ try:
263
+ server = SeedVRServer()
264
+ print("✅ Servidor inicializado com sucesso. Pronto para receber chamadas.")
265
+ except Exception as e:
266
+ print(f"❌ Falha ao inicializar o servidor SeedVR: {e}")
267
+ traceback.print_exc()
268
+ sys.exit(1)