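# motion-stream / eval_causal_TAE.py (uploaded by zirobtc, commit 0e267a7)
# Commit message: "Initial upload of MotionStreamer code, excluding large
# extracted data and output folders."
#
# Evaluates the Causal_HumanTAE checkpoint given by args.resume_pth and logs
# its FID and MPJPE.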
import os
import torch
import numpy as np
from torch.utils.tensorboard import SummaryWriter
import json
import models.tae as tae
import options.option_tae as option_tae
import utils.utils_model as utils_model
import utils.eval_trans as eval_trans
from humanml3d_272 import dataset_eval_tae
import sys
import warnings
# Silence library warnings for the whole evaluation run.
warnings.filterwarnings('ignore')

# Work from inside the evaluator directory and make it importable; the `mld`
# evaluator package and the 'epoch=99.ckpt' checkpoint used further down are
# presumably resolved from here.
# NOTE(review): because this chdir happens before argument parsing, any
# relative path used later (args.out_dir, args.resume_pth, ...) is resolved
# relative to Evaluator_272, not the directory the script was launched from.
os.chdir('Evaluator_272')
evaluator_dir = os.getcwd()
sys.path.insert(0, evaluator_dir)

# Evaluation is hard-wired to a CUDA device.
comp_device = torch.device('cuda')
##### ---- Exp dirs ---- #####
args = option_tae.get_args_parser()
torch.manual_seed(args.seed)
# Each run writes into its own <out_dir>/<exp_name> sub-directory; the
# resolved path is stored back on args so it is also what gets logged below.
run_dir = os.path.join(args.out_dir, str(args.exp_name))
args.out_dir = run_dir
os.makedirs(run_dir, exist_ok=True)

##### ---- Logger ---- #####
logger = utils_model.get_logger(run_dir)
writer = SummaryWriter(run_dir)
# Log the full run configuration (after out_dir has been resolved).
config_dump = json.dumps(vars(args), sort_keys=True, indent=4)
logger.info(config_dump)

# Evaluation loader; positional arguments are (dataset name, True, 32) -- the
# meaning of the flag and of 32 (presumably batch size) is defined in
# humanml3d_272.dataset_eval_tae; confirm there.
val_loader = dataset_eval_tae.DATALoader(args.dataname, True, 32)
##### ---- Network ---- #####
# Clip range handed to the TAE constructor (its semantics live in models.tae).
clip_range = [-30, 20]
tae_config = dict(
    hidden_size=args.hidden_size,
    down_t=args.down_t,
    stride_t=args.stride_t,
    depth=args.depth,
    dilation_growth_rate=args.dilation_growth_rate,
    activation='relu',
    latent_dim=args.latent_dim,
    clip_range=clip_range,
)
net = tae.Causal_HumanTAE(**tae_config)

# Restore the trained weights: read the checkpoint onto the CPU, require an
# exact key match with the architecture (strict=True), then switch to eval
# mode and move the model to the compute device.
print(f'loading checkpoint from {args.resume_pth}')
ckpt = torch.load(args.resume_pth, map_location='cpu')
net.load_state_dict(ckpt['net'], strict=True)
net.eval()
net.to(comp_device)
# load evaluator:--------------------------------
# The `mld` package is presumably found via the Evaluator_272 entry that was
# put on sys.path at start-up, which is why these imports live down here.
from mld.models.architectures.temos.textencoder.distillbert_actor import DistilbertActorAgnosticEncoder
from mld.models.architectures.temos.motionencoder.actor import ActorAgnosticEncoder


def _submodule_state_dict(state_dict, prefix):
    """Return the entries of ``state_dict`` stored under ``<prefix>.``.

    The leading ``<prefix>.`` is stripped from each selected key so the result
    can be passed straight to the sub-module's ``load_state_dict``. Only that
    leading prefix is removed; any later occurrence inside a key is kept.

    Args:
        state_dict: Mapping of parameter names to tensors, e.g. the
            ``'state_dict'`` entry of the evaluator checkpoint.
        prefix: Top-level sub-module name, without the trailing dot.

    Returns:
        A new dict containing only the matching entries, with stripped keys.
    """
    head = prefix + '.'
    return {
        key[len(head):]: value
        for key, value in state_dict.items()
        if key.startswith(head)
    }


modelpath = 'distilbert-base-uncased'
textencoder = DistilbertActorAgnosticEncoder(modelpath, num_layers=4, latent_dim=256)
motionencoder = ActorAgnosticEncoder(nfeats=272, vae=True, num_layers=4, latent_dim=256, max_len=300)

# Read onto the CPU, as for the TAE checkpoint above. load_state_dict copies
# the tensors into the freshly built modules (assumed to be on CPU at
# construction), which are only moved to comp_device afterwards, so there is
# no need to materialize the whole checkpoint on the GPU first.
# NOTE(review): torch>=2.6 defaults torch.load to weights_only=True; if this
# full checkpoint stores non-tensor objects it may need weights_only=False.
ckpt = torch.load('epoch=99.ckpt', map_location='cpu')
evaluator_state = ckpt['state_dict']

# load textencoder
textencoder.load_state_dict(_submodule_state_dict(evaluator_state, 'textencoder'), strict=True)
textencoder.eval()
textencoder.to(comp_device)

# load motionencoder
motionencoder.load_state_dict(_submodule_state_dict(evaluator_state, 'motionencoder'), strict=True)
motionencoder.eval()
motionencoder.to(comp_device)
#--------------------------------
evaluator = [textencoder, motionencoder]
##### ---- Evaluation ---- #####
# One evaluation pass. The results are wrapped in lists, so the final log
# lines read e.g. "fid: [<value>]".
fid = []
mpjpe = []
best_fid, best_mpjpe, writer, logger = eval_trans.evaluation_tae_single(
    args.out_dir, val_loader, net, logger, writer,
    evaluator=evaluator, device=comp_device,
)
fid.append(best_fid)
mpjpe.append(best_mpjpe)
logger.info('final result:')
logger.info(f'fid: {fid}')
logger.info(f'mpjpe: {mpjpe} (mm)')

# Flush pending TensorBoard events and close the event file; otherwise the
# tail of the run's summaries can be dropped when the interpreter exits.
# (Assumes evaluation_tae_single hands back the SummaryWriter it was given.)
writer.close()