class MDMConfig:
    """Configuration class for the MDM model.

    Encapsulates all model hyperparameters and options as plain instance
    attributes, one per ``__init__`` parameter and with the same name, so a
    config can be exported with ``to_dict`` and rebuilt with ``from_args``.
    """

    # Class-level constant; not an instance attribute, so it is not part of to_dict().
    model_name = "MDM"

    def __init__(
        self,
        arch="trans_enc",
        text_encoder_type="clip",
        emb_trans_dec=False,
        layers=8,
        latent_dim=512,
        cond_mask_prob=0.1,
        mask_frames=False,
        lambda_rcxyz=0.0,
        lambda_vel=0.0,
        lambda_fc=0.0,
        lambda_target_loc=0.0,
        unconstrained=False,
        pos_embed_max_len=5000,
        use_ema=False,
        multi_target_cond=False,
        multi_encoder_type="single",
        target_enc_layers=1,
        context_len=0,
        pred_len=0,
        # Additional MDM-specific args
        modeltype=None,
        njoints=None,
        nfeats=None,
        num_actions=None,
        translation=None,
        pose_rep=None,
        glob=None,
        glob_rot=None,
        ff_size=1024,
        num_heads=4,
        dropout=0.1,
        ablation=None,
        activation="gelu",
        legacy=False,
        data_rep="rot6d",
        dataset="amass",
        cond_mode="text",
        clip_dim=512,
        clip_version=None,
        action_emb=None,
        normalize_encoder_output=False,
        emb_policy="add",
        all_goal_joint_names=None,
        diffusion_steps=1000,
        noise_schedule="linear",
        sigma_small=False,
    ):
        """Store every argument verbatim as an attribute of the same name.

        The only normalization is ``all_goal_joint_names``: a falsy value
        (e.g. ``None``) is replaced by an empty list.
        """
        self.arch = arch
        self.text_encoder_type = text_encoder_type
        self.emb_trans_dec = emb_trans_dec
        self.layers = layers
        self.latent_dim = latent_dim
        self.cond_mask_prob = cond_mask_prob
        self.mask_frames = mask_frames
        self.lambda_rcxyz = lambda_rcxyz
        self.lambda_vel = lambda_vel
        self.lambda_fc = lambda_fc
        self.lambda_target_loc = lambda_target_loc
        self.unconstrained = unconstrained
        self.pos_embed_max_len = pos_embed_max_len
        self.use_ema = use_ema
        self.multi_target_cond = multi_target_cond
        self.multi_encoder_type = multi_encoder_type
        self.target_enc_layers = target_enc_layers
        self.context_len = context_len
        self.pred_len = pred_len

        # MDM-specific
        self.modeltype = modeltype
        self.njoints = njoints
        self.nfeats = nfeats
        self.num_actions = num_actions
        self.translation = translation
        self.pose_rep = pose_rep
        self.glob = glob
        self.glob_rot = glob_rot
        self.ff_size = ff_size
        self.num_heads = num_heads
        self.dropout = dropout
        self.ablation = ablation
        self.activation = activation
        self.legacy = legacy
        self.data_rep = data_rep
        self.dataset = dataset
        self.cond_mode = cond_mode
        self.clip_dim = clip_dim
        self.clip_version = clip_version
        self.action_emb = action_emb
        self.normalize_encoder_output = normalize_encoder_output
        self.emb_policy = emb_policy
        # None default avoids a shared mutable default; normalize to [] for callers.
        self.all_goal_joint_names = all_goal_joint_names or []
        self.diffusion_steps = diffusion_steps
        self.noise_schedule = noise_schedule
        self.sigma_small = sigma_small

    def __repr__(self):
        fields = ", ".join(f"{key}={value!r}" for key, value in vars(self).items())
        return f"{type(self).__name__}({fields})"

    def to_dict(self):
        """Return the configuration as a new ``dict`` of attribute name -> value.

        The result is a copy: top-level keys and any list values (such as
        ``all_goal_joint_names``) can be mutated without changing this
        config. Keys match the ``__init__`` parameter names, so
        ``MDMConfig.from_args(cfg.to_dict())`` round-trips.
        """
        return {
            key: list(value) if isinstance(value, list) else value
            for key, value in vars(self).items()
        }

    @classmethod
    def from_args(cls, args, *, ignore_unknown=False):
        """Create a config from an ``argparse.Namespace`` or a mapping.

        Args:
            args: An ``argparse.Namespace`` (or any object with a
                ``__dict__``), or a mapping such as a ``dict`` whose keys are
                ``__init__`` parameter names.
            ignore_unknown: If True, keys that are not ``__init__``
                parameters are dropped instead of being passed through. This
                is useful when ``args`` is a larger namespace containing
                unrelated options.

        Returns:
            A new instance of ``cls``.

        Raises:
            TypeError: If ``args`` contains a key that is not an ``__init__``
                parameter and ``ignore_unknown`` is False.
        """
        import inspect
        from collections.abc import Mapping

        # vars() only works on objects with __dict__; dicts must be copied directly.
        if isinstance(args, Mapping):
            kwargs = dict(args)
        else:
            kwargs = dict(vars(args))

        if ignore_unknown:
            params = inspect.signature(cls.__init__).parameters
            # A subclass __init__ taking **kwargs accepts everything, so don't filter.
            accepts_any = any(
                p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
            )
            if not accepts_any:
                kwargs = {
                    key: value
                    for key, value in kwargs.items()
                    if key in params and key != "self"
                }
        return cls(**kwargs)