eeuuia committed (verified)
Commit b8a0748 · 1 Parent(s): c8b13b1

Update managers/vae_manager.py

Files changed (1):
  1. managers/vae_manager.py +73 -72
managers/vae_manager.py CHANGED
@@ -1,90 +1,91 @@
-# vae_manager.py - simple version (beta 1.0)
-# Responsible for decoding latents (B,C,T,H,W) into pixels (B,C,T,H',W') in [0,1].
+# FILE: managers/vae_manager.py
+# DESCRIPTION: Singleton manager for VAE decoding operations, supporting dedicated GPU devices.
 
 import torch
 import contextlib
-import os
-import subprocess
-import sys
-from pathlib import Path
-
-from huggingface_hub import logging
-
-
-logging.set_verbosity_error()
-logging.set_verbosity_warning()
-logging.set_verbosity_info()
-logging.set_verbosity_debug()
-
-
-
-
-DEPS_DIR = Path("/data")
-LTX_VIDEO_REPO_DIR = DEPS_DIR / "LTX-Video"
-if not LTX_VIDEO_REPO_DIR.exists():
-    print(f"[DEBUG] Repository not found at {LTX_VIDEO_REPO_DIR}. Running setup...")
-    run_setup()
-
-def add_deps_to_path():
-    repo_path = str(LTX_VIDEO_REPO_DIR.resolve())
-    if str(LTX_VIDEO_REPO_DIR.resolve()) not in sys.path:
-        sys.path.insert(0, repo_path)
-        print(f"[DEBUG] Repo added to sys.path: {repo_path}")
-
-add_deps_to_path()
-
-
-
-from ltx_video.models.autoencoders.vae_encode import vae_encode, vae_decode
-
+import logging
 
 class _SimpleVAEManager:
-    def __init__(self, pipeline=None, device=None, autocast_dtype=torch.float32):
-        """
-        pipeline: LTX object that exposes decode_latents(...) or .vae.decode(...)
-        device: "cuda" or "cpu", where decoding should run
-        autocast_dtype: autocast dtype when on CUDA (bf16/fp16/fp32)
-        """
-        self.pipeline = pipeline
-        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
-        self.autocast_dtype = autocast_dtype
+    """
+    Manages VAE decoding. It's designed to be aware that the VAE might reside
+    on a different GPU than the main generation pipeline (e.g., Transformer).
+    """
+    def __init__(self):
+        """Initializes the manager without a pipeline attached."""
+        self.pipeline = None
+        self.device = torch.device("cpu")  # Defaults to CPU until a device is attached.
+        self.autocast_dtype = torch.float32
 
     def attach_pipeline(self, pipeline, device=None, autocast_dtype=None):
+        """
+        Attaches the main pipeline and, crucially, stores the specific device
+        that this manager and its associated VAE should operate on.
+
+        Args:
+            pipeline: The main LTX video pipeline instance.
+            device (torch.device or str): The target device for VAE operations (e.g., 'cuda:1').
+            autocast_dtype (torch.dtype): The precision for torch.autocast.
+        """
         self.pipeline = pipeline
         if device is not None:
-            self.device = device
+            self.device = torch.device(device)
+            logging.info(f"[VAEManager] VAE device successfully set to: {self.device}")
         if autocast_dtype is not None:
             self.autocast_dtype = autocast_dtype
 
-
-
     @torch.no_grad()
     def decode(self, latent_tensor: torch.Tensor, decode_timestep: float = 0.05) -> torch.Tensor:
-
-        # Ensure device and dtype match the runtime
-        latent_tensor_gpu = latent_tensor.to(self.device, dtype=self.autocast_dtype if self.device == "cuda" else latent_tensor.dtype)
-
-        # Build the timestep vector (one per item in batch B)
-        num_items_in_batch = latent_tensor_gpu.shape[0]
-        timestep_tensor = torch.tensor([decode_timestep] * num_items_in_batch, device=self.device, dtype=latent_tensor_gpu.dtype)
+        """
+        Decodes a latent tensor into a pixel tensor.
+
+        This method ensures that the decoding operation happens on the correct,
+        potentially dedicated, VAE device.
+
+        Args:
+            latent_tensor (torch.Tensor): The latents to decode, typically on the main device or CPU.
+            decode_timestep (float): The timestep for VAE decoding.
+
+        Returns:
+            torch.Tensor: The resulting pixel tensor, moved to the CPU for general use.
+        """
+        if self.pipeline is None:
+            raise RuntimeError("VAEManager: No pipeline has been attached. Call attach_pipeline() first.")
+        if not hasattr(self.pipeline, 'vae'):
+            raise AttributeError("VAEManager: The attached pipeline does not have a 'vae' attribute.")
+
+        # 1. Move the input latents to the dedicated VAE device. This is the critical step.
+        logging.debug(f"[VAEManager] Moving latents from {latent_tensor.device} to VAE device {self.device} for decoding.")
+        latent_tensor_on_vae_device = latent_tensor.to(self.device)
+
+        # 2. Get a reference to the VAE model (which is already on the correct device).
+        vae = self.pipeline.vae
+
+        # 3. Prepare other necessary tensors on the same VAE device.
+        num_items_in_batch = latent_tensor_on_vae_device.shape[0]
+        timestep_tensor = torch.tensor([decode_timestep] * num_items_in_batch, device=self.device)
 
-        ctx = torch.autocast(device_type="cuda", dtype=self.autocast_dtype) if self.device == "cuda" else contextlib.nullcontext()
+        # 4. Set up the autocast context for the target device type.
+        autocast_device_type = self.device.type
+        ctx = torch.autocast(
+            device_type=autocast_device_type,
+            dtype=self.autocast_dtype,
+            enabled=(autocast_device_type == 'cuda')
+        )
+
+        # 5. Perform the decoding operation within the autocast context.
         with ctx:
-            pixels = vae_decode(
-                latent_tensor_gpu,
-                self.pipeline.vae if hasattr(self.pipeline, "vae") else self.pipeline,  # compat
-                is_video=True,
-                timestep=timestep_tensor,
-                vae_per_channel_normalize=True,
-            )
-
-        # Normalize to [0,1] if the output comes in [-1,1]
-        if pixels.min() < 0:
-            pixels = (pixels.clamp(-1, 1) + 1.0) / 2.0
-        else:
-            pixels = pixels.clamp(0, 1)
-        return pixels
+            logging.debug(f"[VAEManager] Decoding latents with shape {latent_tensor_on_vae_device.shape} on {self.device}.")
+            # The VAE expects latents scaled by its scaling factor.
+            scaled_latents = latent_tensor_on_vae_device / vae.config.scaling_factor
+            pixels = vae.decode(scaled_latents, timesteps=timestep_tensor).sample
 
+        # 6. Post-process the output: normalize to [0, 1] range.
+        pixels = (pixels.clamp(-1, 1) + 1.0) / 2.0
+
+        # 7. Move the final pixel tensor to the CPU. This is a safe default, as subsequent
+        #    operations like video saving or UI display typically expect CPU tensors.
+        logging.debug(f"[VAEManager] Decoding complete. Moving pixel tensor to CPU.")
+        return pixels.cpu()
 
-# Simple global singleton
-vae_manager_singleton = _SimpleVAEManager()
+# Create a single, global instance of the manager to be used throughout the application.
+vae_manager_singleton = _SimpleVAEManager()
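
For reference, the sketch below shows how the new singleton is meant to be driven. It is a minimal, hypothetical usage example and not part of the commit: the stub pipeline, its scaling factor, and the tensor shapes are assumptions standing in for a real LTX pipeline whose .vae exposes config.scaling_factor and decode(latents, timesteps=...).sample, the interface the new decode() relies on.

# Minimal usage sketch (assumption: a real LTX pipeline would replace this stub).
import types
import torch

from managers.vae_manager import vae_manager_singleton

class _StubVAE:
    # Stub interface matching what decode() relies on; the scaling factor is illustrative.
    config = types.SimpleNamespace(scaling_factor=0.18215)

    def decode(self, latents, timesteps=None):
        # Pretend-decode: map the first 3 latent channels to pixel values in [-1, 1].
        return types.SimpleNamespace(sample=torch.tanh(latents[:, :3]))

pipeline = types.SimpleNamespace(vae=_StubVAE())

# Attach once at startup; in a real multi-GPU run, device would be the VAE's GPU, e.g. "cuda:1".
vae_manager_singleton.attach_pipeline(pipeline, device="cpu", autocast_dtype=torch.float32)

# Decode latents shaped (B, C, T, H, W); the result is a CPU tensor normalized to [0, 1].
latents = torch.randn(1, 8, 4, 16, 16)
pixels = vae_manager_singleton.decode(latents, decode_timestep=0.05)
print(pixels.shape, float(pixels.min()), float(pixels.max()))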