|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
from torch import Tensor |
|
|
from typing import Optional, List, Tuple |
|
|
from pathlib import Path |
|
|
import os |
|
|
import sys |
|
|
from dataclasses import dataclass, replace |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Third-party building blocks used by the patch. If any of them cannot be
# imported, a diagnostic line is printed and the original ImportError is
# re-raised, so importing this module fails as well.
try:
    from ltx_video.pipelines.pipeline_ltx_video import (
        LTXVideoPipeline,
        ConditioningItem as OriginalConditioningItem,
    )
    from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
    from ltx_video.models.autoencoders.vae_encode import vae_encode, latent_to_pixel_coords
    from diffusers.utils.torch_utils import randn_tensor
except ImportError as import_error:
    print(
        "FATAL ERROR: Could not import dependencies from 'ltx_video'. "
        f"Please ensure the environment is correctly set up. Error: {import_error}"
    )
    raise

# Reached only when every import above succeeded.
print("[INFO] Patch module 'aduc_ltx_latent_patch' loaded successfully.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class PatchedConditioningItem:
    """
    Conditioning item that carries either pixel data or pre-encoded latents.

    Modified version of `ConditioningItem` that accepts pixel tensors
    (`media_item`) or already-encoded latent tensors (`latents`).

    Attributes:
        media_frame_number (int): First frame of the video that this
            conditioning item applies to.
        conditioning_strength (float): Conditioning strength (0.0 to 1.0).
        media_item (Optional[Tensor]): Media tensor (pixels). Used when
            `latents` is None.
        media_x (Optional[int]): X coordinate (left) for spatial placement.
        media_y (Optional[int]): Y coordinate (top) for spatial placement.
        latents (Optional[Tensor]): Pre-encoded latent tensor. Takes
            precedence over `media_item`.

    Raises:
        ValueError: If neither `media_item` nor `latents` is provided.
    """

    media_frame_number: int
    conditioning_strength: float
    media_item: Optional[Tensor] = None
    media_x: Optional[int] = None
    media_y: Optional[int] = None
    latents: Optional[Tensor] = None

    def __post_init__(self):
        """Validate the object's state right after initialisation."""
        has_pixels = self.media_item is not None
        has_latents = self.latents is not None

        # At least one payload is mandatory.
        if not (has_pixels or has_latents):
            raise ValueError("A `PatchedConditioningItem` must have either 'media_item' or 'latents' defined.")

        # Both payloads are tolerated, but the user is told which one wins.
        if has_pixels and has_latents:
            print("[WARNING] `PatchedConditioningItem` received both 'media_item' and 'latents'. "
                  "The 'latents' tensor will take precedence.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def prepare_conditioning_with_latents(
    self: LTXVideoPipeline,
    conditioning_items: Optional[List[PatchedConditioningItem]],
    init_latents: Tensor,
    num_frames: int,
    height: int,
    width: int,
    vae_per_channel_normalize: bool = False,
    generator: Optional[torch.Generator] = None,
) -> Tuple[Tensor, Tensor, Optional[Tensor], int]:
    """
    Modified `prepare_conditioning` that prefers pre-computed latents.

    For every conditioning item, `item.latents` is used directly when present,
    which avoids re-encoding through the VAE; otherwise `item.media_item` is
    resized and encoded with `vae_encode`.

    Items with `media_frame_number == 0` are blended into `init_latents`
    (in place) and recorded in the conditioning mask. Items at later frames
    are handed to `_handle_non_first_conditioning_sequence` when they span
    more than one latent frame; whatever latents remain are mixed with noise,
    patchified and prepended as extra conditioning tokens.

    Args:
        self: The pipeline instance this function is bound to.
        conditioning_items: Items to condition on. `None` or an empty list
            means "no conditioning".
        init_latents: Initial noise latents; indexed as a 5-D tensor and
            modified in place for frame-0 items.
        num_frames: Accepted for signature compatibility with the method
            being replaced; not read by this implementation.
        height: Target height, forwarded to the resize and placement helpers.
        width: Target width, forwarded to the resize and placement helpers.
        vae_per_channel_normalize: Forwarded to `vae_encode`.
        generator: Optional RNG used for the noise mixed into extra
            conditioning latents.

    Returns:
        A tuple `(latents, pixel_coords, conditioning_mask, num_extra)`:
        the patchified latents, their pixel coordinates, the patchified
        conditioning mask (`None` when there are no conditioning items) and
        the number of extra conditioning tokens that were prepended.

    Raises:
        ValueError: If an item's latents are not 5-D, or an item carries
            neither `latents` nor `media_item`.
        AssertionError: If `self`/`self.vae` have unexpected types or an
            encoded `media_item` is not 5-D.
    """
    # Local import keeps this function self-contained; `copy.copy` is used
    # below to clone an item without re-running its `__post_init__`.
    import copy

    assert isinstance(self, LTXVideoPipeline), "This function must be called as a method of LTXVideoPipeline."
    assert isinstance(self.vae, CausalVideoAutoencoder), "VAE must be of type CausalVideoAutoencoder."

    # Fast path: nothing to condition on, just patchify the noise latents.
    if not conditioning_items:
        init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents)
        init_pixel_coords = latent_to_pixel_coords(
            init_latent_coords, self.vae,
            causal_fix=self.transformer.config.causal_temporal_positioning
        )
        return init_latents, init_pixel_coords, None, 0

    # One mask value per latent position (the channel axis is dropped).
    init_conditioning_mask = torch.zeros(
        init_latents[:, 0, :, :, :].shape, dtype=torch.float32, device=init_latents.device
    )
    extra_conditioning_latents = []
    extra_conditioning_pixel_coords = []
    extra_conditioning_mask = []
    # Number of extra tokens prepended to the sequence (returned to the caller).
    extra_conditioning_num_latents = 0

    for item in conditioning_items:
        item_latents: Tensor
        # Object handed to `_get_latent_spatial_position`. It must describe the
        # tensor that `item_latents` was actually derived from.
        spatial_item = item
        used_precomputed_latents = item.latents is not None

        if used_precomputed_latents:
            item_latents = item.latents.to(dtype=init_latents.dtype, device=init_latents.device)
            if item_latents.ndim != 5:
                raise ValueError(f"Latents must have 5 dimensions (b, c, f, h, w), but got {item_latents.ndim}")
        elif item.media_item is not None:
            resized_item = self._resize_conditioning_item(item, height, width)
            media_item = resized_item.media_item
            assert media_item.ndim == 5, f"media_item must have 5 dims, but got {media_item.ndim}"
            item_latents = vae_encode(
                media_item.to(dtype=self.vae.dtype, device=self.vae.device),
                self.vae,
                vae_per_channel_normalize=vae_per_channel_normalize,
            ).to(dtype=init_latents.dtype)
            # The latents come from the *resized* media, so placement must be
            # computed from the resized item, not the caller's original one.
            spatial_item = resized_item
        else:
            raise ValueError("ConditioningItem is invalid: it has neither 'latents' nor 'media_item'.")

        media_frame_number = item.media_frame_number
        strength = item.conditioning_strength

        if media_frame_number == 0:
            if used_precomputed_latents:
                # The placement helper is given a conditioning item and,
                # presumably, reads the spatial size of its `media_item`
                # (TODO confirm against ltx_video). Give it a placeholder whose
                # pixel size matches the latents actually being used, even if
                # the caller also supplied a `media_item`.
                latent_h, latent_w = item_latents.shape[-2:]
                pixel_h = latent_h * self.vae_scale_factor
                pixel_w = latent_w * self.vae_scale_factor
                placeholder_media_item = torch.empty(
                    (1, 1, 1, pixel_h, pixel_w), device=item_latents.device, dtype=item_latents.dtype
                )
                # A shallow copy (rather than `dataclasses.replace`) avoids
                # re-running `__post_init__`, which would print a misleading
                # "received both 'media_item' and 'latents'" warning.
                spatial_item = copy.copy(item)
                spatial_item.media_item = placeholder_media_item

            item_latents, l_x, l_y = self._get_latent_spatial_position(
                item_latents, spatial_item, height, width, strip_latent_border=True
            )

            # First frame or sequence: blend into the initial latents in place
            # and record the strength in the mask for the same region.
            _, _, f_l, h_l, w_l = item_latents.shape
            init_latents[:, :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l] = torch.lerp(
                init_latents[:, :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l], item_latents, strength
            )
            init_conditioning_mask[:, :f_l, l_y : l_y + h_l, l_x : l_x + w_l] = strength
        else:
            # Non-first frame. Multi-frame latents are delegated; the helper
            # may return `None` for `item_latents`, hence the check below.
            if item_latents.shape[2] > 1:
                (init_latents, init_conditioning_mask, item_latents) = self._handle_non_first_conditioning_sequence(
                    init_latents, init_conditioning_mask, item_latents, media_frame_number, strength
                )

            if item_latents is not None:
                noise = randn_tensor(
                    item_latents.shape, generator=generator,
                    device=item_latents.device, dtype=item_latents.dtype
                )
                item_latents = torch.lerp(noise, item_latents, strength)
                item_latents, latent_coords = self.patchifier.patchify(latents=item_latents)
                pixel_coords = latent_to_pixel_coords(
                    latent_coords, self.vae,
                    causal_fix=self.transformer.config.causal_temporal_positioning
                )
                # Shift the first coordinate row by the target frame number.
                pixel_coords[:, 0] += media_frame_number
                extra_conditioning_num_latents += item_latents.shape[1]
                conditioning_mask = torch.full(
                    item_latents.shape[:2], strength,
                    dtype=torch.float32, device=init_latents.device
                )
                extra_conditioning_latents.append(item_latents)
                extra_conditioning_pixel_coords.append(pixel_coords)
                extra_conditioning_mask.append(conditioning_mask)

    # Patchify the (possibly updated) latents and the mask.
    init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents)
    init_pixel_coords = latent_to_pixel_coords(
        init_latent_coords, self.vae,
        causal_fix=self.transformer.config.causal_temporal_positioning
    )
    init_conditioning_mask, _ = self.patchifier.patchify(latents=init_conditioning_mask.unsqueeze(1))
    init_conditioning_mask = init_conditioning_mask.squeeze(-1)

    if extra_conditioning_latents:
        # Prepend the extra conditioning tokens, their coordinates and mask.
        init_latents = torch.cat([*extra_conditioning_latents, init_latents], dim=1)
        init_pixel_coords = torch.cat([*extra_conditioning_pixel_coords, init_pixel_coords], dim=2)
        init_conditioning_mask = torch.cat([*extra_conditioning_mask, init_conditioning_mask], dim=1)

        # Trim only when tokens were actually prepended: with zero extra
        # tokens the slice `[:-0]` would be empty and wipe out the tensors.
        if self.transformer.use_tpu_flash_attention:
            init_latents = init_latents[:, :-extra_conditioning_num_latents]
            init_pixel_coords = init_pixel_coords[:, :, :-extra_conditioning_num_latents]
            init_conditioning_mask = init_conditioning_mask[:, :-extra_conditioning_num_latents]

    return init_latents, init_pixel_coords, init_conditioning_mask, extra_conditioning_num_latents
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LTXLatentConditioningPatch:
    """
    Static helper that applies and reverts the monkey patch on the
    LTX-Video pipeline.

    State is kept on the class itself: `_original_prepare_conditioning`
    stores the replaced method and `_is_patched` records whether the patch
    is currently active.
    """

    _original_prepare_conditioning = None
    _is_patched = False

    @staticmethod
    def apply():
        """
        Install the monkey patch on the `LTXVideoPipeline` class.

        Does nothing except print a warning when the patch is already active.
        """
        patch_state = LTXLatentConditioningPatch

        if patch_state._is_patched:
            print("[WARNING] LTXLatentConditioningPatch has already been applied. Ignoring.")
            return

        print("[INFO] Applying monkey patch for latent-based conditioning...")

        # Remember the original method first so that `revert` can restore it.
        patch_state._original_prepare_conditioning = LTXVideoPipeline.prepare_conditioning
        LTXVideoPipeline.prepare_conditioning = prepare_conditioning_with_latents
        patch_state._is_patched = True

        for message in (
            "[SUCCESS] Monkey patch applied successfully.",
            " - `LTXVideoPipeline.prepare_conditioning` has been updated.",
            " - NOTE: Remember to use `aduc_ltx_latent_patch.PatchedConditioningItem` when creating conditioning items.",
        ):
            print(message)

    @staticmethod
    def revert():
        """
        Undo the monkey patch, restoring the original implementation.

        Prints a warning when the patch is not active, and an error when no
        original implementation was saved (the patched flag is then left as is).
        """
        patch_state = LTXLatentConditioningPatch

        if not patch_state._is_patched:
            print("[WARNING] Patch is not currently applied. No action taken.")
            return

        saved_method = patch_state._original_prepare_conditioning
        if not saved_method:
            print("[ERROR] Cannot revert: original implementation was not saved.")
            return

        print("[INFO] Reverting LTXLatentConditioningPatch...")
        LTXVideoPipeline.prepare_conditioning = saved_method
        patch_state._is_patched = False
        print("[SUCCESS] Patch reverted successfully. Original functionality restored.")