motion-stream / utils /utils_model.py
zirobtc's picture
Initial upload of MotionStreamer code, excluding large extracted data and output folders.
0e267a7 verified
import numpy as np
import torch
import torch.optim as optim
import logging
import os
import sys
def getCi(accLog):
    """Return the mean of ``accLog`` and its 95% confidence half-width.

    The half-width is ``1.96 * std / sqrt(n)``, where ``std`` is NumPy's
    default standard deviation (``ddof=0``) and ``n`` is ``len(accLog)``.

    Returns:
        A ``(mean, ci95)`` tuple.
    """
    center = np.mean(accLog)
    spread = np.std(accLog)
    root_n = np.sqrt(len(accLog))
    # Same evaluation order as (1.96 * std) / sqrt(n).
    return center, 1.96 * spread / root_n
def get_logger(out_dir):
    """Return the shared ``'Exp'`` logger writing to ``<out_dir>/run.log`` and stdout.

    ``logging.getLogger('Exp')`` always returns the same logger object, so
    handlers installed by an earlier call to this function are removed (and
    closed) before the new ones are attached. Without that, every repeated
    call would add another pair of handlers and each message would be
    emitted several times. Handlers attached to the logger by other code are
    left untouched.

    Args:
        out_dir: Directory that receives ``run.log``. It is created if it
            does not exist yet. An empty string means the current directory.

    Returns:
        The configured ``logging.Logger`` (level ``INFO``).
    """
    logger = logging.getLogger('Exp')
    logger.setLevel(logging.INFO)

    # FileHandler does not create missing directories on its own.
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    formatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s")
    file_hdlr = logging.FileHandler(os.path.join(out_dir, "run.log"))
    strm_hdlr = logging.StreamHandler(sys.stdout)

    # Only drop handlers this function created earlier; the marker attribute
    # distinguishes them from handlers added by other code.
    for stale in list(logger.handlers):
        if getattr(stale, '_installed_by_get_logger', False):
            logger.removeHandler(stale)
            stale.close()

    for handler in (file_hdlr, strm_hdlr):
        handler.setFormatter(formatter)
        handler._installed_by_get_logger = True
        logger.addHandler(handler)
    return logger
def initial_optim(decay_option, lr, weight_decay, net, optimizer):
    """Build an Adam-family optimizer over all parameters of ``net``.

    Args:
        decay_option: Only ``'all'`` is supported (weight decay applied to
            every parameter).
        lr: Learning rate.
        weight_decay: Weight-decay coefficient passed to the optimizer.
        net: Module whose ``parameters()`` are optimized.
        optimizer: ``'adamw'`` or ``'adam'``.

    Returns:
        The constructed optimizer, using ``betas=(0.9, 0.99)``.

    Raises:
        NotImplementedError: If ``optimizer`` or ``decay_option`` is not one
            of the supported values.
    """
    if optimizer == 'adamw':
        optimizer_adam_family = optim.AdamW
    elif optimizer == 'adam':
        optimizer_adam_family = optim.Adam
    else:
        # Previously this fell through and surfaced later as an
        # UnboundLocalError; fail with an explicit message instead.
        raise NotImplementedError(
            f"Unsupported optimizer {optimizer!r}; expected 'adamw' or 'adam'."
        )

    if decay_option != 'all':
        raise NotImplementedError(
            f"Unsupported decay_option {decay_option!r}; expected 'all'."
        )

    return optimizer_adam_family(
        net.parameters(), lr=lr, betas=(0.9, 0.99), weight_decay=weight_decay
    )
def initial_optim_with_eps(decay_option, lr, weight_decay, net, optimizer, eps):
    """Build an Adam-family optimizer with an explicit ``eps``.

    Args:
        decay_option: ``'all'`` applies ``weight_decay`` to every parameter.
            ``'noVQ'`` applies it to every parameter except those owned by
            ``net.vq_layer``, which get a weight decay of 0.
        lr: Learning rate.
        weight_decay: Weight-decay coefficient.
        net: Module to optimize. For ``'noVQ'`` it must expose a
            ``vq_layer`` sub-module.
        optimizer: ``'adamw'`` or ``'adam'``.
        eps: Numerical-stability term passed to the optimizer.

    Returns:
        The constructed optimizer. For ``'noVQ'`` the first parameter group
        holds the ``vq_layer`` parameters and the second holds the rest.

    Raises:
        NotImplementedError: If ``optimizer`` or ``decay_option`` is not one
            of the supported values.
    """
    if optimizer == 'adamw':
        optimizer_adam_family = optim.AdamW
    elif optimizer == 'adam':
        optimizer_adam_family = optim.Adam
    else:
        # Previously this fell through and surfaced later as an
        # UnboundLocalError; fail with an explicit message instead.
        raise NotImplementedError(
            f"Unsupported optimizer {optimizer!r}; expected 'adamw' or 'adam'."
        )

    if decay_option == 'all':
        return optimizer_adam_family(
            net.parameters(),
            lr=lr,
            betas=(0.9, 0.99),
            weight_decay=weight_decay,
            eps=eps,
        )

    if decay_option == 'noVQ':
        # Collect the *parameters* of the VQ layer. Putting the module itself
        # into the group would hand the optimizer a non-tensor and would not
        # exclude anything from the decayed group.
        no_decay = list(net.vq_layer.parameters())
        no_decay_ids = {id(p) for p in no_decay}
        # Filter by identity and keep net.parameters() order so the group
        # layout is deterministic across runs.
        decay = [p for p in net.parameters() if id(p) not in no_decay_ids]
        # NOTE(review): this branch uses the optimizer's default betas,
        # unlike the 'all' branch; left as in the original.
        return optimizer_adam_family(
            [
                {'params': no_decay, 'weight_decay': 0},
                {'params': decay, 'weight_decay': weight_decay},
            ],
            lr=lr,
            eps=eps,
        )

    # Previously an unknown option silently returned the optimizer *name*.
    raise NotImplementedError(
        f"Unsupported decay_option {decay_option!r}; expected 'all' or 'noVQ'."
    )
def get_motion_with_trans(motion, velocity):
    '''
    Add the accumulated root translation back onto every joint of ``motion``.

    motion : torch.tensor, shape (batch_size, T, J * C), joint features with
        the global translation removed (e.g. 63 for 21 joints or 72 for
        24 joints when C == 3)
    velocity : torch.tensor, shape (batch_size, T, C), per-frame root velocity

    The translation is the cumulative sum of ``velocity`` over time, shifted
    so that the first frame sits at the origin. It is tiled ``J`` times along
    the last axis so that it lines up with each joint's C coordinates.

    Returns a tensor with the same shape as ``motion``.
    Raises ValueError if the feature size of ``motion`` is not a positive
    multiple of the feature size of ``velocity``.
    '''
    coord_dim = velocity.shape[-1]
    feat_dim = motion.shape[-1]
    if coord_dim == 0 or feat_dim == 0 or feat_dim % coord_dim != 0:
        raise ValueError(
            f"motion feature size ({feat_dim}) must be a positive multiple "
            f"of velocity feature size ({coord_dim})"
        )
    # Derived from the shapes instead of a hard-coded 21 so that skeletons
    # with a different joint count (e.g. 24 joints -> 72 features) also work.
    num_joints = feat_dim // coord_dim

    trans = torch.cumsum(velocity, dim=1)
    # Anchor the trajectory so the first frame has zero translation.
    trans = trans - trans[:, :1]
    trans = trans.repeat((1, 1, num_joints))
    motion_with_trans = motion + trans
    return motion_with_trans