Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import math | |
| from typing import Any, Dict, Optional, Tuple, Union | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from diffusers.configuration_utils import ConfigMixin, register_to_config | |
| from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin | |
| from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers | |
| from diffusers.utils.torch_utils import maybe_allow_in_graph | |
| from diffusers.models.attention import FeedForward | |
| from diffusers.models.attention_processor import Attention | |
| from diffusers.models.cache_utils import CacheMixin | |
| from diffusers.models.embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed | |
| from diffusers.models.modeling_outputs import Transformer2DModelOutput | |
| from diffusers.models.modeling_utils import ModelMixin | |
| try: | |
| from sageattention import sageattn | |
| except ImportError: | |
| sageattn = None | |
class FP32LayerNorm(nn.LayerNorm):
    """LayerNorm whose output is always returned in the dtype of its input.

    Note: despite the class name, this override does not upcast anything. ``F.layer_norm``
    runs on the input in its own dtype, using the stored ``weight``/``bias`` (``None`` when
    ``elementwise_affine=False``). The result is then cast back to the input dtype.
    """

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        original_dtype = inputs.dtype
        normalized = F.layer_norm(inputs, self.normalized_shape, self.weight, self.bias, self.eps)
        return normalized.to(original_dtype)
# Module-level logger obtained from diffusers' logging utilities (`diffusers.utils.logging`).
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
class WanAttnProcessor2_0:
    """Attention processor used by the Wan transformer blocks.

    Supports three modes:
    - self-attention, when ``encoder_hidden_states`` is ``None``;
    - cross-attention over ``encoder_hidden_states``;
    - an extra image-context attention when ``attn.add_k_proj`` is set. In that case the
      tokens before the last 512 of ``encoder_hidden_states`` are treated as image context.

    Rotary embeddings, when provided as a ``(cos, sin)`` pair, are applied to queries and keys.
    Attention goes through ``sageattn`` when it is installed and no mask is requested.
    Otherwise ``F.scaled_dot_product_attention`` is used.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("WanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")

    @staticmethod
    def _apply_rotary_emb(
        hidden_states: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
    ) -> torch.Tensor:
        """Rotate each interleaved (even, odd) channel pair of ``hidden_states``.

        Only the even columns of ``freqs_cos`` and the odd columns of ``freqs_sin`` are read.
        This presumes every frequency is stored twice (repeat-interleaved), which is how
        WanRotaryPosEmbed builds its tables. Both tables must broadcast against the pairs.
        """
        pairs = hidden_states.view(*hidden_states.shape[:-1], -1, 2)
        x1, x2 = pairs[..., 0], pairs[..., 1]
        cos = freqs_cos[..., 0::2]
        sin = freqs_sin[..., 1::2]
        out = torch.empty_like(hidden_states)
        out[..., 0::2] = x1 * cos - x2 * sin
        out[..., 1::2] = x1 * sin + x2 * cos
        return out.type_as(hidden_states)

    @staticmethod
    def _attention(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Attention over ``(batch, heads, seq, head_dim)`` tensors.

        ``sageattn`` is only used when no mask is requested. The mask keyword passed to it is
        not honoured by the sageattention API (extra kwargs are ignored; verify for the
        installed version), so masked calls fall back to SDPA instead of dropping the mask.
        """
        if sageattn is not None and attn_mask is None:
            half_dtypes = (torch.float16, torch.bfloat16)
            # Under autocast the projections may still be fp32. Convert all three tensors
            # together (checking each one) so the kernel never sees mixed dtypes.
            if torch.is_autocast_enabled() and any(t.dtype not in half_dtypes for t in (query, key, value)):
                query, key, value = (t.to(torch.bfloat16) for t in (query, key, value))
            return sageattn(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
        return F.scaled_dot_product_attention(
            query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False
        )

    def __call__(
        self,
        attn: "Attention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        encoder_hidden_states_img = None
        if attn.add_k_proj is not None:
            # 512 is the context length of the text encoder, hardcoded for now
            image_context_length = encoder_hidden_states.shape[1] - 512
            encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
            encoder_hidden_states = encoder_hidden_states[:, image_context_length:]
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        if attn.norm_q is not None:
            query = attn.norm_q(query).to(hidden_states.dtype)
        if attn.norm_k is not None:
            key = attn.norm_k(key).to(hidden_states.dtype)

        # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)

        if rotary_emb is not None:
            query = self._apply_rotary_emb(query, *rotary_emb)
            key = self._apply_rotary_emb(key, *rotary_emb)

        # I2V task: attend separately to the image-context tokens and add the result.
        hidden_states_img = None
        if encoder_hidden_states_img is not None:
            key_img = attn.add_k_proj(encoder_hidden_states_img)
            key_img = attn.norm_added_k(key_img)
            value_img = attn.add_v_proj(encoder_hidden_states_img)
            key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
            value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
            hidden_states_img = self._attention(query, key_img, value_img)
            hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3)
            hidden_states_img = hidden_states_img.type_as(query)

        hidden_states = self._attention(query, key, value, attention_mask)
        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
        hidden_states = hidden_states.type_as(query)

        if hidden_states_img is not None:
            hidden_states = hidden_states + hidden_states_img

        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states
class WanImageEmbedding(torch.nn.Module):
    """Project image-encoder features into the transformer width.

    Pipeline: optional learned positional embedding -> LayerNorm -> FeedForward -> LayerNorm.
    """

    def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None):
        super().__init__()
        self.norm1 = FP32LayerNorm(in_features)
        self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu")
        self.norm2 = FP32LayerNorm(out_features)
        self.pos_embed = (
            nn.Parameter(torch.zeros(1, pos_embed_seq_len, in_features))
            if pos_embed_seq_len is not None
            else None
        )

    def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor:
        features = encoder_hidden_states_image
        if self.pos_embed is not None:
            # Expects a 3D (batch, tokens, channels) input. Consecutive batch entries are
            # merged in pairs into one sequence of 2 * tokens before the positional
            # embedding is added (presumably two images per sample — confirm with callers).
            _, tokens, channels = features.shape
            features = features.view(-1, 2 * tokens, channels) + self.pos_embed
        for layer in (self.norm1, self.ff, self.norm2):
            features = layer(features)
        return features
class WanTimeTextImageEmbedding(nn.Module):
    """Embed the diffusion timestep, the text context and, optionally, image context.

    Returns ``(temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image)``:
    - ``temb``: the timestep embedding of width ``dim``;
    - ``timestep_proj``: ``time_proj(SiLU(temb))`` of width ``time_proj_dim``;
    - the text states projected by ``text_embedder``;
    - the image states passed through ``image_embedder``, or ``None`` when no image is given.
    """

    def __init__(
        self,
        dim: int,
        time_freq_dim: int,
        time_proj_dim: int,
        text_embed_dim: int,
        image_embed_dim: Optional[int] = None,
        pos_embed_seq_len: Optional[int] = None,
    ):
        super().__init__()
        self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
        self.act_fn = nn.SiLU()
        self.time_proj = nn.Linear(dim, time_proj_dim)
        self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh")
        self.image_embedder = None
        if image_embed_dim is not None:
            self.image_embedder = WanImageEmbedding(image_embed_dim, dim, pos_embed_seq_len=pos_embed_seq_len)

    def forward(
        self,
        timestep: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        encoder_hidden_states_image: Optional[torch.Tensor] = None,
        timestep_seq_len: Optional[int] = None,
    ):
        """Embed the conditioning inputs.

        Args:
            timestep: 1D timesteps. For per-token timesteps this is the flattened
                ``(batch * timestep_seq_len,)`` tensor.
            encoder_hidden_states: text-encoder states fed to ``text_embedder``.
            encoder_hidden_states_image: optional image features fed to ``image_embedder``.
            timestep_seq_len: tokens per sample when ``timestep`` is per-token; ``None`` otherwise.
        """
        timestep = self.timesteps_proj(timestep)
        if timestep_seq_len is not None:
            # Restore the (batch, seq_len) axes. The batch size is inferred rather than fixed
            # to 1, so batched per-token timesteps work.
            timestep = timestep.unflatten(0, (-1, timestep_seq_len))

        # Match the time embedder's parameter dtype, except for int8 parameters, which
        # presumably come from a quantized checkpoint (TODO confirm).
        time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
        if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
            timestep = timestep.to(time_embedder_dtype)
        temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
        timestep_proj = self.time_proj(self.act_fn(temb))

        encoder_hidden_states = self.text_embedder(encoder_hidden_states)
        if encoder_hidden_states_image is not None:
            encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image)

        return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image
class WanRotaryPosEmbed(nn.Module):
    """3D rotary position embedding tables for patchified Wan video tokens.

    The per-head channels are divided into a temporal, a height and a width band. Each band
    gets its own 1D rotary table covering ``max_seq_len`` positions. ``forward`` gathers
    rows for every (frame, row, column) patch position and concatenates the three bands.
    """

    def __init__(
        self,
        attention_head_dim: int,
        patch_size: Tuple[int, int, int],
        max_seq_len: int,
        theta: float = 10000.0,
    ):
        super().__init__()
        self.attention_head_dim = attention_head_dim
        self.patch_size = patch_size
        self.max_seq_len = max_seq_len
        t_dim, h_dim, w_dim = self._band_dims(attention_head_dim)
        # float64 tables for precision, except on MPS.
        freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
        freqs_cos = []
        freqs_sin = []
        for dim in [t_dim, h_dim, w_dim]:
            freq_cos, freq_sin = get_1d_rotary_pos_embed(
                dim,
                max_seq_len,
                theta,
                use_real=True,
                repeat_interleave_real=True,
                freqs_dtype=freqs_dtype,
            )
            freqs_cos.append(freq_cos)
            freqs_sin.append(freq_sin)
        self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False)
        self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False)

    @staticmethod
    def _band_dims(attention_head_dim: int) -> Tuple[int, int, int]:
        """Return the (temporal, height, width) channel widths of the rotary bands.

        ``__init__`` (building the tables) and ``forward`` (splitting them) both use this,
        so the split always matches the layout of the concatenated buffers.
        """
        h_dim = w_dim = 2 * (attention_head_dim // 6)
        t_dim = attention_head_dim - h_dim - w_dim
        return t_dim, h_dim, w_dim

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(freqs_cos, freqs_sin)``, each of shape ``(1, 1, num_patches, attention_head_dim)``.

        ``hidden_states`` must be 5D ``(batch, channels, frames, height, width)``. Only its
        spatial/temporal sizes are used, divided by ``patch_size``.
        """
        _, _, num_frames, height, width = hidden_states.shape
        p_t, p_h, p_w = self.patch_size
        ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w

        split_sizes = list(self._band_dims(self.attention_head_dim))
        freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
        freqs_sin = self.freqs_sin.split(split_sizes, dim=1)

        freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
        freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
        freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)

        freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
        freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
        freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)

        freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
        freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
        return freqs_cos, freqs_sin
class WanTransformerBlock(nn.Module):
    """One Wan DiT block: modulated self-attention, cross-attention and modulated feed-forward.

    The shift/scale/gate terms come from ``scale_shift_table + temb``. ``temb`` is either
    ``(batch, 6, dim)`` or, with per-token timesteps, ``(batch, seq_len, 6, dim)``.
    """

    def __init__(
        self,
        dim: int,
        ffn_dim: int,
        num_heads: int,
        qk_norm: str = "rms_norm_across_heads",
        cross_attn_norm: bool = False,
        eps: float = 1e-6,
        added_kv_proj_dim: Optional[int] = None,
    ):
        super().__init__()

        # 1. Self-attention
        self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
        self.attn1 = Attention(
            query_dim=dim,
            heads=num_heads,
            kv_heads=num_heads,
            dim_head=dim // num_heads,
            qk_norm=qk_norm,
            eps=eps,
            bias=True,
            cross_attention_dim=None,
            out_bias=True,
            processor=WanAttnProcessor2_0(),
        )

        # 2. Cross-attention
        self.attn2 = Attention(
            query_dim=dim,
            heads=num_heads,
            kv_heads=num_heads,
            dim_head=dim // num_heads,
            qk_norm=qk_norm,
            eps=eps,
            bias=True,
            cross_attention_dim=None,
            out_bias=True,
            added_kv_proj_dim=added_kv_proj_dim,
            added_proj_bias=True,
            processor=WanAttnProcessor2_0(),
        )
        self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()

        # 3. Feed-forward
        self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
        self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)

        self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)

    @staticmethod
    def _modulate(x: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor, inplace: bool) -> torch.Tensor:
        """Return ``x * (1 + scale) + shift`` in ``x``'s dtype, reusing ``x``'s storage if ``inplace``."""
        if inplace:
            return x.mul_(1 + scale).add_(shift)
        return (x * (1 + scale) + shift).type_as(x)

    @staticmethod
    def _residual(hidden_states: torch.Tensor, update: torch.Tensor, inplace: bool) -> torch.Tensor:
        """Return ``hidden_states + update`` in ``hidden_states``'s dtype, in place if ``inplace``."""
        if inplace:
            hidden_states += update
            return hidden_states
        return (hidden_states + update).type_as(hidden_states)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        rotary_emb: torch.Tensor,
    ) -> torch.Tensor:
        if temb.ndim == 4:
            # temb: batch_size, seq_len, 6, inner_dim (wan2.2 ti2v)
            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
                self.scale_shift_table.unsqueeze(0) + temb
            ).chunk(6, dim=2)
            # batch_size, seq_len, 1, inner_dim -> batch_size, seq_len, inner_dim
            shift_msa = shift_msa.squeeze(2)
            scale_msa = scale_msa.squeeze(2)
            gate_msa = gate_msa.squeeze(2)
            c_shift_msa = c_shift_msa.squeeze(2)
            c_scale_msa = c_scale_msa.squeeze(2)
            c_gate_msa = c_gate_msa.squeeze(2)
        else:
            # temb: batch_size, 6, inner_dim (wan2.1/wan2.2 14B)
            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
                self.scale_shift_table + temb
            ).chunk(6, dim=1)

        # In-place updates save activation memory when no graph is recorded. Under autograd,
        # however, ``layer_norm`` saves ``hidden_states`` for backward, so mutating it would
        # make backward fail. Use out-of-place arithmetic (cast back to the original dtype)
        # whenever gradients are enabled.
        inplace = not torch.is_grad_enabled()

        # 1. Self-attention
        norm_hidden_states = self._modulate(self.norm1(hidden_states), scale_msa, shift_msa, inplace)
        attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
        hidden_states = self._residual(hidden_states, attn_output * gate_msa, inplace)

        # 2. Cross-attention
        norm_hidden_states = self.norm2(hidden_states)
        attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
        hidden_states = self._residual(hidden_states, attn_output, inplace)

        # 3. Feed-forward
        norm_hidden_states = self._modulate(self.norm3(hidden_states), c_scale_msa, c_shift_msa, inplace)
        ff_output = self.ffn(norm_hidden_states)
        ff_output = ff_output.mul_(c_gate_msa) if inplace else ff_output * c_gate_msa
        hidden_states = self._residual(hidden_states, ff_output, inplace)
        return hidden_states
class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
    r"""
    A Transformer model for video-like data used in the Wan model.

    Args:
        patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`):
            3D patch dimensions for video embedding (t_patch, h_patch, w_patch).
        num_attention_heads (`int`, defaults to `40`):
            The number of attention heads in each transformer block.
        attention_head_dim (`int`, defaults to `128`):
            The number of channels in each head.
        in_channels (`int`, defaults to `16`):
            The number of channels in the input.
        out_channels (`int`, defaults to `16`):
            The number of channels in the output. Falls back to `in_channels` when falsy.
        text_dim (`int`, defaults to `4096`):
            Input dimension for text embeddings.
        freq_dim (`int`, defaults to `256`):
            Dimension for sinusoidal time embeddings.
        ffn_dim (`int`, defaults to `13824`):
            Intermediate dimension in feed-forward network.
        num_layers (`int`, defaults to `40`):
            The number of layers of transformer blocks to use.
        cross_attn_norm (`bool`, defaults to `True`):
            Apply an affine LayerNorm to the hidden states before cross-attention.
        qk_norm (`str`, *optional*, defaults to `"rms_norm_across_heads"`):
            Query/key normalization type forwarded to every attention layer.
        eps (`float`, defaults to `1e-6`):
            Epsilon value for normalization layers.
        image_dim (`int`, *optional*, defaults to `None`):
            Channel dimension of image-conditioning embeddings; enables the image embedder when set.
        added_kv_proj_dim (`int`, *optional*, defaults to `None`):
            The number of channels to use for the added key and value projections. If `None`, no projection is used.
        rope_max_seq_len (`int`, defaults to `1024`):
            Number of positions precomputed for each rotary-embedding axis.
        pos_embed_seq_len (`int`, *optional*, defaults to `None`):
            Length of the learned positional embedding added to image embeddings, if any.
    """

    _supports_gradient_checkpointing = True
    _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
    _no_split_modules = ["WanTransformerBlock"]
    _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
    _keys_to_ignore_on_load_unexpected = ["norm_added_q"]
    _repeated_blocks = ["WanTransformerBlock"]
    # dtype that the transformer-block inputs (hidden states, text context, timestep
    # projection, rotary tables) are cast to before the block loop. Set to ``None`` (on the
    # class or an instance) to keep whatever dtype the embedding layers produced.
    _block_compute_dtype: Optional[torch.dtype] = torch.bfloat16

    # ``register_to_config`` records the init arguments in ``self.config``; ``forward``
    # reads ``self.config.patch_size``.
    @register_to_config
    def __init__(
        self,
        patch_size: Tuple[int] = (1, 2, 2),
        num_attention_heads: int = 40,
        attention_head_dim: int = 128,
        in_channels: int = 16,
        out_channels: int = 16,
        text_dim: int = 4096,
        freq_dim: int = 256,
        ffn_dim: int = 13824,
        num_layers: int = 40,
        cross_attn_norm: bool = True,
        qk_norm: Optional[str] = "rms_norm_across_heads",
        eps: float = 1e-6,
        image_dim: Optional[int] = None,
        added_kv_proj_dim: Optional[int] = None,
        rope_max_seq_len: int = 1024,
        pos_embed_seq_len: Optional[int] = None,
    ) -> None:
        super().__init__()

        inner_dim = num_attention_heads * attention_head_dim
        out_channels = out_channels or in_channels

        # 1. Patch & position embedding
        self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
        self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)

        # 2. Condition embeddings
        # image_embedding_dim=1280 for I2V model
        self.condition_embedder = WanTimeTextImageEmbedding(
            dim=inner_dim,
            time_freq_dim=freq_dim,
            time_proj_dim=inner_dim * 6,
            text_embed_dim=text_dim,
            image_embed_dim=image_dim,
            pos_embed_seq_len=pos_embed_seq_len,
        )

        # 3. Transformer blocks
        self.blocks = nn.ModuleList(
            [
                WanTransformerBlock(
                    inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim
                )
                for _ in range(num_layers)
            ]
        )

        # 4. Output norm & projection
        self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
        self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
        self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        encoder_hidden_states: torch.Tensor,
        encoder_hidden_states_image: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        """Denoise a latent video.

        Args:
            hidden_states: latent video of shape ``(batch, channels, frames, height, width)``.
            timestep: ``(batch,)`` timesteps, or ``(batch, seq_len)`` per-token timesteps.
            encoder_hidden_states: text-encoder states.
            encoder_hidden_states_image: optional image-encoder states (I2V).
            return_dict: return a ``Transformer2DModelOutput`` instead of a 1-tuple.
            attention_kwargs: only the ``"scale"`` entry is read here (LoRA scale, PEFT backend only).
        """
        requested_scale = attention_kwargs.get("scale", None) if attention_kwargs is not None else None
        lora_scale = 1.0 if requested_scale is None else requested_scale

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        elif requested_scale is not None:
            logger.warning(
                "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
            )

        try:
            output = self._transformer_forward(
                hidden_states, timestep, encoder_hidden_states, encoder_hidden_states_image
            )
        finally:
            if USE_PEFT_BACKEND:
                # remove `lora_scale` from each PEFT layer, even if the forward pass raised
                unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)

    def _transformer_forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        encoder_hidden_states: torch.Tensor,
        encoder_hidden_states_image: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Patchify, condition, run the transformer blocks and unpatchify; returns the sample tensor."""
        batch_size, _, num_frames, height, width = hidden_states.shape
        p_t, p_h, p_w = self.config.patch_size
        post_patch_num_frames = num_frames // p_t
        post_patch_height = height // p_h
        post_patch_width = width // p_w

        rotary_emb = self.rope(hidden_states)

        hidden_states = self.patch_embedding(hidden_states)
        hidden_states = hidden_states.flatten(2).transpose(1, 2)

        # timestep shape: batch_size, or batch_size, seq_len (wan 2.2 ti2v)
        if timestep.ndim == 2:
            ts_seq_len = timestep.shape[1]
            timestep = timestep.flatten()  # batch_size * seq_len
        else:
            ts_seq_len = None

        temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
            timestep, encoder_hidden_states, encoder_hidden_states_image, timestep_seq_len=ts_seq_len
        )
        if ts_seq_len is not None:
            # batch_size, seq_len, 6, inner_dim
            timestep_proj = timestep_proj.unflatten(2, (6, -1))
        else:
            # batch_size, 6, inner_dim
            timestep_proj = timestep_proj.unflatten(1, (6, -1))

        if encoder_hidden_states_image is not None:
            encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)

        block_dtype = self._block_compute_dtype
        if block_dtype is not None:
            encoder_hidden_states = encoder_hidden_states.to(block_dtype)
            timestep_proj = timestep_proj.to(block_dtype)
            rotary_emb = [rotary_emb[0].to(block_dtype), rotary_emb[1].to(block_dtype)]
            hidden_states = hidden_states.to(block_dtype)

        # 4. Transformer blocks
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            for block in self.blocks:
                hidden_states = self._gradient_checkpointing_func(
                    block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
                )
        else:
            for block in self.blocks:
                hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)

        # 5. Output norm, projection & unpatchify
        if temb.ndim == 3:
            # batch_size, seq_len, inner_dim (wan 2.2 ti2v)
            shift, scale = (self.scale_shift_table.unsqueeze(0) + temb.unsqueeze(2)).chunk(2, dim=2)
            shift = shift.squeeze(2)
            scale = scale.squeeze(2)
        else:
            # batch_size, inner_dim
            shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)

        # Move the shift and scale tensors to the same device as hidden_states.
        # When using multi-GPU inference via accelerate these will be on the
        # first device rather than the last device, which hidden_states ends up
        # on.
        shift = shift.to(hidden_states.device)
        scale = scale.to(hidden_states.device)

        hidden_states = (self.norm_out(hidden_states) * (1 + scale) + shift).type_as(hidden_states)
        hidden_states = self.proj_out(hidden_states)

        hidden_states = hidden_states.reshape(
            batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
        )
        hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
        return hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)