import os
import sys
import warnings

warnings.filterwarnings('ignore')

# The evaluator code lives under Evaluator_272/; switch into it and put it on
# the import path *before* importing any of its modules below.
os.chdir('Evaluator_272')
sys.path.insert(0, os.getcwd())

import json

import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

import models.tae as tae
import options.option_tae as option_tae
import utils.eval_trans as eval_trans
import utils.utils_model as utils_model
from humanml3d_272 import dataset_eval_tae

# ----- Run setup -----

comp_device = torch.device('cuda')

args = option_tae.get_args_parser()
torch.manual_seed(args.seed)

args.out_dir = os.path.join(args.out_dir, f'{args.exp_name}')
os.makedirs(args.out_dir, exist_ok=True)

logger = utils_model.get_logger(args.out_dir)
writer = SummaryWriter(args.out_dir)
logger.info(json.dumps(vars(args), indent=4, sort_keys=True))

# Validation loader over the 272-feature HumanML3D data (batch size 32).
val_loader = dataset_eval_tae.DATALoader(args.dataname, True, 32)
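
# Optional visibility (a sketch, safe to delete): how many validation batches
# the run will cover; assumes the returned loader supports len(), as standard
# torch DataLoaders do.
logger.info(f'validation batches: {len(val_loader)}')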

# ----- Temporal autoencoder (TAE) under evaluation -----

clip_range = [-30, 20]

net = tae.Causal_HumanTAE(
    hidden_size=args.hidden_size,
    down_t=args.down_t,
    stride_t=args.stride_t,
    depth=args.depth,
    dilation_growth_rate=args.dilation_growth_rate,
    activation='relu',
    latent_dim=args.latent_dim,
    clip_range=clip_range,
)

print('loading checkpoint from {}'.format(args.resume_pth))
ckpt = torch.load(args.resume_pth, map_location='cpu')
net.load_state_dict(ckpt['net'], strict=True)
net.eval()
net.to(comp_device)
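
# Optional diagnostic (a sketch, safe to delete): report the size of the
# reconstruction model being evaluated.
n_params = sum(p.numel() for p in net.parameters())
logger.info(f'Causal_HumanTAE parameters: {n_params / 1e6:.2f}M')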

# ----- Pretrained evaluator: TEMOS-style text and motion encoders -----

from mld.models.architectures.temos.textencoder.distillbert_actor import DistilbertActorAgnosticEncoder
from mld.models.architectures.temos.motionencoder.actor import ActorAgnosticEncoder

modelpath = 'distilbert-base-uncased'

textencoder = DistilbertActorAgnosticEncoder(modelpath, num_layers=4, latent_dim=256)
motionencoder = ActorAgnosticEncoder(nfeats=272, vae=True, num_layers=4, latent_dim=256, max_len=300)

# Joint checkpoint holding both encoders; load on CPU, move to GPU below.
ckpt = torch.load('epoch=99.ckpt', map_location='cpu')
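
# Optional diagnostic (a sketch, safe to delete): list the top-level submodule
# prefixes stored in the joint checkpoint. The loops below expect
# 'textencoder' and 'motionencoder' among them.
print('checkpoint submodules:', sorted({k.split('.')[0] for k in ckpt['state_dict'].keys()}))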

# Restore each encoder by stripping its submodule prefix from the joint
# state dict; strict=True surfaces any missing or unexpected keys.
textencoder_ckpt = {}
for k, v in ckpt['state_dict'].items():
    if k.split('.')[0] == 'textencoder':
        name = k.replace('textencoder.', '')
        textencoder_ckpt[name] = v
textencoder.load_state_dict(textencoder_ckpt, strict=True)
textencoder.eval()
textencoder.to(comp_device)

motionencoder_ckpt = {}
for k, v in ckpt['state_dict'].items():
    if k.split('.')[0] == 'motionencoder':
        name = k.replace('motionencoder.', '')
        motionencoder_ckpt[name] = v
motionencoder.load_state_dict(motionencoder_ckpt, strict=True)
motionencoder.eval()
motionencoder.to(comp_device)
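
# Optional safeguard (a sketch, not part of the original script): freeze the
# evaluator so no gradients are ever tracked through it during evaluation.
for enc in (textencoder, motionencoder):
    for p in enc.parameters():
        p.requires_grad_(False)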

# ----- Evaluation -----

evaluator = [textencoder, motionencoder]

fid = []
mpjpe = []

best_fid, best_mpjpe, writer, logger = eval_trans.evaluation_tae_single(
    args.out_dir, val_loader, net, logger, writer,
    evaluator=evaluator, device=comp_device)
fid.append(best_fid)
mpjpe.append(best_mpjpe)

logger.info('final result:')
logger.info(f'fid: {fid}')
logger.info(f'mpjpe: {mpjpe} (mm)')
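
# Example invocation (a sketch; the script filename and flag spellings are
# assumptions -- check options/option_tae.py for the actual parser):
#   python eval_tae.py --exp-name tae_eval --resume-pth output/tae/net_best_fid.pth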