Update managers/vae_manager.py
managers/vae_manager.py  +73 -72  CHANGED
@@ -1,90 +1,91 @@
-# vae_manager.py
-#
+# FILE: managers/vae_manager.py
+# DESCRIPTION: Singleton manager for VAE decoding operations, supporting dedicated GPU devices.

 import torch
 import contextlib
-import
-import subprocess
-import sys
-from pathlib import Path
-
-from huggingface_hub import logging
-
-
-logging.set_verbosity_error()
-logging.set_verbosity_warning()
-logging.set_verbosity_info()
-logging.set_verbosity_debug()
-
-
-
-
-DEPS_DIR = Path("/data")
-LTX_VIDEO_REPO_DIR = DEPS_DIR / "LTX-Video"
-if not LTX_VIDEO_REPO_DIR.exists():
-    print(f"[DEBUG] Repository not found at {LTX_VIDEO_REPO_DIR}. Running setup...")
-    run_setup()
-
-def add_deps_to_path():
-    repo_path = str(LTX_VIDEO_REPO_DIR.resolve())
-    if str(LTX_VIDEO_REPO_DIR.resolve()) not in sys.path:
-        sys.path.insert(0, repo_path)
-        print(f"[DEBUG] Repo added to sys.path: {repo_path}")
-
-add_deps_to_path()
-
-
-
-from ltx_video.models.autoencoders.vae_encode import vae_encode, vae_decode
-
+import logging

 class _SimpleVAEManager:
-
-
-
-
-
-    """
-        self.pipeline =
-        self.device = device
-        self.autocast_dtype =
+    """
+    Manages VAE decoding. It's designed to be aware that the VAE might reside
+    on a different GPU than the main generation pipeline (e.g., Transformer).
+    """
+    def __init__(self):
+        """Initializes the manager without a pipeline attached."""
+        self.pipeline = None
+        self.device = torch.device("cpu")  # Defaults to CPU until a device is attached.
+        self.autocast_dtype = torch.float32

     def attach_pipeline(self, pipeline, device=None, autocast_dtype=None):
+        """
+        Attaches the main pipeline and, crucially, stores the specific device
+        that this manager and its associated VAE should operate on.
+
+        Args:
+            pipeline: The main LTX video pipeline instance.
+            device (torch.device or str): The target device for VAE operations (e.g., 'cuda:1').
+            autocast_dtype (torch.dtype): The precision for torch.autocast.
+        """
         self.pipeline = pipeline
         if device is not None:
-            self.device = device
+            self.device = torch.device(device)
+            logging.info(f"[VAEManager] VAE device successfully set to: {self.device}")
         if autocast_dtype is not None:
             self.autocast_dtype = autocast_dtype

-
-
     @torch.no_grad()
     def decode(self, latent_tensor: torch.Tensor, decode_timestep: float = 0.05) -> torch.Tensor:
-
-
-
-
-
-
-
+        """
+        Decodes a latent tensor into a pixel tensor.
+
+        This method ensures that the decoding operation happens on the correct,
+        potentially dedicated, VAE device.
+
+        Args:
+            latent_tensor (torch.Tensor): The latents to decode, typically on the main device or CPU.
+            decode_timestep (float): The timestep for VAE decoding.
+
+        Returns:
+            torch.Tensor: The resulting pixel tensor, moved to the CPU for general use.
+        """
+        if self.pipeline is None:
+            raise RuntimeError("VAEManager: No pipeline has been attached. Call attach_pipeline() first.")
+        if not hasattr(self.pipeline, 'vae'):
+            raise AttributeError("VAEManager: The attached pipeline does not have a 'vae' attribute.")
+
+        # 1. Move the input latents to the dedicated VAE device. This is the critical step.
+        logging.debug(f"[VAEManager] Moving latents from {latent_tensor.device} to VAE device {self.device} for decoding.")
+        latent_tensor_on_vae_device = latent_tensor.to(self.device)
+
+        # 2. Get a reference to the VAE model (which is already on the correct device).
+        vae = self.pipeline.vae
+
+        # 3. Prepare other necessary tensors on the same VAE device.
+        num_items_in_batch = latent_tensor_on_vae_device.shape[0]
+        timestep_tensor = torch.tensor([decode_timestep] * num_items_in_batch, device=self.device)

-
+        # 4. Set up the autocast context for the target device type.
+        autocast_device_type = self.device.type
+        ctx = torch.autocast(
+            device_type=autocast_device_type,
+            dtype=self.autocast_dtype,
+            enabled=(autocast_device_type == 'cuda')
+        )
+
+        # 5. Perform the decoding operation within the autocast context.
         with ctx:
-
-
-
-
-            timestep=timestep_tensor,
-            vae_per_channel_normalize=True,
-        )
-
-        # Normalize to [0, 1] if the output is in [-1, 1]
-        if pixels.min() < 0:
-            pixels = (pixels.clamp(-1, 1) + 1.0) / 2.0
-        else:
-            pixels = pixels.clamp(0, 1)
-        return pixels
+            logging.debug(f"[VAEManager] Decoding latents with shape {latent_tensor_on_vae_device.shape} on {self.device}.")
+            # The VAE expects latents scaled by its scaling factor.
+            scaled_latents = latent_tensor_on_vae_device / vae.config.scaling_factor
+            pixels = vae.decode(scaled_latents, timesteps=timestep_tensor).sample

+        # 6. Post-process the output: normalize to [0, 1] range.
+        pixels = (pixels.clamp(-1, 1) + 1.0) / 2.0
+
+        # 7. Move the final pixel tensor to the CPU. This is a safe default, as subsequent
+        # operations like video saving or UI display typically expect CPU tensors.
+        logging.debug("[VAEManager] Decoding complete. Moving pixel tensor to CPU.")
+        return pixels.cpu()

-#
-vae_manager_singleton = _SimpleVAEManager()
+# Create a single, global instance of the manager to be used throughout the application.
+vae_manager_singleton = _SimpleVAEManager()
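
For reference, below is a minimal usage sketch of the new singleton. It is a sketch, not part of the commit: the stub pipeline and VAE, the 0.18215 scaling factor, and the import path are assumptions chosen so the flow can run without the real LTX-Video models.

# Hypothetical usage sketch for vae_manager_singleton (stub objects, not the real pipeline).
import types
import torch

from managers.vae_manager import vae_manager_singleton  # assumes the managers package is importable

class _StubVAE:
    """Hypothetical stand-in for the pipeline's VAE: decode() returns .sample in [-1, 1]."""
    config = types.SimpleNamespace(scaling_factor=0.18215)  # assumed value

    def decode(self, latents, timesteps=None):
        # Fake decode: collapse the latent channels into an RGB-like tensor in [-1, 1].
        pixels = torch.tanh(latents.mean(dim=1, keepdim=True)).repeat(1, 3, 1, 1)
        return types.SimpleNamespace(sample=pixels)

pipeline = types.SimpleNamespace(vae=_StubVAE())

# Attach with a dedicated VAE device; fall back to CPU when no second GPU is present.
vae_device = "cuda:1" if torch.cuda.device_count() > 1 else "cpu"
vae_manager_singleton.attach_pipeline(pipeline, device=vae_device, autocast_dtype=torch.bfloat16)

latents = torch.randn(1, 8, 32, 32)  # latents on the caller's device
pixels = vae_manager_singleton.decode(latents, decode_timestep=0.05)
print(pixels.shape, pixels.device)  # torch.Size([1, 3, 32, 32]) cpu, values in [0, 1]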
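
One behavioral change visible in the diff: step 6 normalizes unconditionally, whereas the old code only remapped when pixels.min() < 0, so the new version assumes the VAE always emits values in [-1, 1]. A quick check of that affine mapping:

# Step-6 normalization: clamp to [-1, 1], then map onto [0, 1].
import torch

x = torch.tensor([-1.5, -1.0, 0.0, 1.0, 1.5])  # raw decoder output, possibly out of range
y = (x.clamp(-1, 1) + 1.0) / 2.0               # same expression as in decode()
print(y)  # tensor([0.0000, 0.0000, 0.5000, 1.0000, 1.0000])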