caarleexx committed
Commit f2a0118 · verified · 1 Parent(s): 8f69c0a

Upload pipeline_ltx_video.py

LTX-Video/ltx_video/pipelines/pipeline_ltx_video.py CHANGED
@@ -13,7 +13,7 @@ from diffusers.image_processor import VaeImageProcessor
 from diffusers.models import AutoencoderKL
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from diffusers.schedulers import DPMSolverMultistepScheduler
- #from diffusers.utils import deprecate, logging
+ from diffusers.utils import deprecate, logging
 from diffusers.utils.torch_utils import randn_tensor
 from einops import rearrange
 from transformers import (
@@ -24,7 +24,6 @@ from transformers import (
 AutoTokenizer,
 )

-
 from ltx_video.models.autoencoders.causal_video_autoencoder import (
 CausalVideoAutoencoder,
 )
@@ -45,127 +44,8 @@ from ltx_video.models.autoencoders.vae_encode import (
 normalize_latents,
 )

- import warnings
- warnings.filterwarnings("ignore", category=UserWarning)
- warnings.filterwarnings("ignore", category=FutureWarning)
- warnings.filterwarnings("ignore", message=".*")
-
- from huggingface_hub import logging
-
- logging.set_verbosity_error()
- logging.set_verbosity_warning()
- logging.set_verbosity_info()
- logging.set_verbosity_debug()
-
-
- #logger = logging.get_logger(__name__) # pylint: disable=invalid-name
-
-
- class SpyLatent:
-
- """
- A class for inspecting latent tensors at various stages of a pipeline.
- Prints statistics and can save visualizations decoded by a VAE.
- """
-
- import torch
- import os
- import traceback
- from einops import rearrange
- from torchvision.utils import save_image
-
- def __init__(self, vae=None, output_dir: str = "/app/output"):
- """
- Initializes the spy.
-
- Args:
- vae: The VAE model instance used to decode the latents. If None,
- visualization is disabled.
- output_dir (str): The default directory for saving visualization images.
- """
- self.vae = vae
- self.output_dir = output_dir
- self.device = vae.device if hasattr(vae, 'device') else torch.device("cpu")
-
- if self.vae is None:
- print("[SpyLatent] WARNING: No VAE provided. Image visualization is disabled.")
-
- def inspect(
- self,
- tensor: torch.Tensor,
- tag: str,
- reference_shape_5d: tuple = None,
- save_visual: bool = True,
- ):
- """
- Inspects a latent tensor.
-
- Args:
- tensor (torch.Tensor): The tensor to inspect.
- tag (str): A label identifying the inspection point in the logs.
- reference_shape_5d (tuple, optional): The reference 5D shape (B, C, F, H, W),
- required if the input tensor is 3D.
- save_visual (bool): If True, decodes with the VAE and saves an image.
- """
- #print(f"\n--- [LATENT INSPECTION: {tag}] ---")
- #if not isinstance(tensor, torch.Tensor):
- # print(f" WARNING: The object provided for '{tag}' is not a tensor.")
- # print("--- [END OF INSPECTION] ---\n")
- # return
-
- try:
- # --- Print statistics of the original tensor ---
- #self._print_stats("Original Tensor", tensor)
-
- # --- Convert to 5D if necessary ---
- tensor_5d = self._to_5d(tensor, reference_shape_5d)
- if tensor_5d is not None and tensor.ndim == 3:
- self._print_stats("Converted to 5D", tensor_5d)
-
- # --- Visualization with the VAE ---
- if save_visual and self.vae is not None and tensor_5d is not None:
- os.makedirs(self.output_dir, exist_ok=True)
- #print(f" VISUALIZATION (VAE): Saving image to {self.output_dir}...")
-
- frame_idx_to_viz = min(1, tensor_5d.shape[2] - 1)
- if frame_idx_to_viz < 0:
- print(" VISUALIZATION (VAE): The tensor has no frames to visualize.")
- else:
- #print(f" VISUALIZATION (VAE): Using frame index {frame_idx_to_viz}.")
- latent_slice = tensor_5d[:, :, frame_idx_to_viz:frame_idx_to_viz+1, :, :]
-
- with torch.no_grad(), torch.autocast(device_type=self.device.type):
- pixel_slice = self.vae.decode(latent_slice / self.vae.config.scaling_factor).sample
-
- save_image((pixel_slice / 2 + 0.5).clamp(0, 1), os.path.join(self.output_dir, f"inspect_{tag.lower()}.png"))
- print(" VISUALIZATION (VAE): Image saved.")
-
- except Exception as e:
- #print(f" ERROR during inspection: {e}")
- traceback.print_exc()
-
- def _to_5d(self, tensor: torch.Tensor, shape_5d: tuple) -> torch.Tensor:
- """Converts a patchified 3D tensor back to 5D."""
- if tensor.ndim == 5:
- return tensor
- if tensor.ndim == 3 and shape_5d:
- try:
- b, c, f, h, w = shape_5d
- return rearrange(tensor, "b (f h w) c -> b c f h w", c=c, f=f, h=h, w=w)
- except Exception as e:
- #print(f" WARNING: Error rearranging 3D tensor to 5D: {e}. Visualization may fail.")
- return None
- return None
-
- def _print_stats(self, prefix: str, tensor: torch.Tensor):
- """Helper for printing statistics of a tensor."""
- mean = tensor.mean().item()
- std = tensor.std().item()
- min_val = tensor.min().item()
- max_val = tensor.max().item()
- print(f" {prefix}: {tensor.shape}")
-

+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name


  ASPECT_RATIO_1024_BIN = {
@@ -247,8 +127,8 @@ def retrieve_timesteps(
 num_inference_steps: Optional[int] = None,
 device: Optional[Union[str, torch.device]] = None,
 timesteps: Optional[List[int]] = None,
- skip_initial_inference_steps: Optional[int] = 0,
- skip_final_inference_steps: Optional[int] = 0,
+ skip_initial_inference_steps: int = 0,
+ skip_final_inference_steps: int = 0,
 **kwargs,
 ):
 """
@@ -306,12 +186,6 @@ def retrieve_timesteps(
 ]
 scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
 num_inference_steps = len(timesteps)
-
- try:
- print(f"[LTX]LATENTS {latents.shape}")
- except Exception:
- pass
-

 return timesteps, num_inference_steps

@@ -358,8 +232,6 @@ class LTXVideoPipeline(DiffusionPipeline):
 scheduler ([`SchedulerMixin`]):
 A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
 """
-
-

 bad_punct_regex = re.compile(
 r"["
@@ -422,8 +294,6 @@ class LTXVideoPipeline(DiffusionPipeline):
 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

 self.allowed_inference_steps = allowed_inference_steps
-
- self.spy = SpyLatent(vae=vae)

 def mask_text_embeddings(self, emb, mask):
 if emb.shape[0] == 1:
@@ -473,7 +343,7 @@ class LTXVideoPipeline(DiffusionPipeline):

 if "mask_feature" in kwargs:
 deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version."
- #deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False)
+ deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False)

 if device is None:
 device = self._execution_device
@@ -486,10 +356,9 @@ class LTXVideoPipeline(DiffusionPipeline):
 batch_size = prompt_embeds.shape[0]

 # See Section 3.1. of the paper.
- max_length = 256
- #(
- # text_encoder_max_tokens # TPU supports only lengths multiple of 128
- #)
+ max_length = (
+ text_encoder_max_tokens # TPU supports only lengths multiple of 128
+ )
 if prompt_embeds is None:
 assert (
 self.text_encoder is not None
@@ -515,10 +384,10 @@ class LTXVideoPipeline(DiffusionPipeline):
 removed_text = self.tokenizer.batch_decode(
 untruncated_ids[:, max_length - 1 : -1]
 )
- #logger.warning(
- # "The following part of your input was truncated because CLIP can only handle sequences up to"
- # f" {max_length} tokens: {removed_text}"
- #)
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {max_length} tokens: {removed_text}"
+ )

 prompt_attention_mask = text_inputs.attention_mask
 prompt_attention_mask = prompt_attention_mask.to(text_enc_device)
@@ -1006,21 +875,15 @@ class LTXVideoPipeline(DiffusionPipeline):
 tone_map_compression_ratio: compression ratio for tone mapping, defaults to 0.0.
 If set to 0.0, no tone mapping is applied. If set to 1.0 - full compression is applied.
 Examples:
+
 Returns:
 [`~pipelines.ImagePipelineOutput`] or `tuple`:
 If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
 returned where the first element is a list with the generated images
 """
-
- try:
- print(f"[LTX]LATENTS {latents.shape}")
- except Exception:
- pass
-
-
 if "mask_feature" in kwargs:
 deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version."
- #deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False)
+ deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False)

 is_video = kwargs.get("is_video", False)
 self.check_inputs(
  self.check_inputs(
@@ -1085,12 +948,7 @@ class LTXVideoPipeline(DiffusionPipeline):
1085
  skip_final_inference_steps=skip_final_inference_steps,
1086
  **retrieve_timesteps_kwargs,
1087
  )
1088
-
1089
- try:
1090
- print(f"[LTX2]LATENTS {latents.shape}")
1091
- except Exception:
1092
- pass
1093
-
1094
  if self.allowed_inference_steps is not None:
1095
  for timestep in [round(x, 4) for x in timesteps.tolist()]:
1096
  assert (
@@ -1158,12 +1016,7 @@ class LTXVideoPipeline(DiffusionPipeline):
 conditioning_items,
 max_new_tokens=text_encoder_max_tokens,
 )
-
- try:
- print(f"[LTX3]LATENTS {latents.shape}")
- except Exception:
- pass
-
+
 # 3. Encode input prompt
 if self.text_encoder is not None:
 self.text_encoder = self.text_encoder.to(self._execution_device)
@@ -1228,15 +1081,7 @@ class LTXVideoPipeline(DiffusionPipeline):
 generator=generator,
 vae_per_channel_normalize=vae_per_channel_normalize,
 )
-
- try:
- print(f"[LTX4]LATENTS {latents.shape}")
- original_shape = latents
- except Exception:
- pass
-
-
-
+
 # Update the latents with the conditioning items and patchify them into (b, n, c)
 latents, pixel_coords, conditioning_mask, num_cond_latents = (
 self.prepare_conditioning(
@@ -1251,33 +1096,9 @@ class LTXVideoPipeline(DiffusionPipeline):
 )
 init_latents = latents.clone() # Used for image_cond_noise_update

- try:
- print(f"[LTXCond]conditioning_mask {conditioning_mask.shape}")
- except Exception:
- pass
-
- try:
- print(f"[LTXCond]pixel_coords {pixel_coords.shape}")
- except Exception:
- pass
-
- try:
- print(f"[LTXCond]pixel_coords {pixel_coords.shape}")
- except Exception:
- pass
-
-
-
-
 # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
 extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

-
- try:
- print(f"[LTX5]LATENTS {latents.shape}")
- except Exception:
- pass
-
 # 7. Denoising loop
 num_warmup_steps = max(
 len(timesteps) - num_inference_steps * self.scheduler.order, 0
@@ -1336,14 +1157,6 @@ class LTXVideoPipeline(DiffusionPipeline):
 orig_conditioning_mask,
 generator,
 )
-
- try:
- print(f"[LTX6]LATENTS {latents.shape}")
- self.spy.inspect(latents, "LTX6_After_Patchify", reference_shape_5d=original_shape)
- except Exception:
- pass
-
-

 latent_model_input = (
 torch.cat([latents] * num_conds) if num_conds > 1 else latents
@@ -1351,12 +1164,6 @@ class LTXVideoPipeline(DiffusionPipeline):
 latent_model_input = self.scheduler.scale_model_input(
 latent_model_input, t
 )
-
- try:
- print(f"[LTX7]LATENTS {latent_model_input.shape}")
- self.spy.inspect(latents, "LTX7_After_Patchify", reference_shape_5d=original_shape)
- except Exception:
- pass

 current_timestep = t
 if not torch.is_tensor(current_timestep):
@@ -1472,12 +1279,6 @@ class LTXVideoPipeline(DiffusionPipeline):
 extra_step_kwargs,
 stochastic_sampling=stochastic_sampling,
 )
-
- try:
- print(f"[LTX8]LATENTS {latents.shape}")
- self.spy.inspect(latents, "LTX8_After_Patchify", reference_shape_5d=original_shape)
- except Exception:
- pass

 # call the callback, if provided
 if i == len(timesteps) - 1 or (
@@ -1488,16 +1289,6 @@ class LTXVideoPipeline(DiffusionPipeline):
 if callback_on_step_end is not None:
 callback_on_step_end(self, i, t, {})

-
-
- try:
- print(f"[LTX9]LATENTS {latents.shape}")
- self.spy.inspect(latents, "LTX9_After_Patchify", reference_shape_5d=original_shape)
-
- except Exception:
- pass
-
-
 if offload_to_cpu:
 self.transformer = self.transformer.cpu()
 if self._execution_device == "cuda":
@@ -1505,13 +1296,6 @@ class LTXVideoPipeline(DiffusionPipeline):

 # Remove the added conditioning latents
 latents = latents[:, num_cond_latents:]
-
-
- try:
- print(f"[LTX10]LATENTS {latents.shape}")
- self.spy.inspect(latents, "LTX10_After_Patchify", reference_shape_5d=original_shape)
- except Exception:
- pass

 latents = self.patchifier.unpatchify(
 latents=latents,
@@ -1547,11 +1331,6 @@ class LTXVideoPipeline(DiffusionPipeline):
 vae_per_channel_normalize=kwargs["vae_per_channel_normalize"],
 timestep=decode_timestep,
 )
-
- try:
- print(f"[LTX11]LATENTS {latents.shape}")
- except Exception:
- pass

 image = self.image_processor.postprocess(image, output_type=output_type)

@@ -1659,30 +1438,41 @@ class LTXVideoPipeline(DiffusionPipeline):

 # Process each conditioning item
 for conditioning_item in conditioning_items:
- conditioning_item = self._resize_conditioning_item(
- conditioning_item, height, width
- )
- media_item = conditioning_item.media_item
- media_frame_number = conditioning_item.media_frame_number
- strength = conditioning_item.conditioning_strength
- assert media_item.ndim == 5 # (b, c, f, h, w)
- b, c, n_frames, h, w = media_item.shape
- assert (
- height == h and width == w
- ) or media_frame_number == 0, f"Dimensions do not match: {height}x{width} != {h}x{w} - allowed only when media_frame_number == 0"
- assert n_frames % 8 == 1
- assert (
- media_frame_number >= 0
- and media_frame_number + n_frames <= num_frames
- )
-
- # Encode the provided conditioning media item
- media_item_latents = vae_encode(
- media_item.to(dtype=self.vae.dtype, device=self.vae.device),
- self.vae,
- vae_per_channel_normalize=vae_per_channel_normalize,
- ).to(dtype=init_latents.dtype)
-
+
+ print(f"media_item_latents ini {conditioning_item.media_item.shape}")
+
+ c = conditioning_item.media_item.shape[1]
+ if c == self.transformer.config.in_channels:
+ media_item_latents = conditioning_item.media_item.to(dtype=init_latents_dtype, device=init_latents_device)
+ strength = conditioning_item.conditioning_strength
+ media_frame_number = conditioning_item.media_frame_number
+ else:
+ conditioning_item = self._resize_conditioning_item(
+ conditioning_item, height, width
+ )
+ media_item = conditioning_item.media_item
+ media_frame_number = conditioning_item.media_frame_number
+ strength = conditioning_item.conditioning_strength
+ assert media_item.ndim == 5 # (b, c, f, h, w)
+ b, c, n_frames, h, w = media_item.shape
+ assert (
+ height == h and width == w
+ ) or media_frame_number == 0, f"Dimensions do not match: {height}x{width} != {h}x{w} - allowed only when media_frame_number == 0"
+ assert n_frames % 8 == 1
+ assert (
+ media_frame_number >= 0
+ and media_frame_number + n_frames <= num_frames
+ )
+
+ # Encode the provided conditioning media item
+ media_item_latents = vae_encode(
+ media_item.to(dtype=self.vae.dtype, device=self.vae.device),
+ self.vae,
+ vae_per_channel_normalize=vae_per_channel_normalize,
+ ).to(dtype=init_latents.dtype)
+
+ print(f"media_item_latents encode vae {conditioning_item.media_item.shape}")
+
 # Handle the different conditioning cases
 if media_frame_number == 0:
 # Get the target spatial position of the latent conditioning item
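The main functional change in this commit is the new branch in prepare_conditioning: a conditioning item whose media tensor already has transformer.config.in_channels channels is treated as a pre-encoded latent and bypasses the resize/VAE-encode path. Below is a minimal sketch of how a caller might build such items. It assumes ConditioningItem is the container defined in this module, that its fields follow the attribute names used by the pipeline (media_item, media_frame_number, conditioning_strength) in that order, and the helper name make_conditioning_item is hypothetical; verify against the actual class definition before use.

import torch

from ltx_video.pipelines.pipeline_ltx_video import ConditioningItem, LTXVideoPipeline


def make_conditioning_item(
    pipeline: LTXVideoPipeline,
    media: torch.Tensor,
    frame_number: int = 0,
    strength: float = 1.0,
) -> ConditioningItem:
    """Wrap pixel-space media or an already-encoded latent as a conditioning item."""
    assert media.ndim == 5, "expected a 5D tensor (b, c, f, h, w)"
    # Mirrors the channel check added in this commit: a channel count equal to the
    # transformer's in_channels is taken to mean the tensor is already a latent.
    is_latent = media.shape[1] == pipeline.transformer.config.in_channels
    if not is_latent:
        # Pixel-space media still goes through the original path (resize + VAE encode),
        # which requires a frame count of the form 8*k + 1.
        assert media.shape[2] % 8 == 1, "pixel media needs 8*k + 1 frames"
    # Field order is assumed from the attribute names used in prepare_conditioning.
    return ConditioningItem(media, frame_number, strength)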