Spaces:
Paused
Paused
Update api/ltx_server_refactored_complete.py
Browse files
api/ltx_server_refactored_complete.py
CHANGED
|
@@ -54,9 +54,16 @@ add_deps_to_path()
|
|
| 54 |
|
| 55 |
# --- PROJECT IMPORTS ---
|
| 56 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
from api.gpu_manager import gpu_manager
|
| 58 |
from ltx_video.models.autoencoders.vae_encode import (normalize_latents, un_normalize_latents)
|
| 59 |
-
from ltx_video.pipelines.pipeline_ltx_video import (ConditioningItem, LTXMultiScalePipeline, adain_filter_latent, create_latent_upsampler
|
| 60 |
from ltx_video.utils.inference_utils import load_image_to_tensor_with_resize_and_crop
|
| 61 |
from managers.vae_manager import vae_manager_singleton
|
| 62 |
from tools.video_encode_tool import video_encode_tool_singleton
|
|
@@ -158,28 +165,82 @@ class VideoService:
|
|
| 158 |
with open(config_path, "r") as file:
|
| 159 |
return yaml.safe_load(file)
|
| 160 |
|
| 161 |
-
def _load_models(self) -> Tuple[
|
| 162 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
t0 = time.perf_counter()
|
| 164 |
-
logging.info("
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
latent_upsampler = None
|
| 176 |
if self.config.get("spatial_upscaler_model_path"):
|
|
|
|
| 177 |
spatial_path = self.config["spatial_upscaler_model_path"]
|
| 178 |
latent_upsampler = create_latent_upsampler(spatial_path, device="cpu")
|
|
|
|
|
|
|
| 179 |
|
| 180 |
-
logging.info(f"
|
| 181 |
return pipeline, latent_upsampler
|
| 182 |
-
|
|
|
|
| 183 |
def move_to_device(self, main_device_str: str, vae_device_str: str):
|
| 184 |
"""Moves pipeline components to their target devices."""
|
| 185 |
target_main_device = torch.device(main_device_str)
|
|
|
|
| 54 |
|
| 55 |
# --- PROJECT IMPORTS ---
|
| 56 |
try:
|
| 57 |
+
from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline, create_latent_upsampler # E outros...
|
| 58 |
+
from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
|
| 59 |
+
from ltx_video.models.transformers.transformer3d import Transformer3DModel
|
| 60 |
+
from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
|
| 61 |
+
from ltx_video.schedulers.rf import RectifiedFlowScheduler
|
| 62 |
+
from transformers import T5EncoderModel, T5Tokenizer
|
| 63 |
+
from safetensors import safe_open
|
| 64 |
from api.gpu_manager import gpu_manager
|
| 65 |
from ltx_video.models.autoencoders.vae_encode import (normalize_latents, un_normalize_latents)
|
| 66 |
+
from ltx_video.pipelines.pipeline_ltx_video import (ConditioningItem, LTXMultiScalePipeline, adain_filter_latent, create_latent_upsampler)
|
| 67 |
from ltx_video.utils.inference_utils import load_image_to_tensor_with_resize_and_crop
|
| 68 |
from managers.vae_manager import vae_manager_singleton
|
| 69 |
from tools.video_encode_tool import video_encode_tool_singleton
|
|
|
|
| 165 |
with open(config_path, "r") as file:
|
| 166 |
return yaml.safe_load(file)
|
| 167 |
|
| 168 |
+
def _load_models(self) -> Tuple[LTXVideoPipeline, Optional[torch.nn.Module]]:
|
| 169 |
+
"""
|
| 170 |
+
Carrega todos os sub-modelos do pipeline na CPU.
|
| 171 |
+
Esta função substitui a necessidade de chamar a `create_ltx_video_pipeline` externa,
|
| 172 |
+
dando-nos controle total sobre o processo.
|
| 173 |
+
"""
|
| 174 |
t0 = time.perf_counter()
|
| 175 |
+
logging.info("Carregando sub-modelos do LTX para a CPU...")
|
| 176 |
+
|
| 177 |
+
ckpt_path = Path(self.config["checkpoint_path"])
|
| 178 |
+
if not ckpt_path.is_file():
|
| 179 |
+
raise FileNotFoundError(f"Arquivo de checkpoint principal não encontrado em: {ckpt_path}")
|
| 180 |
+
|
| 181 |
+
# 1. Carrega Metadados do Checkpoint
|
| 182 |
+
with safe_open(ckpt_path, framework="pt") as f:
|
| 183 |
+
metadata = f.metadata() or {}
|
| 184 |
+
config_str = metadata.get("config", "{}")
|
| 185 |
+
configs = json.loads(config_str)
|
| 186 |
+
allowed_inference_steps = configs.get("allowed_inference_steps")
|
| 187 |
+
|
| 188 |
+
# 2. Carrega os Componentes Individuais (todos na CPU)
|
| 189 |
+
# O `.from_pretrained(ckpt_path)` é inteligente e carrega os pesos corretos do arquivo .safetensors.
|
| 190 |
+
logging.info("Carregando VAE...")
|
| 191 |
+
vae = CausalVideoAutoencoder.from_pretrained(ckpt_path).to("cpu")
|
| 192 |
+
|
| 193 |
+
logging.info("Carregando Transformer...")
|
| 194 |
+
transformer = Transformer3DModel.from_pretrained(ckpt_path).to("cpu")
|
| 195 |
+
|
| 196 |
+
logging.info("Carregando Scheduler...")
|
| 197 |
+
scheduler = RectifiedFlowScheduler.from_pretrained(ckpt_path)
|
| 198 |
+
|
| 199 |
+
logging.info("Carregando Text Encoder e Tokenizer...")
|
| 200 |
+
text_encoder_path = self.config["text_encoder_model_name_or_path"]
|
| 201 |
+
text_encoder = T5EncoderModel.from_pretrained(text_encoder_path, subfolder="text_encoder").to("cpu")
|
| 202 |
+
tokenizer = T5Tokenizer.from_pretrained(text_encoder_path, subfolder="tokenizer")
|
| 203 |
+
|
| 204 |
+
patchifier = SymmetricPatchifier(patch_size=1)
|
| 205 |
+
|
| 206 |
+
# 3. Define a precisão dos modelos (ainda na CPU, será aplicado na GPU depois)
|
| 207 |
+
precision = self.config.get("precision", "bfloat16")
|
| 208 |
+
if precision == "bfloat16":
|
| 209 |
+
vae.to(torch.bfloat16)
|
| 210 |
+
transformer.to(torch.bfloat16)
|
| 211 |
+
text_encoder.to(torch.bfloat16)
|
| 212 |
|
| 213 |
+
# 4. Monta o objeto do Pipeline com os componentes carregados
|
| 214 |
+
logging.info("Montando o objeto LTXVideoPipeline...")
|
| 215 |
+
submodel_dict = {
|
| 216 |
+
"transformer": transformer,
|
| 217 |
+
"patchifier": patchifier,
|
| 218 |
+
"text_encoder": text_encoder,
|
| 219 |
+
"tokenizer": tokenizer,
|
| 220 |
+
"scheduler": scheduler,
|
| 221 |
+
"vae": vae,
|
| 222 |
+
"allowed_inference_steps": allowed_inference_steps,
|
| 223 |
+
# Os prompt enhancers são opcionais e não são carregados por padrão para economizar memória
|
| 224 |
+
"prompt_enhancer_image_caption_model": None,
|
| 225 |
+
"prompt_enhancer_image_caption_processor": None,
|
| 226 |
+
"prompt_enhancer_llm_model": None,
|
| 227 |
+
"prompt_enhancer_llm_tokenizer": None,
|
| 228 |
+
}
|
| 229 |
+
pipeline = LTXVideoPipeline(**submodel_dict)
|
| 230 |
+
|
| 231 |
+
# 5. Carrega o Latent Upsampler (também na CPU)
|
| 232 |
latent_upsampler = None
|
| 233 |
if self.config.get("spatial_upscaler_model_path"):
|
| 234 |
+
logging.info("Carregando Latent Upsampler...")
|
| 235 |
spatial_path = self.config["spatial_upscaler_model_path"]
|
| 236 |
latent_upsampler = create_latent_upsampler(spatial_path, device="cpu")
|
| 237 |
+
if precision == "bfloat16":
|
| 238 |
+
latent_upsampler.to(torch.bfloat16)
|
| 239 |
|
| 240 |
+
logging.info(f"Modelos LTX carregados na CPU em {time.perf_counter()-t0:.2f}s")
|
| 241 |
return pipeline, latent_upsampler
|
| 242 |
+
|
| 243 |
+
|
| 244 |
def move_to_device(self, main_device_str: str, vae_device_str: str):
|
| 245 |
"""Moves pipeline components to their target devices."""
|
| 246 |
target_main_device = torch.device(main_device_str)
|