import clip
import numpy as np
import torch
import torch.nn as nn

from model.config import MDMConfig
from model.rotation2xyz import Rotation2xyz


class MDM(nn.Module):
    @classmethod
    def from_config(cls, config: MDMConfig):
        """
        Instantiate MDM from an MDMConfig object.
        """
        return cls(
            modeltype=config.modeltype,
            njoints=config.njoints,
            nfeats=config.nfeats,
            num_actions=config.num_actions,
            translation=config.translation,
            pose_rep=config.pose_rep,
            glob=config.glob,
            glob_rot=config.glob_rot,
            latent_dim=config.latent_dim,
            ff_size=config.ff_size,
            num_layers=config.layers,
            num_heads=config.num_heads,
            dropout=config.dropout,
            ablation=config.ablation,
            activation=config.activation,
            legacy=config.legacy,
            data_rep=config.data_rep,
            dataset=config.dataset,
            clip_dim=config.clip_dim,
            arch=config.arch,
            emb_trans_dec=config.emb_trans_dec,
            clip_version=config.clip_version,
            action_emb=config.action_emb,
            normalize_encoder_output=config.normalize_encoder_output,
            cond_mask_prob=config.cond_mask_prob,
            mask_frames=config.mask_frames,
            emb_policy=config.emb_policy,
            pos_embed_max_len=config.pos_embed_max_len,
            pred_len=config.pred_len,
            context_len=config.context_len,
            all_goal_joint_names=config.all_goal_joint_names,
            multi_target_cond=config.multi_target_cond,
            multi_encoder_type=config.multi_encoder_type,
            target_enc_layers=config.target_enc_layers,
        )

    def __init__(
        self,
        modeltype,
        njoints,
        nfeats,
        num_actions,
        translation,
        pose_rep,
        glob,
        glob_rot,
        latent_dim=256,
        ff_size=1024,
        num_layers=8,
        num_heads=4,
        dropout=0.1,
        ablation=None,
        activation="gelu",
        legacy=False,
        data_rep="rot6d",
        dataset="amass",
        clip_dim=512,
        arch="trans_enc",
        emb_trans_dec=False,
        clip_version=None,
        **kargs,
    ):
        super().__init__()

        self.legacy = legacy
        self.modeltype = modeltype
        self.njoints = njoints
        self.nfeats = nfeats
        self.num_actions = num_actions
        self.data_rep = data_rep
        self.dataset = dataset

        self.pose_rep = pose_rep
        self.glob = glob
        self.glob_rot = glob_rot
        self.translation = translation

        self.latent_dim = latent_dim
        self.ff_size = ff_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dropout = dropout
        self.ablation = ablation
        self.activation = activation
        self.clip_dim = clip_dim
        self.action_emb = kargs.get("action_emb", None)

        self.input_feats = self.njoints * self.nfeats
        self.normalize_output = kargs.get("normalize_encoder_output", False)
        self.cond_mode = kargs.get("cond_mode", "no_cond")
        self.cond_mask_prob = kargs.get("cond_mask_prob", 0.0)
        self.mask_frames = kargs.get("mask_frames", False)
        self.arch = arch
        self.emb_policy = kargs.get("emb_policy", "add")
        self.pred_len = kargs.get("pred_len", 0)
        self.context_len = kargs.get("context_len", 0)
        self.total_len = self.pred_len + self.context_len
        self.is_prefix_comp = self.total_len > 0
        self.all_goal_joint_names = kargs.get("all_goal_joint_names", [])
        self.multi_target_cond = kargs.get("multi_target_cond", False)
        self.text_encoder_type = kargs.get("text_encoder_type", "clip")

        # Assert some assumptions we're making for simplicity
        assert self.arch == "trans_enc"
        assert self.cond_mode == "text"
        assert self.text_encoder_type == "clip"
        assert not self.multi_target_cond
        assert not self.is_prefix_comp
        assert self.emb_policy == "add"
        assert self.data_rep == "hml_vec"

        # Using the Transformer encoder architecture
        transformer_encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.latent_dim,
            nhead=self.num_heads,
            dim_feedforward=self.ff_size,
            dropout=self.dropout,
            activation=self.activation,
        )
        self.seqTransEncoder = nn.TransformerEncoder(
            transformer_encoder_layer, num_layers=self.num_layers
        )
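
        # Conditioning pipeline (summarised from the methods below): the diffusion
        # timestep is embedded by indexing the sinusoidal table of
        # sequence_pos_encoder and passing it through a small MLP (TimestepEmbedder),
        # the CLIP text feature is projected from clip_dim to latent_dim by
        # embed_text, and the two are summed (emb_policy == "add") into a single
        # conditioning token that motion_to_sequence prepends to the motion sequence.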
        self.sequence_pos_encoder = PositionalEncoding(
            self.latent_dim, self.dropout, max_len=kargs.get("pos_embed_max_len", 5000)
        )
        self.embed_timestep = TimestepEmbedder(
            self.latent_dim, self.sequence_pos_encoder
        )

        # We'll use CLIP for now
        self.clip_version = clip_version
        self.clip_model = load_and_freeze_clip(clip_version)
        self.encode_text = self.clip_encode_text
        self.embed_text = nn.Linear(self.clip_dim, self.latent_dim)

        # Linear input and output layers
        self.input_process = InputProcess(self.input_feats, self.latent_dim)
        self.output_process = OutputProcess(
            self.input_feats, self.latent_dim, self.njoints, self.nfeats
        )

        self.rot2xyz = Rotation2xyz(device="cpu", dataset=self.dataset)

    def parameters_wo_clip(self):
        return [
            p
            for name, p in self.named_parameters()
            if not name.startswith("clip_model.")
        ]

    def mask_cond(self, cond, force_mask=False):
        bs = cond.shape[-2]
        if force_mask:
            return torch.zeros_like(cond)
        elif self.training and self.cond_mask_prob > 0.0:
            mask = torch.bernoulli(
                torch.ones(bs, device=cond.device) * self.cond_mask_prob
            ).view(1, bs, 1)  # 1 -> use null_cond, 0 -> use real cond
            return cond * (1.0 - mask)
        else:
            return cond

    def clip_encode_text(self, raw_text):
        # raw_text - list (batch_size length) of strings with input text prompts
        device = next(self.parameters()).device
        max_text_len = (
            20 if self.dataset in ["kit", "humanml", "humanml_with_images"] else None
        )  # Specific hardcoding for humanml dataset
        if max_text_len is not None:
            default_context_length = 77
            context_length = max_text_len + 2  # start_token + 20 + end_token
            assert context_length < default_context_length
            texts = clip.tokenize(
                raw_text, context_length=context_length, truncate=True
            ).to(device)  # [bs, context_length]; if n_tokens > context_length -> will truncate
            # print('texts', texts.shape)
            zero_pad = torch.zeros(
                [texts.shape[0], default_context_length - context_length],
                dtype=texts.dtype,
                device=texts.device,
            )
            texts = torch.cat([texts, zero_pad], dim=1)
            # print('texts after pad', texts.shape, texts)
        else:
            texts = clip.tokenize(raw_text, truncate=True).to(
                device
            )  # [bs, context_length]; if n_tokens > 77 -> will truncate
        return self.clip_model.encode_text(texts).float().unsqueeze(0)

    def motion_to_sequence(self, motion, timesteps, y):
        if "text_embed" not in y:
            clip_encoded_text = self.encode_text(y["text"])
        else:
            clip_encoded_text = y["text_embed"]

        # casting mask for the single-prompt-for-all case
        force_mask = y.get("uncond", False)

        # [1, bs, latent_dim]
        text_embedding = self.embed_text(
            self.mask_cond(clip_encoded_text, force_mask=force_mask)
        )

        # compute the embedding of the timestep + text, z_tk in the paper
        time_embedding = self.embed_timestep(timesteps)  # [1, bs, latent_dim]
        embedding = text_embedding + time_embedding  # [1, bs, latent_dim]

        # get the motion into latent space
        sequence = self.input_process(motion)  # [num_frames, bs, latent_dim]
        sequence_plus_emb = torch.cat(
            (embedding, sequence), dim=0
        )  # [num_frames + 1, bs, latent_dim]
        return sequence_plus_emb

    def sequence_to_motion(self, sequence_plus_emb):
        # remove the embedding from the sequence, i.e. remove the z_tk from the paper
        sequence = sequence_plus_emb[1:]  # [num_frames, bs, latent_dim]
        # get back the motion from the latent space
        motion = self.output_process(
            sequence
        )  # [bs, num_joints, num_features, num_frames]
        return motion
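
    # Note on the padding mask: nn.TransformerEncoder ignores positions whose
    # src_key_padding_mask entry is True. prepare_mask therefore inverts
    # y["mask"] (True = valid frame) into a padding mask (True = padded frame)
    # and prepends a single False column so that the conditioning token z_tk
    # is always attended to.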
    def prepare_mask(self, sequence, device, y, bs):
        # Don't use mask with the generate script
        is_valid_mask = y["mask"].shape[-1] > 1
        if self.mask_frames and is_valid_mask:
            frames_mask = torch.logical_not(
                y["mask"][..., : sequence.shape[0]].squeeze(1).squeeze(1)
            ).to(device=device)
            step_mask = torch.zeros((bs, 1), dtype=torch.bool, device=device)
            return torch.cat([step_mask, frames_mask], dim=1)
        else:
            return None

    def forward(self, motion, timesteps, y=None):
        """
        motion: [bs, num_joints, num_features, num_frames]
        timesteps: [bs]
        """
        sequence = self.motion_to_sequence(motion, timesteps, y)

        # apply positional encoding
        sequence = self.sequence_pos_encoder(
            sequence
        )  # [num_frames + 1, bs, latent_dim]

        frames_mask = self.prepare_mask(sequence, motion.device, y, motion.shape[0])

        # actual transformer magic
        sequence = self.seqTransEncoder(sequence, src_key_padding_mask=frames_mask)

        motion = self.sequence_to_motion(sequence)
        return motion

    def _apply(self, fn):
        super()._apply(fn)
        self.rot2xyz.smpl_model._apply(fn)
        return self

    def train(self, *args, **kwargs):
        super().train(*args, **kwargs)
        self.rot2xyz.smpl_model.train(*args, **kwargs)
        return self


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)

        self.register_buffer("pe", pe)

    def forward(self, x):
        x = x + self.pe[: x.shape[0], :]
        return self.dropout(x)


class TimestepEmbedder(nn.Module):
    def __init__(self, latent_dim, sequence_pos_encoder):
        super().__init__()
        self.sequence_pos_encoder = sequence_pos_encoder

        self.time_embed = nn.Sequential(
            nn.Linear(latent_dim, latent_dim),
            nn.SiLU(),
            nn.Linear(latent_dim, latent_dim),
        )

    def forward(self, timesteps):
        return self.time_embed(self.sequence_pos_encoder.pe[timesteps]).permute(1, 0, 2)


class InputProcess(nn.Module):
    """
    Applies the linear layer on the motion sequence at the beginning of the MDM.
    Also changes the shape
    [bs, num_joints, num_features, num_frames] -> [num_frames, bs, latent_dim]
    """

    def __init__(self, input_feats, latent_dim):
        super().__init__()
        self.poseEmbedding = nn.Linear(input_feats, latent_dim)

    def forward(self, sequence):
        bs, num_joints, num_features, num_frames = sequence.shape
        sequence = sequence.permute((3, 0, 1, 2)).reshape(
            num_frames, bs, num_joints * num_features
        )
        sequence = self.poseEmbedding(sequence)
        return sequence


class OutputProcess(nn.Module):
    """
    Applies the linear layer on the motion sequence at the end of the MDM.
    Also changes the shape
    [num_frames, bs, latent_dim] -> [bs, num_joints, num_features, num_frames]
    """

    def __init__(self, input_feats, latent_dim, num_joints, num_features):
        super().__init__()
        self.input_feats = input_feats
        self.latent_dim = latent_dim
        self.num_joints = num_joints
        self.num_features = num_features
        self.poseFinal = nn.Linear(latent_dim, input_feats)

    def forward(self, sequence):
        num_frames, bs, _ = sequence.shape
        sequence = self.poseFinal(sequence)
        sequence = sequence.reshape(
            num_frames, bs, self.num_joints, self.num_features
        ).permute(1, 2, 3, 0)
        return sequence


def load_and_freeze_clip(clip_version):
    # Must set jit=False for training
    clip_model, clip_preprocess = clip.load(clip_version, device="cpu", jit=False)
    clip_model.eval()

    # Freeze CLIP weights
    clip_model.requires_grad_(False)
    return clip_model
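

# Illustrative usage sketch, not part of the module itself. It assumes a
# HumanML3D-style setup (njoints=263, nfeats=1, data_rep="hml_vec"), that the
# CLIP "ViT-B/32" checkpoint can be downloaded, and that the SMPL assets
# required by Rotation2xyz are available locally; the hyperparameters below
# are placeholders, not the values of any released checkpoint.
if __name__ == "__main__":
    model = MDM(
        modeltype="",
        njoints=263,
        nfeats=1,
        num_actions=1,
        translation=True,
        pose_rep="rot6d",
        glob=True,
        glob_rot=True,
        data_rep="hml_vec",
        dataset="humanml",
        clip_version="ViT-B/32",
        cond_mode="text",
        cond_mask_prob=0.1,
    )
    model.eval()

    bs, num_frames = 2, 60
    motion = torch.randn(bs, 263, 1, num_frames)  # [bs, num_joints, num_features, num_frames]
    timesteps = torch.randint(0, 1000, (bs,))  # one diffusion step index per sample
    y = {
        "text": ["a person walks forward", "a person jumps in place"],
        "mask": torch.ones(bs, 1, 1, num_frames, dtype=torch.bool),
    }
    with torch.no_grad():
        denoised = model(motion, timesteps, y)
    print(denoised.shape)  # expected: torch.Size([2, 263, 1, 60]), same shape as the input motion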