# FlashWorld-ZeroGPU/models/transformer_wan.py
# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Any, Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from diffusers.utils.torch_utils import maybe_allow_in_graph
from diffusers.models.attention import FeedForward
from diffusers.models.attention_processor import Attention
from diffusers.models.cache_utils import CacheMixin
from diffusers.models.embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# SageAttention is optional; fall back to PyTorch SDPA when it is not installed.
try:
    from sageattention import sageattn
except ImportError:
    sageattn = None
class FP32LayerNorm(nn.LayerNorm):
    """LayerNorm computed in float32 for numerical stability, cast back to the input dtype."""

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        origin_dtype = inputs.dtype
        return F.layer_norm(
            inputs.float(),
            self.normalized_shape,
            self.weight.float() if self.weight is not None else None,
            self.bias.float() if self.bias is not None else None,
            self.eps,
        ).to(origin_dtype)
class WanAttnProcessor2_0:
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("WanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
encoder_hidden_states_img = None
if attn.add_k_proj is not None:
# 512 is the context length of the text encoder, hardcoded for now
image_context_length = encoder_hidden_states.shape[1] - 512
encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
encoder_hidden_states = encoder_hidden_states[:, image_context_length:]
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
query = attn.to_q(hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
if attn.norm_q is not None:
query = attn.norm_q(query).to(hidden_states.dtype)
if attn.norm_k is not None:
key = attn.norm_k(key).to(hidden_states.dtype)
query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
if rotary_emb is not None:
def apply_rotary_emb(
hidden_states: torch.Tensor,
freqs_cos: torch.Tensor,
freqs_sin: torch.Tensor,
):
x = hidden_states.view(*hidden_states.shape[:-1], -1, 2)
x1, x2 = x[..., 0], x[..., 1]
cos = freqs_cos[..., 0::2]
sin = freqs_sin[..., 1::2]
out = torch.empty_like(hidden_states)
out[..., 0::2] = x1 * cos - x2 * sin
out[..., 1::2] = x1 * sin + x2 * cos
return out.type_as(hidden_states)
query = apply_rotary_emb(query, *rotary_emb)
key = apply_rotary_emb(key, *rotary_emb)
# I2V task
hidden_states_img = None
if encoder_hidden_states_img is not None:
key_img = attn.add_k_proj(encoder_hidden_states_img)
key_img = attn.norm_added_k(key_img)
value_img = attn.add_v_proj(encoder_hidden_states_img)
key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
if sageattn is not None:
# Ensure kernels receive fp16/bf16 tensors under autocast
if torch.is_autocast_enabled() and query.dtype not in (torch.float16, torch.bfloat16):
target_dtype = torch.bfloat16
query = query.to(target_dtype)
key_img = key_img.to(target_dtype)
value_img = value_img.to(target_dtype)
hidden_states_img = sageattn(
query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False
)
else:
hidden_states_img = F.scaled_dot_product_attention(
query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False
)
hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3)
hidden_states_img = hidden_states_img.type_as(query)
        if sageattn is not None:
            # Ensure kernels receive fp16/bf16 tensors under autocast. Key/value are
            # checked rather than query, which the I2V branch above may already have cast.
            if torch.is_autocast_enabled() and key.dtype not in (torch.float16, torch.bfloat16):
                target_dtype = torch.bfloat16
                query = query.to(target_dtype)
                key = key.to(target_dtype)
                value = value.to(target_dtype)
hidden_states = sageattn(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
else:
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
hidden_states = hidden_states.type_as(query)
if hidden_states_img is not None:
hidden_states = hidden_states + hidden_states_img
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
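
# Usage note (illustrative): WanTransformerBlock below attaches this processor via
# `Attention(..., processor=WanAttnProcessor2_0())`. Self-attention supplies
# `rotary_emb=(freqs_cos, freqs_sin)` and no `encoder_hidden_states`; cross-attention
# supplies `encoder_hidden_states` and no `rotary_emb`. When the optional SageAttention
# kernel is not installed, both paths fall back to `F.scaled_dot_product_attention`.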
class WanImageEmbedding(nn.Module):
def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None):
super().__init__()
self.norm1 = FP32LayerNorm(in_features)
self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu")
self.norm2 = FP32LayerNorm(out_features)
if pos_embed_seq_len is not None:
self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_seq_len, in_features))
else:
self.pos_embed = None
def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor:
        if self.pos_embed is not None:
            # Two image embeddings per sample (e.g. first/last-frame conditioning) are
            # packed along the sequence axis before the learned positional embedding is added.
            batch_size, seq_len, embed_dim = encoder_hidden_states_image.shape
            encoder_hidden_states_image = encoder_hidden_states_image.view(-1, 2 * seq_len, embed_dim)
            encoder_hidden_states_image = encoder_hidden_states_image + self.pos_embed
hidden_states = self.norm1(encoder_hidden_states_image)
hidden_states = self.ff(hidden_states)
hidden_states = self.norm2(hidden_states)
return hidden_states
class WanTimeTextImageEmbedding(nn.Module):
def __init__(
self,
dim: int,
time_freq_dim: int,
time_proj_dim: int,
text_embed_dim: int,
image_embed_dim: Optional[int] = None,
pos_embed_seq_len: Optional[int] = None,
):
super().__init__()
self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
self.act_fn = nn.SiLU()
self.time_proj = nn.Linear(dim, time_proj_dim)
self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh")
self.image_embedder = None
if image_embed_dim is not None:
self.image_embedder = WanImageEmbedding(image_embed_dim, dim, pos_embed_seq_len=pos_embed_seq_len)
def forward(
self,
timestep: torch.Tensor,
encoder_hidden_states: torch.Tensor,
encoder_hidden_states_image: Optional[torch.Tensor] = None,
timestep_seq_len: Optional[int] = None,
):
timestep = self.timesteps_proj(timestep)
if timestep_seq_len is not None:
timestep = timestep.unflatten(0, (1, timestep_seq_len))
time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
timestep = timestep.to(time_embedder_dtype)
temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
timestep_proj = self.time_proj(self.act_fn(temb))
encoder_hidden_states = self.text_embedder(encoder_hidden_states)
if encoder_hidden_states_image is not None:
encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image)
return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image
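
# Return contract: `temb` conditions the final output projection, `timestep_proj` is
# later reshaped into the per-block six-way modulation, and the text/image embeddings
# are concatenated into the cross-attention context by WanTransformer3DModel.forward.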
class WanRotaryPosEmbed(nn.Module):
def __init__(
self,
attention_head_dim: int,
patch_size: Tuple[int, int, int],
max_seq_len: int,
theta: float = 10000.0,
):
super().__init__()
self.attention_head_dim = attention_head_dim
self.patch_size = patch_size
self.max_seq_len = max_seq_len
h_dim = w_dim = 2 * (attention_head_dim // 6)
t_dim = attention_head_dim - h_dim - w_dim
freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
freqs_cos = []
freqs_sin = []
for dim in [t_dim, h_dim, w_dim]:
freq_cos, freq_sin = get_1d_rotary_pos_embed(
dim,
max_seq_len,
theta,
use_real=True,
repeat_interleave_real=True,
freqs_dtype=freqs_dtype,
)
freqs_cos.append(freq_cos)
freqs_sin.append(freq_sin)
self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False)
self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False)
    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p_t, p_h, p_w = self.patch_size
ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
split_sizes = [
self.attention_head_dim - 2 * (self.attention_head_dim // 3),
self.attention_head_dim // 3,
self.attention_head_dim // 3,
]
freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
return freqs_cos, freqs_sin
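
# Shape sketch (illustrative): with `attention_head_dim=128`, the buffers hold
# (max_seq_len, 128) cos/sin tables split as [44, 42, 42] between the temporal,
# height, and width axes. In `forward`, each axis table is broadcast over the
# (ppf, pph, ppw) patch grid and the pieces are concatenated back into a single
# (1, 1, ppf * pph * ppw, 128) pair, matching the (batch, heads, seq, head_dim)
# layout that `apply_rotary_emb` in WanAttnProcessor2_0 expects.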
@maybe_allow_in_graph
class WanTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
ffn_dim: int,
num_heads: int,
qk_norm: str = "rms_norm_across_heads",
cross_attn_norm: bool = False,
eps: float = 1e-6,
added_kv_proj_dim: Optional[int] = None,
):
super().__init__()
# 1. Self-attention
self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
self.attn1 = Attention(
query_dim=dim,
heads=num_heads,
kv_heads=num_heads,
dim_head=dim // num_heads,
qk_norm=qk_norm,
eps=eps,
bias=True,
cross_attention_dim=None,
out_bias=True,
processor=WanAttnProcessor2_0(),
)
# 2. Cross-attention
self.attn2 = Attention(
query_dim=dim,
heads=num_heads,
kv_heads=num_heads,
dim_head=dim // num_heads,
qk_norm=qk_norm,
eps=eps,
bias=True,
cross_attention_dim=None,
out_bias=True,
added_kv_proj_dim=added_kv_proj_dim,
added_proj_bias=True,
processor=WanAttnProcessor2_0(),
)
self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
# 3. Feed-forward
self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)
self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
rotary_emb: torch.Tensor,
) -> torch.Tensor:
if temb.ndim == 4:
# temb: batch_size, seq_len, 6, inner_dim (wan2.2 ti2v)
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
self.scale_shift_table.unsqueeze(0) + temb
).chunk(6, dim=2)
# batch_size, seq_len, 1, inner_dim
shift_msa = shift_msa.squeeze(2)
scale_msa = scale_msa.squeeze(2)
gate_msa = gate_msa.squeeze(2)
c_shift_msa = c_shift_msa.squeeze(2)
c_scale_msa = c_scale_msa.squeeze(2)
c_gate_msa = c_gate_msa.squeeze(2)
else:
# temb: batch_size, 6, inner_dim (wan2.1/wan2.2 14B)
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
self.scale_shift_table + temb
).chunk(6, dim=1)
        # 1. Self-attention
        # Note: the in-place `mul_`/`add_`/`+=` updates below assume inference;
        # autograd can reject in-place modification of tensors saved for backward.
        norm_hidden_states = self.norm1(hidden_states).mul_(1 + scale_msa).add_(shift_msa)
        attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
        hidden_states += attn_output * gate_msa
# 2. Cross-attention
norm_hidden_states = self.norm2(hidden_states)
attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
hidden_states += attn_output
# 3. Feed-forward
norm_hidden_states = (self.norm3(hidden_states).mul_(1 + c_scale_msa).add_(c_shift_msa))
ff_output = self.ffn(norm_hidden_states)
hidden_states += ff_output.mul_(c_gate_msa)
return hidden_states
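
# Modulation layout: `scale_shift_table` holds six per-channel vectors that, summed
# with the projected timestep embedding, yield (shift, scale, gate) for the
# self-attention branch and (c_shift, c_scale, c_gate) for the feed-forward branch;
# the cross-attention output is added without a gate.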
class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
r"""
A Transformer model for video-like data used in the Wan model.
Args:
patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`):
3D patch dimensions for video embedding (t_patch, h_patch, w_patch).
num_attention_heads (`int`, defaults to `40`):
Fixed length for text embeddings.
attention_head_dim (`int`, defaults to `128`):
The number of channels in each head.
in_channels (`int`, defaults to `16`):
The number of channels in the input.
out_channels (`int`, defaults to `16`):
The number of channels in the output.
text_dim (`int`, defaults to `512`):
Input dimension for text embeddings.
freq_dim (`int`, defaults to `256`):
Dimension for sinusoidal time embeddings.
ffn_dim (`int`, defaults to `13824`):
Intermediate dimension in feed-forward network.
num_layers (`int`, defaults to `40`):
The number of layers of transformer blocks to use.
window_size (`Tuple[int]`, defaults to `(-1, -1)`):
Window size for local attention (-1 indicates global attention).
cross_attn_norm (`bool`, defaults to `True`):
Enable cross-attention normalization.
qk_norm (`bool`, defaults to `True`):
Enable query/key normalization.
eps (`float`, defaults to `1e-6`):
Epsilon value for normalization layers.
add_img_emb (`bool`, defaults to `False`):
Whether to use img_emb.
added_kv_proj_dim (`int`, *optional*, defaults to `None`):
The number of channels to use for the added key and value projections. If `None`, no projection is used.
"""
_supports_gradient_checkpointing = True
_skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
_no_split_modules = ["WanTransformerBlock"]
_keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
_keys_to_ignore_on_load_unexpected = ["norm_added_q"]
_repeated_blocks = ["WanTransformerBlock"]
@register_to_config
def __init__(
self,
patch_size: Tuple[int] = (1, 2, 2),
num_attention_heads: int = 40,
attention_head_dim: int = 128,
in_channels: int = 16,
out_channels: int = 16,
text_dim: int = 4096,
freq_dim: int = 256,
ffn_dim: int = 13824,
num_layers: int = 40,
cross_attn_norm: bool = True,
qk_norm: Optional[str] = "rms_norm_across_heads",
eps: float = 1e-6,
image_dim: Optional[int] = None,
added_kv_proj_dim: Optional[int] = None,
rope_max_seq_len: int = 1024,
pos_embed_seq_len: Optional[int] = None,
) -> None:
super().__init__()
inner_dim = num_attention_heads * attention_head_dim
out_channels = out_channels or in_channels
# 1. Patch & position embedding
self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
# 2. Condition embeddings
# image_embedding_dim=1280 for I2V model
self.condition_embedder = WanTimeTextImageEmbedding(
dim=inner_dim,
time_freq_dim=freq_dim,
time_proj_dim=inner_dim * 6,
text_embed_dim=text_dim,
image_embed_dim=image_dim,
pos_embed_seq_len=pos_embed_seq_len,
)
# 3. Transformer blocks
self.blocks = nn.ModuleList(
[
WanTransformerBlock(
inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim
)
for _ in range(num_layers)
]
)
# 4. Output norm & projection
self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
timestep: torch.LongTensor,
encoder_hidden_states: torch.Tensor,
encoder_hidden_states_image: Optional[torch.Tensor] = None,
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
        else:
            # `scale` was already popped into `lora_scale` above, so check that value
            # rather than the (now scale-free) kwargs copy.
            if lora_scale != 1.0:
                logger.warning(
                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
                )
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p_t, p_h, p_w = self.config.patch_size
post_patch_num_frames = num_frames // p_t
post_patch_height = height // p_h
post_patch_width = width // p_w
rotary_emb = self.rope(hidden_states)
hidden_states = self.patch_embedding(hidden_states)
hidden_states = hidden_states.flatten(2).transpose(1, 2)
        # timestep shape: (batch_size,) or (batch_size, seq_len) for Wan 2.2 ti2v
if timestep.ndim == 2:
ts_seq_len = timestep.shape[1]
timestep = timestep.flatten() # batch_size * seq_len
else:
ts_seq_len = None
temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
timestep, encoder_hidden_states, encoder_hidden_states_image, timestep_seq_len=ts_seq_len
)
if ts_seq_len is not None:
# batch_size, seq_len, 6, inner_dim
timestep_proj = timestep_proj.unflatten(2, (6, -1))
else:
# batch_size, 6, inner_dim
timestep_proj = timestep_proj.unflatten(1, (6, -1))
if encoder_hidden_states_image is not None:
encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
        # The attention kernels above expect fp16/bf16 inputs, so the block inputs
        # are cast to bfloat16 unconditionally here.
        encoder_hidden_states = encoder_hidden_states.to(torch.bfloat16)
        timestep_proj = timestep_proj.to(torch.bfloat16)
        rotary_emb = [rotary_emb[0].to(torch.bfloat16), rotary_emb[1].to(torch.bfloat16)]
        hidden_states = hidden_states.to(torch.bfloat16)
# 4. Transformer blocks
if torch.is_grad_enabled() and self.gradient_checkpointing:
for block in self.blocks:
hidden_states = self._gradient_checkpointing_func(
block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
)
else:
for block in self.blocks:
hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
# 5. Output norm, projection & unpatchify
if temb.ndim == 3:
# batch_size, seq_len, inner_dim (wan 2.2 ti2v)
shift, scale = (self.scale_shift_table.unsqueeze(0) + temb.unsqueeze(2)).chunk(2, dim=2)
shift = shift.squeeze(2)
scale = scale.squeeze(2)
else:
# batch_size, inner_dim
shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
# Move the shift and scale tensors to the same device as hidden_states.
# When using multi-GPU inference via accelerate these will be on the
# first device rather than the last device, which hidden_states ends up
# on.
shift = shift.to(hidden_states.device)
scale = scale.to(hidden_states.device)
hidden_states = (self.norm_out(hidden_states) * (1 + scale) + shift).type_as(hidden_states)
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.reshape(
batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
)
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
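

if __name__ == "__main__":
    # Minimal smoke test (illustrative; not part of the original app). The tiny
    # config below is hypothetical -- real Wan checkpoints use the defaults
    # documented on the class (40 layers, 40 heads, head dim 128). Because the
    # forward pass hard-casts block inputs to bfloat16, the whole model is moved
    # to bfloat16 as well; this assumes a PyTorch build with bfloat16 kernels on
    # the chosen device (recent CPU or CUDA builds qualify).
    model = WanTransformer3DModel(
        num_attention_heads=2,
        attention_head_dim=12,
        in_channels=4,
        out_channels=4,
        text_dim=16,
        freq_dim=32,
        ffn_dim=48,
        num_layers=1,
    ).to(torch.bfloat16)
    video = torch.randn(1, 4, 2, 8, 8, dtype=torch.bfloat16)  # (batch, channels, frames, height, width)
    text = torch.randn(1, 10, 16, dtype=torch.bfloat16)       # (batch, text_seq_len, text_dim)
    timestep = torch.tensor([500])                            # one diffusion timestep per batch element
    with torch.no_grad():
        sample = model(video, timestep, text, return_dict=False)[0]
    print(sample.shape)  # expected: torch.Size([1, 4, 2, 8, 8])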