Upload pipeline_ltx_video.py
LTX-Video/ltx_video/pipelines/pipeline_ltx_video.py
CHANGED
@@ -13,7 +13,7 @@ from diffusers.image_processor import VaeImageProcessor
 from diffusers.models import AutoencoderKL
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from diffusers.schedulers import DPMSolverMultistepScheduler
-
+from diffusers.utils import deprecate, logging
 from diffusers.utils.torch_utils import randn_tensor
 from einops import rearrange
 from transformers import (
@@ -24,7 +24,6 @@ from transformers import (
     AutoTokenizer,
 )
 
-
 from ltx_video.models.autoencoders.causal_video_autoencoder import (
     CausalVideoAutoencoder,
 )
@@ -45,127 +44,8 @@ from ltx_video.models.autoencoders.vae_encode import (
     normalize_latents,
 )
 
-import warnings
-warnings.filterwarnings("ignore", category=UserWarning)
-warnings.filterwarnings("ignore", category=FutureWarning)
-warnings.filterwarnings("ignore", message=".*")
-
-from huggingface_hub import logging
-
-logging.set_verbosity_error()
-logging.set_verbosity_warning()
-logging.set_verbosity_info()
-logging.set_verbosity_debug()
-
-
-#logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
-
-
-class SpyLatent:
-
-    """
-    A class for inspecting latent tensors at various stages of a pipeline.
-    Prints statistics and can save visualizations decoded by a VAE.
-    """
-
-    import torch
-    import os
-    import traceback
-    from einops import rearrange
-    from torchvision.utils import save_image
-
-    def __init__(self, vae=None, output_dir: str = "/app/output"):
-        """
-        Initializes the spy.
-
-        Args:
-            vae: The VAE model instance used to decode the latents. If None,
-                visualization is disabled.
-            output_dir (str): Default directory where visualization images are saved.
-        """
-        self.vae = vae
-        self.output_dir = output_dir
-        self.device = vae.device if hasattr(vae, 'device') else torch.device("cpu")
-
-        if self.vae is None:
-            print("[SpyLatent] WARNING: No VAE provided. Image visualization is disabled.")
-
-    def inspect(
-        self,
-        tensor: torch.Tensor,
-        tag: str,
-        reference_shape_5d: tuple = None,
-        save_visual: bool = True,
-    ):
-        """
-        Inspects a latent tensor.
-
-        Args:
-            tensor (torch.Tensor): The tensor to inspect.
-            tag (str): A label identifying the inspection point in the logs.
-            reference_shape_5d (tuple, optional): The reference 5D shape (B, C, F, H, W),
-                required when the input tensor is 3D.
-            save_visual (bool): If True, decode with the VAE and save an image.
-        """
-        #print(f"\n--- [LATENT INSPECTION: {tag}] ---")
-        #if not isinstance(tensor, torch.Tensor):
-        #    print(f"  WARNING: The object provided for '{tag}' is not a tensor.")
-        #    print("--- [END OF INSPECTION] ---\n")
-        #    return
-
-        try:
-            # --- Print statistics for the original tensor ---
-            #self._print_stats("Original tensor", tensor)
-
-            # --- Convert to 5D if needed ---
-            tensor_5d = self._to_5d(tensor, reference_shape_5d)
-            if tensor_5d is not None and tensor.ndim == 3:
-                self._print_stats("Converted to 5D", tensor_5d)
-
-            # --- VAE visualization ---
-            if save_visual and self.vae is not None and tensor_5d is not None:
-                os.makedirs(self.output_dir, exist_ok=True)
-                #print(f"  VISUALIZATION (VAE): Saving image to {self.output_dir}...")
-
-                frame_idx_to_viz = min(1, tensor_5d.shape[2] - 1)
-                if frame_idx_to_viz < 0:
-                    print("  VISUALIZATION (VAE): Tensor has no frames to visualize.")
-                else:
-                    #print(f"  VISUALIZATION (VAE): Using frame index {frame_idx_to_viz}.")
-                    latent_slice = tensor_5d[:, :, frame_idx_to_viz:frame_idx_to_viz+1, :, :]
-
-                    with torch.no_grad(), torch.autocast(device_type=self.device.type):
-                        pixel_slice = self.vae.decode(latent_slice / self.vae.config.scaling_factor).sample
-
-                    save_image((pixel_slice / 2 + 0.5).clamp(0, 1), os.path.join(self.output_dir, f"inspect_{tag.lower()}.png"))
-                    print("  VISUALIZATION (VAE): Image saved.")
-
-        except Exception as e:
-            #print(f"  ERROR during inspection: {e}")
-            traceback.print_exc()
-
-    def _to_5d(self, tensor: torch.Tensor, shape_5d: tuple) -> torch.Tensor:
-        """Converts a patchified 3D tensor back to 5D."""
-        if tensor.ndim == 5:
-            return tensor
-        if tensor.ndim == 3 and shape_5d:
-            try:
-                b, c, f, h, w = shape_5d
-                return rearrange(tensor, "b (f h w) c -> b c f h w", c=c, f=f, h=h, w=w)
-            except Exception as e:
-                #print(f"  WARNING: Failed to rearrange 3D tensor to 5D: {e}. Visualization may fail.")
-                return None
-        return None
-
-    def _print_stats(self, prefix: str, tensor: torch.Tensor):
-        """Helper that prints statistics for a tensor."""
-        mean = tensor.mean().item()
-        std = tensor.std().item()
-        min_val = tensor.min().item()
-        max_val = tensor.max().item()
-        print(f"  {prefix}: {tensor.shape}")
-
 
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
 ASPECT_RATIO_1024_BIN = {
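The hunk above replaces the blanket `warnings.filterwarnings` calls and the ad-hoc `SpyLatent` debugging helper with diffusers' module-level logger. As a reference, a minimal sketch of that logging pattern, using only the public `diffusers.utils.logging` helpers (assuming diffusers is installed):

from diffusers.utils import logging

# One logger per module, as the new code does at import time.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# Verbosity is set once, globally, instead of suppressing every warning.
logging.set_verbosity_warning()

logger.debug("hidden unless verbosity is set to DEBUG")
logger.warning("emitted at WARNING verbosity and above")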
@@ -247,8 +127,8 @@ def retrieve_timesteps(
     num_inference_steps: Optional[int] = None,
     device: Optional[Union[str, torch.device]] = None,
     timesteps: Optional[List[int]] = None,
-    skip_initial_inference_steps:
-    skip_final_inference_steps:
+    skip_initial_inference_steps: int = 0,
+    skip_final_inference_steps: int = 0,
     **kwargs,
 ):
     """
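The new signature gives `skip_initial_inference_steps` and `skip_final_inference_steps` integer defaults of 0. Their implementation is not visible in this hunk; a plausible sketch, assuming they simply trim the computed schedule from either end, would be:

# Hypothetical helper - the actual retrieve_timesteps body is not shown in this hunk.
def trim_timesteps(timesteps, skip_initial_inference_steps=0, skip_final_inference_steps=0):
    """Drop the first/last entries of a timestep schedule and return it with its new length."""
    end = len(timesteps) - skip_final_inference_steps
    trimmed = timesteps[skip_initial_inference_steps:end]
    return trimmed, len(trimmed)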
@@ -306,12 +186,6 @@ def retrieve_timesteps(
         ]
         scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
         num_inference_steps = len(timesteps)
-
-    try:
-        print(f"[LTX]LATENTS {latents.shape}")
-    except Exception:
-        pass
-
 
     return timesteps, num_inference_steps
 
@@ -358,8 +232,6 @@ class LTXVideoPipeline(DiffusionPipeline):
         scheduler ([`SchedulerMixin`]):
             A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
     """
-
-
 
     bad_punct_regex = re.compile(
         r"["
@@ -422,8 +294,6 @@ class LTXVideoPipeline(DiffusionPipeline):
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
 
         self.allowed_inference_steps = allowed_inference_steps
-
-        self.spy = SpyLatent(vae=vae)
 
     def mask_text_embeddings(self, emb, mask):
         if emb.shape[0] == 1:
@@ -473,7 +343,7 @@ class LTXVideoPipeline(DiffusionPipeline):
 
         if "mask_feature" in kwargs:
             deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version."
-
+            deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False)
 
         if device is None:
             device = self._execution_device
@@ -486,10 +356,9 @@ class LTXVideoPipeline(DiffusionPipeline):
         batch_size = prompt_embeds.shape[0]
 
         # See Section 3.1. of the paper.
-        max_length =
-
-
-        #)
+        max_length = (
+            text_encoder_max_tokens  # TPU supports only lengths multiple of 128
+        )
         if prompt_embeds is None:
             assert (
                 self.text_encoder is not None
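`max_length` now comes straight from `text_encoder_max_tokens` (the inline comment notes that TPU prefers sequence lengths that are multiples of 128). In pipelines of this family the value is typically passed to the tokenizer with fixed-length padding; a sketch of that usual call, with `prompt`, `tokenizer`, and `max_length` assumed to be in scope:

# Sketch of the typical fixed-length tokenization step (names assumed from context).
text_inputs = tokenizer(
    prompt,
    padding="max_length",   # pad every prompt to max_length tokens
    max_length=max_length,  # e.g. text_encoder_max_tokens
    truncation=True,
    return_tensors="pt",
)
text_input_ids = text_inputs.input_ids

# Re-tokenizing without truncation reveals prompts that were cut off, which is
# what the removed_text / logger.warning branch in the next hunk reports.
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids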
@@ -515,10 +384,10 @@ class LTXVideoPipeline(DiffusionPipeline):
                 removed_text = self.tokenizer.batch_decode(
                     untruncated_ids[:, max_length - 1 : -1]
                 )
-
-
-
-
+                logger.warning(
+                    "The following part of your input was truncated because CLIP can only handle sequences up to"
+                    f" {max_length} tokens: {removed_text}"
+                )
 
             prompt_attention_mask = text_inputs.attention_mask
             prompt_attention_mask = prompt_attention_mask.to(text_enc_device)
@@ -1006,21 +875,15 @@ class LTXVideoPipeline(DiffusionPipeline):
             tone_map_compression_ratio: compression ratio for tone mapping, defaults to 0.0.
                 If set to 0.0, no tone mapping is applied. If set to 1.0 - full compression is applied.
         Examples:
+
         Returns:
             [`~pipelines.ImagePipelineOutput`] or `tuple`:
                 If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
                 returned where the first element is a list with the generated images
         """
-
-        try:
-            print(f"[LTX]LATENTS {latents.shape}")
-        except Exception:
-            pass
-
-
         if "mask_feature" in kwargs:
             deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version."
-
+            deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False)
 
         is_video = kwargs.get("is_video", False)
         self.check_inputs(
@@ -1085,12 +948,7 @@ class LTXVideoPipeline(DiffusionPipeline):
             skip_final_inference_steps=skip_final_inference_steps,
             **retrieve_timesteps_kwargs,
         )
-
-        try:
-            print(f"[LTX2]LATENTS {latents.shape}")
-        except Exception:
-            pass
-
+
         if self.allowed_inference_steps is not None:
             for timestep in [round(x, 4) for x in timesteps.tolist()]:
                 assert (
@@ -1158,12 +1016,7 @@ class LTXVideoPipeline(DiffusionPipeline):
             conditioning_items,
             max_new_tokens=text_encoder_max_tokens,
         )
-
-        try:
-            print(f"[LTX3]LATENTS {latents.shape}")
-        except Exception:
-            pass
-
+
         # 3. Encode input prompt
         if self.text_encoder is not None:
             self.text_encoder = self.text_encoder.to(self._execution_device)
@@ -1228,15 +1081,7 @@ class LTXVideoPipeline(DiffusionPipeline):
             generator=generator,
             vae_per_channel_normalize=vae_per_channel_normalize,
         )
-
-        try:
-            print(f"[LTX4]LATENTS {latents.shape}")
-            original_shape = latents
-        except Exception:
-            pass
-
-
+
         # Update the latents with the conditioning items and patchify them into (b, n, c)
         latents, pixel_coords, conditioning_mask, num_cond_latents = (
             self.prepare_conditioning(
@@ -1251,33 +1096,9 @@ class LTXVideoPipeline(DiffusionPipeline):
         )
         init_latents = latents.clone()  # Used for image_cond_noise_update
 
-        try:
-            print(f"[LTXCond]conditioning_mask {conditioning_mask.shape}")
-        except Exception:
-            pass
-
-        try:
-            print(f"[LTXCond]pixel_coords {pixel_coords.shape}")
-        except Exception:
-            pass
-
-        try:
-            print(f"[LTXCond]pixel_coords {pixel_coords.shape}")
-        except Exception:
-            pass
-
-
-
-
         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
-
-        try:
-            print(f"[LTX5]LATENTS {latents.shape}")
-        except Exception:
-            pass
-
         # 7. Denoising loop
         num_warmup_steps = max(
             len(timesteps) - num_inference_steps * self.scheduler.order, 0
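The debug prints around `prepare_extra_step_kwargs` are gone; the helper itself is the standard diffusers pattern that only forwards `eta` and `generator` when the scheduler's `step()` accepts them. A sketch of that common implementation (an assumption here, since the method body is not part of this diff):

import inspect

def prepare_extra_step_kwargs(scheduler, generator, eta):
    # eta only applies to DDIM-style schedulers; other schedulers ignore it,
    # so it is forwarded only when step() actually declares the parameter.
    extra_step_kwargs = {}
    if "eta" in inspect.signature(scheduler.step).parameters:
        extra_step_kwargs["eta"] = eta
    if "generator" in inspect.signature(scheduler.step).parameters:
        extra_step_kwargs["generator"] = generator
    return extra_step_kwargs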
@@ -1336,14 +1157,6 @@ class LTXVideoPipeline(DiffusionPipeline):
                         orig_conditioning_mask,
                         generator,
                     )
-
-                try:
-                    print(f"[LTX6]LATENTS {latents.shape}")
-                    self.spy.inspect(latents, "LTX6_After_Patchify", reference_shape_5d=original_shape)
-                except Exception:
-                    pass
-
-
 
                 latent_model_input = (
                     torch.cat([latents] * num_conds) if num_conds > 1 else latents
@@ -1351,12 +1164,6 @@ class LTXVideoPipeline(DiffusionPipeline):
                 latent_model_input = self.scheduler.scale_model_input(
                     latent_model_input, t
                 )
-
-                try:
-                    print(f"[LTX7]LATENTS {latent_model_input.shape}")
-                    self.spy.inspect(latents, "LTX7_After_Patchify", reference_shape_5d=original_shape)
-                except Exception:
-                    pass
 
                 current_timestep = t
                 if not torch.is_tensor(current_timestep):
@@ -1472,12 +1279,6 @@ class LTXVideoPipeline(DiffusionPipeline):
                     extra_step_kwargs,
                     stochastic_sampling=stochastic_sampling,
                 )
-
-                try:
-                    print(f"[LTX8]LATENTS {latents.shape}")
-                    self.spy.inspect(latents, "LTX8_After_Patchify", reference_shape_5d=original_shape)
-                except Exception:
-                    pass
 
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or (
@@ -1488,16 +1289,6 @@ class LTXVideoPipeline(DiffusionPipeline):
                 if callback_on_step_end is not None:
                     callback_on_step_end(self, i, t, {})
 
-
-
-                try:
-                    print(f"[LTX9]LATENTS {latents.shape}")
-                    self.spy.inspect(latents, "LTX9_After_Patchify", reference_shape_5d=original_shape)
-
-                except Exception:
-                    pass
-
-
         if offload_to_cpu:
             self.transformer = self.transformer.cpu()
             if self._execution_device == "cuda":
@@ -1505,13 +1296,6 @@ class LTXVideoPipeline(DiffusionPipeline):
 
         # Remove the added conditioning latents
         latents = latents[:, num_cond_latents:]
-
-
-        try:
-            print(f"[LTX10]LATENTS {latents.shape}")
-            self.spy.inspect(latents, "LTX10_After_Patchify", reference_shape_5d=original_shape)
-        except Exception:
-            pass
 
         latents = self.patchifier.unpatchify(
             latents=latents,
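`self.patchifier.unpatchify` maps the flattened token sequence back into a video latent. The removed `SpyLatent._to_5d` helper performed the same reshape with einops; for reference, the core rearrange (with example sizes standing in for the real latent dimensions) looks like:

import torch
from einops import rearrange

# Example sizes only; the real values come from the patchifier / latent shape.
b, c, f, h, w = 1, 128, 3, 16, 16
tokens = torch.randn(b, f * h * w, c)          # patchified latents, shape (b, n, c)

latents_5d = rearrange(tokens, "b (f h w) c -> b c f h w", f=f, h=h, w=w)
print(latents_5d.shape)                         # torch.Size([1, 128, 3, 16, 16])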
@@ -1547,11 +1331,6 @@ class LTXVideoPipeline(DiffusionPipeline):
                 vae_per_channel_normalize=kwargs["vae_per_channel_normalize"],
                 timestep=decode_timestep,
             )
-
-            try:
-                print(f"[LTX11]LATENTS {latents.shape}")
-            except Exception:
-                pass
 
             image = self.image_processor.postprocess(image, output_type=output_type)
 
@@ -1659,30 +1438,41 @@ class LTXVideoPipeline(DiffusionPipeline):
 
         # Process each conditioning item
         for conditioning_item in conditioning_items:
-
-
-
-
-
-
-
-
-
-
-
-
-
-            media_frame_number
-
-
-
-
-
-
-
-
-
+
+            print(f"media_item_latents ini {conditioning_item.media_item.shape}")
+
+            c = conditioning_item.media_item.shape[1]
+            if c == self.transformer.config.in_channels:
+                media_item_latents = conditioning_item.media_item.to(dtype=init_latents_dtype, device=init_latents_device)
+                strength = conditioning_item.conditioning_strength
+                media_frame_number = conditioning_item.media_frame_number
+            else:
+                conditioning_item = self._resize_conditioning_item(
+                    conditioning_item, height, width
+                )
+                media_item = conditioning_item.media_item
+                media_frame_number = conditioning_item.media_frame_number
+                strength = conditioning_item.conditioning_strength
+                assert media_item.ndim == 5  # (b, c, f, h, w)
+                b, c, n_frames, h, w = media_item.shape
+                assert (
+                    height == h and width == w
+                ) or media_frame_number == 0, f"Dimensions do not match: {height}x{width} != {h}x{w} - allowed only when media_frame_number == 0"
+                assert n_frames % 8 == 1
+                assert (
+                    media_frame_number >= 0
+                    and media_frame_number + n_frames <= num_frames
+                )
+
+                # Encode the provided conditioning media item
+                media_item_latents = vae_encode(
+                    media_item.to(dtype=self.vae.dtype, device=self.vae.device),
+                    self.vae,
+                    vae_per_channel_normalize=vae_per_channel_normalize,
+                ).to(dtype=init_latents.dtype)
+
+            print(f"media_item_latents encode vae {conditioning_item.media_item.shape}")
+
             # Handle the different conditioning cases
             if media_frame_number == 0:
                 # Get the target spatial position of the latent conditioning item