import os
import warnings
import traceback

import torch
import numpy as np
from models.llama_model import LLaMAHF, LLaMAHFConfig
import models.tae as tae
import options.option_transformer as option_trans

import smplx
from utils import bvh, quat
from utils.face_z_align_util import rotation_6d_to_matrix, matrix_to_axis_angle, axis_angle_to_quaternion
from sentence_transformers import SentenceTransformer

warnings.filterwarnings('ignore')

# --- save_motion_as_bvh function is unchanged ---
def save_motion_as_bvh(motion_data, output_path, fps=30):
    print(f"--- Starting direct conversion to BVH: {os.path.basename(output_path)} ---")
    try:
        if isinstance(motion_data, torch.Tensor):
            motion_data = motion_data.detach().cpu().numpy()
        if motion_data.ndim == 3 and motion_data.shape[0] == 1:
            motion_data = motion_data.squeeze(0)
        elif motion_data.ndim != 2:
            raise ValueError(f"Input motion data must be 2D, but got shape {motion_data.shape}")
        njoint = 22
        nfrm, _ = motion_data.shape
        # Per-joint 6D rotations -> rotation matrices.
        rotations_matrix = rotation_6d_to_matrix(torch.from_numpy(motion_data[:, 8+6*njoint : 8+12*njoint]).reshape(nfrm, -1, 6)).numpy()
        # Accumulate per-frame heading deltas into an absolute global heading.
        global_heading_diff_rot = rotation_6d_to_matrix(torch.from_numpy(motion_data[:, 2:8])).numpy()
        global_heading_rot = np.zeros_like(global_heading_diff_rot)
        global_heading_rot[0] = global_heading_diff_rot[0]
        for i in range(1, nfrm):
            global_heading_rot[i] = np.matmul(global_heading_diff_rot[i], global_heading_rot[i-1])
        velocities_root_xy = motion_data[:, :2]
        height = motion_data[:, 8 : 8+3*njoint].reshape(nfrm, -1, 3)[:, 0, 1]
        # Undo the heading normalization on the root joint, then integrate the
        # root XZ velocities into a world-space trajectory.
        inv_global_heading_rot = np.transpose(global_heading_rot, (0, 2, 1))
        rotations_matrix[:, 0, ...] = np.matmul(inv_global_heading_rot, rotations_matrix[:, 0, ...])
        velocities_root_xyz = np.zeros((nfrm, 3))
        velocities_root_xyz[:, 0] = velocities_root_xy[:, 0]
        velocities_root_xyz[:, 2] = velocities_root_xy[:, 1]
        velocities_root_xyz[1:, :] = np.matmul(inv_global_heading_rot[:-1], velocities_root_xyz[1:, :, None]).squeeze(-1)
        root_translation = np.cumsum(velocities_root_xyz, axis=0)
        root_translation[:, 1] = height
        # Pad the 22-joint pose to the 24-joint SMPL layout (hand joints stay zero).
        axis_angle = matrix_to_axis_angle(torch.from_numpy(rotations_matrix)).numpy().reshape(nfrm, -1)
        poses_24_joints = np.zeros((nfrm, 72))
        poses_24_joints[:, :66] = axis_angle
        # Skeleton topology and rest-pose offsets come from the SMPL body model.
        model = smplx.create(model_path="body_models/human_model_files", model_type="smpl", gender="NEUTRAL")
        parents = model.parents.detach().cpu().numpy()
        rest_pose = model().joints.detach().cpu().numpy().squeeze()[:24, :]
        offsets = rest_pose - rest_pose[parents]
        offsets[0] = np.array([0, 0, 0])
        # BVH stores Euler angles: axis-angle -> quaternion -> degrees (zyx order).
        rotations_quat = axis_angle_to_quaternion(torch.from_numpy(poses_24_joints.reshape(-1, 24, 3))).numpy()
        rotations_euler = np.degrees(quat.to_euler(rotations_quat, order="zyx"))
        positions = np.zeros_like(rotations_quat[..., :3])
        positions[:, 0] = root_translation
        joint_names = ["Pelvis", "Left_hip", "Right_hip", "Spine1", "Left_knee", "Right_knee", "Spine2", "Left_ankle", "Right_ankle", "Spine3", "Left_foot", "Right_foot", "Neck", "Left_collar", "Right_collar", "Head", "Left_shoulder", "Right_shoulder", "Left_elbow", "Right_elbow", "Left_wrist", "Right_wrist", "Left_hand", "Right_hand"]
        bvh.save(output_path, {"rotations": rotations_euler, "positions": positions, "offsets": offsets, "parents": parents, "names": joint_names, "order": "zyx", "frametime": 1.0 / fps})
        print(f"✅ BVH file saved successfully to {output_path}")
    except Exception as e:
        print(f"❌ BVH Conversion Failed. Error: {e}")
        traceback.print_exc()


if __name__ == '__main__':
    comp_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    args = option_trans.get_args_parser()
    torch.manual_seed(args.seed)

    # --- Load Models ---
    print("Loading models for MotionStreamer...")
    t5_model = SentenceTransformer('sentencet5-xxl/')
    t5_model.eval()
    for p in t5_model.parameters():
        p.requires_grad = False
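    # The SentenceTransformer acts purely as a frozen text encoder for the prompt;
    # no gradients flow through it during sampling.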

    print("Loading Causal TAE (t2m_babel) checkpoint...")
    tae_net = tae.Causal_HumanTAE(
        hidden_size=1024, down_t=2, stride_t=2, depth=3, dilation_growth_rate=3,
        latent_dim=16, clip_range=[-30, 20]
    )
    tae_ckpt = torch.load('Causal_TAE_t2m_babel/net_last.pth', map_location='cpu')
    tae_net.load_state_dict(tae_ckpt['net'], strict=True)
    tae_net.eval()
    tae_net.to(comp_device)
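    # The Causal TAE maps between 272-dim motion frames and the 16-dim latents
    # (latent_dim above) that the transformer generates; its latent width must
    # match the 16-wide motion_history tensor created below.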

    config = LLaMAHFConfig.from_name('Normal_size')
    config.block_size = 78
    trans_encoder = LLaMAHF(config, args.num_diffusion_head_layers, args.latent_dim, comp_device)

    # --- THIS IS THE FIX ---
    print("Loading your trained MotionStreamer checkpoint from 'motionstreamer_model/latest.pth'...")
    # Make sure this path is correct relative to where you run the script
    checkpoint_path = 'motionstreamer_model/latest.pth'
    trans_ckpt = torch.load(checkpoint_path, map_location='cpu')

    # Checkpoints saved from a DataParallel/DistributedDataParallel-wrapped model
    # prefix every parameter key with 'module.'; strip it so the keys match the
    # unwrapped model.
    unwrapped_state_dict = {}
    for key, value in trans_ckpt['trans'].items():
        if key.startswith('module.'):
            # Strip the 'module.' prefix
            unwrapped_state_dict[key[len('module.'):]] = value
        else:
            # Keep keys that don't have the prefix (just in case)
            unwrapped_state_dict[key] = value
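
    # Alternative: recent PyTorch versions ship a helper that does the same
    # unwrapping in place (a sketch, assuming PyTorch >= 1.9):
    #   from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
    #   consume_prefix_in_state_dict_if_present(trans_ckpt['trans'], 'module.')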

    # Load the unwrapped state dict
    trans_encoder.load_state_dict(unwrapped_state_dict, strict=True)
    print("Successfully loaded unwrapped checkpoint.")
    # --- END FIX ---

    trans_encoder.eval()
    trans_encoder.to(comp_device)

    # --- Rest of the script is unchanged ---
    print("Loading mean/std from BABEL dataset...")
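    # Mean/Std are per-dimension statistics of the 272-dim BABEL representation;
    # they undo the training-time normalization after decoding, at the end of
    # this script.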
    mean = np.load('babel_272/t2m_babel_mean_std/Mean.npy')
    std = np.load('babel_272/t2m_babel_mean_std/Std.npy')

    motion_history = torch.empty(0, 16).to(comp_device)
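    # The empty (0, 16) tensor above starts generation with no motion context;
    # for streaming continuation it could instead be seeded with latents from a
    # previously generated clip.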
    cfg_scale = 10.0
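    # Classifier-free guidance strength: judging by the sampler's name below, a
    # conditional and an unconditional forward pass are combined per step, and
    # larger values push the sample toward the text prompt.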

    print(f"Generating motion for text: '{args.text}' with CFG scale: {cfg_scale}")
    with torch.no_grad():
        # Use the new two-forward sampling method to match training
        _, motion_latents = trans_encoder.sample_for_eval_CFG_babel_inference_two_forward(
            B_text=args.text,
            A_motion=motion_history,
            tokenizer='t5-xxl',
            clip_model=t5_model,
            device=comp_device,
            cfg=cfg_scale,
            length=240,
            temperature=1.3
        )

        print("Decoding latents to full motion...")
        motion_seqs = tae_net.forward_decoder(motion_latents)

    motion = motion_seqs.detach().cpu().numpy()
    motion_denormalized = motion * std + mean

    output_dir = 'demo_output_streamer'
    os.makedirs(output_dir, exist_ok=True)

    output_bvh_path = os.path.join(output_dir, f'{args.text.replace(" ", "_")}_cfg{cfg_scale}.bvh')
    save_motion_as_bvh(motion_denormalized, output_bvh_path, fps=30)
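
# Example invocation (a sketch; the script filename is whatever this file is
# saved as, and the flags are defined in options/option_transformer.py):
#   python this_script.py --text "a person walks forward and waves"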