import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class ReConsLoss(nn.Module):
    """Reconstruction and KL losses for a motion VAE.

    The reconstruction terms use the "optimal sigma" Gaussian likelihood
    (https://arxiv.org/pdf/2006.13202): the decoder's standard deviation is
    not learned but set to the RMS error between prediction and target.
    """

    # Soft lower bound applied to log(sigma) through ``softclip``.
    _LOG_SIGMA_FLOOR = -6
    # Smallest mean-squared error fed to sqrt/log. Without it an exact
    # reconstruction yields log(sqrt(0)) = -inf, whose backward pass is 0/0
    # and fills the prediction gradients with NaN. MSE values at or above this
    # threshold are untouched; below it sigma already sits at the softclip
    # floor, so the forward loss is effectively unchanged.
    _MIN_MSE = torch.finfo(torch.float32).tiny

    def __init__(self, motion_dim=272):
        """
        Args:
            motion_dim: number of leading feature channels (last dim) compared
                by ``forward``; any channels after them are ignored.
        """
        super().__init__()
        self.motion_dim = motion_dim

    def softclip(self, tensor, min):
        """Softly clamp ``tensor`` from below at ``min``.

        Returns ``min + softplus(tensor - min)``: smooth, never below ``min``,
        and close to ``tensor`` when it is well above ``min``. The parameter
        name shadows the builtin ``min``; it is kept for keyword callers.
        """
        return min + F.softplus(tensor - min)

    def gaussian_nll(self, mu, log_sigma, x):
        """Element-wise negative log-likelihood of ``x`` under N(mu, exp(log_sigma)**2).

        No reduction is applied; ``log_sigma`` broadcasts against ``mu``/``x``.
        """
        return 0.5 * torch.pow((x - mu) / log_sigma.exp(), 2) + log_sigma + 0.5 * np.log(2 * np.pi)

    def _optimal_sigma_nll(self, pred, gt):
        """Summed Gaussian NLL of ``gt`` given ``pred`` with sigma = RMS error over dims 0-2."""
        # Averaging over dims 0, 1, 2 (keepdim) leaves one sigma per index of
        # any dims after dim 2 -- a single shared sigma for 3-D inputs.
        mse = ((gt - pred) ** 2).mean([0, 1, 2], keepdim=True)
        log_sigma = mse.clamp_min(self._MIN_MSE).sqrt().log()
        log_sigma = self.softclip(log_sigma, self._LOG_SIGMA_FLOOR)
        return self.gaussian_nll(pred, log_sigma, gt).sum()

    def forward(self, motion_pred, motion_gt):
        """Optimal sigma VAE loss, see https://arxiv.org/pdf/2006.13202 for more details.

        Args:
            motion_pred: predicted motion, at least 3-D, features on the last dim.
            motion_gt: ground-truth motion, broadcast-compatible with
                ``motion_pred`` (presumably the same shape).

        Returns:
            Scalar tensor: NLL summed over every element of the first
            ``self.motion_dim`` feature channels.
        """
        dim = self.motion_dim
        return self._optimal_sigma_nll(motion_pred[..., :dim], motion_gt[..., :dim])

    def forward_KL(self, mu, logvar):
        """KL(N(mu, exp(logvar)) || N(0, I)), summed over dims 1 and 2, then averaged over the rest."""
        kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=(1, 2))
        return kl.mean()

    def forward_root(self, motion_pred, motion_gt, root_dim=8):
        """Optimal sigma NLL restricted to the root-joint channels.

        Args:
            motion_pred, motion_gt: as in ``forward``.
            root_dim: number of leading feature channels that relate to the
                root joint (default 8, the previously hard-coded value).

        Returns:
            Scalar tensor: NLL summed over ``[..., :root_dim]``. Sigma is
            estimated from these channels alone, independently of ``forward``.
        """
        return self._optimal_sigma_nll(motion_pred[..., :root_dim], motion_gt[..., :root_dim])