diff --git a/Causal_TAE/net_last.pth b/Causal_TAE/net_last.pth new file mode 100644 index 0000000000000000000000000000000000000000..2a6f7fdfb4df76b1c702fb6beafe95988e0c1edd --- /dev/null +++ b/Causal_TAE/net_last.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8becaeebbd0588d7080ea3baf19ca036fe06851035c8b5f214dac1a5cf23949c +size 304843534 diff --git a/Causal_TAE_t2m_babel/net_last.pth b/Causal_TAE_t2m_babel/net_last.pth new file mode 100644 index 0000000000000000000000000000000000000000..18bf35ff91598c94f4658219278905a36b0ea853 --- /dev/null +++ b/Causal_TAE_t2m_babel/net_last.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8d4cf982269fed7887c45076852fe44be3611ac3c7761caaa5c849a8725ae3c6 +size 304843534 diff --git a/Evaluator_272/.DS_Store b/Evaluator_272/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..7f24622f6c835010627fe92b2c8f432625b4dd6d Binary files /dev/null and b/Evaluator_272/.DS_Store differ diff --git a/Evaluator_272/configs/assets.yaml b/Evaluator_272/configs/assets.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a4f8e7bb4971c64c4268845fd35670aedddcc6e6 --- /dev/null +++ b/Evaluator_272/configs/assets.yaml @@ -0,0 +1,13 @@ +FOLDER: './experiments' # Experiment files saving path + +TEST: + FOLDER: './results' # Testing files saving path + +DATASET: + HUMANML3D_272: + ROOT: './datasets/humanml3d_272' # HumanML3D_272 directory + SPLIT_ROOT: './datasets/humanml3d_272/split' # HumanML3D_272 splits directory + +model: + bert_path: './deps/distilbert-base-uncased' + diff --git a/Evaluator_272/configs/base.yaml b/Evaluator_272/configs/base.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c79e464bb6b248e584b6d36f287a6a35af890adb --- /dev/null +++ b/Evaluator_272/configs/base.yaml @@ -0,0 +1,92 @@ +SEED_VALUE: 1234 +DEBUG: True +TRAIN: + SPLIT: 'train' + NUM_WORKERS: 2 # Number of workers + BATCH_SIZE: 4 # Size of batches + START_EPOCH: 0 # Start epoch + END_EPOCH: 400 # End epoch + RESUME: '' # Experiment path to be resumed training + PRETRAINED_VAE: '' + PRETRAINED: '' # Pretrained model path + + OPTIM: + OPTIM.TYPE: 'AdamW' # Optimizer type + OPTIM.LR: 1e-4 # Learning rate + + ABLATION: + VAE_TYPE: 'actor' # vae ablation: actor or mcross + VAE_ARCH: 'encoder_decoder' # mdiffusion vae architecture + PE_TYPE: 'actor' # mdiffusion mld or actor + DIFF_PE_TYPE: 'actor' # mdiffusion mld or actor + SKIP_CONNECT: False # skip connection for denoiser va + # use linear to expand mean and std rather expand token nums + MLP_DIST: False + IS_DIST: False # Mcross distribution kl + PREDICT_EPSILON: True # noise or motion + +EVAL: + SPLIT: 'gtest' + BATCH_SIZE: 1 # Evaluating Batch size + NUM_WORKERS: 12 # Evaluating Batch size + +TEST: + TEST_DIR: '' + CHECKPOINTS: '' # Pretrained model path + SPLIT: 'gtest' + BATCH_SIZE: 1 # Testing Batch size + NUM_WORKERS: 12 # Evaluating Batch size + SAVE_PREDICTIONS: False # Weather to save predictions + COUNT_TIME: False # Weather to count time during test + REPLICATION_TIMES: 20 # Number of times to replicate the test + MM_NUM_SAMPLES: 100 # Number of samples for multimodal test + MM_NUM_REPEATS: 30 # Number of repeats for multimodal test + MM_NUM_TIMES: 10 # Number of times to repeat the multimodal test + DIVERSITY_TIMES: 300 # Number of times to repeat the diversity test + REP_I: 0 +model: + target: 'modules' + t2m_textencoder: + dim_word: 300 + dim_pos_ohot: 15 + dim_text_hidden: 512 + dim_coemb_hidden: 512 + + 
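The t2m_textencoder entries above set the dimensions of the evaluator's text branch: 300-d word vectors, a 15-d part-of-speech one-hot, a 512-d hidden state, and a 512-d co-embedding shared with the motion branch configured just below. A minimal sketch of a BiGRU module consuming exactly these dimensions follows; the class name, argument names, and forward signature are illustrative, not the repository's actual TextEncoderBiGRUCo implementation.

import torch
import torch.nn as nn

class BiGRUTextEncoderSketch(nn.Module):
    def __init__(self, dim_word=300, dim_pos_ohot=15,
                 dim_text_hidden=512, dim_coemb_hidden=512):
        super().__init__()
        # word embeddings (e.g. GloVe) concatenated with part-of-speech one-hots
        self.input_proj = nn.Linear(dim_word + dim_pos_ohot, dim_text_hidden)
        self.gru = nn.GRU(dim_text_hidden, dim_text_hidden,
                          batch_first=True, bidirectional=True)
        self.out_proj = nn.Linear(2 * dim_text_hidden, dim_coemb_hidden)

    def forward(self, word_emb, pos_onehot):
        # word_emb: (B, T, 300), pos_onehot: (B, T, 15)
        x = self.input_proj(torch.cat([word_emb, pos_onehot], dim=-1))
        _, h = self.gru(x)                   # h: (2, B, 512), one state per direction
        h = torch.cat([h[0], h[1]], dim=-1)  # (B, 1024)
        return self.out_proj(h)              # (B, 512) text co-embedding

# e.g. BiGRUTextEncoderSketch()(torch.randn(4, 20, 300), torch.zeros(4, 20, 15)) -> (4, 512)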
t2m_motionencoder: + dim_move_hidden: 512 + dim_move_latent: 512 + dim_motion_hidden: 1024 + dim_motion_latent: 512 +LOSS: + LAMBDA_LATENT: 1e-5 # Lambda for latent losses + LAMBDA_KL: 1e-5 # Lambda for kl losses + LAMBDA_REC: 1.0 # Lambda for reconstruction losses + LAMBDA_JOINT: 1.0 # Lambda for joint losses + LAMBDA_GEN: 1.0 # Lambda for text-motion generation losses + LAMBDA_CROSS: 1.0 # Lambda for cross-reconstruction losses + LAMBDA_CYCLE: 1.0 # Lambda for cycle losses + LAMBDA_PRIOR: 0.0 + DIST_SYNC_ON_STEP: True +METRIC: + FORCE_IN_METER: True + DIST_SYNC_ON_STEP: True +DATASET: + NCLASSES: 10 + SAMPLER: + MAX_SQE: -1 + MAX_LEN: 196 + MIN_LEN: 40 + MAX_TEXT_LEN: 20 + HUMANML3D_272: + UNIT_LEN: 4 + + +LOGGER: + SACE_CHECKPOINT_EPOCH: 1 + LOG_EVERY_STEPS: 1 + VAL_EVERY_STEPS: 10 + TENSORBOARD: true + WANDB: + OFFLINE: false + PROJECT: null + RESUME_ID: null diff --git a/Evaluator_272/configs/configs_evaluator_272/H3D-TMR.yaml b/Evaluator_272/configs/configs_evaluator_272/H3D-TMR.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5b8eb8ea8cf99e582b117d6e1ffbabc73bc065d2 --- /dev/null +++ b/Evaluator_272/configs/configs_evaluator_272/H3D-TMR.yaml @@ -0,0 +1,95 @@ +NAME: EXP1 # Experiment name +DEBUG: False # Debug mode +ACCELERATOR: 'gpu' # Devices optioncal: “cpu”, “gpu”, “tpu”, “ipu”, “hpu”, “mps, “auto” +DEVICE: [0] # Index of gpus eg. [0] or [0,1,2,3] +# DEVICE: [0] # Index of gpus eg. [0] or [0,1,2,3] + +# Training configuration +TRAIN: + #--------------------------------- + STAGE: temos # stage "vae" or "diffusion", "vae_diffusion" + #--------------------------------- + DATASETS: ['humanml3d_272'] # Training datasets + NUM_WORKERS: 11 # Number of workers + BATCH_SIZE: 256 # Size of batches + START_EPOCH: 0 # Start epochMMOTIONENCODER + END_EPOCH: 100 # End epoch + RESUME: '' # Resume training from this path + OPTIM: + TYPE: AdamW # Optimizer type + LR: 1e-4 # Learning rate + PRETRAINED_MLD: False + +# Evaluating Configuration +EVAL: + DATASETS: ['humanml3d_272'] # Evaluating datasets + BATCH_SIZE: 32 # Evaluating Batch size + SPLIT: test + eval_self_on_gt: True + +# Test Configuration +TEST: + PRETRAINED_CHECKPOINTS_VAE: '' + SAVE_PREDICTIONS: False + CHECKPOINTS: '' # Pretrained model path + DATASETS: ['humanml3d_272'] # training datasets + SPLIT: test + BATCH_SIZE: 32 # training Batch size + MEAN: False + NUM_SAMPLES: 1 + FACT: 1 + inference_vq_code: False + # REPLICATION_TIM + +# Datasets Configuration +DATASET: + JOINT_TYPE: 'humanml3d_v3' # join type + VERSION: '' + MOTION_TYPE: '' +METRIC: + TYPE: ['TMR_TM2TMetrics'] +# Losses Configuration +LOSS: + TYPE: temos # Losses type + USE_INFONCE: True + USE_INFONCE_FILTER: True + LAMBDA_LATENT: 1.0e-5 # Lambda for latent Losses + LAMBDA_KL: 1.0e-5 # Lambda for kl Losses + LAMBDA_REC: 1.0 # Lambda for reconstruction Losses + LAMBDA_GEN: 1.0 # Lambda for text-motion generation losses + LAMBDA_CROSS: 1.0 # Lambda for reconstruction Losses + LAMBDA_CYCLE: 0.0 # Lambda for cycle Losses + LAMBDA_PRIOR: 0.0 + LAMBDA_INFONCE: 0.1 # Lambda for infonce + INFONCE_TEMP: 0.1 + DIST_SYNC_ON_STEP: False # Sync Losses on step when distributed trained + USE_RECLIPLOSS: False + SYNC: False + TRAIN_TMR: False + +# Model Configuration +model: + vae: true # whether vae model + model_type: temos # model type + condition: 'text' + target: modules_temos + ##### + latent_dim: 256 # latent dimension + ff_size: 1024 # + num_layers: 4 # number of layers + num_head: 6 # number of head layers + dropout: 0.1 # dropout rate + activation: 
gelu # activation type + eval_text_encode_way: given_glove + eval_text_source: token + +# Logger configuration +LOGGER: + SAVE_CHECKPOINT_EPOCH: 10 + LOG_EVERY_STEPS: 1 + VAL_EVERY_STEPS: 5 + TENSORBOARD: True + WANDB: + PROJECT: null + OFFLINE: False + RESUME_ID: null \ No newline at end of file diff --git a/Evaluator_272/configs/modules/denoiser.yaml b/Evaluator_272/configs/modules/denoiser.yaml new file mode 100644 index 0000000000000000000000000000000000000000..96964dd96c29fcdc38a9500d82c3cd87f8acfcd9 --- /dev/null +++ b/Evaluator_272/configs/modules/denoiser.yaml @@ -0,0 +1,22 @@ +denoiser: + target: mld.models.architectures.mld_denoiser.MldDenoiser + params: + text_encoded_dim: 768 + ff_size: 1024 + num_layers: 9 + num_heads: 4 + dropout: 0.1 + normalize_before: False + activation: 'gelu' + flip_sin_to_cos: True + return_intermediate_dec: False + position_embedding: 'learned' + arch: trans_enc + freq_shift: 0 + condition: ${model.condition} + latent_dim: ${model.latent_dim} + guidance_scale: ${model.guidance_scale} + guidance_uncondp: ${model.guidance_uncondp} + nfeats: ${DATASET.NFEATS} + nclasses: ${DATASET.NCLASSES} + ablation: ${TRAIN.ABLATION} diff --git a/Evaluator_272/configs/modules/evaluators.yaml b/Evaluator_272/configs/modules/evaluators.yaml new file mode 100644 index 0000000000000000000000000000000000000000..12145873742544d94cfab32660143d91a8739d42 --- /dev/null +++ b/Evaluator_272/configs/modules/evaluators.yaml @@ -0,0 +1,20 @@ +t2m_textencoder: + target: mld.models.architectures.t2m_textenc.TextEncoderBiGRUCo + params: + word_size: 300 + pos_size: 15 + hidden_size: 512 + output_size: 512 + +t2m_moveencoder: + target: mld.models.architectures.t2m_textenc.MovementConvEncoder + params: + hidden_size: 512 + output_size: 512 + +t2m_motionencoder: + target: mld.models.architectures.t2m_motionenc.MotionEncoder + params: + input_size: ${model.t2m_moveencoder.output_size} + hidden_size: 1024 + output_size: 512 diff --git a/Evaluator_272/configs/modules/motion_vae.yaml b/Evaluator_272/configs/modules/motion_vae.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3be33a3e9e5fb84deedba4026c012d4272be039b --- /dev/null +++ b/Evaluator_272/configs/modules/motion_vae.yaml @@ -0,0 +1,15 @@ +motion_vae: + # Optional: mld_vae, vposert_vae + target: mld.models.architectures.mld_vae.MldVae + params: + arch: 'encoder_decoder' + ff_size: 1024 + num_layers: 9 + num_heads: 4 + dropout: 0.1 + normalize_before: false + activation: 'gelu' + position_embedding: 'learned' + latent_dim: ${model.latent_dim} + nfeats: ${DATASET.NFEATS} + ablation: ${TRAIN.ABLATION} diff --git a/Evaluator_272/configs/modules/scheduler.yaml b/Evaluator_272/configs/modules/scheduler.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c248593217d0d8d5ee225fd8ca468b55bfd4d56d --- /dev/null +++ b/Evaluator_272/configs/modules/scheduler.yaml @@ -0,0 +1,25 @@ +scheduler: + target: diffusers.DDIMScheduler + num_inference_timesteps: 50 + eta: 0.0 + params: + num_train_timesteps: 1000 + beta_start: 0.00085 + beta_end: 0.012 + beta_schedule: 'scaled_linear' # Optional: ['linear', 'scaled_linear', 'squaredcos_cap_v2'] + # variance_type: 'fixed_small' + clip_sample: false # clip sample to -1~1 + # below are for ddim + set_alpha_to_one: false + steps_offset: 1 + + +noise_scheduler: + target: diffusers.DDPMScheduler + params: + num_train_timesteps: 1000 + beta_start: 0.00085 + beta_end: 0.012 + beta_schedule: 'scaled_linear' # Optional: ['linear', 'scaled_linear', 'squaredcos_cap_v2'] + 
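    # Note: diffusers.DDIMScheduler / diffusers.DDPMScheduler are classes from the
    # Hugging Face diffusers library. num_inference_timesteps and eta (in the
    # scheduler block above) sit outside params because they are not constructor
    # arguments; presumably the sampling code passes them to
    # scheduler.set_timesteps(...) and scheduler.step(..., eta=...) at inference time.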
variance_type: 'fixed_small' + clip_sample: false # clip sample to -1~1 diff --git a/Evaluator_272/configs/modules/text_encoder.yaml b/Evaluator_272/configs/modules/text_encoder.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0fb89f45c69a09250bc596229752c0ffc19dbb98 --- /dev/null +++ b/Evaluator_272/configs/modules/text_encoder.yaml @@ -0,0 +1,8 @@ +text_encoder: + # Optional: mld_clip, mld_bert + target: mld.models.architectures.mld_clip.MldTextEncoder + params: + finetune: false # if false, model weights are frozen + last_hidden_state: false # if true, the last hidden state is used as the text embedding + latent_dim: ${model.latent_dim} + modelpath: ${model.clip_path} diff --git a/Evaluator_272/configs/modules_temos/motiondecoder.yaml b/Evaluator_272/configs/modules_temos/motiondecoder.yaml new file mode 100644 index 0000000000000000000000000000000000000000..cd701ae85044143773d37d3be23833d527634ff7 --- /dev/null +++ b/Evaluator_272/configs/modules_temos/motiondecoder.yaml @@ -0,0 +1,11 @@ +motiondecoder: + name: actor_decoder + target: mld.models.architectures.temos.motiondecoder.actor.ActorAgnosticDecoder + params: + latent_dim: ${model.latent_dim} + ff_size: ${model.ff_size} + num_layers: ${model.num_layers} + num_head: ${model.num_head} + droupout: ${model.dropout} + activation: ${model.activation} + nfeats: ${DATASET.NFEATS} \ No newline at end of file diff --git a/Evaluator_272/configs/modules_temos/motionencoder.yaml b/Evaluator_272/configs/modules_temos/motionencoder.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e696f00c00cdc4018843dec1489a5b68fd749819 --- /dev/null +++ b/Evaluator_272/configs/modules_temos/motionencoder.yaml @@ -0,0 +1,12 @@ +motionencoder: + name: actor_encoder + target: mld.models.architectures.temos.motionencoder.actor.ActorAgnosticEncoder + params: + latent_dim: ${model.latent_dim} + vae: ${model.vae} + ff_size: ${model.ff_size} + num_layers: ${model.num_layers} + num_head: ${model.num_head} + droupout: ${model.dropout} + activation: ${model.activation} + nfeats: ${DATASET.NFEATS} \ No newline at end of file diff --git a/Evaluator_272/configs/modules_temos/text_encoder.yaml b/Evaluator_272/configs/modules_temos/text_encoder.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c718b4c7a54571bde6fa64891a31f87b679909d6 --- /dev/null +++ b/Evaluator_272/configs/modules_temos/text_encoder.yaml @@ -0,0 +1,13 @@ +textencoder: + name: distilbert_actor + target: mld.models.architectures.temos.textencoder.distillbert_actor.DistilbertActorAgnosticEncoder + params: + latent_dim: ${model.latent_dim} + vae: ${model.vae} + ff_size: ${model.ff_size} + num_layers: ${model.num_layers} + num_head: ${model.num_head} + droupout: ${model.dropout} + activation: ${model.activation} + finetune: false + modelpath: ${model.bert_path} \ No newline at end of file diff --git a/Evaluator_272/datasets/__init__.py b/Evaluator_272/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Evaluator_272/mld/__init__.py b/Evaluator_272/mld/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Evaluator_272/mld/callback/__init__.py b/Evaluator_272/mld/callback/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e290a7d9ac46f036f793c88d7286cbc070d2057d --- /dev/null +++ b/Evaluator_272/mld/callback/__init__.py @@ -0,0 +1 @@ +from 
.progress import ProgressLogger diff --git a/Evaluator_272/mld/callback/progress.py b/Evaluator_272/mld/callback/progress.py new file mode 100644 index 0000000000000000000000000000000000000000..eca07fc20b6e2ec457ac46a7ac938c6f202ac51d --- /dev/null +++ b/Evaluator_272/mld/callback/progress.py @@ -0,0 +1,54 @@ +import logging + +from pytorch_lightning import LightningModule, Trainer +from pytorch_lightning.callbacks import Callback +import psutil + +logger = logging.getLogger() + + +class ProgressLogger(Callback): + + def __init__(self, metric_monitor: dict, precision: int = 3): + # Metric to monitor + self.metric_monitor = metric_monitor + self.precision = precision + + def on_train_start(self, trainer: Trainer, pl_module: LightningModule, + **kwargs) -> None: + logger.info("Training started") + + def on_train_end(self, trainer: Trainer, pl_module: LightningModule, + **kwargs) -> None: + logger.info("Training done") + + def on_validation_epoch_end(self, trainer: Trainer, + pl_module: LightningModule, **kwargs) -> None: + if trainer.sanity_checking: + logger.info("Sanity checking ok.") + + def on_train_epoch_end(self, + trainer: Trainer, + pl_module: LightningModule, + padding=False, + **kwargs) -> None: + metric_format = f"{{:.{self.precision}e}}" + line = f"Epoch {trainer.current_epoch}" + if padding: + line = f"{line:>{len('Epoch xxxx')}}" # Right padding + metrics_str = [] + + losses_dict = trainer.callback_metrics + for metric_name, dico_name in self.metric_monitor.items(): + if dico_name in losses_dict: + metric = losses_dict[dico_name].item() + metric = metric_format.format(metric) + metric = f"{metric_name} {metric}" + metrics_str.append(metric) + + if len(metrics_str) == 0: + return + + memory = f"Memory {psutil.virtual_memory().percent}%" + line = line + ": " + " ".join(metrics_str) + " " + memory + logger.info(line) diff --git a/Evaluator_272/mld/config.py b/Evaluator_272/mld/config.py new file mode 100644 index 0000000000000000000000000000000000000000..05bac9aa6784420357e67b6f37c2383ca6533724 --- /dev/null +++ b/Evaluator_272/mld/config.py @@ -0,0 +1,104 @@ +import importlib +from argparse import ArgumentParser +from omegaconf import OmegaConf +import os + + +def get_module_config(cfg_model, path="modules"): + module_conf = OmegaConf.create() + files = os.listdir(f'./configs/{path}/') + for file in files: + if file.endswith('.yaml'): + with open(f'./configs/{path}/' + file, 'r') as f: + module_conf.merge_with(OmegaConf.load(f)) + module_conf.merge_with(cfg_model) + return module_conf + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +def instantiate_from_config(config): + if not "target" in config: + if config == '__is_first_stage__': + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def parse_args(phase="train"): + parser = ArgumentParser() + + group = parser.add_argument_group("Training options") + if phase in ["train", "test"]: + group.add_argument( + "--cfg", + type=str, + required=False, + default="./configs/config.yaml", + help="config file", + ) + group.add_argument( + "--cfg_assets", + type=str, + required=False, + default="./configs/assets.yaml", + help="config file for asset paths", + ) + 
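    # Example of the "target + params" pattern implemented by get_obj_from_str /
    # instantiate_from_config above and used by the target: keys in the module YAMLs
    # (denoiser.yaml, motion_vae.yaml, text_encoder.yaml, ...); the values below are
    # hypothetical:
    #   node = OmegaConf.create({"target": "torch.nn.Linear",
    #                            "params": {"in_features": 272, "out_features": 256}})
    #   layer = instantiate_from_config(node)  # equivalent to torch.nn.Linear(272, 256)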
group.add_argument("--batch_size", + type=int, + required=False, + help="training batch size") + group.add_argument("--device", + type=int, + nargs="+", + required=False, + help="training device") + group.add_argument("--nodebug", + action="store_true", + required=False, + help="debug or not") + group.add_argument("--dir", + type=str, + required=False, + help="evaluate existing npys") + + # remove None params, and create a dictionnary + params = parser.parse_args() + # params = {key: val for key, val in vars(opt).items() if val is not None} + + # update config from files + cfg_base = OmegaConf.load('./configs/base.yaml') + cfg_exp = OmegaConf.merge(cfg_base, OmegaConf.load(params.cfg)) + cfg_model = get_module_config(cfg_exp.model, cfg_exp.model.target) + cfg_exp.model = cfg_model + cfg_assets = OmegaConf.load(params.cfg_assets) + cfg = OmegaConf.merge(cfg_exp, cfg_model, cfg_assets) + + if phase in ["train", "test"]: + cfg.TRAIN.BATCH_SIZE = (params.batch_size + if params.batch_size else cfg.TRAIN.BATCH_SIZE) + cfg.DEVICE = params.device if params.device else cfg.DEVICE + cfg.DEBUG = not params.nodebug if params.nodebug is not None else cfg.DEBUG + + cfg.DEBUG = False if phase == "test" else cfg.DEBUG + if phase == "test": + cfg.DEBUG = False + cfg.DEVICE = [0] + print("Force no debugging and one gpu when testing") + cfg.TEST.TEST_DIR = params.dir if params.dir else cfg.TEST.TEST_DIR + + # debug mode + if cfg.DEBUG: + cfg.NAME = "debug--" + cfg.NAME + cfg.LOGGER.WANDB.OFFLINE = True + cfg.LOGGER.VAL_EVERY_STEPS = 1 + + return cfg diff --git a/Evaluator_272/mld/data/HumanML3D_272.py b/Evaluator_272/mld/data/HumanML3D_272.py new file mode 100644 index 0000000000000000000000000000000000000000..4cd5c21f18acc0418851f7e0205fdc6964942118 --- /dev/null +++ b/Evaluator_272/mld/data/HumanML3D_272.py @@ -0,0 +1,131 @@ +import numpy as np +import torch + +from mld.data.humanml.scripts.motion_process import (process_file, + recover_from_ric, recover_from_root_rot6d) + +from .base import BASEDataModule +from .humanml.data.dataset import Text2MotionDatasetV2 +from .humanml.common.skeleton import Skeleton +import torch.nn.functional as F + + +class HumanML3D_272_DataModule(BASEDataModule): + + def __init__(self, + cfg, + batch_size, + num_workers, + collate_fn=None, + phase="train", + **kwargs): + super().__init__(batch_size=batch_size, + num_workers=num_workers, + collate_fn=collate_fn) + + self.save_hyperparameters(logger=False) + self.name = "humanml3d_272" + self.njoints = 22 + self.hparams['njoints']=22 + if phase == "text_only": + self.Dataset = TextOnlyDataset + else: + if cfg.TRAIN.STAGE in ['gpt'] and (not cfg.TEST.inference_vq_code): + if cfg.model.vae_type in ['humanvq']: + self.Dataset = Text2MotionDatasetV2_VQToken + elif cfg.model.vae_type in ['hvq']: + self.Dataset = Text2MotionDatasetV2_Dual_codebook_VQToken + else: + raise NotImplementedError + elif cfg.TEST.inference_vq_code: + self.Dataset = VQMotionDataset + else: + self.Dataset = Text2MotionDatasetV2 + self.cfg = cfg + sample_overrides = { + "split": "val", + "tiny": True, + "progress_bar": False + } + + self._sample_set = self.get_sample_set(overrides=sample_overrides) + + self.nfeats = self._sample_set.nfeats + + def recover_from_local_position(self, final_x, njoint): + + def accumulate_rotations(relative_rotations): + R_total = [relative_rotations[0]] + for R_rel in relative_rotations[1:]: + R_total.append(np.matmul(R_rel, R_total[-1])) + + return np.array(R_total) + + def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor: + 
a1, a2 = d6[..., :3], d6[..., 3:] + b1 = F.normalize(a1, dim=-1) + b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1 + b2 = F.normalize(b2, dim=-1) + b3 = torch.cross(b1, b2, dim=-1) + return torch.stack((b1, b2, b3), dim=-2) + + nfrm, _ = final_x.shape + positions_no_heading = final_x[:,8:8+3*njoint].reshape(nfrm, -1, 3) + velocities_root_xy_no_heading = final_x[:,:2] + global_heading_diff_rot = final_x[:,2:8] + + global_heading_rot = accumulate_rotations(rotation_6d_to_matrix(torch.from_numpy(global_heading_diff_rot)).numpy()) + inv_global_heading_rot = np.transpose(global_heading_rot, (0, 2, 1)) + positions_with_heading = np.matmul(np.repeat(inv_global_heading_rot[:, None,:, :], njoint, axis=1), positions_no_heading[...,None]).squeeze(-1) + velocities_root_xyz_no_heading = np.zeros((velocities_root_xy_no_heading.shape[0], 3)) + velocities_root_xyz_no_heading[:, 0] = velocities_root_xy_no_heading[:, 0] + velocities_root_xyz_no_heading[:, 2] = velocities_root_xy_no_heading[:, 1] + velocities_root_xyz_no_heading[1:, :] = np.matmul(inv_global_heading_rot[:-1], velocities_root_xyz_no_heading[1:, :,None]).squeeze(-1) + + root_translation = np.cumsum(velocities_root_xyz_no_heading, axis=0) + positions_with_heading[:, :, 0] += root_translation[:, 0:1] + positions_with_heading[:, :, 2] += root_translation[:, 2:] + + return positions_with_heading + + def feats2joints(self, features, skel=None, motion_type=''): + assert motion_type in [''] + assert features.shape[2] == 272 + mean = torch.tensor(self.hparams.mean).to(features) + std = torch.tensor(self.hparams.std).to(features) + features = features * std + mean + return self.recover_from_local_position(features.reshape(-1, 272).detach().cpu().numpy(), self.njoints).reshape(features.shape[0], -1, 22, 3) + + + def joints2feats(self, features): + features = process_file(features, self.njoints)[0] + return features + + def renorm4t2m(self, features): + ori_mean = torch.tensor(self.hparams.mean).to(features) + ori_std = torch.tensor(self.hparams.std).to(features) + eval_mean = torch.tensor(self.hparams.mean_eval).to(features) + eval_std = torch.tensor(self.hparams.std_eval).to(features) + features = features * ori_std + ori_mean + features = (features - eval_mean) / eval_std + return features + + def renorm2ori(self, features): + mean = torch.tensor(self.hparams.mean).to(features) + std = torch.tensor(self.hparams.std).to(features) + features = features * std + mean + + return features + + + def mm_mode(self, mm_on=True): + if mm_on: + self.is_mm = True + self.name_list = self.test_dataset.name_list + self.mm_list = np.random.choice(self.name_list, + self.cfg.TEST.MM_NUM_SAMPLES, + replace=False) + self.test_dataset.name_list = self.mm_list + else: + self.is_mm = False + self.test_dataset.name_list = self.name_list diff --git a/Evaluator_272/mld/data/__init__.py b/Evaluator_272/mld/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Evaluator_272/mld/data/base.py b/Evaluator_272/mld/data/base.py new file mode 100644 index 0000000000000000000000000000000000000000..013fe81e4260a14497dc504435db9b2e7ea14f94 --- /dev/null +++ b/Evaluator_272/mld/data/base.py @@ -0,0 +1,105 @@ +from os.path import join as pjoin +import numpy as np +import pytorch_lightning as pl +from torch.utils.data import DataLoader + + +class BASEDataModule(pl.LightningDataModule): + + def __init__(self, collate_fn, batch_size: int, num_workers: int): + super().__init__() + + self.dataloader_options = { 
+ "batch_size": batch_size, + "num_workers": num_workers, + "collate_fn": collate_fn, + } + + self.persistent_workers = True + self.is_mm = False + + def get_sample_set(self, overrides={}): + sample_params = self.hparams.copy() + sample_params.update(overrides) + split_file = pjoin( + eval(f"self.cfg.DATASET.{self.name.upper()}.SPLIT_ROOT"), self.cfg.DATASET.VERSION, + self.cfg.EVAL.SPLIT + ".txt", + ) + return self.Dataset(split_file=split_file, **sample_params) + + def __getattr__(self, item): + # train_dataset/val_dataset etc cached like properties + if item.endswith("_dataset") and not item.startswith("_"): + subset = item[:-len("_dataset")] + item_c = "_" + item + if item_c not in self.__dict__: + # todo: config name not consistent + subset = subset.upper() if subset != "val" else "EVAL" + split = eval(f"self.cfg.{subset}.SPLIT") + split_file = pjoin( + eval(f"self.cfg.DATASET.{self.name.upper()}.SPLIT_ROOT"), + self.cfg.DATASET.VERSION, + eval(f"self.cfg.{subset}.SPLIT") + ".txt", + ) + self.__dict__[item_c] = self.Dataset(split_file=split_file, + split=split, + **self.hparams) + return getattr(self, item_c) + classname = self.__class__.__name__ + raise AttributeError(f"'{classname}' object has no attribute '{item}'") + + def setup(self, stage=None): + self.stage = stage + # Use the getter the first time to load the data + if stage in (None, "fit"): + _ = self.train_dataset + _ = self.val_dataset + if stage in (None, "test"): + _ = self.test_dataset + + def train_dataloader(self): + return DataLoader( + self.train_dataset, + shuffle=True, + persistent_workers=True, + **self.dataloader_options, + ) + + def predict_dataloader(self): + dataloader_options = self.dataloader_options.copy() + dataloader_options[ + "batch_size"] = 1 if self.is_mm else self.cfg.TEST.BATCH_SIZE + dataloader_options["num_workers"] = self.cfg.TEST.NUM_WORKERS + dataloader_options["shuffle"] = False + return DataLoader( + self.test_dataset, + persistent_workers=True, + **dataloader_options, + ) + + def val_dataloader(self): + # overrides batch_size and num_workers + dataloader_options = self.dataloader_options.copy() + dataloader_options["batch_size"] = self.cfg.EVAL.BATCH_SIZE + dataloader_options["num_workers"] = self.cfg.EVAL.NUM_WORKERS + dataloader_options["shuffle"] = False + + return DataLoader( + self.val_dataset, + persistent_workers=True, + **dataloader_options, + ) + + def test_dataloader(self): + # overrides batch_size and num_workers + dataloader_options = self.dataloader_options.copy() + dataloader_options[ + "batch_size"] = 1 if self.is_mm else self.cfg.TEST.BATCH_SIZE + dataloader_options["num_workers"] = self.cfg.TEST.NUM_WORKERS + # dataloader_options["drop_last"] = True + dataloader_options["shuffle"] = False + return DataLoader( + self.test_dataset, + persistent_workers=True, + **dataloader_options, + ) diff --git a/Evaluator_272/mld/data/get_data.py b/Evaluator_272/mld/data/get_data.py new file mode 100644 index 0000000000000000000000000000000000000000..5e64ede3c171dacae32182ae8f92ec78bfce7dc4 --- /dev/null +++ b/Evaluator_272/mld/data/get_data.py @@ -0,0 +1,183 @@ +from os.path import join as pjoin +import numpy as np +# from .humanml.utils.word_vectorizer import WordVectorizer, WordVectorizer_only_text_token +from .utils import * +from .HumanML3D_272 import HumanML3D_272_DataModule + + +def get_mean_std(phase, cfg, dataset_name): + assert dataset_name == 'humanml3d_272' + + data_root = eval(f"cfg.DATASET.{dataset_name.upper()}.ROOT") + mean = np.load(pjoin(data_root, 'mean_std', 
cfg.DATASET.VERSION, cfg.DATASET.MOTION_TYPE, "Mean.npy")) + std = np.load(pjoin(data_root, 'mean_std', cfg.DATASET.VERSION, cfg.DATASET.MOTION_TYPE, "Std.npy")) + return mean, std + + + +def get_njoints(dataset_name): + njoints = 22 + return njoints + + +def reget_mean_std(cfg, dataset_name, mean, std): + if 'MINOR_MOTION_TYPE' in cfg.DATASET: + select_motion_type = cfg.DATASET.MINOR_MOTION_TYPE + else: + select_motion_type = cfg.DATASET.MOTION_TYPE + + njoints = get_njoints(dataset_name) + if select_motion_type == 'root_position': + mean = mean[..., :4+(njoints - 1) * 3] + elif select_motion_type == 'root_position_vel': + mean = np.concatenate((mean[..., :4+(njoints - 1) * 3], mean[..., 4+(njoints - 1) * 9: 4+(njoints - 1) * 9 + njoints*3]), axis=0) + elif select_motion_type == 'root_position_rot6d': + mean = np.concatenate((mean[..., :4+(njoints - 1) * 3], mean[..., 4+(njoints - 1) * 3: 4+(njoints - 1) * 9]), axis=0) + elif select_motion_type == 'root_rot6d': + mean = np.concatenate((mean[..., :4], mean[..., 4+(njoints - 1) * 3: 4+(njoints - 1) * 9]), axis=0) + elif select_motion_type in ['all', 'smplx_212', 'vector_263', 'vector_263_ori_humanml', 'smplx_159', '']: + pass + elif select_motion_type == 'root_body_pos_vel_hand_all': + mean = np.concatenate((mean[..., :4+(njoints - 1) * 3], mean[..., 4+(njoints - 1) * 3 + 21 * 6 : 4+(njoints - 1) * 9], mean[..., 4+(njoints - 1) * 9: 4+(njoints - 1) * 9 + njoints*3]), axis=0) + # pass + elif select_motion_type == 'root_body_pos_vel_hand_pos_vel': + mean = np.concatenate((mean[..., :4+(njoints - 1) * 3], mean[..., 4+(njoints - 1) * 9: 4+(njoints - 1) * 9 + njoints*3]), axis=0) + elif select_motion_type == 'root_body_pos_vel_hand_pos': + mean = np.concatenate((mean[..., :4+(njoints - 1) * 3], mean[..., 4+(njoints - 1) * 9 + 22 * 3: 4+(njoints - 1) * 9 + 52*3]), axis=0) + elif select_motion_type == 'root_body_pos_vel_hand_rot': + mean = np.concatenate((mean[..., :4+(22 - 1) * 3], mean[..., 4+(52 - 1) * 3 + (22-1)*6 : 4+(52-1)*9], mean[..., 4+(52 - 1) * 9: 4+(52 - 1) * 9 + 22*3]), axis=0) + elif select_motion_type == 'root_position_vel_only_body': + mean = np.concatenate((mean[..., :4+(22 - 1) * 3], mean[..., 4+(52 - 1) * 9: 4+(52 - 1) * 9 + 22*3]), axis=0) + elif select_motion_type == 'root_body_pos_vel_hand_pos_vel_hand_wrist': + body_pos_mean = mean[..., :4+(22 - 1) * 3] # 67 + left_hand_pos_mean = (mean[..., 4+(22 - 1) * 3:4+(37 - 1) * 3].reshape(15, 3) - body_pos_mean[..., -6:-3]).reshape(-1) # 45 + right_hand_pos_mean = (mean[..., 4+(37 - 1) * 3:4+(52 - 1) * 3].reshape(15, 3) - body_pos_mean[..., -3:]).reshape(-1) # 45 + + body_vel_mean = mean[..., 4+(52 - 1) * 9: 4+(52 - 1) * 9 + 22*3] # 66 + left_hand_vel_mean = (mean[..., 4+(52 - 1) * 9 + 22*3: 4+(52 - 1) * 9 + 22*3 + 15 * 3].reshape(15, 3) - body_vel_mean[..., -6:-3]).reshape(-1) + right_hand_vel_mean = (mean[..., 4+(52 - 1) * 9 + 22*3+ 15 * 3: 4+(52 - 1) * 9 + 22*3 + 15 * 3 + 15 * 3].reshape(15, 3) - body_vel_mean[..., -3:]).reshape(-1) + + mean = np.concatenate((body_pos_mean, left_hand_pos_mean, right_hand_pos_mean, body_vel_mean, left_hand_vel_mean, right_hand_vel_mean), axis=0) + else: + raise NotImplementedError + + if select_motion_type == 'root_position': + std = std[..., :4+(njoints-1)*3] + elif select_motion_type == 'root_position_vel': + std = np.concatenate((std[..., :4+(njoints - 1) * 3], std[..., 4+(njoints - 1) * 9: 4+(njoints - 1) * 9 + njoints*3]), axis=0) + elif select_motion_type == 'root_position_rot6d': + std = np.concatenate((std[..., :4+(njoints - 1) * 3], 
std[..., 4+(njoints - 1) * 3: 4+(njoints - 1) * 9]), axis=0) + elif select_motion_type == 'root_rot6d': + std = np.concatenate((std[..., :4], std[..., 4+(njoints - 1) * 3: 4+(njoints - 1) * 9]), axis=0) + elif select_motion_type in ['all', 'smplx_212', 'vector_263', 'vector_263_ori_humanml', 'smplx_159', '']: + pass + elif select_motion_type == 'root_body_pos_vel_hand_all': + std = np.concatenate((std[..., :4+(njoints - 1) * 3], std[..., 4+(njoints - 1) * 3 + 21 * 6 : 4+(njoints - 1) * 9], std[..., 4+(njoints - 1) * 9: 4+(njoints - 1) * 9 + njoints*3]), axis=0) + # pass + elif select_motion_type == 'root_body_pos_vel_hand_pos_vel': + std = np.concatenate((std[..., :4+(njoints - 1) * 3], std[..., 4+(njoints - 1) * 9: 4+(njoints - 1) * 9 + njoints*3]), axis=0) + elif select_motion_type == 'root_body_pos_vel_hand_pos': + std = np.concatenate((std[..., :4+(njoints - 1) * 3], std[..., 4+(njoints - 1) * 9 + 22 * 3: 4+(njoints - 1) * 9 + 52*3]), axis=0) + elif select_motion_type == 'root_body_pos_vel_hand_rot': + std = np.concatenate((std[..., :4+(22 - 1) * 3], std[..., 4+(52 - 1) * 3 + (22-1)*6 : 4+(52-1)*9], std[..., 4+(52 - 1) * 9: 4+(52 - 1) * 9 + 22*3]), axis=0) + elif select_motion_type == 'root_position_vel_only_body': + std = np.concatenate((std[..., :4+(22 - 1) * 3], std[..., 4+(52 - 1) * 9: 4+(52 - 1) * 9 + 22*3]), axis=0) + elif select_motion_type == 'root_body_pos_vel_hand_pos_vel_hand_wrist': + std = np.concatenate((std[..., :4+(njoints - 1) * 3], std[..., 4+(njoints - 1) * 9: 4+(njoints - 1) * 9 + njoints*3]), axis=0) + else: + raise NotImplementedError + + return mean, std + +# def get_WordVectorizer(cfg, phase, dataset_name): +# if phase not in ["text_only"]: +# if dataset_name.lower() in ['humanml3d_272']: +# if cfg.model.eval_text_source == 'token': +# return WordVectorizer(cfg.DATASET.WORD_VERTILIZER_PATH, "our_vab", cfg.model.eval_text_encode_way) +# else: +# return WordVectorizer_only_text_token(cfg.DATASET.WORD_VERTILIZER_PATH, "our_vab", cfg.model.eval_text_encode_way) +# else: +# raise ValueError("Only support WordVectorizer for HumanML3D_272") +# else: +# return None + + +def get_collate_fn(name, cfg, phase="train"): + if name.lower() in ['humanml3d_272']: + if cfg.model.condition in ['text_all', 'text_face', 'text_body', 'text_hand', 'text_face_body', 'text_seperate', 'only_pose_concat', 'only_pose_fusion'] and (not cfg.TEST.inference_vq_code): + return mld_collate_text_all + elif cfg.TEST.inference_vq_code: + return vq_collate + elif cfg.TRAIN.STAGE in ['gpt'] and (not cfg.TEST.inference_vq_code): + return mld_collate_vq_token + else: + return mld_collate + else: + raise NotImplementedError + + +# map config name to module&path +dataset_module_map = { + 'humanml3d_272': HumanML3D_272_DataModule +} +motion_subdir = {'humanml3d_272': 'motion_data'} + + +def get_datasets(cfg, logger=None, phase="train"): + # get dataset names form cfg + dataset_names = eval(f"cfg.{phase.upper()}.DATASETS") + datasets = [] + for dataset_name in dataset_names: + if dataset_name.lower() in ["humanml3d_272"]: + + if 'MINOR_MOTION_TYPE' in cfg.DATASET: + input_format = cfg.DATASET.MINOR_MOTION_TYPE + else: + input_format = cfg.DATASET.MOTION_TYPE + + data_root = eval(f"cfg.DATASET.{dataset_name.upper()}.ROOT") + # get mean and std corresponding to dataset + mean, std = get_mean_std(phase, cfg, dataset_name) + mean_eval, std_eval = get_mean_std("val", cfg, dataset_name) + + mean, std = reget_mean_std(cfg, dataset_name, mean, std) + mean_eval, std_eval = reget_mean_std(cfg, dataset_name, 
mean_eval, std_eval) + + # get WordVectorizer + # wordVectorizer = get_WordVectorizer(cfg, phase, dataset_name) + # get collect_fn + collate_fn = get_collate_fn(dataset_name, cfg, phase) + # get dataset module + + dataset = dataset_module_map[dataset_name.lower()]( + cfg=cfg, + batch_size=cfg.TRAIN.BATCH_SIZE, + num_workers=cfg.TRAIN.NUM_WORKERS, + debug=cfg.DEBUG, + collate_fn=collate_fn, + mean=mean, + std=std, + mean_eval=mean_eval, + std_eval=std_eval, + # w_vectorizer=wordVectorizer, + input_format=cfg.DATASET.MOTION_TYPE, + text_dir=pjoin(data_root, "texts"), + motion_dir=pjoin(data_root, motion_subdir[dataset_name]), + max_motion_length=cfg.DATASET.SAMPLER.MAX_LEN, + min_motion_length=cfg.DATASET.SAMPLER.MIN_LEN, + max_text_len=cfg.DATASET.SAMPLER.MAX_TEXT_LEN, + unit_length=eval( + f"cfg.DATASET.{dataset_name.upper()}.UNIT_LEN"), + ) + datasets.append(dataset) + + else: + raise NotImplementedError + + if input_format == 'root_body_pos_vel_hand_pos_vel': + cfg.DATASET.NFEATS = 313 + else: + cfg.DATASET.NFEATS = datasets[0].nfeats + + cfg.DATASET.NJOINTS = datasets[0].njoints + return datasets diff --git a/Evaluator_272/mld/data/humanml/__init__.py b/Evaluator_272/mld/data/humanml/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Evaluator_272/mld/data/humanml/common/quaternion.py b/Evaluator_272/mld/data/humanml/common/quaternion.py new file mode 100644 index 0000000000000000000000000000000000000000..dca3d890080a4e91e3f275f442b0aed006562881 --- /dev/null +++ b/Evaluator_272/mld/data/humanml/common/quaternion.py @@ -0,0 +1,423 @@ +# Copyright (c) 2018-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import torch +import numpy as np + +_EPS4 = np.finfo(float).eps * 4.0 + +_FLOAT_EPS = np.finfo(np.float64).eps + +# PyTorch-backed implementations +def qinv(q): + assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)' + mask = torch.ones_like(q) + mask[..., 1:] = -mask[..., 1:] + return q * mask + + +def qinv_np(q): + assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)' + return qinv(torch.from_numpy(q).float()).numpy() + + +def qnormalize(q): + assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)' + return q / torch.norm(q, dim=-1, keepdim=True) + + +def qmul(q, r): + """ + Multiply quaternion(s) q with quaternion(s) r. + Expects two equally-sized tensors of shape (*, 4), where * denotes any number of dimensions. + Returns q*r as a tensor of shape (*, 4). + """ + assert q.shape[-1] == 4 + assert r.shape[-1] == 4 + + original_shape = q.shape + + # Compute outer product + terms = torch.bmm(r.view(-1, 4, 1), q.view(-1, 1, 4)) + + w = terms[:, 0, 0] - terms[:, 1, 1] - terms[:, 2, 2] - terms[:, 3, 3] + x = terms[:, 0, 1] + terms[:, 1, 0] - terms[:, 2, 3] + terms[:, 3, 2] + y = terms[:, 0, 2] + terms[:, 1, 3] + terms[:, 2, 0] - terms[:, 3, 1] + z = terms[:, 0, 3] - terms[:, 1, 2] + terms[:, 2, 1] + terms[:, 3, 0] + return torch.stack((w, x, y, z), dim=1).view(original_shape) + + +def qrot(q, v): + """ + Rotate vector(s) v about the rotation described by quaternion(s) q. + Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v, + where * denotes any number of dimensions. + Returns a tensor of shape (*, 3). 
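    Implementation note: the rotation is applied with the expansion
    v' = v + 2*w*(q_vec x v) + 2*q_vec x (q_vec x v), so no rotation matrix is formed.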
+ """ + assert q.shape[-1] == 4 + assert v.shape[-1] == 3 + assert q.shape[:-1] == v.shape[:-1] + + original_shape = list(v.shape) + # print(q.shape) + q = q.contiguous().view(-1, 4) + v = v.contiguous().view(-1, 3) + + qvec = q[:, 1:] + uv = torch.cross(qvec, v, dim=1) + uuv = torch.cross(qvec, uv, dim=1) + return (v + 2 * (q[:, :1] * uv + uuv)).view(original_shape) + + +def qeuler(q, order, epsilon=0, deg=True): + """ + Convert quaternion(s) q to Euler angles. + Expects a tensor of shape (*, 4), where * denotes any number of dimensions. + Returns a tensor of shape (*, 3). + """ + assert q.shape[-1] == 4 + + original_shape = list(q.shape) + original_shape[-1] = 3 + q = q.view(-1, 4) + + q0 = q[:, 0] + q1 = q[:, 1] + q2 = q[:, 2] + q3 = q[:, 3] + + if order == 'xyz': + x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2)) + y = torch.asin(torch.clamp(2 * (q1 * q3 + q0 * q2), -1 + epsilon, 1 - epsilon)) + z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3)) + elif order == 'yzx': + x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3)) + y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3)) + z = torch.asin(torch.clamp(2 * (q1 * q2 + q0 * q3), -1 + epsilon, 1 - epsilon)) + elif order == 'zxy': + x = torch.asin(torch.clamp(2 * (q0 * q1 + q2 * q3), -1 + epsilon, 1 - epsilon)) + y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q1 * q1 + q2 * q2)) + z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q1 * q1 + q3 * q3)) + elif order == 'xzy': + x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3)) + y = torch.atan2(2 * (q0 * q2 + q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3)) + z = torch.asin(torch.clamp(2 * (q0 * q3 - q1 * q2), -1 + epsilon, 1 - epsilon)) + elif order == 'yxz': + x = torch.asin(torch.clamp(2 * (q0 * q1 - q2 * q3), -1 + epsilon, 1 - epsilon)) + y = torch.atan2(2 * (q1 * q3 + q0 * q2), 1 - 2 * (q1 * q1 + q2 * q2)) + z = torch.atan2(2 * (q1 * q2 + q0 * q3), 1 - 2 * (q1 * q1 + q3 * q3)) + elif order == 'zyx': + x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2)) + y = torch.asin(torch.clamp(2 * (q0 * q2 - q1 * q3), -1 + epsilon, 1 - epsilon)) + z = torch.atan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3)) + else: + raise + + if deg: + return torch.stack((x, y, z), dim=1).view(original_shape) * 180 / np.pi + else: + return torch.stack((x, y, z), dim=1).view(original_shape) + + +# Numpy-backed implementations + +def qmul_np(q, r): + q = torch.from_numpy(q).contiguous().float() + r = torch.from_numpy(r).contiguous().float() + return qmul(q, r).numpy() + + +def qrot_np(q, v): + q = torch.from_numpy(q).contiguous().float() + v = torch.from_numpy(v).contiguous().float() + return qrot(q, v).numpy() + + +def qeuler_np(q, order, epsilon=0, use_gpu=False): + if use_gpu: + q = torch.from_numpy(q).cuda().float() + return qeuler(q, order, epsilon).cpu().numpy() + else: + q = torch.from_numpy(q).contiguous().float() + return qeuler(q, order, epsilon).numpy() + + +def qfix(q): + """ + Enforce quaternion continuity across the time dimension by selecting + the representation (q or -q) with minimal distance (or, equivalently, maximal dot product) + between two consecutive frames. + + Expects a tensor of shape (L, J, 4), where L is the sequence length and J is the number of joints. + Returns a tensor of the same shape. 
+ """ + assert len(q.shape) == 3 + assert q.shape[-1] == 4 + + result = q.copy() + dot_products = np.sum(q[1:] * q[:-1], axis=2) + mask = dot_products < 0 + mask = (np.cumsum(mask, axis=0) % 2).astype(bool) + result[1:][mask] *= -1 + return result + + +def euler2quat(e, order, deg=True): + """ + Convert Euler angles to quaternions. + """ + assert e.shape[-1] == 3 + + original_shape = list(e.shape) + original_shape[-1] = 4 + + e = e.view(-1, 3) + + ## if euler angles in degrees + if deg: + e = e * np.pi / 180. + + x = e[:, 0] + y = e[:, 1] + z = e[:, 2] + + rx = torch.stack((torch.cos(x / 2), torch.sin(x / 2), torch.zeros_like(x), torch.zeros_like(x)), dim=1) + ry = torch.stack((torch.cos(y / 2), torch.zeros_like(y), torch.sin(y / 2), torch.zeros_like(y)), dim=1) + rz = torch.stack((torch.cos(z / 2), torch.zeros_like(z), torch.zeros_like(z), torch.sin(z / 2)), dim=1) + + result = None + for coord in order: + if coord == 'x': + r = rx + elif coord == 'y': + r = ry + elif coord == 'z': + r = rz + else: + raise + if result is None: + result = r + else: + result = qmul(result, r) + + # Reverse antipodal representation to have a non-negative "w" + if order in ['xyz', 'yzx', 'zxy']: + result *= -1 + + return result.view(original_shape) + + +def expmap_to_quaternion(e): + """ + Convert axis-angle rotations (aka exponential maps) to quaternions. + Stable formula from "Practical Parameterization of Rotations Using the Exponential Map". + Expects a tensor of shape (*, 3), where * denotes any number of dimensions. + Returns a tensor of shape (*, 4). + """ + assert e.shape[-1] == 3 + + original_shape = list(e.shape) + original_shape[-1] = 4 + e = e.reshape(-1, 3) + + theta = np.linalg.norm(e, axis=1).reshape(-1, 1) + w = np.cos(0.5 * theta).reshape(-1, 1) + xyz = 0.5 * np.sinc(0.5 * theta / np.pi) * e + return np.concatenate((w, xyz), axis=1).reshape(original_shape) + + +def euler_to_quaternion(e, order): + """ + Convert Euler angles to quaternions. + """ + assert e.shape[-1] == 3 + + original_shape = list(e.shape) + original_shape[-1] = 4 + + e = e.reshape(-1, 3) + + x = e[:, 0] + y = e[:, 1] + z = e[:, 2] + + rx = np.stack((np.cos(x / 2), np.sin(x / 2), np.zeros_like(x), np.zeros_like(x)), axis=1) + ry = np.stack((np.cos(y / 2), np.zeros_like(y), np.sin(y / 2), np.zeros_like(y)), axis=1) + rz = np.stack((np.cos(z / 2), np.zeros_like(z), np.zeros_like(z), np.sin(z / 2)), axis=1) + + result = None + for coord in order: + if coord == 'x': + r = rx + elif coord == 'y': + r = ry + elif coord == 'z': + r = rz + else: + raise + if result is None: + result = r + else: + result = qmul_np(result, r) + + # Reverse antipodal representation to have a non-negative "w" + if order in ['xyz', 'yzx', 'zxy']: + result *= -1 + + return result.reshape(original_shape) + + +def quaternion_to_matrix(quaternions): + """ + Convert rotations given as quaternions to rotation matrices. + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + Returns: + Rotation matrices as tensor of shape (..., 3, 3). 
+ """ + r, i, j, k = torch.unbind(quaternions, -1) + two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + + +def quaternion_to_matrix_np(quaternions): + q = torch.from_numpy(quaternions).contiguous().float() + return quaternion_to_matrix(q).numpy() + + +def quaternion_to_cont6d_np(quaternions): + rotation_mat = quaternion_to_matrix_np(quaternions) + cont_6d = np.concatenate([rotation_mat[..., 0], rotation_mat[..., 1]], axis=-1) + return cont_6d + + +def quaternion_to_cont6d(quaternions): + rotation_mat = quaternion_to_matrix(quaternions) + cont_6d = torch.cat([rotation_mat[..., 0], rotation_mat[..., 1]], dim=-1) + return cont_6d + + +def cont6d_to_matrix(cont6d): + assert cont6d.shape[-1] == 6, "The last dimension must be 6" + x_raw = cont6d[..., 0:3] + y_raw = cont6d[..., 3:6] + + x = x_raw / torch.norm(x_raw, dim=-1, keepdim=True) + z = torch.cross(x, y_raw, dim=-1) + z = z / torch.norm(z, dim=-1, keepdim=True) + + y = torch.cross(z, x, dim=-1) + + x = x[..., None] + y = y[..., None] + z = z[..., None] + + mat = torch.cat([x, y, z], dim=-1) + return mat + + +def cont6d_to_matrix_np(cont6d): + q = torch.from_numpy(cont6d).contiguous().float() + return cont6d_to_matrix(q).numpy() + + +def qpow(q0, t, dtype=torch.float): + ''' q0 : tensor of quaternions + t: tensor of powers + ''' + q0 = qnormalize(q0) + theta0 = torch.acos(q0[..., 0]) + + ## if theta0 is close to zero, add epsilon to avoid NaNs + mask = (theta0 <= 10e-10) * (theta0 >= -10e-10) + theta0 = (1 - mask) * theta0 + mask * 10e-10 + v0 = q0[..., 1:] / torch.sin(theta0).view(-1, 1) + + if isinstance(t, torch.Tensor): + q = torch.zeros(t.shape + q0.shape) + theta = t.view(-1, 1) * theta0.view(1, -1) + else: ## if t is a number + q = torch.zeros(q0.shape) + theta = t * theta0 + + q[..., 0] = torch.cos(theta) + q[..., 1:] = v0 * torch.sin(theta).unsqueeze(-1) + + return q.to(dtype) + + +def qslerp(q0, q1, t): + ''' + q0: starting quaternion + q1: ending quaternion + t: array of points along the way + + Returns: + Tensor of Slerps: t.shape + q0.shape + ''' + + q0 = qnormalize(q0) + q1 = qnormalize(q1) + q_ = qpow(qmul(q1, qinv(q0)), t) + + return qmul(q_, + q0.contiguous().view(torch.Size([1] * len(t.shape)) + q0.shape).expand(t.shape + q0.shape).contiguous()) + + +def qbetween(v0, v1): + ''' + find the quaternion used to rotate v0 to v1 + ''' + assert v0.shape[-1] == 3, 'v0 must be of the shape (*, 3)' + assert v1.shape[-1] == 3, 'v1 must be of the shape (*, 3)' + + v = torch.cross(v0, v1) + w = torch.sqrt((v0 ** 2).sum(dim=-1, keepdim=True) * (v1 ** 2).sum(dim=-1, keepdim=True)) + (v0 * v1).sum(dim=-1, + keepdim=True) + return qnormalize(torch.cat([w, v], dim=-1)) + + +def qbetween_np(v0, v1): + ''' + find the quaternion used to rotate v0 to v1 + ''' + assert v0.shape[-1] == 3, 'v0 must be of the shape (*, 3)' + assert v1.shape[-1] == 3, 'v1 must be of the shape (*, 3)' + + v0 = torch.from_numpy(v0).float() + v1 = torch.from_numpy(v1).float() + return qbetween(v0, v1).numpy() + + +def lerp(p0, p1, t): + if not isinstance(t, torch.Tensor): + t = torch.Tensor([t]) + + new_shape = t.shape + p0.shape + new_view_t = t.shape + torch.Size([1] * len(p0.shape)) + new_view_p = torch.Size([1] * 
len(t.shape)) + p0.shape + p0 = p0.view(new_view_p).expand(new_shape) + p1 = p1.view(new_view_p).expand(new_shape) + t = t.view(new_view_t).expand(new_shape) + + return p0 + t * (p1 - p0) diff --git a/Evaluator_272/mld/data/humanml/common/skeleton.py b/Evaluator_272/mld/data/humanml/common/skeleton.py new file mode 100644 index 0000000000000000000000000000000000000000..b2ae85ad14df8c1a8d77e689b1cffbc6c814a979 --- /dev/null +++ b/Evaluator_272/mld/data/humanml/common/skeleton.py @@ -0,0 +1,199 @@ +from .quaternion import * +import scipy.ndimage.filters as filters + +class Skeleton(object): + def __init__(self, offset, kinematic_tree, device): + self.device = device + self._raw_offset_np = offset.numpy() + self._raw_offset = offset.clone().detach().to(device).float() + self._kinematic_tree = kinematic_tree + self._offset = None + self._parents = [0] * len(self._raw_offset) + self._parents[0] = -1 + for chain in self._kinematic_tree: + for j in range(1, len(chain)): + self._parents[chain[j]] = chain[j-1] + + def njoints(self): + return len(self._raw_offset) + + def offset(self): + return self._offset + + def set_offset(self, offsets): + self._offset = offsets.clone().detach().to(self.device).float() + + def kinematic_tree(self): + return self._kinematic_tree + + def parents(self): + return self._parents + + # joints (batch_size, joints_num, 3) + def get_offsets_joints_batch(self, joints): + assert len(joints.shape) == 3 + _offsets = self._raw_offset.expand(joints.shape[0], -1, -1).clone() + for i in range(1, self._raw_offset.shape[0]): + _offsets[:, i] = torch.norm(joints[:, i] - joints[:, self._parents[i]], p=2, dim=1)[:, None] * _offsets[:, i] + + self._offset = _offsets.detach() + return _offsets + + # joints (joints_num, 3) + def get_offsets_joints(self, joints): + assert len(joints.shape) == 2 + _offsets = self._raw_offset.clone() + for i in range(1, self._raw_offset.shape[0]): + # print(joints.shape) + _offsets[i] = torch.norm(joints[i] - joints[self._parents[i]], p=2, dim=0) * _offsets[i] + + self._offset = _offsets.detach() + return _offsets + + # face_joint_idx should follow the order of right hip, left hip, right shoulder, left shoulder + # joints (batch_size, joints_num, 3) + def inverse_kinematics_np(self, joints, face_joint_idx, smooth_forward=False): + assert len(face_joint_idx) == 4 + '''Get Forward Direction''' + l_hip, r_hip, sdr_r, sdr_l = face_joint_idx + across1 = joints[:, r_hip] - joints[:, l_hip] + across2 = joints[:, sdr_r] - joints[:, sdr_l] + across = across1 + across2 + across = across / np.sqrt((across**2).sum(axis=-1))[:, np.newaxis] + # print(across1.shape, across2.shape) + + # forward (batch_size, 3) + forward = np.cross(np.array([[0, 1, 0]]), across, axis=-1) + if smooth_forward: + forward = filters.gaussian_filter1d(forward, 20, axis=0, mode='nearest') + # forward (batch_size, 3) + forward = forward / np.sqrt((forward**2).sum(axis=-1))[..., np.newaxis] + + '''Get Root Rotation''' + target = np.array([[0,0,1]]).repeat(len(forward), axis=0) + root_quat = qbetween_np(forward, target) + + '''Inverse Kinematics''' + # quat_params (batch_size, joints_num, 4) + # print(joints.shape[:-1]) + quat_params = np.zeros(joints.shape[:-1] + (4,)) + # print(quat_params.shape) + root_quat[0] = np.array([[1.0, 0.0, 0.0, 0.0]]) + quat_params[:, 0] = root_quat + # quat_params[0, 0] = np.array([[1.0, 0.0, 0.0, 0.0]]) + for chain in self._kinematic_tree: + R = root_quat + for j in range(len(chain) - 1): + # (batch, 3) + u = 
self._raw_offset_np[chain[j+1]][np.newaxis,...].repeat(len(joints), axis=0) + # print(u.shape) + # (batch, 3) + v = joints[:, chain[j+1]] - joints[:, chain[j]] + v = v / np.sqrt((v**2).sum(axis=-1))[:, np.newaxis] + # print(u.shape, v.shape) + rot_u_v = qbetween_np(u, v) + + R_loc = qmul_np(qinv_np(R), rot_u_v) + + quat_params[:,chain[j + 1], :] = R_loc + R = qmul_np(R, R_loc) + + return quat_params + + # Be sure root joint is at the beginning of kinematic chains + def forward_kinematics(self, quat_params, root_pos, skel_joints=None, do_root_R=True): + # quat_params (batch_size, joints_num, 4) + # joints (batch_size, joints_num, 3) + # root_pos (batch_size, 3) + if skel_joints is not None: + offsets = self.get_offsets_joints_batch(skel_joints) + if len(self._offset.shape) == 2: + offsets = self._offset.expand(quat_params.shape[0], -1, -1) + joints = torch.zeros(quat_params.shape[:-1] + (3,)).to(self.device) + joints[:, 0] = root_pos + for chain in self._kinematic_tree: + if do_root_R: + R = quat_params[:, 0] + else: + R = torch.tensor([[1.0, 0.0, 0.0, 0.0]]).expand(len(quat_params), -1).detach().to(self.device) + for i in range(1, len(chain)): + R = qmul(R, quat_params[:, chain[i]]) + offset_vec = offsets[:, chain[i]] + joints[:, chain[i]] = qrot(R, offset_vec) + joints[:, chain[i-1]] + return joints + + # Be sure root joint is at the beginning of kinematic chains + def forward_kinematics_np(self, quat_params, root_pos, skel_joints=None, do_root_R=True): + # quat_params (batch_size, joints_num, 4) + # joints (batch_size, joints_num, 3) + # root_pos (batch_size, 3) + if skel_joints is not None: + skel_joints = torch.from_numpy(skel_joints) + offsets = self.get_offsets_joints_batch(skel_joints) + if len(self._offset.shape) == 2: + offsets = self._offset.expand(quat_params.shape[0], -1, -1) + offsets = offsets.numpy() + joints = np.zeros(quat_params.shape[:-1] + (3,)) + joints[:, 0] = root_pos + for chain in self._kinematic_tree: + if do_root_R: + R = quat_params[:, 0] + else: + R = np.array([[1.0, 0.0, 0.0, 0.0]]).repeat(len(quat_params), axis=0) + for i in range(1, len(chain)): + R = qmul_np(R, quat_params[:, chain[i]]) + offset_vec = offsets[:, chain[i]] + joints[:, chain[i]] = qrot_np(R, offset_vec) + joints[:, chain[i - 1]] + return joints + + def forward_kinematics_cont6d_np(self, cont6d_params, root_pos, skel_joints=None, do_root_R=True): + # cont6d_params (batch_size, joints_num, 6) + # joints (batch_size, joints_num, 3) + # root_pos (batch_size, 3) + if skel_joints is not None: + skel_joints = torch.from_numpy(skel_joints) + offsets = self.get_offsets_joints_batch(skel_joints) + if len(self._offset.shape) == 2: + offsets = self._offset.expand(cont6d_params.shape[0], -1, -1) + offsets = offsets.numpy() + joints = np.zeros(cont6d_params.shape[:-1] + (3,)) + joints[:, 0] = root_pos + for chain in self._kinematic_tree: + if do_root_R: + matR = cont6d_to_matrix_np(cont6d_params[:, 0]) + else: + matR = np.eye(3)[np.newaxis, :].repeat(len(cont6d_params), axis=0) + for i in range(1, len(chain)): + matR = np.matmul(matR, cont6d_to_matrix_np(cont6d_params[:, chain[i]])) + offset_vec = offsets[:, chain[i]][..., np.newaxis] + # print(matR.shape, offset_vec.shape) + joints[:, chain[i]] = np.matmul(matR, offset_vec).squeeze(-1) + joints[:, chain[i-1]] + return joints + + def forward_kinematics_cont6d(self, cont6d_params, root_pos, skel_joints=None, do_root_R=True): + # cont6d_params (batch_size, joints_num, 6) + # joints (batch_size, joints_num, 3) + # root_pos (batch_size, 3) + if skel_joints is 
not None: + # skel_joints = torch.from_numpy(skel_joints) + offsets = self.get_offsets_joints_batch(skel_joints) + if len(self._offset.shape) == 2: + offsets = self._offset.expand(cont6d_params.shape[0], -1, -1) + joints = torch.zeros(cont6d_params.shape[:-1] + (3,)).to(cont6d_params.device) + joints[..., 0, :] = root_pos + for chain in self._kinematic_tree: + if do_root_R: + matR = cont6d_to_matrix(cont6d_params[:, 0]) + else: + matR = torch.eye(3).expand((len(cont6d_params), -1, -1)).detach().to(cont6d_params.device) + for i in range(1, len(chain)): + matR = torch.matmul(matR, cont6d_to_matrix(cont6d_params[:, chain[i]])) + offset_vec = offsets[:, chain[i]].unsqueeze(-1) + # print(matR.shape, offset_vec.shape) + joints[:, chain[i]] = torch.matmul(matR, offset_vec).squeeze(-1) + joints[:, chain[i-1]] + return joints + + + + + diff --git a/Evaluator_272/mld/data/humanml/data/__init__.py b/Evaluator_272/mld/data/humanml/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Evaluator_272/mld/data/humanml/data/dataset.py b/Evaluator_272/mld/data/humanml/data/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..eb899d49f22f40d85243010323361b1620d86dce --- /dev/null +++ b/Evaluator_272/mld/data/humanml/data/dataset.py @@ -0,0 +1,227 @@ +import codecs as cs +import os +import random +from os.path import join as pjoin + +import numpy as np +import spacy +import torch +from rich.progress import track +from torch.utils import data +from torch.utils.data._utils.collate import default_collate +from tqdm import tqdm +import json + + +def collate_fn(batch): + batch.sort(key=lambda x: x[3], reverse=True) + return default_collate(batch) + + + +def findAllFile(base): + file_path = [] + for root, ds, fs in os.walk(base, followlinks=True): + for f in fs: + fullname = os.path.join(root, f) + file_path.append(fullname) + return file_path + + +class Text2MotionDatasetV2(data.Dataset): + + def __init__( + self, + mean, + std, + split_file, + max_motion_length, + min_motion_length, + max_text_len, + unit_length, + motion_dir, + text_dir, + input_format, + njoints, + tiny=False, + debug=False, + progress_bar=True, + **kwargs, + ): + + self.max_length = 20 + self.pointer = 0 + self.max_motion_length = max_motion_length + + self.min_motion_length = min_motion_length + self.max_text_len = max_text_len + self.unit_length = unit_length + data_dict = {} + id_list = [] + with cs.open(split_file, "r") as f: + for line in f.readlines(): + id_list.append(line.strip()) + self.id_list = id_list + if tiny or debug: + progress_bar = False + maxdata = 10 if tiny else 100 + else: + maxdata = 1e10 + + if progress_bar: + enumerator = enumerate( + track( + id_list, + f"Loading {split_file.split('/')[-2]} {split_file.split('/')[-1].split('.')[0]}", + )) + else: + enumerator = enumerate(id_list) + count = 0 + bad_count = 0 + miss_count = 0 + new_name_list = [] + length_list = [] + + for i, name in enumerator: + if count > maxdata: + break + try: + + motion = np.load(pjoin(motion_dir, name + ".npy")) + + if input_format == 'root_position': + motion = motion[..., :4+(njoints-1)*3] + elif input_format == 'root_position_vel': + motion = np.concatenate((motion[..., :4+(njoints - 1) * 3], motion[..., 4+(njoints - 1) * 9: 4+(njoints - 1) * 9 + njoints*3]), axis=-1) + elif input_format == 'root_position_rot6d': + motion = np.concatenate((motion[..., :4+(njoints - 1) * 3], motion[..., 4+(njoints - 1) * 3: 4+(njoints - 1) * 9]), 
axis=-1) + elif input_format == 'root_rot6d': + motion = np.concatenate((motion[..., :4], motion[..., 4+(njoints - 1) * 3: 4+(njoints - 1) * 9]), axis=-1) + elif input_format in ['vector_263', '']: + pass + else: + raise NotImplementedError + + + text_data = [] + flag = False + with cs.open(pjoin(text_dir, name + ".txt")) as f: + for line in f.readlines(): + text_dict = {} + line_split = line.strip().split("#") + caption = line_split[0] + tokens = line_split[1].split(" ") + f_tag = float(line_split[2]) + to_tag = float(line_split[3]) + f_tag = 0.0 if np.isnan(f_tag) else f_tag + to_tag = 0.0 if np.isnan(to_tag) else to_tag + + text_dict["caption"] = caption + text_dict["tokens"] = tokens + if f_tag == 0.0 and to_tag == 0.0: + flag = True + text_data.append(text_dict) + else: + try: + n_motion = motion[int(f_tag * 30):int(to_tag * 30)] + + new_name = ( + random.choice("ABCDEFGHIJKLMNOPQRSTUVW") + + "_" + name) + while new_name in data_dict: + new_name = (random.choice( + "ABCDEFGHIJKLMNOPQRSTUVW") + "_" + + name) + data_dict[new_name] = { + "motion": n_motion, + "length": len(n_motion), + "text": [text_dict], + } + new_name_list.append(new_name) + length_list.append(len(n_motion)) + except: + print(line_split) + print(line_split[2], line_split[3], f_tag, + to_tag, name) + + + if flag: + data_dict[name] = { + "motion": motion, + "length": len(motion), + "text": text_data, + } + new_name_list.append(name) + length_list.append(len(motion)) + count += 1 + + except: + miss_count += 1 + pass + + print(f'Here are {miss_count} not in dataset!') + + name_list, length_list = zip( + *sorted(zip(new_name_list, length_list), key=lambda x: x[1])) + + + + self.mean = mean + self.std = std + + self.length_arr = np.array(length_list) + self.data_dict = data_dict + self.nfeats = motion.shape[1] + self.name_list = name_list + self.reset_max_len(self.max_length) + + + def reset_max_len(self, length): + assert length <= self.max_motion_length + self.pointer = np.searchsorted(self.length_arr, length) + print("Pointer Pointing at %d" % self.pointer) + self.max_length = length + + def inv_transform(self, data): + return data * self.std + self.mean + + def __len__(self): + return len(self.name_list) - self.pointer + + def __getitem__(self, item): + idx = self.pointer + item + data = self.data_dict[self.name_list[idx]] + + retrieval_name = self.name_list[idx].split('_')[-1] + + motion, m_length, text_list = data["motion"], data["length"], data["text"] + + # Randomly select a caption + text_data = random.choice(text_list) + # caption, tokens = text_data["caption"], text_data["tokens"] + caption = text_data["caption"] + + # Crop the motions in to times of 4, and introduce small variations + if self.unit_length < 10: + coin2 = np.random.choice(["single", "single", "double"]) + else: + coin2 = "single" + + if coin2 == "double": + m_length = (m_length // self.unit_length - 1) * self.unit_length + elif coin2 == "single": + m_length = (m_length // self.unit_length) * self.unit_length + idx = random.randint(0, len(motion) - m_length) + motion = motion[idx:idx + m_length] + "Normalization" + motion = (motion - self.mean) / self.std + + if np.any(np.isnan(motion)): + raise ValueError("nan in motion") + + return ( + caption, + motion, + m_length, + retrieval_name + ) diff --git a/Evaluator_272/mld/data/humanml/scripts/motion_process.py b/Evaluator_272/mld/data/humanml/scripts/motion_process.py new file mode 100644 index 0000000000000000000000000000000000000000..12bbbfa13ede245946339a417b1f8a1f36f7ac9f --- /dev/null +++ 
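A hedged usage sketch for Text2MotionDatasetV2. Every path, the mean/std file names, and the njoints/input_format values below are placeholders standing in for whatever the dataset config supplies; batching uses mld_collate (defined in mld/data/utils.py later in this patch) because it zero-pads variable-length motions, which default_collate cannot do.

import numpy as np
from torch.utils.data import DataLoader
from mld.data.humanml.data.dataset import Text2MotionDatasetV2
from mld.data.utils import mld_collate

root = "./datasets/humanml3d_272"                      # placeholder dataset root
mean = np.load(f"{root}/mean.npy")                     # assumed per-dimension feature mean file
std = np.load(f"{root}/std.npy")                       # assumed per-dimension feature std file

dataset = Text2MotionDatasetV2(
    mean=mean, std=std,
    split_file=f"{root}/split/train.txt",
    max_motion_length=300, min_motion_length=40,
    max_text_len=20, unit_length=4,
    motion_dir=f"{root}/new_joint_vecs", text_dir=f"{root}/texts",
    input_format="", njoints=22,
)
caption, motion, m_length, retrieval_name = dataset[0]  # motion is normalized, shape (m_length, nfeats)

loader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=mld_collate)
batch = next(iter(loader))                              # dict with "motion", "text", "length", "retrieval_name"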
b/Evaluator_272/mld/data/humanml/scripts/motion_process.py @@ -0,0 +1,576 @@ +from os.path import join as pjoin + +from ..common.skeleton import Skeleton +import numpy as np +import os +from ..common.quaternion import * +from ..utils.paramUtil import * + +import torch +from tqdm import tqdm + +# positions (batch, joint_num, 3) +def uniform_skeleton(positions, target_offset): + src_skel = Skeleton(n_raw_offsets, kinematic_chain, 'cpu') + src_offset = src_skel.get_offsets_joints(torch.from_numpy(positions[0])) + src_offset = src_offset.numpy() + tgt_offset = target_offset.numpy() + # print(src_offset) + # print(tgt_offset) + '''Calculate Scale Ratio as the ratio of legs''' + src_leg_len = np.abs(src_offset[l_idx1]).max() + np.abs(src_offset[l_idx2]).max() + tgt_leg_len = np.abs(tgt_offset[l_idx1]).max() + np.abs(tgt_offset[l_idx2]).max() + + scale_rt = tgt_leg_len / src_leg_len + # print(scale_rt) + src_root_pos = positions[:, 0] + tgt_root_pos = src_root_pos * scale_rt + + '''Inverse Kinematics''' + quat_params = src_skel.inverse_kinematics_np(positions, face_joint_indx) + # print(quat_params.shape) + + '''Forward Kinematics''' + src_skel.set_offset(target_offset) + new_joints = src_skel.forward_kinematics_np(quat_params, tgt_root_pos) + return new_joints + + +def extract_features(positions, feet_thre, n_raw_offsets, kinematic_chain, face_joint_indx, fid_r, fid_l): + global_positions = positions.copy() + """ Get Foot Contacts """ + + def foot_detect(positions, thres): + velfactor, heightfactor = np.array([thres, thres]), np.array([3.0, 2.0]) + + feet_l_x = (positions[1:, fid_l, 0] - positions[:-1, fid_l, 0]) ** 2 + feet_l_y = (positions[1:, fid_l, 1] - positions[:-1, fid_l, 1]) ** 2 + feet_l_z = (positions[1:, fid_l, 2] - positions[:-1, fid_l, 2]) ** 2 + # feet_l_h = positions[:-1,fid_l,1] + # feet_l = (((feet_l_x + feet_l_y + feet_l_z) < velfactor) & (feet_l_h < heightfactor)).astype(np.float64) + feet_l = ((feet_l_x + feet_l_y + feet_l_z) < velfactor).astype(np.float64) + + feet_r_x = (positions[1:, fid_r, 0] - positions[:-1, fid_r, 0]) ** 2 + feet_r_y = (positions[1:, fid_r, 1] - positions[:-1, fid_r, 1]) ** 2 + feet_r_z = (positions[1:, fid_r, 2] - positions[:-1, fid_r, 2]) ** 2 + # feet_r_h = positions[:-1,fid_r,1] + # feet_r = (((feet_r_x + feet_r_y + feet_r_z) < velfactor) & (feet_r_h < heightfactor)).astype(np.float64) + feet_r = (((feet_r_x + feet_r_y + feet_r_z) < velfactor)).astype(np.float64) + return feet_l, feet_r + + # + feet_l, feet_r = foot_detect(positions, feet_thre) + # feet_l, feet_r = foot_detect(positions, 0.002) + + '''Quaternion and Cartesian representation''' + r_rot = None + + def get_rifke(positions): + '''Local pose''' + positions[..., 0] -= positions[:, 0:1, 0] + positions[..., 2] -= positions[:, 0:1, 2] + '''All pose face Z+''' + positions = qrot_np(np.repeat(r_rot[:, None], positions.shape[1], axis=1), positions) + return positions + + def get_quaternion(positions): + skel = Skeleton(n_raw_offsets, kinematic_chain, "cpu") + # (seq_len, joints_num, 4) + quat_params = skel.inverse_kinematics_np(positions, face_joint_indx, smooth_forward=False) + + '''Fix Quaternion Discontinuity''' + quat_params = qfix(quat_params) + # (seq_len, 4) + r_rot = quat_params[:, 0].copy() + # print(r_rot[0]) + '''Root Linear Velocity''' + # (seq_len - 1, 3) + velocity = (positions[1:, 0] - positions[:-1, 0]).copy() + # print(r_rot.shape, velocity.shape) + velocity = qrot_np(r_rot[1:], velocity) + '''Root Angular Velocity''' + # (seq_len - 1, 4) + r_velocity = qmul_np(r_rot[1:], 
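The foot_detect closure above marks a foot as in contact whenever its squared frame-to-frame displacement stays below the velocity threshold feet_thre (0.002 in the commented HumanML3D __main__ further down). A toy check with made-up foot coordinates:

import numpy as np

foot_xyz = np.array([[0.00, 0.0, 0.00],
                     [0.00, 0.0, 0.01],    # nearly still between frames 0-1 -> contact
                     [0.05, 0.0, 0.30]])   # fast swing between frames 1-2  -> no contact
thres = 0.002
sq_disp = ((foot_xyz[1:] - foot_xyz[:-1]) ** 2).sum(axis=-1)
contact = (sq_disp < thres).astype(np.float64)
print(sq_disp, contact)                    # squared displacements ~[1e-4, 8.7e-2] -> contacts [1., 0.]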
qinv_np(r_rot[:-1])) + quat_params[1:, 0] = r_velocity + # (seq_len, joints_num, 4) + return quat_params, r_velocity, velocity, r_rot + + def get_cont6d_params(positions): + skel = Skeleton(n_raw_offsets, kinematic_chain, "cpu") + # (seq_len, joints_num, 4) + quat_params = skel.inverse_kinematics_np(positions, face_joint_indx, smooth_forward=True) + + '''Quaternion to continuous 6D''' + cont_6d_params = quaternion_to_cont6d_np(quat_params) + # (seq_len, 4) + r_rot = quat_params[:, 0].copy() + # print(r_rot[0]) + '''Root Linear Velocity''' + # (seq_len - 1, 3) + velocity = (positions[1:, 0] - positions[:-1, 0]).copy() + # print(r_rot.shape, velocity.shape) + velocity = qrot_np(r_rot[1:], velocity) + '''Root Angular Velocity''' + # (seq_len - 1, 4) + r_velocity = qmul_np(r_rot[1:], qinv_np(r_rot[:-1])) + # (seq_len, joints_num, 4) + return cont_6d_params, r_velocity, velocity, r_rot + + cont_6d_params, r_velocity, velocity, r_rot = get_cont6d_params(positions) + positions = get_rifke(positions) + + # trejec = np.cumsum(np.concatenate([np.array([[0, 0, 0]]), velocity], axis=0), axis=0) + # r_rotations, r_pos = recover_ric_glo_np(r_velocity, velocity[:, [0, 2]]) + + # plt.plot(positions_b[:, 0, 0], positions_b[:, 0, 2], marker='*') + # plt.plot(ground_positions[:, 0, 0], ground_positions[:, 0, 2], marker='o', color='r') + # plt.plot(trejec[:, 0], trejec[:, 2], marker='^', color='g') + # plt.plot(r_pos[:, 0], r_pos[:, 2], marker='s', color='y') + # plt.xlabel('x') + # plt.ylabel('z') + # plt.axis('equal') + # plt.show() + + '''Root height''' + root_y = positions[:, 0, 1:2] + + '''Root rotation and linear velocity''' + # (seq_len-1, 1) rotation velocity along y-axis + # (seq_len-1, 2) linear velovity on xz plane + r_velocity = np.arcsin(r_velocity[:, 2:3]) + l_velocity = velocity[:, [0, 2]] + # print(r_velocity.shape, l_velocity.shape, root_y.shape) + root_data = np.concatenate([r_velocity, l_velocity, root_y[:-1]], axis=-1) + + '''Get Joint Rotation Representation''' + # (seq_len, (joints_num-1) *6) quaternion for skeleton joints + rot_data = cont_6d_params[:, 1:].reshape(len(cont_6d_params), -1) + + '''Get Joint Rotation Invariant Position Represention''' + # (seq_len, (joints_num-1)*3) local joint position + ric_data = positions[:, 1:].reshape(len(positions), -1) + + '''Get Joint Velocity Representation''' + # (seq_len-1, joints_num*3) + local_vel = qrot_np(np.repeat(r_rot[:-1, None], global_positions.shape[1], axis=1), + global_positions[1:] - global_positions[:-1]) + local_vel = local_vel.reshape(len(local_vel), -1) + + data = root_data + data = np.concatenate([data, ric_data[:-1]], axis=-1) + data = np.concatenate([data, rot_data[:-1]], axis=-1) + # print(dataset.shape, local_vel.shape) + data = np.concatenate([data, local_vel], axis=-1) + data = np.concatenate([data, feet_l, feet_r], axis=-1) + + return data + + +def process_file(positions, feet_thre): + # (seq_len, joints_num, 3) + # '''Down Sample''' + # positions = positions[::ds_num] + + '''Uniform Skeleton''' + positions = uniform_skeleton(positions, tgt_offsets) + + '''Put on Floor''' + floor_height = positions.min(axis=0).min(axis=0)[1] + positions[:, :, 1] -= floor_height + # print(floor_height) + + # plot_3d_motion("./positions_1.mp4", kinematic_chain, positions, 'title', fps=20) + + '''XZ at origin''' + root_pos_init = positions[0] + root_pose_init_xz = root_pos_init[0] * np.array([1, 0, 1]) + positions = positions - root_pose_init_xz + + # '''Move the first pose to origin ''' + # root_pos_init = positions[0] + # positions = 
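The final concatenation (root_data, ric_data, rot_data, local_vel, foot contacts) produces the standard HumanML3D-style per-frame feature vector. A quick dimension check, matching the 263/251 switch used later in recover_rot; note this is the body-only layout, not the 272-dimensional representation this evaluator is named after.

def feature_dim(joints_num: int) -> int:
    # root (4) + local positions (J-1)*3 + rot6d (J-1)*6 + velocities J*3 + foot contacts (4)
    return 4 + (joints_num - 1) * 3 + (joints_num - 1) * 6 + joints_num * 3 + 4

assert feature_dim(22) == 263    # HumanML3D skeleton
assert feature_dim(21) == 251    # KIT skeleton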
positions - root_pos_init[0] + + '''All initially face Z+''' + r_hip, l_hip, sdr_r, sdr_l = face_joint_indx + across1 = root_pos_init[r_hip] - root_pos_init[l_hip] + across2 = root_pos_init[sdr_r] - root_pos_init[sdr_l] + across = across1 + across2 + across = across / np.sqrt((across ** 2).sum(axis=-1))[..., np.newaxis] + + # forward (3,), rotate around y-axis + forward_init = np.cross(np.array([[0, 1, 0]]), across, axis=-1) + # forward (3,) + forward_init = forward_init / np.sqrt((forward_init ** 2).sum(axis=-1))[..., np.newaxis] + + # print(forward_init) + + target = np.array([[0, 0, 1]]) + root_quat_init = qbetween_np(forward_init, target) + root_quat_init = np.ones(positions.shape[:-1] + (4,)) * root_quat_init + + positions_b = positions.copy() + + positions = qrot_np(root_quat_init, positions) + + # plot_3d_motion("./positions_2.mp4", kinematic_chain, positions, 'title', fps=20) + + '''New ground truth positions''' + global_positions = positions.copy() + + # plt.plot(positions_b[:, 0, 0], positions_b[:, 0, 2], marker='*') + # plt.plot(positions[:, 0, 0], positions[:, 0, 2], marker='o', color='r') + # plt.xlabel('x') + # plt.ylabel('z') + # plt.axis('equal') + # plt.show() + + """ Get Foot Contacts """ + + def foot_detect(positions, thres): + velfactor, heightfactor = np.array([thres, thres]), np.array([3.0, 2.0]) + + feet_l_x = (positions[1:, fid_l, 0] - positions[:-1, fid_l, 0]) ** 2 + feet_l_y = (positions[1:, fid_l, 1] - positions[:-1, fid_l, 1]) ** 2 + feet_l_z = (positions[1:, fid_l, 2] - positions[:-1, fid_l, 2]) ** 2 + # feet_l_h = positions[:-1,fid_l,1] + # feet_l = (((feet_l_x + feet_l_y + feet_l_z) < velfactor) & (feet_l_h < heightfactor)).astype(np.float64) + feet_l = ((feet_l_x + feet_l_y + feet_l_z) < velfactor).astype(np.float64) + + feet_r_x = (positions[1:, fid_r, 0] - positions[:-1, fid_r, 0]) ** 2 + feet_r_y = (positions[1:, fid_r, 1] - positions[:-1, fid_r, 1]) ** 2 + feet_r_z = (positions[1:, fid_r, 2] - positions[:-1, fid_r, 2]) ** 2 + # feet_r_h = positions[:-1,fid_r,1] + # feet_r = (((feet_r_x + feet_r_y + feet_r_z) < velfactor) & (feet_r_h < heightfactor)).astype(np.float64) + feet_r = (((feet_r_x + feet_r_y + feet_r_z) < velfactor)).astype(np.float64) + return feet_l, feet_r + # + feet_l, feet_r = foot_detect(positions, feet_thre) + # feet_l, feet_r = foot_detect(positions, 0.002) + + '''Quaternion and Cartesian representation''' + r_rot = None + + def get_rifke(positions): + '''Local pose''' + positions[..., 0] -= positions[:, 0:1, 0] + positions[..., 2] -= positions[:, 0:1, 2] + '''All pose face Z+''' + positions = qrot_np(np.repeat(r_rot[:, None], positions.shape[1], axis=1), positions) + return positions + + def get_quaternion(positions): + skel = Skeleton(n_raw_offsets, kinematic_chain, "cpu") + # (seq_len, joints_num, 4) + quat_params = skel.inverse_kinematics_np(positions, face_joint_indx, smooth_forward=False) + + '''Fix Quaternion Discontinuity''' + quat_params = qfix(quat_params) + # (seq_len, 4) + r_rot = quat_params[:, 0].copy() + # print(r_rot[0]) + '''Root Linear Velocity''' + # (seq_len - 1, 3) + velocity = (positions[1:, 0] - positions[:-1, 0]).copy() + # print(r_rot.shape, velocity.shape) + velocity = qrot_np(r_rot[1:], velocity) + '''Root Angular Velocity''' + # (seq_len - 1, 4) + r_velocity = qmul_np(r_rot[1:], qinv_np(r_rot[:-1])) + quat_params[1:, 0] = r_velocity + # (seq_len, joints_num, 4) + return quat_params, r_velocity, velocity, r_rot + + def get_cont6d_params(positions): + skel = Skeleton(n_raw_offsets, kinematic_chain, "cpu") + # 
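The "All initially face Z+" block estimates a facing direction from the hip and shoulder axes of the first frame and then rotates the whole clip so that direction becomes +Z (via qbetween_np). A toy numpy version of just the direction estimate, with made-up joint coordinates and no dependence on the repo's quaternion helpers:

import numpy as np

r_hip, l_hip = np.array([ 0.1, 0.9, 0.0]), np.array([-0.1, 0.9, 0.0])   # made-up first-frame positions
sdr_r, sdr_l = np.array([ 0.2, 1.5, 0.1]), np.array([-0.2, 1.5, 0.1])

across = (r_hip - l_hip) + (sdr_r - sdr_l)              # left-to-right body axis
across /= np.linalg.norm(across)
forward = np.cross(np.array([0.0, 1.0, 0.0]), across)   # horizontal facing direction
forward /= np.linalg.norm(forward)

yaw = np.degrees(np.arctan2(forward[0], forward[2]))    # yaw of the facing direction relative to +Z
print(forward, yaw)                                     # [0. 0. -1.] 180.0 for this pose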
(seq_len, joints_num, 4) + quat_params = skel.inverse_kinematics_np(positions, face_joint_indx, smooth_forward=True) + + '''Quaternion to continuous 6D''' + cont_6d_params = quaternion_to_cont6d_np(quat_params) + # (seq_len, 4) + r_rot = quat_params[:, 0].copy() + # print(r_rot[0]) + '''Root Linear Velocity''' + # (seq_len - 1, 3) + velocity = (positions[1:, 0] - positions[:-1, 0]).copy() + # print(r_rot.shape, velocity.shape) + velocity = qrot_np(r_rot[1:], velocity) + '''Root Angular Velocity''' + # (seq_len - 1, 4) + r_velocity = qmul_np(r_rot[1:], qinv_np(r_rot[:-1])) + # (seq_len, joints_num, 4) + return cont_6d_params, r_velocity, velocity, r_rot + + cont_6d_params, r_velocity, velocity, r_rot = get_cont6d_params(positions) + positions = get_rifke(positions) + + # trejec = np.cumsum(np.concatenate([np.array([[0, 0, 0]]), velocity], axis=0), axis=0) + # r_rotations, r_pos = recover_ric_glo_np(r_velocity, velocity[:, [0, 2]]) + + # plt.plot(positions_b[:, 0, 0], positions_b[:, 0, 2], marker='*') + # plt.plot(ground_positions[:, 0, 0], ground_positions[:, 0, 2], marker='o', color='r') + # plt.plot(trejec[:, 0], trejec[:, 2], marker='^', color='g') + # plt.plot(r_pos[:, 0], r_pos[:, 2], marker='s', color='y') + # plt.xlabel('x') + # plt.ylabel('z') + # plt.axis('equal') + # plt.show() + + '''Root height''' + root_y = positions[:, 0, 1:2] + + '''Root rotation and linear velocity''' + # (seq_len-1, 1) rotation velocity along y-axis + # (seq_len-1, 2) linear velovity on xz plane + r_velocity = np.arcsin(r_velocity[:, 2:3]) + l_velocity = velocity[:, [0, 2]] + # print(r_velocity.shape, l_velocity.shape, root_y.shape) + root_data = np.concatenate([r_velocity, l_velocity, root_y[:-1]], axis=-1) + + '''Get Joint Rotation Representation''' + # (seq_len, (joints_num-1) *6) quaternion for skeleton joints + rot_data = cont_6d_params[:, 1:].reshape(len(cont_6d_params), -1) + + '''Get Joint Rotation Invariant Position Represention''' + # (seq_len, (joints_num-1)*3) local joint position + ric_data = positions[:, 1:].reshape(len(positions), -1) + + '''Get Joint Velocity Representation''' + # (seq_len-1, joints_num*3) + local_vel = qrot_np(np.repeat(r_rot[:-1, None], global_positions.shape[1], axis=1), + global_positions[1:] - global_positions[:-1]) + local_vel = local_vel.reshape(len(local_vel), -1) + + data = root_data + data = np.concatenate([data, ric_data[:-1]], axis=-1) + data = np.concatenate([data, rot_data[:-1]], axis=-1) + # print(dataset.shape, local_vel.shape) + data = np.concatenate([data, local_vel], axis=-1) + data = np.concatenate([data, feet_l, feet_r], axis=-1) + + return data, global_positions, positions, l_velocity + + +# Recover global angle and positions for rotation dataset +# root_rot_velocity (B, seq_len, 1) +# root_linear_velocity (B, seq_len, 2) +# root_y (B, seq_len, 1) +# ric_data (B, seq_len, (joint_num - 1)*3) +# rot_data (B, seq_len, (joint_num - 1)*6) +# local_velocity (B, seq_len, joint_num*3) +# foot contact (B, seq_len, 4) +def recover_root_rot_pos(data): + rot_vel = data[..., 0] + r_rot_ang = torch.zeros_like(rot_vel).to(data.device) + '''Get Y-axis rotation from rotation velocity''' + r_rot_ang[..., 1:] = rot_vel[..., :-1] + r_rot_ang = torch.cumsum(r_rot_ang, dim=-1) + + r_rot_quat = torch.zeros(data.shape[:-1] + (4,)).to(data.device) + r_rot_quat[..., 0] = torch.cos(r_rot_ang) + r_rot_quat[..., 2] = torch.sin(r_rot_ang) + + r_pos = torch.zeros(data.shape[:-1] + (3,)).to(data.device) + r_pos[..., 1:, [0, 2]] = data[..., :-1, 1:3] + '''Add Y-axis rotation to root 
position''' + r_pos = qrot(qinv(r_rot_quat), r_pos) + + r_pos = torch.cumsum(r_pos, dim=-2) + + r_pos[..., 1] = data[..., 3] + return r_rot_quat, r_pos + + +def recover_from_rot(data, joints_num, skeleton): + r_rot_quat, r_pos = recover_root_rot_pos(data) + + r_rot_cont6d = quaternion_to_cont6d(r_rot_quat) + + start_indx = 1 + 2 + 1 + (joints_num - 1) * 3 + end_indx = start_indx + (joints_num - 1) * 6 + cont6d_params = data[..., start_indx:end_indx] + # print(r_rot_cont6d.shape, cont6d_params.shape, r_pos.shape) + cont6d_params = torch.cat([r_rot_cont6d, cont6d_params], dim=-1) + cont6d_params = cont6d_params.view(-1, joints_num, 6) + + positions = skeleton.forward_kinematics_cont6d(cont6d_params, r_pos) + + return positions + +def recover_from_root_rot6d(data, joints_num, skeleton): + + r_rot_quat, r_pos = recover_root_rot_pos(data) + + r_rot_cont6d = quaternion_to_cont6d(r_rot_quat) + + start_indx = 1 + 2 + 1 + end_indx = start_indx + (joints_num - 1) * 6 + cont6d_params = data[..., start_indx:end_indx] + # print(r_rot_cont6d.shape, cont6d_params.shape, r_pos.shape) + cont6d_params = torch.cat([r_rot_cont6d, cont6d_params], dim=-1) + cont6d_params = cont6d_params.view(-1, joints_num, 6) + r_pos = r_pos.view(-1,3) + positions = skeleton.forward_kinematics_cont6d(cont6d_params, r_pos) + return positions + +def recover_from_body_pos_vel_hand_rot(data, joints_num, skeleton): + assert len(skeleton) == 2 + body_skel = skeleton[0] + all_skel = skeleton[1] + assert joints_num == 52 + face_joint_indx = [2, 1, 17, 16] + + r_rot_quat, r_pos = recover_root_rot_pos(data) + + r_rot_cont6d = quaternion_to_cont6d(r_rot_quat) + + pos_body_data = data[..., : 4 + 21 * 3] + pos_body_data_global = recover_from_ric(pos_body_data, 22) + # pos_body_data_global shape (bs, frame, 22, 3) + quat_params = body_skel.inverse_kinematics(pos_body_data_global, face_joint_indx) + bs = quat_params.shape[0] + frame = quat_params.shape[1] + cont6d_params = quaternion_to_cont6d(quat_params).view(bs, frame, -1) + + # cont6d_params + rot6d_hand_data = data[..., 4 + 21 * 3: 4 + 21 * 3 + 30 * 6] + + cont6d_params = torch.cat([cont6d_params, rot6d_hand_data], dim=-1) + cont6d_params = cont6d_params.view(-1, joints_num, 6) + r_pos = r_pos.view(-1,3) + positions = all_skel.forward_kinematics_cont6d(cont6d_params, r_pos) + return positions + + +def recover_rot(data): + # dataset [bs, seqlen, 263/251] HumanML/KIT + joints_num = 22 if data.shape[-1] == 263 else 21 + r_rot_quat, r_pos = recover_root_rot_pos(data) + r_pos_pad = torch.cat([r_pos, torch.zeros_like(r_pos)], dim=-1).unsqueeze(-2) + r_rot_cont6d = quaternion_to_cont6d(r_rot_quat) + start_indx = 1 + 2 + 1 + (joints_num - 1) * 3 + end_indx = start_indx + (joints_num - 1) * 6 + cont6d_params = data[..., start_indx:end_indx] + cont6d_params = torch.cat([r_rot_cont6d, cont6d_params], dim=-1) + cont6d_params = cont6d_params.view(-1, joints_num, 6) + cont6d_params = torch.cat([cont6d_params, r_pos_pad], dim=-2) + return cont6d_params + + +def recover_from_ric(data, joints_num): + r_rot_quat, r_pos = recover_root_rot_pos(data) + positions = data[..., 4:(joints_num - 1) * 3 + 4] + positions = positions.view(positions.shape[:-1] + (-1, 3)) + + '''Add Y-axis rotation to local joints''' + positions = qrot(qinv(r_rot_quat[..., None, :]).expand(positions.shape[:-1] + (4,)), positions) + + '''Add root XZ to joints''' + positions[..., 0] += r_pos[..., 0:1] + positions[..., 2] += r_pos[..., 2:3] + + '''Concate root and joints''' + positions = torch.cat([r_pos.unsqueeze(-2), positions], 
dim=-2) + + return positions + + +''' +For Text2Motion Dataset +''' +''' +if __name__ == "__main__": + example_id = "000021" + # Lower legs + l_idx1, l_idx2 = 5, 8 + # Right/Left foot + fid_r, fid_l = [8, 11], [7, 10] + # Face direction, r_hip, l_hip, sdr_r, sdr_l + face_joint_indx = [2, 1, 17, 16] + # l_hip, r_hip + r_hip, l_hip = 2, 1 + joints_num = 22 + # ds_num = 8 + data_dir = '../dataset/pose_data_raw/joints/' + save_dir1 = '../dataset/pose_data_raw/new_joints/' + save_dir2 = '../dataset/pose_data_raw/new_joint_vecs/' + + n_raw_offsets = torch.from_numpy(t2m_raw_offsets) + kinematic_chain = t2m_kinematic_chain + + # Get offsets of target skeleton + example_data = np.load(os.path.join(data_dir, example_id + '.npy')) + example_data = example_data.reshape(len(example_data), -1, 3) + example_data = torch.from_numpy(example_data) + tgt_skel = Skeleton(n_raw_offsets, kinematic_chain, 'cpu') + # (joints_num, 3) + tgt_offsets = tgt_skel.get_offsets_joints(example_data[0]) + # print(tgt_offsets) + + source_list = os.listdir(data_dir) + frame_num = 0 + for source_file in tqdm(source_list): + source_data = np.load(os.path.join(data_dir, source_file))[:, :joints_num] + try: + dataset, ground_positions, positions, l_velocity = process_file(source_data, 0.002) + rec_ric_data = recover_from_ric(torch.from_numpy(dataset).unsqueeze(0).float(), joints_num) + np.save(pjoin(save_dir1, source_file), rec_ric_data.squeeze().numpy()) + np.save(pjoin(save_dir2, source_file), dataset) + frame_num += dataset.shape[0] + except Exception as e: + print(source_file) + print(e) + + print('Total clips: %d, Frames: %d, Duration: %fm' % + (len(source_list), frame_num, frame_num / 20 / 60)) +''' + +if __name__ == "__main__": + example_id = "03950_gt" + # Lower legs + l_idx1, l_idx2 = 17, 18 + # Right/Left foot + fid_r, fid_l = [14, 15], [19, 20] + # Face direction, r_hip, l_hip, sdr_r, sdr_l + face_joint_indx = [11, 16, 5, 8] + # l_hip, r_hip + r_hip, l_hip = 11, 16 + joints_num = 21 + # ds_num = 8 + data_dir = '../dataset/kit_mocap_dataset/joints/' + save_dir1 = '../dataset/kit_mocap_dataset/new_joints/' + save_dir2 = '../dataset/kit_mocap_dataset/new_joint_vecs/' + + n_raw_offsets = torch.from_numpy(kit_raw_offsets) + kinematic_chain = kit_kinematic_chain + + '''Get offsets of target skeleton''' + example_data = np.load(os.path.join(data_dir, example_id + '.npy')) + example_data = example_data.reshape(len(example_data), -1, 3) + example_data = torch.from_numpy(example_data) + tgt_skel = Skeleton(n_raw_offsets, kinematic_chain, 'cpu') + # (joints_num, 3) + tgt_offsets = tgt_skel.get_offsets_joints(example_data[0]) + # print(tgt_offsets) + + source_list = os.listdir(data_dir) + frame_num = 0 + '''Read source dataset''' + for source_file in tqdm(source_list): + source_data = np.load(os.path.join(data_dir, source_file))[:, :joints_num] + try: + name = ''.join(source_file[:-7].split('_')) + '.npy' + data, ground_positions, positions, l_velocity = process_file(source_data, 0.05) + rec_ric_data = recover_from_ric(torch.from_numpy(data).unsqueeze(0).float(), joints_num) + if np.isnan(rec_ric_data.numpy()).any(): + print(source_file) + continue + np.save(pjoin(save_dir1, name), rec_ric_data.squeeze().numpy()) + np.save(pjoin(save_dir2, name), data) + frame_num += data.shape[0] + except Exception as e: + print(source_file) + print(e) + + print('Total clips: %d, Frames: %d, Duration: %fm' % + (len(source_list), frame_num, frame_num / 12.5 / 60)) diff --git a/Evaluator_272/mld/data/humanml/utils/__init__.py 
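recover_from_ric is the function downstream visualisation and evaluation rely on: it converts the relative per-frame features back to global (frames, joints, 3) positions. A hedged usage sketch, assuming a (T, 263) feature array already produced by process_file; the .npy path is a placeholder.

import numpy as np
import torch
from mld.data.humanml.scripts.motion_process import recover_from_ric

feats = np.load("some_clip.npy")                            # placeholder: (T, 263) feature array
joints = recover_from_ric(torch.from_numpy(feats).float()[None], joints_num=22)
print(joints.shape)                                         # torch.Size([1, T, 22, 3])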
b/Evaluator_272/mld/data/humanml/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Evaluator_272/mld/data/humanml/utils/metrics.py b/Evaluator_272/mld/data/humanml/utils/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..8357de6035068fce1f67669e2f021175a61cd8f5 --- /dev/null +++ b/Evaluator_272/mld/data/humanml/utils/metrics.py @@ -0,0 +1,142 @@ +import numpy as np +from scipy import linalg + +def euclidean_distance_matrix(matrix1, matrix2): + """ + Params: + -- matrix1: N1 x D + -- matrix2: N2 x D + Returns: + -- dist: N1 x N2 + dist[i, j] == distance(matrix1[i], matrix2[j]) + """ + assert matrix1.shape[1] == matrix2.shape[1] + d1 = -2 * np.dot(matrix1, matrix2.T) + d2 = np.sum(np.square(matrix1), axis=1, keepdims=True) + d3 = np.sum(np.square(matrix2), axis=1) + dists = np.sqrt(d1 + d2 + d3) + return dists + +def calculate_top_k(mat, top_k): + size = mat.shape[0] + gt_mat = np.expand_dims(np.arange(size), 1).repeat(size, 1) + bool_mat = (mat == gt_mat) + correct_vec = False + top_k_list = [] + for i in range(top_k): + correct_vec = (correct_vec | bool_mat[:, i]) + top_k_list.append(correct_vec[:, None]) + top_k_mat = np.concatenate(top_k_list, axis=1) + return top_k_mat + + +def calculate_R_precision(embedding1, embedding2, top_k, sum_all=False): + dist_mat = euclidean_distance_matrix(embedding1, embedding2) + argmax = np.argsort(dist_mat, axis=1) + top_k_mat = calculate_top_k(argmax, top_k) + if sum_all: + return top_k_mat.sum(axis=0) + else: + return top_k_mat + + +def calculate_matching_score(embedding1, embedding2, sum_all=False): + assert len(embedding1.shape) == 2 + assert embedding1.shape[0] == embedding2.shape[0] + assert embedding1.shape[1] == embedding2.shape[1] + + dist = linalg.norm(embedding1 - embedding2, axis=1) + if sum_all: + return dist.sum(axis=0) + else: + return dist + + + +def calculate_activation_statistics(activations): + """ + Params: + -- activation: num_samples x dim_feat + Returns: + -- mu: dim_feat + -- sigma: dim_feat x dim_feat + """ + mu = np.mean(activations, axis=0) + cov = np.cov(activations, rowvar=False) + return mu, cov + + +def calculate_diversity(activation, diversity_times): + assert len(activation.shape) == 2 + assert activation.shape[0] > diversity_times + num_samples = activation.shape[0] + + first_indices = np.random.choice(num_samples, diversity_times, replace=False) + second_indices = np.random.choice(num_samples, diversity_times, replace=False) + dist = linalg.norm(activation[first_indices] - activation[second_indices], axis=1) + return dist.mean() + + +def calculate_multimodality(activation, multimodality_times): + assert len(activation.shape) == 3 + assert activation.shape[1] > multimodality_times + num_per_sent = activation.shape[1] + + first_dices = np.random.choice(num_per_sent, multimodality_times, replace=False) + second_dices = np.random.choice(num_per_sent, multimodality_times, replace=False) + dist = linalg.norm(activation[:, first_dices] - activation[:, second_dices], axis=2) + return dist.mean() + + +def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): + """Numpy implementation of the Frechet Distance. + The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) + and X_2 ~ N(mu_2, C_2) is + d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). + Stable version by Dougal J. Sutherland. 
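A short sketch of how the retrieval metrics above are typically combined: for matched text/motion embedding matrices, R-precision is the fraction of rows whose true pair lands in the top-k nearest neighbours. With random embeddings the values hover around k/N, which makes for a cheap sanity check.

import numpy as np
from mld.data.humanml.utils.metrics import calculate_R_precision, calculate_matching_score

rng = np.random.default_rng(0)
text_emb = rng.normal(size=(32, 512))       # one row per text, matched by index to motions
motion_emb = rng.normal(size=(32, 512))

top_k_counts = calculate_R_precision(text_emb, motion_emb, top_k=3, sum_all=True)
r_precision = top_k_counts / len(text_emb)                      # R@1, R@2, R@3
matching_score = calculate_matching_score(text_emb, motion_emb, sum_all=True) / len(text_emb)
print(r_precision, matching_score)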
+ Params: + -- mu1 : Numpy array containing the activations of a layer of the + inception net (like returned by the function 'get_predictions') + for generated samples. + -- mu2 : The sample mean over activations, precalculated on an + representative dataset set. + -- sigma1: The covariance matrix over activations for generated samples. + -- sigma2: The covariance matrix over activations, precalculated on an + representative dataset set. + Returns: + -- : The Frechet Distance. + """ + + mu1 = np.atleast_1d(mu1) + mu2 = np.atleast_1d(mu2) + + sigma1 = np.atleast_2d(sigma1) + sigma2 = np.atleast_2d(sigma2) + + assert mu1.shape == mu2.shape, \ + 'Training and test mean vectors have different lengths' + assert sigma1.shape == sigma2.shape, \ + 'Training and test covariances have different dimensions' + + diff = mu1 - mu2 + + # Product might be almost singular + covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) + if not np.isfinite(covmean).all(): + msg = ('fid calculation produces singular product; ' + 'adding %s to diagonal of cov estimates') % eps + print(msg) + offset = np.eye(sigma1.shape[0]) * eps + covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + + # Numerical error might give slight imaginary component + if np.iscomplexobj(covmean): + if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): + m = np.max(np.abs(covmean.imag)) + raise ValueError('Imaginary component {}'.format(m)) + covmean = covmean.real + + tr_covmean = np.trace(covmean) + + return (diff.dot(diff) + np.trace(sigma1) + + np.trace(sigma2) - 2 * tr_covmean) \ No newline at end of file diff --git a/Evaluator_272/mld/data/humanml/utils/paramUtil.py b/Evaluator_272/mld/data/humanml/utils/paramUtil.py new file mode 100644 index 0000000000000000000000000000000000000000..a9f1708b85ca80a9051cb3675cec9b999a0d0e2b --- /dev/null +++ b/Evaluator_272/mld/data/humanml/utils/paramUtil.py @@ -0,0 +1,63 @@ +import numpy as np + +# Define a kinematic tree for the skeletal struture +kit_kinematic_chain = [[0, 11, 12, 13, 14, 15], [0, 16, 17, 18, 19, 20], [0, 1, 2, 3, 4], [3, 5, 6, 7], [3, 8, 9, 10]] + +kit_raw_offsets = np.array( + [ + [0, 0, 0], + [0, 1, 0], + [0, 1, 0], + [0, 1, 0], + [0, 1, 0], + [1, 0, 0], + [0, -1, 0], + [0, -1, 0], + [-1, 0, 0], + [0, -1, 0], + [0, -1, 0], + [1, 0, 0], + [0, -1, 0], + [0, -1, 0], + [0, 0, 1], + [0, 0, 1], + [-1, 0, 0], + [0, -1, 0], + [0, -1, 0], + [0, 0, 1], + [0, 0, 1] + ] +) + +t2m_raw_offsets = np.array([[0,0,0], + [1,0,0], + [-1,0,0], + [0,1,0], + [0,-1,0], + [0,-1,0], + [0,1,0], + [0,-1,0], + [0,-1,0], + [0,1,0], + [0,0,1], + [0,0,1], + [0,1,0], + [1,0,0], + [-1,0,0], + [0,0,1], + [0,-1,0], + [0,-1,0], + [0,-1,0], + [0,-1,0], + [0,-1,0], + [0,-1,0]]) + +t2m_kinematic_chain = [[0, 2, 5, 8, 11], [0, 1, 4, 7, 10], [0, 3, 6, 9, 12, 15], [9, 14, 17, 19, 21], [9, 13, 16, 18, 20]] +t2m_left_hand_chain = [[20, 22, 23, 24], [20, 34, 35, 36], [20, 25, 26, 27], [20, 31, 32, 33], [20, 28, 29, 30]] +t2m_right_hand_chain = [[21, 43, 44, 45], [21, 46, 47, 48], [21, 40, 41, 42], [21, 37, 38, 39], [21, 49, 50, 51]] + + +kit_tgt_skel_id = '03950' + +t2m_tgt_skel_id = '000021' + diff --git a/Evaluator_272/mld/data/humanml/utils/plot_script.py b/Evaluator_272/mld/data/humanml/utils/plot_script.py new file mode 100644 index 0000000000000000000000000000000000000000..0118cdac276b6c330730196953a5510a8e72f786 --- /dev/null +++ b/Evaluator_272/mld/data/humanml/utils/plot_script.py @@ -0,0 +1,103 @@ +import math +# import cv2 +from textwrap import wrap + +import matplotlib +import 
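And the matching sketch for FID: fit a Gaussian to each activation set with calculate_activation_statistics, then compare with calculate_frechet_distance. The activations here are random, so the number itself is meaningless; identical statistics should give a value near zero.

import numpy as np
from mld.data.humanml.utils.metrics import calculate_activation_statistics, calculate_frechet_distance

rng = np.random.default_rng(0)
gt_act = rng.normal(size=(1000, 256))                  # e.g. motion-encoder features of GT clips
gen_act = rng.normal(loc=0.1, size=(1000, 256))        # features of generated clips

mu_gt, cov_gt = calculate_activation_statistics(gt_act)
mu_gen, cov_gen = calculate_activation_statistics(gen_act)
print(calculate_frechet_distance(mu_gt, cov_gt, mu_gen, cov_gen))   # > 0
print(calculate_frechet_distance(mu_gt, cov_gt, mu_gt, cov_gt))     # ~ 0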
matplotlib.pyplot as plt +import mpl_toolkits.mplot3d.axes3d as p3 +import numpy as np +from matplotlib.animation import FFMpegFileWriter, FuncAnimation +from mpl_toolkits.mplot3d import Axes3D +from mpl_toolkits.mplot3d.art3d import Poly3DCollection + +import mld.data.humanml.utils.paramUtil as paramUtil + +skeleton = paramUtil.t2m_kinematic_chain + + +def list_cut_average(ll, intervals): + if intervals == 1: + return ll + + bins = math.ceil(len(ll) * 1.0 / intervals) + ll_new = [] + for i in range(bins): + l_low = intervals * i + l_high = l_low + intervals + l_high = l_high if l_high < len(ll) else len(ll) + ll_new.append(np.mean(ll[l_low:l_high])) + return ll_new + + +def plot_3d_motion(save_path, joints, title, figsize=(3, 3), fps=120, radius=3, kinematic_tree=skeleton): + matplotlib.use('Agg') + title = '\n'.join(wrap(title, 20)) + + def init(): + ax.set_xlim3d([-radius / 2, radius / 2]) + ax.set_ylim3d([0, radius]) + ax.set_zlim3d([-radius / 3., radius * 2 / 3.]) + fig.suptitle(title, fontsize=10) + ax.grid(b=False) + + def plot_xzPlane(minx, maxx, miny, minz, maxz): + verts = [ + [minx, miny, minz], + [minx, miny, maxz], + [maxx, miny, maxz], + [maxx, miny, minz] + ] + xz_plane = Poly3DCollection([verts]) + xz_plane.set_facecolor((0.5, 0.5, 0.5, 0.5)) + ax.add_collection3d(xz_plane) + + + data = joints.copy().reshape(len(joints), -1, 3) + fig = plt.figure(figsize=figsize) + plt.tight_layout() + ax = p3.Axes3D(fig) + init() + MINS = data.min(axis=0).min(axis=0) + MAXS = data.max(axis=0).max(axis=0) + + colors = ["#DD5A37", "#D69E00", "#B75A39", "#DD5A37", "#D69E00", + "#FF6D00", "#FF6D00", "#FF6D00", "#FF6D00", "#FF6D00", + "#DDB50E", "#DDB50E", "#DDB50E", "#DDB50E", "#DDB50E", ] + + frame_number = data.shape[0] + + height_offset = MINS[1] + data[:, :, 1] -= height_offset + trajec = data[:, 0, [0, 2]] + + data[..., 0] -= data[:, 0:1, 0] + data[..., 2] -= data[:, 0:1, 2] + + + def update(index): + + ax.view_init(elev=120, azim=-90) + ax.dist = 7.5 + plot_xzPlane(MINS[0] - trajec[index, 0], MAXS[0] - trajec[index, 0], 0, MINS[2] - trajec[index, 1], + MAXS[2] - trajec[index, 1]) + + + for i, (chain, color) in enumerate(zip(kinematic_tree, colors)): + # print(color) + if i < 5: + linewidth = 4.0 + else: + linewidth = 2.0 + ax.plot3D(data[index, chain, 0], data[index, chain, 1], data[index, chain, 2], linewidth=linewidth, + color=color) + + plt.axis('off') + ax.set_xticklabels([]) + ax.set_yticklabels([]) + ax.set_zticklabels([]) + + ani = FuncAnimation(fig, update, frames=frame_number, + interval=1000 / fps, repeat=False) + + ani.save(save_path, fps=fps) + plt.close() diff --git a/Evaluator_272/mld/data/humanml/utils/utils.py b/Evaluator_272/mld/data/humanml/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e7044995214e321a44b92f619e1039a5bb4fbe62 --- /dev/null +++ b/Evaluator_272/mld/data/humanml/utils/utils.py @@ -0,0 +1,163 @@ +import os +import numpy as np +# import cv2 +from PIL import Image +import paramUtil +import math +import time +import matplotlib.pyplot as plt +from scipy.ndimage import gaussian_filter + + +def mkdir(path): + if not os.path.exists(path): + os.makedirs(path) + +COLORS = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], + [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], + [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]] + +MISSING_VALUE = -1 + +def save_image(image_numpy, image_path): + img_pil = 
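plot_3d_motion expects global joint positions in the 22-joint HumanML3D layout, typically the output of recover_from_ric. A hedged usage sketch: the .npy path is a placeholder, saving an mp4 needs a working ffmpeg install, and the p3.Axes3D(fig) call inside the function assumes an older matplotlib.

import numpy as np
from mld.data.humanml.utils.plot_script import plot_3d_motion

joints = np.load("some_clip_joints.npy")               # placeholder: (T, 22, 3) global joint positions
plot_3d_motion("demo.mp4", joints, title="a person walks forward and waves", fps=20)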
Image.fromarray(image_numpy) + img_pil.save(image_path) + + +def save_logfile(log_loss, save_path): + with open(save_path, 'wt') as f: + for k, v in log_loss.items(): + w_line = k + for digit in v: + w_line += ' %.3f' % digit + f.write(w_line + '\n') + + +def print_current_loss(start_time, niter_state, losses, epoch=None, sub_epoch=None, + inner_iter=None, tf_ratio=None, sl_steps=None): + + def as_minutes(s): + m = math.floor(s / 60) + s -= m * 60 + return '%dm %ds' % (m, s) + + def time_since(since, percent): + now = time.time() + s = now - since + es = s / percent + rs = es - s + return '%s (- %s)' % (as_minutes(s), as_minutes(rs)) + + if epoch is not None: + print('epoch: %3d niter: %6d sub_epoch: %2d inner_iter: %4d' % (epoch, niter_state, sub_epoch, inner_iter), end=" ") + + + now = time.time() + message = '%s'%(as_minutes(now - start_time)) + + for k, v in losses.items(): + message += ' %s: %.4f ' % (k, v) + message += ' sl_length:%2d tf_ratio:%.2f'%(sl_steps, tf_ratio) + print(message) + +def print_current_loss_decomp(start_time, niter_state, total_niters, losses, epoch=None, inner_iter=None): + + def as_minutes(s): + m = math.floor(s / 60) + s -= m * 60 + return '%dm %ds' % (m, s) + + def time_since(since, percent): + now = time.time() + s = now - since + es = s / percent + rs = es - s + return '%s (- %s)' % (as_minutes(s), as_minutes(rs)) + + print('epoch: %03d inner_iter: %5d' % (epoch, inner_iter), end=" ") + # now = time.time() + message = '%s niter: %07d completed: %3d%%)'%(time_since(start_time, niter_state / total_niters), niter_state, niter_state / total_niters * 100) + for k, v in losses.items(): + message += ' %s: %.4f ' % (k, v) + print(message) + + +def compose_gif_img_list(img_list, fp_out, duration): + img, *imgs = [Image.fromarray(np.array(image)) for image in img_list] + img.save(fp=fp_out, format='GIF', append_images=imgs, optimize=False, + save_all=True, loop=0, duration=duration) + + +def save_images(visuals, image_path): + if not os.path.exists(image_path): + os.makedirs(image_path) + + for i, (label, img_numpy) in enumerate(visuals.items()): + img_name = '%d_%s.jpg' % (i, label) + save_path = os.path.join(image_path, img_name) + save_image(img_numpy, save_path) + + +def save_images_test(visuals, image_path, from_name, to_name): + if not os.path.exists(image_path): + os.makedirs(image_path) + + for i, (label, img_numpy) in enumerate(visuals.items()): + img_name = "%s_%s_%s" % (from_name, to_name, label) + save_path = os.path.join(image_path, img_name) + save_image(img_numpy, save_path) + + +def compose_and_save_img(img_list, save_dir, img_name, col=4, row=1, img_size=(256, 200)): + # print(col, row) + compose_img = compose_image(img_list, col, row, img_size) + if not os.path.exists(save_dir): + os.makedirs(save_dir) + img_path = os.path.join(save_dir, img_name) + compose_img.save(img_path) + + +def compose_image(img_list, col, row, img_size): + to_image = Image.new('RGB', (col * img_size[0], row * img_size[1])) + for y in range(0, row): + for x in range(0, col): + from_img = Image.fromarray(img_list[y * col + x]) + + paste_area = (x * img_size[0], y*img_size[1], + (x + 1) * img_size[0], (y + 1) * img_size[1]) + to_image.paste(from_img, paste_area) + return to_image + + +def plot_loss_curve(losses, save_path, intervals=500): + plt.figure(figsize=(10, 5)) + plt.title("Loss During Training") + for key in losses.keys(): + plt.plot(list_cut_average(losses[key], intervals), label=key) + plt.xlabel("Iterations/" + str(intervals)) + plt.ylabel("Loss") + plt.legend() + 
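A tiny sketch of the logging helpers above, assuming save_logfile and print_current_loss_decomp from this utils module are in scope; the loss history is hypothetical.

import time

losses = {"recon": [0.91, 0.73, 0.58], "kl": [0.12, 0.10, 0.09]}   # hypothetical running losses
save_logfile(losses, "./loss_log.txt")                 # writes one whitespace-separated line per key
print_current_loss_decomp(time.time() - 60, niter_state=300, total_niters=10000,
                          losses={k: v[-1] for k, v in losses.items()},
                          epoch=1, inner_iter=300)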
plt.savefig(save_path) + plt.show() + + +def list_cut_average(ll, intervals): + if intervals == 1: + return ll + + bins = math.ceil(len(ll) * 1.0 / intervals) + ll_new = [] + for i in range(bins): + l_low = intervals * i + l_high = l_low + intervals + l_high = l_high if l_high < len(ll) else len(ll) + ll_new.append(np.mean(ll[l_low:l_high])) + return ll_new + + +def motion_temporal_filter(motion, sigma=1): + motion = motion.reshape(motion.shape[0], -1) + for i in range(motion.shape[1]): + motion[:, i] = gaussian_filter(motion[:, i], sigma=sigma, mode="nearest") + return motion.reshape(motion.shape[0], -1, 3) + diff --git a/Evaluator_272/mld/data/humanml/utils/word_vectorizer.py b/Evaluator_272/mld/data/humanml/utils/word_vectorizer.py new file mode 100644 index 0000000000000000000000000000000000000000..dc48321b4c8ba11607610ab1cb220fba23b9febf --- /dev/null +++ b/Evaluator_272/mld/data/humanml/utils/word_vectorizer.py @@ -0,0 +1,143 @@ +import numpy as np +import pickle +from os.path import join as pjoin + +POS_enumerator = { + 'VERB': 0, + 'NOUN': 1, + 'DET': 2, + 'ADP': 3, + 'NUM': 4, + 'AUX': 5, + 'PRON': 6, + 'ADJ': 7, + 'ADV': 8, + 'Loc_VIP': 9, + 'Body_VIP': 10, + 'Obj_VIP': 11, + 'Act_VIP': 12, + 'Desc_VIP': 13, + 'OTHER': 14, +} + +Loc_list = ('left', 'right', 'clockwise', 'counterclockwise', 'anticlockwise', 'forward', 'back', 'backward', + 'up', 'down', 'straight', 'curve') + +Body_list = ('arm', 'chin', 'foot', 'feet', 'face', 'hand', 'mouth', 'leg', 'waist', 'eye', 'knee', 'shoulder', 'thigh') + +Obj_List = ('stair', 'dumbbell', 'chair', 'window', 'floor', 'car', 'ball', 'handrail', 'baseball', 'basketball') + +Act_list = ('walk', 'run', 'swing', 'pick', 'bring', 'kick', 'put', 'squat', 'throw', 'hop', 'dance', 'jump', 'turn', + 'stumble', 'dance', 'stop', 'sit', 'lift', 'lower', 'raise', 'wash', 'stand', 'kneel', 'stroll', + 'rub', 'bend', 'balance', 'flap', 'jog', 'shuffle', 'lean', 'rotate', 'spin', 'spread', 'climb') + +Desc_list = ('slowly', 'carefully', 'fast', 'careful', 'slow', 'quickly', 'happy', 'angry', 'sad', 'happily', + 'angrily', 'sadly') + +VIP_dict = { + 'Loc_VIP': Loc_list, + 'Body_VIP': Body_list, + 'Obj_VIP': Obj_List, + 'Act_VIP': Act_list, + 'Desc_VIP': Desc_list, +} + + +class WordVectorizer(object): + def __init__(self, meta_root, prefix, text_encode_way): + + self.text_encode_way = text_encode_way + + vectors = np.load(pjoin(meta_root, '%s_data.npy'%prefix)) + words = pickle.load(open(pjoin(meta_root, '%s_words.pkl'%prefix), 'rb')) + word2idx = pickle.load(open(pjoin(meta_root, '%s_idx.pkl'%prefix), 'rb')) + self.word2vec = {w: vectors[word2idx[w]] for w in words} + + if 'glove_6B' in self.text_encode_way: + from torchtext.vocab import GloVe + glove_6b = GloVe(name='6B', dim=300) + self.word2vec_glove_6b = glove_6b.get_vecs_by_tokens + + def _get_pos_ohot(self, pos): + pos_vec = np.zeros(len(POS_enumerator)) + if pos in POS_enumerator: + pos_vec[POS_enumerator[pos]] = 1 + else: + pos_vec[POS_enumerator['OTHER']] = 1 + return pos_vec + + def __len__(self): + return len(self.word2vec) + + def __getitem__(self, item): + word, pos = item.split('/') + if 'given_glove' in self.text_encode_way: + if word in self.word2vec: + word_vec = self.word2vec[word] + vip_pos = None + for key, values in VIP_dict.items(): + if word in values: + vip_pos = key + break + if vip_pos is not None: + pos_vec = self._get_pos_ohot(vip_pos) + else: + pos_vec = self._get_pos_ohot(pos) + else: + word_vec = self.word2vec['unk'] + pos_vec = self._get_pos_ohot('OTHER') + + elif 'glove_6B' 
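motion_temporal_filter smooths every joint coordinate over time with a Gaussian kernel, which is handy for de-jittering recovered joints before rendering. A quick sketch on synthetic data, assuming motion_temporal_filter from the utils module above is in scope (the module does a bare `import paramUtil`, so its directory may need to be on sys.path):

import numpy as np

rng = np.random.default_rng(0)
joints = rng.normal(size=(120, 22, 3))                 # (frames, joints, xyz), synthetic noise
smoothed = motion_temporal_filter(joints, sigma=2.5)
print(smoothed.shape)                                  # (120, 22, 3), same layout, less jitter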
in self.text_encode_way: + word_vec = self.word2vec_glove_6b([word]).squeeze() + + if word in self.word2vec: + vip_pos = None + for key, values in VIP_dict.items(): + if word in values: + vip_pos = key + break + if vip_pos is not None: + pos_vec = self._get_pos_ohot(vip_pos) + else: + pos_vec = self._get_pos_ohot(pos) + else: + pos_vec = self._get_pos_ohot('OTHER') + + + + return word_vec, pos_vec + +class WordVectorizer_only_text_token(object): + def __init__(self, meta_root, prefix, text_encode_way): + + self.text_encode_way = text_encode_way + + vectors = np.load(pjoin(meta_root, '%s_data.npy'%prefix)) + words = pickle.load(open(pjoin(meta_root, '%s_words.pkl'%prefix), 'rb')) + word2idx = pickle.load(open(pjoin(meta_root, '%s_idx.pkl'%prefix), 'rb')) + self.word2vec = {w: vectors[word2idx[w]] for w in words} + + if 'glove_6B' in self.text_encode_way: + from torchtext.vocab import GloVe + glove_6b = GloVe(name='6B', dim=300) + self.word2vec_glove_6b = glove_6b.get_vecs_by_tokens + + def __len__(self): + return len(self.word2vec) + + def __getitem__(self, item): + word = item + + if 'given_glove' in self.text_encode_way: + if word in self.word2vec: + word_vec = self.word2vec[word] + else: + word_vec = self.word2vec['unk'] + + elif 'glove_6B' in self.text_encode_way: + word_vec = self.word2vec_glove_6b([word]).squeeze() + + return word_vec + + + diff --git a/Evaluator_272/mld/data/sampling/__init__.py b/Evaluator_272/mld/data/sampling/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8c9d0bea9d4a507a43e240a3644bdedf83a107e4 --- /dev/null +++ b/Evaluator_272/mld/data/sampling/__init__.py @@ -0,0 +1,2 @@ +from .base import FrameSampler +from .framerate import subsample, upsample diff --git a/Evaluator_272/mld/data/sampling/base.py b/Evaluator_272/mld/data/sampling/base.py new file mode 100644 index 0000000000000000000000000000000000000000..0ab6a8a0fb9fa41d4a586d392a4722f55e34e689 --- /dev/null +++ b/Evaluator_272/mld/data/sampling/base.py @@ -0,0 +1,41 @@ +from .frames import get_frameix_from_data_index + +class FrameSampler: + def __init__(self, sampling="conseq", sampling_step=1, request_frames=None,threshold_reject=0.75,max_len=1000,min_len=10): + self.sampling = sampling + + self.sampling_step = sampling_step + self.request_frames = request_frames + self.threshold_reject = threshold_reject + self.max_len = max_len + self.min_len = min_len + + def __call__(self, num_frames): + + return get_frameix_from_data_index(num_frames, + self.request_frames, + self.sampling, + self.sampling_step) + + def accept(self, duration): + # Outputs have original lengths + # Check if it is too long + if self.request_frames is None: + if duration > self.max_len: + return False + elif duration < self.min_len: + return False + else: + # Reject sample if the length is + # too little relative to + # the request frames + min_number = self.threshold_reject * self.request_frames + if duration < min_number: + return False + return True + + def get(self, key, default=None): + return getattr(self, key, default) + + def __getitem__(self, key): + return getattr(self, key) diff --git a/Evaluator_272/mld/data/sampling/framerate.py b/Evaluator_272/mld/data/sampling/framerate.py new file mode 100644 index 0000000000000000000000000000000000000000..72dd08f0ff7e2fedfab55c9d04393a740aaab54f --- /dev/null +++ b/Evaluator_272/mld/data/sampling/framerate.py @@ -0,0 +1,32 @@ +import numpy as np + +def subsample(num_frames, last_framerate, new_framerate): + step = int(last_framerate / new_framerate) + 
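FrameSampler pairs a clip-length filter (accept) with frame-index selection (the call itself). A short sketch with assumed settings:

from mld.data.sampling import FrameSampler

sampler = FrameSampler(sampling="conseq", sampling_step=1, request_frames=64,
                       threshold_reject=0.75, max_len=1000, min_len=10)

if sampler.accept(200):                  # a 200-frame clip passes the length filter
    frame_ix = sampler(200)              # 64 consecutive indices with a random shift
    print(len(frame_ix))                 # 64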
assert step >= 1 + frames = np.arange(0, num_frames, step) + return frames + + + +def upsample(motion, last_framerate, new_framerate): + step = int(new_framerate / last_framerate) + assert step >= 1 + + # Alpha blending => interpolation + alpha = np.linspace(0, 1, step+1) + last = np.einsum("l,...->l...", 1-alpha, motion[:-1]) + new = np.einsum("l,...->l...", alpha, motion[1:]) + + chuncks = (last + new)[:-1] + output = np.concatenate(chuncks.swapaxes(1, 0)) + # Don't forget the last one + output = np.concatenate((output, motion[[-1]])) + return output + + +if __name__ == "__main__": + motion = np.arange(105) + submotion = motion[subsample(len(motion), 100.0, 12.5)] + newmotion = upsample(submotion, 12.5, 100) + + print(newmotion) diff --git a/Evaluator_272/mld/data/sampling/frames.py b/Evaluator_272/mld/data/sampling/frames.py new file mode 100644 index 0000000000000000000000000000000000000000..ab9a6ed47987d5d04651fa153f54741ad5442e64 --- /dev/null +++ b/Evaluator_272/mld/data/sampling/frames.py @@ -0,0 +1,58 @@ +from typing import Optional + +import numpy as np +from numpy import ndarray as Array +import random + + +def get_frameix_from_data_index(num_frames: int, + request_frames: Optional[int], + sampling: str = "conseq", + sampling_step: int = 1) -> Array: + nframes = num_frames + + if request_frames is None: + frame_ix = np.arange(nframes) + else: + + if request_frames > nframes: + fair = False # True + if fair: + # distills redundancy everywhere + choices = np.random.choice(range(nframes), + request_frames, + replace=True) + frame_ix = sorted(choices) + else: + # adding the last frame until done + ntoadd = max(0, request_frames - nframes) + lastframe = nframes - 1 + padding = lastframe * np.ones(ntoadd, dtype=int) + frame_ix = np.concatenate((np.arange(0, nframes), + padding)) + + elif sampling in ["conseq", "random_conseq"]: + step_max = (nframes - 1) // (request_frames - 1) + if sampling == "conseq": + if sampling_step == -1 or sampling_step * (request_frames - 1) >= nframes: + step = step_max + else: + step = sampling_step + elif sampling == "random_conseq": + step = random.randint(1, step_max) + + lastone = step * (request_frames - 1) + shift_max = nframes - lastone - 1 + shift = random.randint(0, max(0, shift_max - 1)) + frame_ix = shift + np.arange(0, lastone + 1, step) + + elif sampling == "random": + choices = np.random.choice(range(nframes), + request_frames, + replace=False) + frame_ix = sorted(choices) + + else: + raise ValueError("Sampling not recognized.") + + return frame_ix diff --git a/Evaluator_272/mld/data/utils.py b/Evaluator_272/mld/data/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b16acd6ba8b8eaafdba4cde0765f1bc812c25392 --- /dev/null +++ b/Evaluator_272/mld/data/utils.py @@ -0,0 +1,38 @@ +import torch + + +def lengths_to_mask(lengths): + max_len = max(lengths) + mask = torch.arange(max_len, device=lengths.device).expand( + len(lengths), max_len) < lengths.unsqueeze(1) + return mask + + +def collate_tensors(batch): + dims = batch[0].dim() + max_size = [max([b.size(i) for b in batch]) for i in range(dims)] + size = (len(batch), ) + tuple(max_size) + canvas = batch[0].new_zeros(size=size) + for i, b in enumerate(batch): + sub_tensor = canvas[i] + for d in range(dims): + sub_tensor = sub_tensor.narrow(d, 0, b.size(d)) + sub_tensor.add_(b) + return canvas + +def mld_collate(batch): + notnone_batches = [b for b in batch if b is not None] + notnone_batches.sort(key=lambda x: x[2], reverse=True) + adapted_batch = { + "motion": + 
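lengths_to_mask and collate_tensors are the batching primitives mld_collate builds on: zero-pad variable-length motion tensors to a common shape and keep a boolean mask of the valid frames. A short sketch (the 272 feature dimension is an assumption):

import torch
from mld.data.utils import collate_tensors, lengths_to_mask

motions = [torch.randn(60, 272), torch.randn(45, 272)]   # two clips of different lengths
lengths = torch.tensor([60, 45])

padded = collate_tensors(motions)        # (2, 60, 272); the shorter clip is zero-padded
mask = lengths_to_mask(lengths)          # (2, 60) bool, True on real frames
print(padded.shape, mask.sum(dim=1))     # torch.Size([2, 60, 272]) tensor([60, 45])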
collate_tensors([torch.tensor(b[1]).float() for b in notnone_batches]), + "text": [b[0] for b in notnone_batches], + "length": [b[2] for b in notnone_batches], + "retrieval_name": [b[3] for b in notnone_batches] + } + return adapted_batch + + + + + diff --git a/Evaluator_272/mld/launch/__init__.py b/Evaluator_272/mld/launch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Evaluator_272/mld/launch/blender.py b/Evaluator_272/mld/launch/blender.py new file mode 100644 index 0000000000000000000000000000000000000000..cad6c9daa0461ee0bfecc5f76a40def4101d837a --- /dev/null +++ b/Evaluator_272/mld/launch/blender.py @@ -0,0 +1,23 @@ +# Fix blender path +import sys +import os +# local packages +sys.path.append(os.path.expanduser("~/.local/lib/python3.9/site-packages")) +import bpy +import os +from argparse import ArgumentParser + +# Monkey patch argparse such that +# blender / python / hydra parsing works +def parse_args(self, args=None, namespace=None): + if args is not None: + return self.parse_args_bak(args=args, namespace=namespace) + try: + idx = sys.argv.index("--") + args = sys.argv[idx+1:] # the list after '--' + except ValueError as e: # '--' not in the list: + args = [] + return self.parse_args_bak(args=args, namespace=namespace) + +setattr(ArgumentParser, 'parse_args_bak', ArgumentParser.parse_args) +setattr(ArgumentParser, 'parse_args', parse_args) diff --git a/Evaluator_272/mld/launch/prepare.py b/Evaluator_272/mld/launch/prepare.py new file mode 100644 index 0000000000000000000000000000000000000000..a9934211edb0bce0222ff74d53df6ae0fe6fa72e --- /dev/null +++ b/Evaluator_272/mld/launch/prepare.py @@ -0,0 +1,66 @@ +import os +import warnings +from pathlib import Path + +import hydra +from mld.tools.runid import generate_id +from omegaconf import OmegaConf + + +# Local paths +def code_path(path=""): + code_dir = hydra.utils.get_original_cwd() + code_dir = Path(code_dir) + return str(code_dir / path) + + +def working_path(path): + return str(Path(os.getcwd()) / path) + + +# fix the id for this run +ID = generate_id() + + +def generate_id(): + return ID + + +def get_last_checkpoint(path, ckpt_name="last.ckpt"): + output_dir = Path(hydra.utils.to_absolute_path(path)) + last_ckpt_path = output_dir / "checkpoints" / ckpt_name + return str(last_ckpt_path) + + +def get_kitname(load_amass_data: bool, load_with_rot: bool): + if not load_amass_data: + return "kit-mmm-xyz" + if load_amass_data and not load_with_rot: + return "kit-amass-xyz" + if load_amass_data and load_with_rot: + return "kit-amass-rot" + + +OmegaConf.register_new_resolver("code_path", code_path) +OmegaConf.register_new_resolver("working_path", working_path) +OmegaConf.register_new_resolver("generate_id", generate_id) +OmegaConf.register_new_resolver("absolute_path", hydra.utils.to_absolute_path) +OmegaConf.register_new_resolver("get_last_checkpoint", get_last_checkpoint) +OmegaConf.register_new_resolver("get_kitname", get_kitname) + + +# Remove warnings +warnings.filterwarnings( + "ignore", ".*Trying to infer the `batch_size` from an ambiguous collection.*" +) + +warnings.filterwarnings( + "ignore", ".*does not have many workers which may be a bottleneck*" +) + +warnings.filterwarnings( + "ignore", ".*Our suggested max number of worker in current system is*" +) + + +os.environ["NUMEXPR_MAX_THREADS"] = "24" diff --git a/Evaluator_272/mld/launch/tools.py b/Evaluator_272/mld/launch/tools.py new file mode 100644 index 
0000000000000000000000000000000000000000..119185262a35d89a4c99bc265dbccda554200d82 --- /dev/null +++ b/Evaluator_272/mld/launch/tools.py @@ -0,0 +1,9 @@ +from pathlib import Path +from omegaconf import DictConfig, OmegaConf +import hydra +import os + + +def resolve_cfg_path(cfg: DictConfig): + working_dir = os.getcwd() + cfg.working_dir = working_dir diff --git a/Evaluator_272/mld/models/__init__.py b/Evaluator_272/mld/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Evaluator_272/mld/models/architectures/__init__.py b/Evaluator_272/mld/models/architectures/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Evaluator_272/mld/models/architectures/actor_vae.py b/Evaluator_272/mld/models/architectures/actor_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..be712ca758a22a6d0a387c8434620c8ca13a294a --- /dev/null +++ b/Evaluator_272/mld/models/architectures/actor_vae.py @@ -0,0 +1,258 @@ +from typing import List, Optional, Union +import numpy as np +import torch +import torch.nn as nn +from torch import Tensor, nn +from torch.distributions.distribution import Distribution +from mld.utils.temos_utils import lengths_to_mask +from mld.models.operator import PositionalEncoding + + +class ActorVae(nn.Module): + + def __init__(self, + ablation, + nfeats: int, + latent_dim: list = [1, 256], + ff_size: int = 1024, + num_layers: int = 9, + num_heads: int = 4, + dropout: float = 0.1, + is_vae: bool = True, + activation: str = "gelu", + position_embedding: str = "learned", + **kwargs) -> None: + + super().__init__() + + self.latent_size = latent_dim[0] + self.latent_dim = latent_dim[-1] + self.is_vae = is_vae + input_feats = nfeats + output_feats = nfeats + + self.encoder = ActorAgnosticEncoder(nfeats=input_feats, + vae=True, + latent_dim=self.latent_dim, + ff_size=ff_size, + num_layers=num_layers, + num_heads=num_heads, + dropout=dropout, + activation=activation, + **kwargs) + + self.decoder = ActorAgnosticDecoder(nfeats=output_feats, + vae=True, + latent_dim=self.latent_dim, + ff_size=ff_size, + num_layers=num_layers, + num_heads=num_heads, + dropout=dropout, + activation=activation, + **kwargs) + + def forward(self, features: Tensor, lengths: Optional[List[int]] = None): + # Temp + # Todo + # remove and test this function + print("Should Not enter here") + + z, dist = self.encode(features, lengths) + feats_rst = self.decode(z, lengths) + return feats_rst, z, dist + + def encode( + self, + features: Tensor, + lengths: Optional[List[int]] = None + ) -> Union[Tensor, Distribution]: + + dist = self.encoder(features, lengths) + if self.is_vae: + latent = sample_from_distribution(dist) + else: + latent = dist.unsqueeze(0) + + return latent, dist + + def decode(self, z: Tensor, lengths: List[int]): + + feats = self.decoder(z, lengths) + return feats + + +class ActorAgnosticEncoder(nn.Module): + + def __init__(self, + nfeats: int, + vae: bool, + latent_dim: int = 256, + ff_size: int = 1024, + num_layers: int = 4, + num_heads: int = 4, + dropout: float = 0.1, + activation: str = "gelu", + **kwargs) -> None: + super().__init__() + + input_feats = nfeats + self.vae = vae + self.skel_embedding = nn.Linear(input_feats, latent_dim) + + # Action agnostic: only one set of params + if vae: + self.mu_token = nn.Parameter(torch.randn(latent_dim)) + self.logvar_token = nn.Parameter(torch.randn(latent_dim)) + else: + 
self.emb_token = nn.Parameter(torch.randn(latent_dim)) + + self.sequence_pos_encoding = PositionalEncoding(latent_dim, dropout) + + seq_trans_encoder_layer = nn.TransformerEncoderLayer( + d_model=latent_dim, + nhead=num_heads, + dim_feedforward=ff_size, + dropout=dropout, + activation=activation) + + self.seqTransEncoder = nn.TransformerEncoder(seq_trans_encoder_layer, + num_layers=num_layers) + + def forward( + self, + features: Tensor, + lengths: Optional[List[int]] = None + ) -> Union[Tensor, Distribution]: + if lengths is None: + lengths = [len(feature) for feature in features] + + device = features.device + + bs, nframes, nfeats = features.shape + mask = lengths_to_mask(lengths, device) + + x = features + # Embed each human poses into latent vectors + x = self.skel_embedding(x) + + # Switch sequence and batch_size because the input of + # Pytorch Transformer is [Sequence, Batch size, ...] + x = x.permute(1, 0, 2) # now it is [nframes, bs, latent_dim] + + # Each batch has its own set of tokens + if self.vae: + mu_token = torch.tile(self.mu_token, (bs, )).reshape(bs, -1) + logvar_token = torch.tile(self.logvar_token, + (bs, )).reshape(bs, -1) + + # adding the distribution tokens for all sequences + xseq = torch.cat((mu_token[None], logvar_token[None], x), 0) + + # create a bigger mask, to allow attend to mu and logvar + token_mask = torch.ones((bs, 2), dtype=bool, device=x.device) + aug_mask = torch.cat((token_mask, mask), 1) + else: + emb_token = torch.tile(self.emb_token, (bs, )).reshape(bs, -1) + + # adding the embedding token for all sequences + xseq = torch.cat((emb_token[None], x), 0) + + # create a bigger mask, to allow attend to emb + token_mask = torch.ones((bs, 1), dtype=bool, device=x.device) + aug_mask = torch.cat((token_mask, mask), 1) + + # add positional encoding + xseq = self.sequence_pos_encoding(xseq) + final = self.seqTransEncoder(xseq, src_key_padding_mask=~aug_mask) + + if self.vae: + mu, logvar = final[0], final[1] + std = logvar.exp().pow(0.5) + # https://github.com/kampta/pytorch-distributions/blob/master/gaussian_vae.py + dist = torch.distributions.Normal(mu, std) + return dist + else: + return final[0] + + +class ActorAgnosticDecoder(nn.Module): + + def __init__(self, + nfeats: int, + latent_dim: int = 256, + ff_size: int = 1024, + num_layers: int = 4, + num_heads: int = 4, + dropout: float = 0.1, + activation: str = "gelu", + **kwargs) -> None: + super().__init__() + + output_feats = nfeats + self.latent_dim = latent_dim + self.nfeats = nfeats + + self.sequence_pos_encoding = PositionalEncoding(latent_dim, dropout) + + seq_trans_decoder_layer = nn.TransformerDecoderLayer( + d_model=latent_dim, + nhead=num_heads, + dim_feedforward=ff_size, + dropout=dropout, + activation=activation) + + self.seqTransDecoder = nn.TransformerDecoder(seq_trans_decoder_layer, + num_layers=num_layers) + + self.final_layer = nn.Linear(latent_dim, output_feats) + + def forward(self, z: Tensor, lengths: List[int]): + mask = lengths_to_mask(lengths, z.device) + # latent_dim = z.shape[1] + bs, nframes = mask.shape + nfeats = self.nfeats + + # z = z[None] # sequence of 1 element for the memory + + # Construct time queries + time_queries = torch.zeros(nframes, + bs, + self.latent_dim, + device=z.device) + time_queries = self.sequence_pos_encoding(time_queries) + + # Pass through the transformer decoder + # with the latent vector for memory + output = self.seqTransDecoder(tgt=time_queries, + memory=z, + tgt_key_padding_mask=~mask) + + output = self.final_layer(output) + # zero for padded 
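The encoder's distribution-token trick is easiest to see in isolation: prepend learned mu/logvar tokens to the frame sequence and widen the padding mask so attention can reach them; after the transformer, positions 0 and 1 are read out as mu and logvar. A toy shape-only sketch, not the repo's module:

import torch

bs, nframes, latent_dim = 2, 8, 16
x = torch.randn(nframes, bs, latent_dim)               # [seq, batch, dim], as the encoder uses
mask = torch.ones(bs, nframes, dtype=torch.bool)       # True on valid frames
mu_token, logvar_token = torch.randn(latent_dim), torch.randn(latent_dim)

tokens = torch.stack([mu_token, logvar_token])                        # (2, dim)
xseq = torch.cat((tokens[:, None].expand(-1, bs, -1), x), dim=0)      # (2 + nframes, bs, dim)
aug_mask = torch.cat((torch.ones(bs, 2, dtype=torch.bool), mask), dim=1)   # (bs, 2 + nframes)
print(xseq.shape, aug_mask.shape)        # torch.Size([10, 2, 16]) torch.Size([2, 10])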
area + output[~mask.T] = 0 + # Pytorch Transformer: [Sequence, Batch size, ...] + feats = output.permute(1, 0, 2) + return feats + + +def sample_from_distribution( + dist, + *, + fact=1.0, + sample_mean=False, +) -> Tensor: + + if sample_mean: + return dist.loc.unsqueeze(0) + + # Reparameterization trick + if fact is None: + return dist.rsample().unsqueeze(0) + + # Resclale the eps + eps = dist.rsample() - dist.loc + z = dist.loc + fact * eps + + # add latent size + z = z.unsqueeze(0) + return z diff --git a/Evaluator_272/mld/models/architectures/fc.py b/Evaluator_272/mld/models/architectures/fc.py new file mode 100644 index 0000000000000000000000000000000000000000..91380acf73428b711bfcf1d19a63c16b61270c36 --- /dev/null +++ b/Evaluator_272/mld/models/architectures/fc.py @@ -0,0 +1,100 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class Encoder_FC(nn.Module): + def __init__(self, modeltype, njoints, nfeats, num_frames, num_classes, translation, pose_rep, glob, glob_rot, + latent_dim=256, **kargs): + super().__init__() + + self.modeltype = modeltype + self.njoints = njoints + self.nfeats = nfeats + self.num_frames = num_frames + self.num_classes = num_classes + self.translation = translation + self.pose_rep = pose_rep + self.glob = glob + self.glob_rot = glob_rot + + self.latent_dim = latent_dim + + self.activation = nn.GELU() + + self.input_dim = self.njoints*self.nfeats*self.num_frames+self.num_classes + + self.fully_connected = nn.Sequential(nn.Linear(self.input_dim, 512), + nn.GELU(), + nn.Linear(512, 256), + nn.GELU()) + if self.modeltype == "cvae": + self.mu = nn.Linear(256, self.latent_dim) + self.var = nn.Linear(256, self.latent_dim) + else: + self.final = nn.Linear(256, self.latent_dim) + + def forward(self, batch): + x, y = batch["x"], batch["y"] + bs, njoints, feats, nframes = x.size() + if (njoints * feats * nframes) != self.njoints*self.nfeats*self.num_frames: + raise ValueError("This model is not adapted with this input") + + if len(y.shape) == 1: # can give on hot encoded as input + y = F.one_hot(y, self.num_classes) + y = y.to(dtype=x.dtype) + x = x.reshape(bs, njoints*feats*nframes) + x = torch.cat((x, y), 1) + + x = self.fully_connected(x) + + if self.modeltype == "cvae": + return {"mu": self.mu(x), "logvar": self.var(x)} + else: + return {"z": self.final(x)} + + +class Decoder_FC(nn.Module): + def __init__(self, modeltype, njoints, nfeats, num_frames, num_classes, translation, pose_rep, glob, glob_rot, + latent_dim=256, **kargs): + super().__init__() + + self.modeltype = modeltype + self.njoints = njoints + self.nfeats = nfeats + self.num_frames = num_frames + self.num_classes = num_classes + self.translation = translation + self.pose_rep = pose_rep + self.glob = glob + self.glob_rot = glob_rot + + self.latent_dim = latent_dim + + self.input_dim = self.latent_dim + self.num_classes + self.output_dim = self.njoints*self.nfeats*self.num_frames + + self.fully_connected = nn.Sequential(nn.Linear(self.input_dim, 256), + nn.GELU(), + nn.Linear(256, 512), + nn.GELU(), + nn.Linear(512, self.output_dim), + nn.GELU()) + + def forward(self, batch): + z, y = batch["z"], batch["y"] + # z: [batch_size, latent_dim] + # y: [batch_size] + if len(y.shape) == 1: # can give on hot encoded as input + y = F.one_hot(y, self.num_classes) + y = y.to(dtype=z.dtype) # y: [batch_size, num_classes] + # z: [batch_size, latent_dim+num_classes] + z = torch.cat((z, y), dim=1) + + z = self.fully_connected(z) + + bs, _ = z.size() + + z = z.reshape(bs, self.njoints, 
self.nfeats, self.num_frames) + batch["output"] = z + return batch diff --git a/Evaluator_272/mld/models/architectures/gpt/clip.py b/Evaluator_272/mld/models/architectures/gpt/clip.py new file mode 100644 index 0000000000000000000000000000000000000000..270b6d0ef5db571a142795e0159326871b482f9c --- /dev/null +++ b/Evaluator_272/mld/models/architectures/gpt/clip.py @@ -0,0 +1,90 @@ +import os +from typing import List, Union + +import torch +from torch import Tensor, nn +from torch.distributions.distribution import Distribution +from transformers import AutoModel, AutoTokenizer, CLIPTextModel, CLIPTokenizer + +from mld.models.operator import PositionalEncoding +from mld.utils.temos_utils import lengths_to_mask + +import pytorch_lightning as pl +class TextEncoder(pl.LightningModule): + + def __init__( + self, + modelpath: str, + finetune: bool = False, + last_hidden_state: bool = False, + latent_dim: list = [1, 256], + ) -> None: + + super().__init__() + + self.latent_dim = latent_dim + + self.tokenizer = AutoTokenizer.from_pretrained(modelpath) + self.text_model = AutoModel.from_pretrained(modelpath) + + # Don't train the model + if not finetune: + self.text_model.training = False + for p in self.text_model.parameters(): + p.requires_grad = False + + # Then configure the model + self.max_length = self.tokenizer.model_max_length + if "clip" in modelpath: + self.text_encoded_dim = self.text_model.config.text_config.hidden_size + if last_hidden_state: + self.name = "clip_hidden" + else: + self.name = "clip" + elif "bert" in modelpath: + self.name = "bert" + self.text_encoded_dim = self.text_model.config.hidden_size + else: + raise ValueError(f"Model {modelpath} not supported") + + def forward(self, texts: List[str]): + # get prompt text embeddings + if self.name in ["clip", "clip_hidden"]: + text_inputs = self.tokenizer( + texts, + padding="max_length", + truncation=True, + max_length=self.max_length, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + # split into max length Clip can handle + if text_input_ids.shape[-1] > self.tokenizer.model_max_length: + text_input_ids = text_input_ids[:, :self.tokenizer. 
+ model_max_length] + elif self.name == "bert": + text_inputs = self.tokenizer(texts, + return_tensors="pt", + padding=True) + + # use pooled ouuput if latent dim is two-dimensional + # pooled = 0 if self.latent_dim[0] == 1 else 1 # (bs, seq_len, text_encoded_dim) -> (bs, text_encoded_dim) + # text encoder forward, clip must use get_text_features + if self.name == "clip": + # (batch_Size, text_encoded_dim) + text_embeddings = self.text_model.get_text_features( + text_input_ids.to(self.text_model.device)) + # (batch_Size, 1, text_encoded_dim) + text_embeddings = text_embeddings.unsqueeze(1) + elif self.name == "clip_hidden": + # (batch_Size, seq_length , text_encoded_dim) + text_embeddings = self.text_model.text_model( + text_input_ids.to(self.text_model.device)).last_hidden_state + elif self.name == "bert": + # (batch_Size, seq_length , text_encoded_dim) + text_embeddings = self.text_model( + **text_inputs.to(self.text_model.device)).last_hidden_state + else: + raise NotImplementedError(f"Model {self.name} not implemented") + + return text_embeddings diff --git a/Evaluator_272/mld/models/architectures/gpt/pos_encoding.py b/Evaluator_272/mld/models/architectures/gpt/pos_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..066be3e1f8a1636f7eaabd1c534b9c618ee3e9f8 --- /dev/null +++ b/Evaluator_272/mld/models/architectures/gpt/pos_encoding.py @@ -0,0 +1,43 @@ +""" +Various positional encodings for the transformer. +""" +import math +import torch +from torch import nn + +def PE1d_sincos(seq_length, dim): + """ + :param d_model: dimension of the model + :param length: length of positions + :return: length*d_model position matrix + """ + if dim % 2 != 0: + raise ValueError("Cannot use sin/cos positional encoding with " + "odd dim (got dim={:d})".format(dim)) + pe = torch.zeros(seq_length, dim) + position = torch.arange(0, seq_length).unsqueeze(1) + div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.float) * + -(math.log(10000.0) / dim))) + pe[:, 0::2] = torch.sin(position.float() * div_term) + pe[:, 1::2] = torch.cos(position.float() * div_term) + + return pe.unsqueeze(1) + + +class PositionEmbedding(nn.Module): + """ + Absolute pos embedding (standard), learned. 
+ """ + def __init__(self, seq_length, dim, dropout, grad=False): + super().__init__() + self.embed = nn.Parameter(data=PE1d_sincos(seq_length, dim), requires_grad=grad) + self.dropout = nn.Dropout(p=dropout) + + def forward(self, x): + # x.shape: bs, seq_len, feat_dim + l = x.shape[1] + x = x.permute(1, 0, 2) + self.embed[:l].expand(x.permute(1, 0, 2).shape) + x = self.dropout(x.permute(1, 0, 2)) + return x + + \ No newline at end of file diff --git a/Evaluator_272/mld/models/architectures/gpt/t2m_trans.py b/Evaluator_272/mld/models/architectures/gpt/t2m_trans.py new file mode 100644 index 0000000000000000000000000000000000000000..10f9a0f24ac93f72602e7c65e8c7fe6c6778ddfe --- /dev/null +++ b/Evaluator_272/mld/models/architectures/gpt/t2m_trans.py @@ -0,0 +1,265 @@ +import math +import torch +import torch.nn as nn +from torch.nn import functional as F +from torch.distributions import Categorical +import mld.models.architectures.gpt.pos_encoding as pos_encoding +import random + +class AttentionPool2d(nn.Module): + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.flatten(start_dim=2).permute(2, 0, 1) + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) + x = x + self.positional_embedding[:, None, :].to(x.dtype) + x, _ = F.multi_head_attention_forward( + query=x[:1], key=x, value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False + ) + return x.squeeze(0) + +class Text2Motion_Transformer(nn.Module): + + def __init__(self, + num_vq=1024, + embed_dim=512, + clip_dim=512, + block_size=16, + num_layers=2, + n_head=8, + drop_out_rate=0.1, + fc_rate=4): + super().__init__() + self.trans_base = CrossCondTransBase(num_vq, embed_dim, clip_dim, block_size, num_layers, n_head, drop_out_rate, fc_rate) + self.trans_head = CrossCondTransHead(num_vq, embed_dim, block_size, num_layers, n_head, drop_out_rate, fc_rate) + self.block_size = block_size + self.num_vq = num_vq + + def get_block_size(self): + return self.block_size + + def forward(self, idxs, clip_feature): + ''' + Input: + idx: [32, 50] + clip_feature: [32, 768] + + Output: + logits: (32, 51, 513) + ''' + feat = self.trans_base(idxs, clip_feature) + logits = self.trans_head(feat) + return logits + + def sample(self, clip_feature, if_categorial=False): + for k in range(self.block_size): + if k == 0: + x = [] + else: + x = xs + logits = self.forward(x, clip_feature) + logits = logits[:, -1, :] + probs = F.softmax(logits, dim=-1) + if if_categorial: + dist = Categorical(probs) + idx = dist.sample() + if idx == self.num_vq: + break + idx = idx.unsqueeze(-1) + else: + _, idx = torch.topk(probs, k=1, dim=-1) + if idx[0] == self.num_vq: + break + # append to the sequence and 
continue + if k == 0: + xs = idx + else: + xs = torch.cat((xs, idx), dim=1) + + if k == self.block_size - 1: + return xs[:, :-1] + + + return xs + + + + +class CausalCrossConditionalSelfAttention(nn.Module): + + def __init__(self, embed_dim=512, block_size=16, n_head=8, drop_out_rate=0.1): + super().__init__() + assert embed_dim % 8 == 0 + # key, query, value projections for all heads + self.key = nn.Linear(embed_dim, embed_dim) + self.query = nn.Linear(embed_dim, embed_dim) + self.value = nn.Linear(embed_dim, embed_dim) + + self.attn_drop = nn.Dropout(drop_out_rate) + self.resid_drop = nn.Dropout(drop_out_rate) + + self.proj = nn.Linear(embed_dim, embed_dim) + # causal mask to ensure that attention is only applied to the left in the input sequence + self.register_buffer("mask", torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size)) + self.n_head = n_head + + def forward(self, x): + B, T, C = x.size() + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf')) + att = F.softmax(att, dim=-1) + att = self.attn_drop(att) + y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side + + # output projection + y = self.resid_drop(self.proj(y)) + return y + +class Block(nn.Module): + + def __init__(self, embed_dim=512, block_size=16, n_head=8, drop_out_rate=0.1, fc_rate=4): + super().__init__() + self.ln1 = nn.LayerNorm(embed_dim) + self.ln2 = nn.LayerNorm(embed_dim) + self.attn = CausalCrossConditionalSelfAttention(embed_dim, block_size, n_head, drop_out_rate) + self.mlp = nn.Sequential( + nn.Linear(embed_dim, fc_rate * embed_dim), + nn.GELU(), + nn.Linear(fc_rate * embed_dim, embed_dim), + nn.Dropout(drop_out_rate), + ) + + def forward(self, x): + x = x + self.attn(self.ln1(x)) + x = x + self.mlp(self.ln2(x)) + return x + +class CrossCondTransBase(nn.Module): + + def __init__(self, + num_vq=1024, + embed_dim=512, + clip_dim=512, + block_size=16, + num_layers=2, + n_head=8, + drop_out_rate=0.1, + fc_rate=4): + super().__init__() + self.tok_emb = nn.Embedding(num_vq + 2, embed_dim) + self.cond_emb = nn.Linear(clip_dim, embed_dim) + self.pos_embedding = nn.Embedding(block_size, embed_dim) + self.drop = nn.Dropout(drop_out_rate) + # transformer block + self.blocks = nn.Sequential(*[Block(embed_dim, block_size, n_head, drop_out_rate, fc_rate) for _ in range(num_layers)]) + self.pos_embed = pos_encoding.PositionEmbedding(block_size, embed_dim, 0.0, False) + # self.attention_pool = AttentionPool2d() + self.block_size = block_size + + self.apply(self._init_weights) + + + def get_block_size(self): + return self.block_size + + def _init_weights(self, module): + if isinstance(module, (nn.Linear, nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=0.02) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + 
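+ # forward() below prepends the projected text feature (cond_emb) as the first token,
+ # embeds the discrete motion-code indices with tok_emb, adds the fixed sinusoidal
+ # position embedding, and runs the stack of causal self-attention blocks.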
+ def forward(self, idx, clip_feature): + if len(clip_feature.shape) == 3: + clip_feature = clip_feature.mean(axis=1, keepdim=False) + + assert len(clip_feature.shape) == 2 + if len(idx) == 0: + token_embeddings = self.cond_emb(clip_feature).unsqueeze(1) + else: + b, t = idx.size() + assert t <= self.block_size, "Cannot forward, model block size is exhausted." + # forward the Trans model + token_embeddings = self.tok_emb(idx) + token_embeddings = torch.cat([self.cond_emb(clip_feature).unsqueeze(1), token_embeddings], dim=1) + + x = self.pos_embed(token_embeddings) + x = self.blocks(x) + + return x + + +class CrossCondTransHead(nn.Module): + + def __init__(self, + num_vq=1024, + embed_dim=512, + block_size=16, + num_layers=2, + n_head=8, + drop_out_rate=0.1, + fc_rate=4): + super().__init__() + + self.blocks = nn.Sequential(*[Block(embed_dim, block_size, n_head, drop_out_rate, fc_rate) for _ in range(num_layers)]) + self.ln_f = nn.LayerNorm(embed_dim) + self.head = nn.Linear(embed_dim, num_vq + 1, bias=False) + self.block_size = block_size + + self.apply(self._init_weights) + + def get_block_size(self): + return self.block_size + + def _init_weights(self, module): + if isinstance(module, (nn.Linear, nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=0.02) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def forward(self, x): + x = self.blocks(x) + x = self.ln_f(x) + logits = self.head(x) + return logits + + + + + + diff --git a/Evaluator_272/mld/models/architectures/gpt/wmr_text_encoder.py b/Evaluator_272/mld/models/architectures/gpt/wmr_text_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..e0e10548bc454e5655b4ddc48b2f362e019de1a0 --- /dev/null +++ b/Evaluator_272/mld/models/architectures/gpt/wmr_text_encoder.py @@ -0,0 +1,55 @@ +import os +from typing import List, Union + +import torch +from torch import Tensor, nn +from torch.distributions.distribution import Distribution +from transformers import AutoModel, AutoTokenizer, CLIPTextModel, CLIPTokenizer + +from mld.models.operator import PositionalEncoding +from mld.utils.temos_utils import lengths_to_mask + +from mld.models.architectures.temos.motionencoder.actor import ActorAgnosticEncoder +from mld.models.architectures.temos.textencoder.distillbert_actor import DistilbertActorAgnosticEncoder +from collections import OrderedDict +import pytorch_lightning as pl + +class TextEncoder(pl.LightningModule): + + def __init__( + self, + modelpath: str, + finetune: bool = False, + last_hidden_state: bool = False, + latent_dim: list = [1, 256], + ) -> None: + + super().__init__() + + self.latent_dim = latent_dim + + model_dict = OrderedDict() + state_dict = torch.load(modelpath)["state_dict"] + + self.text_model = DistilbertActorAgnosticEncoder('distilbert-base-uncased', num_layers=4) + + for k, v in state_dict.items(): + # print(k) + if k.split(".")[0] == "textencoder": + name = k.replace("textencoder.", "") + model_dict[name] = v + + self.text_model.load_state_dict(model_dict, strict=True) + + if not finetune: + self.text_model.training = False + for p in self.text_model.parameters(): + p.requires_grad = False + + + + def forward(self, texts: List[str]): + feat_clip_text = self.text_model(texts).loc.to(self.text_model.device) + feat_clip_text = torch.cat((feat_clip_text, feat_clip_text), dim=1) + + return feat_clip_text diff --git 
a/Evaluator_272/mld/models/architectures/mld_bert.py b/Evaluator_272/mld/models/architectures/mld_bert.py new file mode 100644 index 0000000000000000000000000000000000000000..3508dddef57cf9a5af6c4abc232d631d86b959b6 --- /dev/null +++ b/Evaluator_272/mld/models/architectures/mld_bert.py @@ -0,0 +1,164 @@ +import torch +import os + +from typing import List, Union +from torch import nn, Tensor +from torch.distributions.distribution import Distribution + +from mld.models.operator import PositionalEncoding +from mld.utils.temos_utils import lengths_to_mask + + +class MLDTextEncoder(nn.Module): + def __init__(self, + cfg, + modelpath: str, + finetune: bool = False, + vae: bool = True, + latent_dim: int = 256, + ff_size: int = 1024, + num_layers: int = 6, + num_heads: int = 4, + dropout: float = 0.1, + activation: str = "gelu", + **kwargs) -> None: + + super().__init__() + + from transformers import AutoTokenizer, AutoModel + from transformers import logging + + logging.set_verbosity_error() + # Tokenizer + os.environ["TOKENIZERS_PARALLELISM"] = "false" + + self.tokenizer = AutoTokenizer.from_pretrained(modelpath) + + # Text model + self.text_model = AutoModel.from_pretrained(modelpath) + # Don't train the model + if not finetune: + self.text_model.training = False + for p in self.text_model.parameters(): + p.requires_grad = False + + # Then configure the model + self.text_encoded_dim = self.text_model.config.hidden_size + self.text_encoded_dim = latent_dim # enable projection + # self.save_hyperparameters(logger=False) + + encoded_dim = self.text_model.config.hidden_size + + # Projection of the text-outputs into the latent space + self.projection = nn.Sequential(nn.ReLU(), + nn.Linear(encoded_dim, latent_dim)) + + # TransformerVAE adapted from ACTOR + # Action agnostic: only one set of params + + vae = False + if vae: + self.mu_token = nn.Parameter(torch.randn(latent_dim)) + self.logvar_token = nn.Parameter(torch.randn(latent_dim)) + else: + self.global_text_token = nn.Parameter(torch.randn(latent_dim)) + + self.sequence_pos_encoding = PositionalEncoding(latent_dim, dropout) + seq_trans_encoder_layer = nn.TransformerEncoderLayer( + d_model=latent_dim, + nhead=num_heads, + dim_feedforward=ff_size, + dropout=dropout, + activation=activation) + self.seqTransEncoder = nn.TransformerEncoder(seq_trans_encoder_layer, + num_layers=num_layers) + + + if self.is_action_branch: + action_trans_encoder_layer = nn.TransformerEncoderLayer( + d_model=latent_dim, + nhead=num_heads, + dim_feedforward=ff_size, + dropout=dropout, + activation=activation) + self.actionTransEncoder = nn.TransformerEncoder( + action_trans_encoder_layer, num_layers=num_layers) + self.mean_token = nn.Parameter(torch.randn(latent_dim)) + self.std_token = nn.Parameter(torch.randn(latent_dim)) + + def global_branch(self, x, mask): + bs = x.shape[0] + + # Switch sequence and batch_size because the input of + # Pytorch Transformer is [Sequence, Batch size, ...] 
+ x = x.permute(1, 0, 2) # now it is [nframes, bs, latent_dim] + + + global_tokens = torch.tile(self.global_text_token, + (bs, )).reshape(bs, -1) + + if self.is_cross_token: + mean_tokens = torch.tile(self.mean_token, (bs, )).reshape(bs, -1) + std_tokens = torch.tile(self.std_token, (bs, )).reshape(bs, -1) + # adding the embedding token for all sequences + xseq = torch.cat( + (mean_tokens[None], std_tokens[None], global_tokens[None], x), + 0) + + # create a bigger mask, to allow attend to emb + token_mask = torch.ones((bs, 3), dtype=bool, device=x.device) + aug_mask = torch.cat((token_mask, mask), 1) + else: + # adding the embedding token for all sequences + xseq = torch.cat((global_tokens[None], x), 0) + + # create a bigger mask, to allow attend to global + token_mask = torch.ones((bs, 1), dtype=bool, device=x.device) + aug_mask = torch.cat((token_mask, mask), 1) + + # add positional encoding + xseq = self.sequence_pos_encoding(xseq) + # content encode + text_tokens = self.seqTransEncoder(xseq, + src_key_padding_mask=~aug_mask) + return text_tokens + + def action_branch(self, x, mask): + bs = x.shape[0] + mean_tokens = torch.tile(self.mean_token, (bs, )).reshape(bs, -1) + std_tokens = torch.tile(self.std_token, (bs, )).reshape(bs, -1) + + # adding the embedding token for all sequences + actionSeq = torch.cat((mean_tokens[None], std_tokens[None], x), 0) + + # create a bigger mask, to allow attend to emb + token_mask = torch.ones((bs, 2), dtype=bool, device=x.device) + aug_mask = torch.cat((token_mask, mask), 1) + + # Pass through the transformer decoder + # with the latent vector for memory + # add positional encoding + actionSeq = self.sequence_pos_encoding(actionSeq) + action_tokens = self.actionTransEncoder(actionSeq, + src_key_padding_mask=~aug_mask) + return action_tokens[0:2] + + def forward(self, texts: List[str]): + text_encoded, mask = self.get_last_hidden_state(texts, + return_mask=True) + text_emb = self.projection(text_encoded) + + return text_emb + + def get_last_hidden_state(self, + texts: List[str], + return_mask: bool = False + ): #-> Union[Tensor, tuple[Tensor, Tensor]]: + encoded_inputs = self.tokenizer(texts, + return_tensors="pt", + padding=True) + output = self.text_model(**encoded_inputs.to(self.text_model.device)) + if not return_mask: + return output.last_hidden_state + return output.last_hidden_state, encoded_inputs.attention_mask.to( + dtype=bool) diff --git a/Evaluator_272/mld/models/architectures/mld_clip.py b/Evaluator_272/mld/models/architectures/mld_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..cf0ddf803717b8bfcaee134e6a57b28b4707653d --- /dev/null +++ b/Evaluator_272/mld/models/architectures/mld_clip.py @@ -0,0 +1,90 @@ +import os +from typing import List, Union + +import torch +from torch import Tensor, nn +from torch.distributions.distribution import Distribution +from transformers import AutoModel, AutoTokenizer, CLIPTextModel, CLIPTokenizer + +from mld.models.operator import PositionalEncoding +from mld.utils.temos_utils import lengths_to_mask + + +class MldTextEncoder(nn.Module): + + def __init__( + self, + modelpath: str, + finetune: bool = False, + last_hidden_state: bool = False, + latent_dim: list = [1, 256], + ) -> None: + + super().__init__() + + self.latent_dim = latent_dim + + self.tokenizer = AutoTokenizer.from_pretrained(modelpath) + self.text_model = AutoModel.from_pretrained(modelpath) + + # Don't train the model + if not finetune: + self.text_model.training = False + for p in self.text_model.parameters(): + 
p.requires_grad = False + + # Then configure the model + self.max_length = self.tokenizer.model_max_length + if "clip" in modelpath: + self.text_encoded_dim = self.text_model.config.text_config.hidden_size + if last_hidden_state: + self.name = "clip_hidden" + else: + self.name = "clip" + elif "bert" in modelpath: + self.name = "bert" + self.text_encoded_dim = self.text_model.config.hidden_size + else: + raise ValueError(f"Model {modelpath} not supported") + + def forward(self, texts: List[str]): + # get prompt text embeddings + if self.name in ["clip", "clip_hidden"]: + text_inputs = self.tokenizer( + texts, + padding="max_length", + truncation=True, + max_length=self.max_length, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + # split into max length Clip can handle + if text_input_ids.shape[-1] > self.tokenizer.model_max_length: + text_input_ids = text_input_ids[:, :self.tokenizer. + model_max_length] + elif self.name == "bert": + text_inputs = self.tokenizer(texts, + return_tensors="pt", + padding=True) + + # use pooled ouuput if latent dim is two-dimensional + # pooled = 0 if self.latent_dim[0] == 1 else 1 # (bs, seq_len, text_encoded_dim) -> (bs, text_encoded_dim) + # text encoder forward, clip must use get_text_features + if self.name == "clip": + # (batch_Size, text_encoded_dim) + text_embeddings = self.text_model.get_text_features( + text_input_ids.to(self.text_model.device)) + # (batch_Size, 1, text_encoded_dim) + text_embeddings = text_embeddings.unsqueeze(1) + elif self.name == "clip_hidden": + # (batch_Size, seq_length , text_encoded_dim) + text_embeddings = self.text_model.text_model( + text_input_ids.to(self.text_model.device)).last_hidden_state + elif self.name == "bert": + # (batch_Size, seq_length , text_encoded_dim) + text_embeddings = self.text_model( + **text_inputs.to(self.text_model.device)).last_hidden_state + else: + raise NotImplementedError(f"Model {self.name} not implemented") + + return text_embeddings diff --git a/Evaluator_272/mld/models/architectures/mld_denoiser.py b/Evaluator_272/mld/models/architectures/mld_denoiser.py new file mode 100644 index 0000000000000000000000000000000000000000..b3bd47f56fc0a5fbbebbe1dbe81afc0f3e1300ff --- /dev/null +++ b/Evaluator_272/mld/models/architectures/mld_denoiser.py @@ -0,0 +1,279 @@ +import torch +import torch.nn as nn +from torch import nn +from mld.models.architectures.tools.embeddings import (TimestepEmbedding, + Timesteps) +from mld.models.operator import PositionalEncoding +from mld.models.operator.cross_attention import (SkipTransformerEncoder, + TransformerDecoder, + TransformerDecoderLayer, + TransformerEncoder, + TransformerEncoderLayer) +from mld.models.operator.position_encoding import build_position_encoding +from mld.utils.temos_utils import lengths_to_mask + + +class MldDenoiser(nn.Module): + + def __init__(self, + ablation, + nfeats: int = 263, + condition: str = "text", + latent_dim: list = [1, 256], + ff_size: int = 1024, + num_layers: int = 6, + num_heads: int = 4, + dropout: float = 0.1, + normalize_before: bool = False, + activation: str = "gelu", + flip_sin_to_cos: bool = True, + return_intermediate_dec: bool = False, + position_embedding: str = "learned", + arch: str = "trans_enc", + freq_shift: int = 0, + guidance_scale: float = 7.5, + guidance_uncondp: float = 0.1, + text_encoded_dim: int = 768, + nclasses: int = 10, + **kwargs) -> None: + + super().__init__() + + self.latent_dim = latent_dim[-1] + self.text_encoded_dim = text_encoded_dim + self.condition = condition + 
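+ # abl_plus is hard-coded to False: when True, the timestep and condition embeddings
+ # are summed in forward(); when False they are concatenated along the token axis.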
self.abl_plus = False + self.ablation_skip_connection = ablation.SKIP_CONNECT + self.diffusion_only = ablation.VAE_TYPE == "no" + self.arch = arch + self.pe_type = ablation.DIFF_PE_TYPE + + if self.diffusion_only: + # assert self.arch == "trans_enc", "only implement encoder for diffusion-only" + self.pose_embd = nn.Linear(nfeats, self.latent_dim) + self.pose_proj = nn.Linear(self.latent_dim, nfeats) + + # emb proj + if self.condition in ["text", "text_uncond", "text_all", 'text_face', 'text_body', 'text_hand', 'text_face_body', "text_seperate", "only_pose_concat", "only_pose_fusion"]: + # text condition + # project time from text_encoded_dim to latent_dim + self.time_proj = Timesteps(text_encoded_dim, flip_sin_to_cos, + freq_shift) + self.time_embedding = TimestepEmbedding(text_encoded_dim, + self.latent_dim) + # project time+text to latent_dim + if text_encoded_dim != self.latent_dim: + # todo 10.24 debug why relu + self.emb_proj = nn.Sequential( + nn.ReLU(), nn.Linear(text_encoded_dim, self.latent_dim)) + elif self.condition in ['action']: + self.time_proj = Timesteps(self.latent_dim, flip_sin_to_cos, + freq_shift) + self.time_embedding = TimestepEmbedding(self.latent_dim, + self.latent_dim) + self.emb_proj = EmbedAction(nclasses, + self.latent_dim, + guidance_scale=guidance_scale, + guidance_uncodp=guidance_uncondp) + else: + raise TypeError(f"condition type {self.condition} not supported") + + if self.pe_type == "actor": + self.query_pos = PositionalEncoding(self.latent_dim, dropout) + self.mem_pos = PositionalEncoding(self.latent_dim, dropout) + elif self.pe_type == "mld": + self.query_pos = build_position_encoding( + self.latent_dim, position_embedding=position_embedding) + self.mem_pos = build_position_encoding( + self.latent_dim, position_embedding=position_embedding) + else: + raise ValueError("Not Support PE type") + + if self.arch == "trans_enc": + if self.ablation_skip_connection: + # use DETR transformer + encoder_layer = TransformerEncoderLayer( + self.latent_dim, + num_heads, + ff_size, + dropout, + activation, + normalize_before, + ) + encoder_norm = nn.LayerNorm(self.latent_dim) + self.encoder = SkipTransformerEncoder(encoder_layer, + num_layers, encoder_norm) + else: + # use torch transformer + encoder_layer = nn.TransformerEncoderLayer( + d_model=self.latent_dim, + nhead=num_heads, + dim_feedforward=ff_size, + dropout=dropout, + activation=activation) + self.encoder = nn.TransformerEncoder(encoder_layer, + num_layers=num_layers) + elif self.arch == "trans_dec": + decoder_layer = TransformerDecoderLayer( + self.latent_dim, + num_heads, + ff_size, + dropout, + activation, + normalize_before, + ) + decoder_norm = nn.LayerNorm(self.latent_dim) + self.decoder = TransformerDecoder( + decoder_layer, + num_layers, + decoder_norm, + return_intermediate=return_intermediate_dec, + ) + else: + raise ValueError(f"Not supported architechure{self.arch}!") + + def forward(self, + sample, + timestep, + encoder_hidden_states, + lengths=None, + **kwargs): + # 0. dimension matching + # sample [latent_dim[0], batch_size, latent_dim] <= [batch_size, latent_dim[0], latent_dim[1]] + sample = sample.permute(1, 0, 2) + + # 0. check lengths for no vae (diffusion only) + if lengths not in [None, []]: + mask = lengths_to_mask(lengths, sample.device) + + # 1. 
time_embedding + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timestep.expand(sample.shape[1]).clone() + time_emb = self.time_proj(timesteps) + time_emb = time_emb.to(dtype=sample.dtype) + # [1, bs, latent_dim] <= [bs, latent_dim] + time_emb = self.time_embedding(time_emb).unsqueeze(0) + + # 2. condition + time embedding + if self.condition in ["text", "text_uncond", "text_all", 'text_face', 'text_body', 'text_hand', 'text_face_body', "text_seperate", "only_pose_concat", "only_pose_fusion"]: + # text_emb [seq_len, batch_size, text_encoded_dim] <= [batch_size, seq_len, text_encoded_dim] + encoder_hidden_states = encoder_hidden_states.permute(1, 0, 2) + text_emb = encoder_hidden_states # [num_words, bs, latent_dim] + # textembedding projection + if self.text_encoded_dim != self.latent_dim: + # [1 or 2, bs, latent_dim] <= [1 or 2, bs, text_encoded_dim] + text_emb_latent = self.emb_proj(text_emb) + else: + text_emb_latent = text_emb + if self.abl_plus: + emb_latent = time_emb + text_emb_latent + else: + emb_latent = torch.cat((time_emb, text_emb_latent), 0) + elif self.condition in ['action']: + action_emb = self.emb_proj(encoder_hidden_states) + if self.abl_plus: + emb_latent = action_emb + time_emb + else: + emb_latent = torch.cat((time_emb, action_emb), 0) + else: + raise TypeError(f"condition type {self.condition} not supported") + + # 4. transformer + if self.arch == "trans_enc": + if self.diffusion_only: + sample = self.pose_embd(sample) + xseq = torch.cat((emb_latent, sample), axis=0) + else: + xseq = torch.cat((sample, emb_latent), axis=0) + + # if self.ablation_skip_connection: + # xseq = self.query_pos(xseq) + # tokens = self.encoder(xseq) + # else: + # # adding the timestep embed + # # [seqlen+1, bs, d] + # # todo change to query_pos_decoder + xseq = self.query_pos(xseq) + tokens = self.encoder(xseq) + + if self.diffusion_only: + sample = tokens[emb_latent.shape[0]:] + sample = self.pose_proj(sample) + + # zero for padded area + sample[~mask.T] = 0 + else: + sample = tokens[:sample.shape[0]] + + elif self.arch == "trans_dec": + if self.diffusion_only: + sample = self.pose_embd(sample) + + # tgt - [1 or 5 or 10, bs, latent_dim] + # memory - [token_num, bs, latent_dim] + sample = self.query_pos(sample) + emb_latent = self.mem_pos(emb_latent) + sample = self.decoder(tgt=sample, memory=emb_latent).squeeze(0) + + if self.diffusion_only: + sample = self.pose_proj(sample) + # zero for padded area + sample[~mask.T] = 0 + else: + raise TypeError("{self.arch} is not supoorted") + + # 5. 
[batch_size, latent_dim[0], latent_dim[1]] <= [latent_dim[0], batch_size, latent_dim[1]] + sample = sample.permute(1, 0, 2) + + return (sample, ) + + +class EmbedAction(nn.Module): + + def __init__(self, + num_actions, + latent_dim, + guidance_scale=7.5, + guidance_uncodp=0.1, + force_mask=False): + super().__init__() + self.nclasses = num_actions + self.guidance_scale = guidance_scale + self.action_embedding = nn.Parameter( + torch.randn(num_actions, latent_dim)) + + self.guidance_uncodp = guidance_uncodp + self.force_mask = force_mask + self._reset_parameters() + + def forward(self, input): + idx = input[:, 0].to(torch.long) # an index array must be long + output = self.action_embedding[idx] + if not self.training and self.guidance_scale > 1.0: + uncond, output = output.chunk(2) + uncond_out = self.mask_cond(uncond, force=True) + out = self.mask_cond(output) + output = torch.cat((uncond_out, out)) + + output = self.mask_cond(output) + + return output.unsqueeze(0) + + def mask_cond(self, output, force=False): + bs, d = output.shape + # classifer guidence + if self.force_mask or force: + return torch.zeros_like(output) + elif self.training and self.guidance_uncodp > 0.: + mask = torch.bernoulli( + torch.ones(bs, device=output.device) * + self.guidance_uncodp).view( + bs, 1) # 1-> use null_cond, 0-> use real cond + return output * (1. - mask) + else: + return output + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) diff --git a/Evaluator_272/mld/models/architectures/mld_dual_vae.py b/Evaluator_272/mld/models/architectures/mld_dual_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..b63844a3fb9b6fdbb9019a7b139127b912691221 --- /dev/null +++ b/Evaluator_272/mld/models/architectures/mld_dual_vae.py @@ -0,0 +1,346 @@ +from functools import reduce +from typing import List, Optional, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor, nn +from torch.distributions.distribution import Distribution + +from mld.models.architectures.tools.embeddings import TimestepEmbedding, Timesteps +from mld.models.operator import PositionalEncoding +from mld.models.operator.cross_attention import ( + SkipTransformerEncoder, + SkipTransformerDecoder, + TransformerDecoder, + TransformerDecoderLayer, + TransformerEncoder, + TransformerEncoderLayer, +) +from mld.models.operator.position_encoding import build_position_encoding +from mld.utils.temos_utils import lengths_to_mask +""" +vae + +skip connection encoder +skip connection decoder + +mem for each decoder layer +""" + + +class MldDualVae(nn.Module): + + def __init__(self, + ablation, + nfeats: int, + latent_dim: list = [1, 256], + ff_size: int = 1024, + num_layers: int = 9, + num_heads: int = 4, + dropout: float = 0.1, + arch: str = "all_encoder", + normalize_before: bool = False, + activation: str = "gelu", + position_embedding: str = "learned", + **kwargs) -> None: + + super().__init__() + + assert nfeats == 313 + + + self.latent_size = latent_dim[0] + self.latent_dim = latent_dim[-1] + input_feats = nfeats + + body_input_feats = 4 + 21 * 3 + 22 * 3 + hand_input_feats = 30 * 3 + 30 * 3 + + output_feats = nfeats + + body_output_feats = 4 + 21 * 3 + 22 * 3 + hand_output_feats = 30 * 3 + 30 * 3 + + self.arch = arch + self.mlp_dist = ablation.MLP_DIST + self.pe_type = ablation.PE_TYPE + + if self.pe_type == "actor": + self.query_pos_encoder = PositionalEncoding( + self.latent_dim, dropout) + 
self.query_pos_decoder = PositionalEncoding( + self.latent_dim, dropout) + elif self.pe_type == "mld": + # self.query_pos_encoder = build_position_encoding( + # self.latent_dim, position_embedding=position_embedding) + self.body_query_pos_encoder = build_position_encoding( + self.latent_dim, position_embedding=position_embedding) + self.hand_query_pos_encoder = build_position_encoding( + self.latent_dim, position_embedding=position_embedding) + + # self.query_pos_decoder = build_position_encoding( + # self.latent_dim, position_embedding=position_embedding) + self.body_query_pos_decoder = build_position_encoding( + self.latent_dim, position_embedding=position_embedding) + self.hand_query_pos_decoder = build_position_encoding( + self.latent_dim, position_embedding=position_embedding) + + else: + raise ValueError("Not Support PE type") + + # encoder_layer = TransformerEncoderLayer( + # self.latent_dim, + # num_heads, + # ff_size, + # dropout, + # activation, + # normalize_before, + # ) + + body_encoder_layer = TransformerEncoderLayer( + self.latent_dim, + num_heads, + ff_size, + dropout, + activation, + normalize_before, + ) + + hand_encoder_layer = TransformerEncoderLayer( + self.latent_dim, + num_heads, + ff_size, + dropout, + activation, + normalize_before, + ) + + body_encoder_norm = nn.LayerNorm(self.latent_dim) + hand_encoder_norm = nn.LayerNorm(self.latent_dim) + + # self.encoder = SkipTransformerEncoder(encoder_layer, num_layers, + # encoder_norm) + + self.body_encoder = SkipTransformerEncoder(body_encoder_layer, num_layers, + body_encoder_norm) + self.hand_encoder = SkipTransformerEncoder(hand_encoder_layer, num_layers, + hand_encoder_norm) + + + if self.arch == "all_encoder": + decoder_norm = nn.LayerNorm(self.latent_dim) + self.decoder = SkipTransformerEncoder(encoder_layer, num_layers, + decoder_norm) + elif self.arch == "encoder_decoder": + + body_decoder_layer = TransformerDecoderLayer( + self.latent_dim, + num_heads, + ff_size, + dropout, + activation, + normalize_before, + ) + hand_decoder_layer = TransformerDecoderLayer( + self.latent_dim, + num_heads, + ff_size, + dropout, + activation, + normalize_before, + ) + body_decoder_norm = nn.LayerNorm(self.latent_dim) + hand_decoder_norm = nn.LayerNorm(self.latent_dim) + + self.body_decoder = SkipTransformerDecoder(body_decoder_layer, num_layers, + body_decoder_norm) + + self.hand_decoder = SkipTransformerDecoder(hand_decoder_layer, num_layers, + hand_decoder_norm) + + + else: + raise ValueError("Not support architecture!") + + if self.mlp_dist: + self.global_motion_token = nn.Parameter( + torch.randn(self.latent_size, self.latent_dim)) + self.dist_layer = nn.Linear(self.latent_dim, 2 * self.latent_dim) + else: + + + self.body_global_motion_token = nn.Parameter( + torch.randn(self.latent_size * 2, self.latent_dim)) + + self.hand_global_motion_token = nn.Parameter( + torch.randn(self.latent_size * 2, self.latent_dim)) + + # self.skel_embedding = nn.Linear(input_feats, self.latent_dim) + self.body_skel_embedding = nn.Linear(body_output_feats, self.latent_dim) + self.hand_skel_embedding = nn.Linear(hand_output_feats, self.latent_dim) + + # self.final_layer = nn.Linear(self.latent_dim, output_feats) + self.body_final_layer = nn.Linear(self.latent_dim, body_output_feats) + self.hand_final_layer = nn.Linear(self.latent_dim, hand_output_feats) + + + def forward(self, features: Tensor, lengths: Optional[List[int]] = None): + + print("Should Not enter here") + z, dist = self.encode(features, lengths) + feats_rst = self.decode(z, lengths) + 
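+ # NOTE: encode() below returns (body_latent, hand_latent, body_dist, hand_dist) and
+ # decode() expects both latents, so this forward() would fail if reached; the print
+ # above suggests it is a leftover path that is never called.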
return feats_rst, z, dist + + def encode( + self, + features: Tensor, + lengths: Optional[List[int]] = None + ) -> Union[Tensor, Distribution]: + if lengths is None: + lengths = [len(feature) for feature in features] + + device = features.device + + body_features = torch.cat((features[..., :4+21*3], features[..., 4+51*3:4+51*3+22*3]), dim=-1) # (32, 196, 133) + hand_features = torch.cat((features[..., 4+21*3:4+51*3], features[..., 4+51*3+22*3:]), dim=-1) # (132, 196, 180) + bs, nframes, _ = features.shape # (32, 196, 313) + mask = lengths_to_mask(lengths, device) # (32, 196) + + body_x = body_features + hand_x = hand_features + # Embed each human poses into latent vectors + # x = self.skel_embedding(x) + body_x = self.body_skel_embedding(body_x) + hand_x = self.hand_skel_embedding(hand_x) + + # Switch sequence and batch_size because the input of + # Pytorch Transformer is [Sequence, Batch size, ...] + # x = x.permute(1, 0, 2) # now it is [nframes, bs, latent_dim] (196, 32, 256) + body_x = body_x.permute(1,0,2) + hand_x = hand_x.permute(1,0,2) + + # Each batch has its own set of tokens + # dist = torch.tile(self.global_motion_token[:, None, :], (1, bs, 1)) # (2, 32, 256) + body_dist = torch.tile(self.body_global_motion_token[:, None, :], (1, bs, 1)) # (2, 32, 256) + hand_dist = torch.tile(self.hand_global_motion_token[:, None, :], (1, bs, 1)) # (2, 32, 256) + + # create a bigger mask, to allow attend to emb + dist_masks = torch.ones((bs, body_dist.shape[0]), + dtype=bool, + device=body_x.device) # (32, 2) all one + + aug_mask = torch.cat((dist_masks, mask), 1) + + # adding the embedding token for all sequences + # xseq = torch.cat((dist, x), 0) + xseq_body = torch.cat((body_dist, body_x), 0) + xseq_hand = torch.cat((hand_dist, hand_x), 0) + + if self.pe_type == "actor": + xseq = self.query_pos_encoder(xseq) + dist = self.encoder(xseq, + src_key_padding_mask=~aug_mask)[:dist.shape[0]] + elif self.pe_type == "mld": + # xseq = self.query_pos_encoder(xseq) + # dist = self.encoder(xseq, + # src_key_padding_mask=~aug_mask)[:dist.shape[0]] + + xseq_body = self.body_query_pos_encoder(xseq_body) + body_dist = self.body_encoder(xseq_body, + src_key_padding_mask=~aug_mask)[:body_dist.shape[0]] + + xseq_hand = self.hand_query_pos_encoder(xseq_hand) + hand_dist = self.hand_encoder(xseq_hand, + src_key_padding_mask=~aug_mask)[:hand_dist.shape[0]] + + # content distribution + # self.latent_dim => 2*self.latent_dim + if self.mlp_dist: + tokens_dist = self.dist_layer(dist) + mu = tokens_dist[:, :, :self.latent_dim] + logvar = tokens_dist[:, :, self.latent_dim:] + else: + + body_mu = body_dist[0:self.latent_size, ...] + body_logvar = body_dist[self.latent_size:, ...] + hand_mu = hand_dist[0:self.latent_size, ...] + hand_logvar = hand_dist[self.latent_size:, ...] 
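+ # reparameterization trick: std = exp(0.5 * logvar), then rsample() draws body and
+ # hand latents from independent Normal distributions while keeping gradients.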
+ + + body_std = body_logvar.exp().pow(0.5) + body_dist = torch.distributions.Normal(body_mu, body_std) + body_latent = body_dist.rsample() + + hand_std = hand_logvar.exp().pow(0.5) + hand_dist = torch.distributions.Normal(hand_mu, hand_std) + hand_latent = hand_dist.rsample() + + # return latent, dist + return body_latent, hand_latent, body_dist, hand_dist + + def decode(self, body_z: Tensor, hand_z: Tensor, lengths: List[int]): + mask = lengths_to_mask(lengths, body_z.device) + bs, nframes = mask.shape + + # queries = torch.zeros(nframes, bs, self.latent_dim, device=z.device) + body_queries = torch.zeros(nframes, bs, self.latent_dim, device=body_z.device) + hand_queries = torch.zeros(nframes, bs, self.latent_dim, device=hand_z.device) + + + # Pass through the transformer decoder + # with the latent vector for memory + if self.arch == "all_encoder": + xseq = torch.cat((z, queries), axis=0) + z_mask = torch.ones((bs, self.latent_size), + dtype=bool, + device=z.device) + augmask = torch.cat((z_mask, mask), axis=1) + + if self.pe_type == "actor": + xseq = self.query_pos_decoder(xseq) + output = self.decoder( + xseq, src_key_padding_mask=~augmask)[z.shape[0]:] + elif self.pe_type == "mld": + xseq = self.query_pos_decoder(xseq) + output = self.decoder( + xseq, src_key_padding_mask=~augmask)[z.shape[0]:] + + + elif self.arch == "encoder_decoder": + if self.pe_type == "actor": + queries = self.query_pos_decoder(queries) + output = self.decoder(tgt=queries, + memory=z, + tgt_key_padding_mask=~mask).squeeze(0) + elif self.pe_type == "mld": + # queries = self.query_pos_decoder(queries) + body_queries = self.body_query_pos_decoder(body_queries) + hand_queries = self.hand_query_pos_decoder(hand_queries) + + + body_output = self.body_decoder( + tgt=body_queries, + memory=body_z, + tgt_key_padding_mask=~mask, + + ).squeeze(0) + + hand_output = self.hand_decoder( + tgt=hand_queries, + memory=hand_z, + tgt_key_padding_mask=~mask, + + ).squeeze(0) + + + body_output = self.body_final_layer(body_output) + hand_output = self.hand_final_layer(hand_output) + # zero for padded area + # output[~mask.T] = 0 + body_output[~mask.T] = 0 + hand_output[~mask.T] = 0 + # Pytorch Transformer: [Sequence, Batch size, ...] 
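+ # permute both outputs back to [batch, frames, feats] and concatenate along the
+ # feature axis into a single 313-dim (body, hand) vector.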
+ feats = torch.cat((body_output.permute(1, 0, 2), hand_output.permute(1, 0, 2)), dim=-1) + return feats diff --git a/Evaluator_272/mld/models/architectures/mld_vae.py b/Evaluator_272/mld/models/architectures/mld_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..91e0496b557e50832a665c1bb2bdb22008221a21 --- /dev/null +++ b/Evaluator_272/mld/models/architectures/mld_vae.py @@ -0,0 +1,226 @@ +from functools import reduce +from typing import List, Optional, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor, nn +from torch.distributions.distribution import Distribution + +from mld.models.architectures.tools.embeddings import TimestepEmbedding, Timesteps +from mld.models.operator import PositionalEncoding +from mld.models.operator.cross_attention import ( + SkipTransformerEncoder, + SkipTransformerDecoder, + TransformerDecoder, + TransformerDecoderLayer, + TransformerEncoder, + TransformerEncoderLayer, +) +from mld.models.operator.position_encoding import build_position_encoding +from mld.utils.temos_utils import lengths_to_mask +""" +vae + +skip connection encoder +skip connection decoder + +mem for each decoder layer +""" + + +class MldVae(nn.Module): + + def __init__(self, + ablation, + nfeats: int, + latent_dim: list = [1, 256], + ff_size: int = 1024, + num_layers: int = 9, + num_heads: int = 4, + dropout: float = 0.1, + arch: str = "all_encoder", + normalize_before: bool = False, + activation: str = "gelu", + position_embedding: str = "learned", + **kwargs) -> None: + + super().__init__() + + self.latent_size = latent_dim[0] + self.latent_dim = latent_dim[-1] + input_feats = nfeats + output_feats = nfeats + self.arch = arch + self.mlp_dist = ablation.MLP_DIST + self.pe_type = ablation.PE_TYPE + + if self.pe_type == "actor": + self.query_pos_encoder = PositionalEncoding( + self.latent_dim, dropout) + self.query_pos_decoder = PositionalEncoding( + self.latent_dim, dropout) + elif self.pe_type == "mld": + self.query_pos_encoder = build_position_encoding( + self.latent_dim, position_embedding=position_embedding) + self.query_pos_decoder = build_position_encoding( + self.latent_dim, position_embedding=position_embedding) + else: + raise ValueError("Not Support PE type") + + encoder_layer = TransformerEncoderLayer( + self.latent_dim, + num_heads, + ff_size, + dropout, + activation, + normalize_before, + ) + encoder_norm = nn.LayerNorm(self.latent_dim) + self.encoder = SkipTransformerEncoder(encoder_layer, num_layers, + encoder_norm) + + if self.arch == "all_encoder": + decoder_norm = nn.LayerNorm(self.latent_dim) + self.decoder = SkipTransformerEncoder(encoder_layer, num_layers, + decoder_norm) + elif self.arch == "encoder_decoder": + decoder_layer = TransformerDecoderLayer( + self.latent_dim, + num_heads, + ff_size, + dropout, + activation, + normalize_before, + ) + decoder_norm = nn.LayerNorm(self.latent_dim) + self.decoder = SkipTransformerDecoder(decoder_layer, num_layers, + decoder_norm) + else: + raise ValueError("Not support architecture!") + + if self.mlp_dist: + self.global_motion_token = nn.Parameter( + torch.randn(self.latent_size, self.latent_dim)) + self.dist_layer = nn.Linear(self.latent_dim, 2 * self.latent_dim) + else: + self.global_motion_token = nn.Parameter( + torch.randn(self.latent_size * 2, self.latent_dim)) + + self.skel_embedding = nn.Linear(input_feats, self.latent_dim) + self.final_layer = nn.Linear(self.latent_dim, output_feats) + + def forward(self, features: Tensor, 
lengths: Optional[List[int]] = None): + + print("Should Not enter here") + + z, dist = self.encode(features, lengths) + feats_rst = self.decode(z, lengths) + return feats_rst, z, dist + + def encode( + self, + features: Tensor, + lengths: Optional[List[int]] = None + ) -> Union[Tensor, Distribution]: + if lengths is None: + lengths = [len(feature) for feature in features] + + device = features.device + + bs, nframes, nfeats = features.shape + mask = lengths_to_mask(lengths, device) + + x = features + # Embed each human poses into latent vectors + x = self.skel_embedding(x) + + # Switch sequence and batch_size because the input of + # Pytorch Transformer is [Sequence, Batch size, ...] + x = x.permute(1, 0, 2) # now it is [nframes, bs, latent_dim] + + # Each batch has its own set of tokens + dist = torch.tile(self.global_motion_token[:, None, :], (1, bs, 1)) + + # create a bigger mask, to allow attend to emb + dist_masks = torch.ones((bs, dist.shape[0]), + dtype=bool, + device=x.device) + aug_mask = torch.cat((dist_masks, mask), 1) + + # adding the embedding token for all sequences + xseq = torch.cat((dist, x), 0) + if self.pe_type == "actor": + xseq = self.query_pos_encoder(xseq) + dist = self.encoder(xseq, + src_key_padding_mask=~aug_mask)[:dist.shape[0]] + elif self.pe_type == "mld": + xseq = self.query_pos_encoder(xseq) + dist = self.encoder(xseq, + src_key_padding_mask=~aug_mask)[:dist.shape[0]] + + + # content distribution + # self.latent_dim => 2*self.latent_dim + if self.mlp_dist: + tokens_dist = self.dist_layer(dist) + mu = tokens_dist[:, :, :self.latent_dim] + logvar = tokens_dist[:, :, self.latent_dim:] + else: + mu = dist[0:self.latent_size, ...] + logvar = dist[self.latent_size:, ...] + + # resampling + std = logvar.exp().pow(0.5) + dist = torch.distributions.Normal(mu, std) + latent = dist.rsample() + return latent, dist + + def decode(self, z: Tensor, lengths: List[int]): + mask = lengths_to_mask(lengths, z.device) + bs, nframes = mask.shape + + queries = torch.zeros(nframes, bs, self.latent_dim, device=z.device) + + + if self.arch == "all_encoder": + xseq = torch.cat((z, queries), axis=0) + z_mask = torch.ones((bs, self.latent_size), + dtype=bool, + device=z.device) + augmask = torch.cat((z_mask, mask), axis=1) + + if self.pe_type == "actor": + xseq = self.query_pos_decoder(xseq) + output = self.decoder( + xseq, src_key_padding_mask=~augmask)[z.shape[0]:] + elif self.pe_type == "mld": + xseq = self.query_pos_decoder(xseq) + output = self.decoder( + xseq, src_key_padding_mask=~augmask)[z.shape[0]:] + + elif self.arch == "encoder_decoder": + if self.pe_type == "actor": + queries = self.query_pos_decoder(queries) + output = self.decoder(tgt=queries, + memory=z, + tgt_key_padding_mask=~mask).squeeze(0) + elif self.pe_type == "mld": + queries = self.query_pos_decoder(queries) + # mem_pos = self.mem_pos_decoder(z) + output = self.decoder( + tgt=queries, + memory=z, + tgt_key_padding_mask=~mask, + # query_pos=query_pos, + # pos=mem_pos, + ).squeeze(0) + + + output = self.final_layer(output) + # zero for padded area + output[~mask.T] = 0 + # Pytorch Transformer: [Sequence, Batch size, ...] 
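+ # permute back to batch-first [batch, frames, feats] before returning to the caller.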
+ feats = output.permute(1, 0, 2) + return feats diff --git a/Evaluator_272/mld/models/architectures/t2m_motionenc.py b/Evaluator_272/mld/models/architectures/t2m_motionenc.py new file mode 100644 index 0000000000000000000000000000000000000000..cb3c3a304e0f0c457b25bba752de0eccd8798e2c --- /dev/null +++ b/Evaluator_272/mld/models/architectures/t2m_motionenc.py @@ -0,0 +1,58 @@ +import torch +import torch.nn as nn +from torch.nn.utils.rnn import pack_padded_sequence + + +class MovementConvEncoder(nn.Module): + def __init__(self, input_size, hidden_size, output_size): + super(MovementConvEncoder, self).__init__() + self.main = nn.Sequential( + nn.Conv1d(input_size, hidden_size, 4, 2, 1), + nn.Dropout(0.2, inplace=True), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv1d(hidden_size, output_size, 4, 2, 1), + nn.Dropout(0.2, inplace=True), + nn.LeakyReLU(0.2, inplace=True), + ) + self.out_net = nn.Linear(output_size, output_size) + + def forward(self, inputs): + inputs = inputs.permute(0, 2, 1) + outputs = self.main(inputs).permute(0, 2, 1) + return self.out_net(outputs) + + +class MotionEncoderBiGRUCo(nn.Module): + def __init__(self, input_size, hidden_size, output_size): + super(MotionEncoderBiGRUCo, self).__init__() + + self.input_emb = nn.Linear(input_size, hidden_size) + self.gru = nn.GRU( + hidden_size, hidden_size, batch_first=True, bidirectional=True + ) + self.output_net = nn.Sequential( + nn.Linear(hidden_size * 2, hidden_size), + nn.LayerNorm(hidden_size), + nn.LeakyReLU(0.2, inplace=True), + nn.Linear(hidden_size, output_size), + ) + + self.hidden_size = hidden_size + self.hidden = nn.Parameter( + torch.randn((2, 1, self.hidden_size), requires_grad=True) + ) + + def forward(self, inputs, m_lens): + num_samples = inputs.shape[0] + + input_embs = self.input_emb(inputs) + hidden = self.hidden.repeat(1, num_samples, 1) + + cap_lens = m_lens.data.tolist() + emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True) + + gru_seq, gru_last = self.gru(emb, hidden) + + gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1) + + return self.output_net(gru_last) diff --git a/Evaluator_272/mld/models/architectures/t2m_textenc.py b/Evaluator_272/mld/models/architectures/t2m_textenc.py new file mode 100644 index 0000000000000000000000000000000000000000..afcc54c898b24fd1fe641ac47ace3197fe4b0167 --- /dev/null +++ b/Evaluator_272/mld/models/architectures/t2m_textenc.py @@ -0,0 +1,78 @@ +import torch +import torch.nn as nn +from torch.nn.utils.rnn import pack_padded_sequence + + +class TextEncoderBiGRUCo(nn.Module): + def __init__(self, word_size, pos_size, hidden_size, output_size): + super(TextEncoderBiGRUCo, self).__init__() + + self.pos_emb = nn.Linear(pos_size, word_size) + self.input_emb = nn.Linear(word_size, hidden_size) + self.gru = nn.GRU( + hidden_size, hidden_size, batch_first=True, bidirectional=True + ) + self.output_net = nn.Sequential( + nn.Linear(hidden_size * 2, hidden_size), + nn.LayerNorm(hidden_size), + nn.LeakyReLU(0.2, inplace=True), + nn.Linear(hidden_size, output_size), + ) + + self.hidden_size = hidden_size + self.hidden = nn.Parameter( + torch.randn((2, 1, self.hidden_size), requires_grad=True) + ) + + def forward(self, word_embs, pos_onehot, cap_lens): + num_samples = word_embs.shape[0] + + pos_embs = self.pos_emb(pos_onehot) + inputs = word_embs + pos_embs + input_embs = self.input_emb(inputs) + hidden = self.hidden.repeat(1, num_samples, 1) + + cap_lens = cap_lens.data.tolist() + emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True) + + gru_seq, gru_last 
= self.gru(emb, hidden) + + gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1) + + return self.output_net(gru_last) + + + +class TextEncoderBiGRUCoV2(nn.Module): + def __init__(self, word_size, hidden_size, output_size): + super(TextEncoderBiGRUCoV2, self).__init__() + + self.input_emb = nn.Linear(word_size, hidden_size) + self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True) + self.output_net = nn.Sequential( + nn.Linear(hidden_size * 2, hidden_size), + nn.LayerNorm(hidden_size), + nn.LeakyReLU(0.2, inplace=True), + nn.Linear(hidden_size, output_size) + ) + + + self.hidden_size = hidden_size + self.hidden = nn.Parameter(torch.randn((2, 1, self.hidden_size), requires_grad=True)) + + + def forward(self, word_embs, cap_lens): + num_samples = word_embs.shape[0] + + inputs = word_embs + input_embs = self.input_emb(inputs) + hidden = self.hidden.repeat(1, num_samples, 1) + + cap_lens = cap_lens.data.tolist() + emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True) + + gru_seq, gru_last = self.gru(emb, hidden) + + gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1) + + return self.output_net(gru_last) diff --git a/Evaluator_272/mld/models/architectures/temos/__init__.py b/Evaluator_272/mld/models/architectures/temos/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Evaluator_272/mld/models/architectures/temos/motiondecoder/__init__.py b/Evaluator_272/mld/models/architectures/temos/motiondecoder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Evaluator_272/mld/models/architectures/temos/motiondecoder/actor.py b/Evaluator_272/mld/models/architectures/temos/motiondecoder/actor.py new file mode 100644 index 0000000000000000000000000000000000000000..66e98e542660488bbb96d19d84b9920c356a8433 --- /dev/null +++ b/Evaluator_272/mld/models/architectures/temos/motiondecoder/actor.py @@ -0,0 +1,60 @@ +import torch +import torch.nn as nn +import numpy as np +import pytorch_lightning as pl + +from typing import List, Optional +from torch import nn, Tensor + +from mld.models.operator import PositionalEncoding +from mld.utils.temos_utils import lengths_to_mask + + +class ActorAgnosticDecoder(pl.LightningModule): + def __init__(self, nfeats: int, + latent_dim: int = 256, ff_size: int = 1024, + num_layers: int = 4, num_heads: int = 4, + dropout: float = 0.1, + activation: str = "gelu", **kwargs) -> None: + + super().__init__() + self.save_hyperparameters(logger=False) + + output_feats = nfeats + + self.sequence_pos_encoding = PositionalEncoding(latent_dim, dropout) + + seq_trans_decoder_layer = nn.TransformerDecoderLayer(d_model=latent_dim, + nhead=num_heads, + dim_feedforward=ff_size, + dropout=dropout, + activation=activation) + + self.seqTransDecoder = nn.TransformerDecoder(seq_trans_decoder_layer, + num_layers=num_layers) + + self.final_layer = nn.Linear(latent_dim, output_feats) + + def forward(self, z: Tensor, lengths: List[int]): + mask = lengths_to_mask(lengths, z.device) + latent_dim = z.shape[1] + bs, nframes = mask.shape + nfeats = self.hparams.nfeats + + z = z[None] # sequence of 1 element for the memory + + # Construct time queries + time_queries = torch.zeros(nframes, bs, latent_dim, device=z.device) + time_queries = self.sequence_pos_encoding(time_queries) + + # Pass through the transformer decoder + # with the latent vector for memory + output = 
self.seqTransDecoder(tgt=time_queries, memory=z, + tgt_key_padding_mask=~mask) + + output = self.final_layer(output) + # zero for padded area + output[~mask.T] = 0 + # Pytorch Transformer: [Sequence, Batch size, ...] + feats = output.permute(1, 0, 2) + return feats diff --git a/Evaluator_272/mld/models/architectures/temos/motionencoder/__init__.py b/Evaluator_272/mld/models/architectures/temos/motionencoder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Evaluator_272/mld/models/architectures/temos/motionencoder/actor.py b/Evaluator_272/mld/models/architectures/temos/motionencoder/actor.py new file mode 100644 index 0000000000000000000000000000000000000000..f5c516460571cf569419cb76bcf24de800329667 --- /dev/null +++ b/Evaluator_272/mld/models/architectures/temos/motionencoder/actor.py @@ -0,0 +1,101 @@ +import torch +import torch.nn as nn +import numpy as np +import pytorch_lightning as pl + +from typing import List, Optional, Union +from torch import nn, Tensor +from torch.distributions.distribution import Distribution + +from mld.models.operator import PositionalEncoding + +class ActorAgnosticEncoder(pl.LightningModule): + def __init__(self, nfeats: int, vae: bool, + latent_dim: int = 256, ff_size: int = 1024, + num_layers: int = 4, num_heads: int = 4, + dropout: float = 0.1, + activation: str = "gelu", max_len: int = -1, **kwargs) -> None: + super().__init__() + self.save_hyperparameters(logger=False) + input_feats = nfeats + self.skel_embedding = nn.Linear(input_feats, latent_dim) + self.max_len = max_len + + # Action agnostic: only one set of params + if vae: + self.mu_token = nn.Parameter(torch.randn(latent_dim)) + self.logvar_token = nn.Parameter(torch.randn(latent_dim)) + else: + self.emb_token = nn.Parameter(torch.randn(latent_dim)) + + self.sequence_pos_encoding = PositionalEncoding(latent_dim, dropout) + + seq_trans_encoder_layer = nn.TransformerEncoderLayer(d_model=latent_dim, + nhead=num_heads, + dim_feedforward=ff_size, + dropout=dropout, + activation=activation) + + self.seqTransEncoder = nn.TransformerEncoder(seq_trans_encoder_layer, + num_layers=num_layers) + + def lengths_to_mask(self, lengths, device): + if self.max_len == -1: + max_len = max(lengths) + mask = torch.arange(max_len, device=device).expand(len(lengths), max_len) < lengths.unsqueeze(1) + else: + mask = torch.arange(self.max_len, device=lengths.device).expand(len(lengths), self.max_len) < lengths.unsqueeze(1) + return mask + + def forward(self, features: Tensor, lengths: Optional[List[int]] = None) -> Union[Tensor, Distribution]: + if lengths is None: + lengths = [len(feature) for feature in features] + + device = features.device + + bs, nframes, nfeats = features.shape + + if not isinstance(lengths, torch.Tensor): + lengths = torch.tensor(lengths, device=device) + mask = self.lengths_to_mask(lengths, device).to(device) + + x = features + # Embed each human poses into latent vectors + x = self.skel_embedding(x) + + # Switch sequence and batch_size because the input of + # Pytorch Transformer is [Sequence, Batch size, ...] 
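+        # [bs, nframes, latent_dim] -> [nframes, bs, latent_dim]; the learned mu/logvar
+        # (or emb) tokens are prepended along the frame axis further below.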
+ x = x.permute(1, 0, 2) # now it is [nframes, bs, latent_dim] + + # Each batch has its own set of tokens + if self.hparams.vae: + mu_token = torch.tile(self.mu_token, (bs,)).reshape(bs, -1) + logvar_token = torch.tile(self.logvar_token, (bs,)).reshape(bs, -1) + + # adding the distribution tokens for all sequences + xseq = torch.cat((mu_token[None], logvar_token[None], x), 0) + + # create a bigger mask, to allow attend to mu and logvar + token_mask = torch.ones((bs, 2), dtype=bool, device=x.device) + aug_mask = torch.cat((token_mask, mask), 1) + else: + emb_token = torch.tile(self.emb_token, (bs,)).reshape(bs, -1) + + # adding the embedding token for all sequences + xseq = torch.cat((emb_token[None], x), 0) + + # create a bigger mask, to allow attend to emb + token_mask = torch.ones((bs, 1), dtype=bool, device=x.device) + aug_mask = torch.cat((token_mask, mask), 1) + + # add positional encoding + xseq = self.sequence_pos_encoding(xseq) + final = self.seqTransEncoder(xseq, src_key_padding_mask=~aug_mask) + + if self.hparams.vae: + mu, logvar = final[0], final[1] + std = logvar.exp().pow(0.5) + dist = torch.distributions.Normal(mu, std) + return dist + else: + return final[0] diff --git a/Evaluator_272/mld/models/architectures/temos/textencoder/__init__.py b/Evaluator_272/mld/models/architectures/temos/textencoder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Evaluator_272/mld/models/architectures/temos/textencoder/distillbert.py b/Evaluator_272/mld/models/architectures/temos/textencoder/distillbert.py new file mode 100644 index 0000000000000000000000000000000000000000..3c9a2053954fc838013e56211bdeb44a851f063e --- /dev/null +++ b/Evaluator_272/mld/models/architectures/temos/textencoder/distillbert.py @@ -0,0 +1,51 @@ +from typing import List, Union, Tuple +import pytorch_lightning as pl + +import torch.nn as nn +import os + +import torch +from torch import Tensor +from torch.distributions.distribution import Distribution + + +class DistilbertEncoderBase(pl.LightningModule): + def __init__(self, modelpath: str, + finetune: bool = False) -> None: + super().__init__() + + from transformers import AutoTokenizer, AutoModel + from transformers import logging + logging.set_verbosity_error() + # Tokenizer + os.environ["TOKENIZERS_PARALLELISM"] = "false" + self.tokenizer = AutoTokenizer.from_pretrained(modelpath) + + # Text model + self.text_model = AutoModel.from_pretrained(modelpath) + # Don't train the model + if not finetune: + self.text_model.training = False + for p in self.text_model.parameters(): + p.requires_grad = False + + # Then configure the model + self.text_encoded_dim = self.text_model.config.hidden_size + + def train(self, mode: bool = True): + self.training = mode + for module in self.children(): + # Don't put the model in + if module == self.text_model and not self.hparams.finetune: + continue + module.train(mode) + return self + + def get_last_hidden_state(self, texts: List[str], + return_mask: bool = False + ) -> Union[Tensor, Tuple[Tensor, Tensor]]: + encoded_inputs = self.tokenizer(texts, return_tensors="pt", padding=True) + output = self.text_model(**encoded_inputs.to(self.text_model.device)) + if not return_mask: + return output.last_hidden_state + return output.last_hidden_state, encoded_inputs.attention_mask.to(dtype=bool) diff --git a/Evaluator_272/mld/models/architectures/temos/textencoder/distillbert_actor.py 
b/Evaluator_272/mld/models/architectures/temos/textencoder/distillbert_actor.py new file mode 100644 index 0000000000000000000000000000000000000000..8c6d4c3ba4d60c08b06b889ff965575520441ef5 --- /dev/null +++ b/Evaluator_272/mld/models/architectures/temos/textencoder/distillbert_actor.py @@ -0,0 +1,91 @@ +from .distillbert import DistilbertEncoderBase +import torch + +from typing import List, Union +from torch import nn, Tensor +from torch.distributions.distribution import Distribution + +from mld.models.operator import PositionalEncoding +from mld.utils.temos_utils import lengths_to_mask + + +class DistilbertActorAgnosticEncoder(DistilbertEncoderBase): + def __init__(self, modelpath: str, + finetune: bool = False, + vae: bool = True, + latent_dim: int = 256, + ff_size: int = 1024, + num_layers: int = 4, num_heads: int = 4, + dropout: float = 0.1, + activation: str = "gelu", **kwargs) -> None: + super().__init__(modelpath=modelpath, finetune=finetune) + self.save_hyperparameters(logger=False) + + encoded_dim = self.text_encoded_dim + # Projection of the text-outputs into the latent space + self.projection = nn.Sequential(nn.ReLU(), + nn.Linear(encoded_dim, latent_dim)) + + # TransformerVAE adapted from ACTOR + # Action agnostic: only one set of params + if vae: + self.mu_token = nn.Parameter(torch.randn(latent_dim)) + self.logvar_token = nn.Parameter(torch.randn(latent_dim)) + else: + self.emb_token = nn.Parameter(torch.randn(latent_dim)) + + self.sequence_pos_encoding = PositionalEncoding(latent_dim, dropout) + + seq_trans_encoder_layer = nn.TransformerEncoderLayer(d_model=latent_dim, + nhead=num_heads, + dim_feedforward=ff_size, + dropout=dropout, + activation=activation) + + self.seqTransEncoder = nn.TransformerEncoder(seq_trans_encoder_layer, + num_layers=num_layers) + + def forward(self, texts: List[str]) -> Union[Tensor, Distribution]: + text_encoded, mask = self.get_last_hidden_state(texts, return_mask=True) + + x = self.projection(text_encoded) + bs, nframes, _ = x.shape + # bs, nframes, totjoints, nfeats = x.shape + # Switch sequence and batch_size because the input of + # Pytorch Transformer is [Sequence, Batch size, ...] 
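+        # [bs, ntokens, latent_dim] -> [ntokens, bs, latent_dim]; `mask` is the DistilBERT
+        # attention mask and is extended below to also cover the prepended token(s).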
+        x = x.permute(1, 0, 2)  # now it is [nframes, bs, latent_dim]
+
+        if self.hparams.vae:
+            mu_token = torch.tile(self.mu_token, (bs,)).reshape(bs, -1)
+            logvar_token = torch.tile(self.logvar_token, (bs,)).reshape(bs, -1)
+
+            # adding the distribution tokens for all sequences
+            xseq = torch.cat((mu_token[None], logvar_token[None], x), 0)
+
+            # create a bigger mask, to allow attend to mu and logvar
+            token_mask = torch.ones((bs, 2), dtype=bool, device=x.device)
+            aug_mask = torch.cat((token_mask, mask), 1)
+        else:
+            emb_token = torch.tile(self.emb_token, (bs,)).reshape(bs, -1)
+
+            # adding the embedding token for all sequences
+            xseq = torch.cat((emb_token[None], x), 0)
+
+            # create a bigger mask, to allow attend to emb
+            token_mask = torch.ones((bs, 1), dtype=bool, device=x.device)
+            aug_mask = torch.cat((token_mask, mask), 1)
+
+        # add positional encoding
+        xseq = self.sequence_pos_encoding(xseq)
+        final = self.seqTransEncoder(xseq, src_key_padding_mask=~aug_mask)
+
+        if self.hparams.vae:
+            mu, logvar = final[0], final[1]
+            std = logvar.exp().pow(0.5)
+            # Construct the latent distribution directly; the previous bare
+            # try/except-pass could fall through to returning an unbound `dist`.
+            dist = torch.distributions.Normal(mu, std)
+            return dist
+        else:
+            return final[0]
diff --git a/Evaluator_272/mld/models/architectures/tools/embeddings.py b/Evaluator_272/mld/models/architectures/tools/embeddings.py
new file mode 100644
index 0000000000000000000000000000000000000000..632f8c7996314f2a558cbb8e60f687b6f0771fd0
--- /dev/null
+++ b/Evaluator_272/mld/models/architectures/tools/embeddings.py
@@ -0,0 +1,320 @@
+# This file is taken from signjoey repository
+import math
+
+import torch
+from torch import Tensor, nn
+
+
+def get_activation(activation_type):
+    if activation_type == "relu":
+        return nn.ReLU()
+    elif activation_type == "relu6":
+        return nn.ReLU6()
+    elif activation_type == "prelu":
+        return nn.PReLU()
+    elif activation_type == "selu":
+        return nn.SELU()
+    elif activation_type == "celu":
+        return nn.CELU()
+    elif activation_type == "gelu":
+        return nn.GELU()
+    elif activation_type == "sigmoid":
+        return nn.Sigmoid()
+    elif activation_type == "softplus":
+        return nn.Softplus()
+    elif activation_type == "softshrink":
+        return nn.Softshrink()
+    elif activation_type == "softsign":
+        return nn.Softsign()
+    elif activation_type == "tanh":
+        return nn.Tanh()
+    elif activation_type == "tanhshrink":
+        return nn.Tanhshrink()
+    else:
+        raise ValueError("Unknown activation type {}".format(activation_type))
+
+
+class MaskedNorm(nn.Module):
+    """
+    Original Code from:
+    https://discuss.pytorch.org/t/batchnorm-for-different-sized-samples-in-batch/44251/8
+    """
+
+    def __init__(self, norm_type, num_groups, num_features):
+        super().__init__()
+        self.norm_type = norm_type
+        if self.norm_type == "batch":
+            self.norm = nn.BatchNorm1d(num_features=num_features)
+        elif self.norm_type == "group":
+            self.norm = nn.GroupNorm(num_groups=num_groups, num_channels=num_features)
+        elif self.norm_type == "layer":
+            self.norm = nn.LayerNorm(normalized_shape=num_features)
+        else:
+            raise ValueError("Unsupported Normalization Layer")
+
+        self.num_features = num_features
+
+    def forward(self, x: Tensor, mask: Tensor):
+        if self.training:
+            reshaped = x.reshape([-1, self.num_features])
+            reshaped_mask = mask.reshape([-1, 1]) > 0
+            selected = torch.masked_select(reshaped, reshaped_mask).reshape(
+                [-1, self.num_features]
+            )
+            batch_normed = self.norm(selected)
+            scattered = reshaped.masked_scatter(reshaped_mask, batch_normed)
+            return scattered.reshape([x.shape[0], -1, self.num_features])
+        else:
+            reshaped =
x.reshape([-1, self.num_features]) + batched_normed = self.norm(reshaped) + return batched_normed.reshape([x.shape[0], -1, self.num_features]) + + + +class Embeddings(nn.Module): + + """ + Simple embeddings class + """ + + # pylint: disable=unused-argument + def __init__( + self, + embedding_dim: int = 64, + num_heads: int = 8, + scale: bool = False, + scale_factor: float = None, + norm_type: str = None, + activation_type: str = None, + vocab_size: int = 0, + padding_idx: int = 1, + freeze: bool = False, + **kwargs + ): + """ + Create new embeddings for the vocabulary. + Use scaling for the Transformer. + + :param embedding_dim: + :param scale: + :param vocab_size: + :param padding_idx: + :param freeze: freeze the embeddings during training + """ + super().__init__() + + self.embedding_dim = embedding_dim + self.vocab_size = vocab_size + self.lut = nn.Embedding(vocab_size, self.embedding_dim, padding_idx=padding_idx) + + self.norm_type = norm_type + if self.norm_type: + self.norm = MaskedNorm( + norm_type=norm_type, num_groups=num_heads, num_features=embedding_dim + ) + + self.activation_type = activation_type + if self.activation_type: + self.activation = get_activation(activation_type) + + self.scale = scale + if self.scale: + if scale_factor: + self.scale_factor = scale_factor + else: + self.scale_factor = math.sqrt(self.embedding_dim) + + if freeze: + freeze_params(self) + + # pylint: disable=arguments-differ + def forward(self, x: Tensor, mask: Tensor = None) -> Tensor: + """ + Perform lookup for input `x` in the embedding table. + + :param mask: token masks + :param x: index in the vocabulary + :return: embedded representation for `x` + """ + + x = self.lut(x) + + if self.norm_type: + x = self.norm(x, mask) + + if self.activation_type: + x = self.activation(x) + + if self.scale: + return x * self.scale_factor + else: + return x + + def __repr__(self): + return "%s(embedding_dim=%d, vocab_size=%d)" % ( + self.__class__.__name__, + self.embedding_dim, + self.vocab_size, + ) + + +class SpatialEmbeddings(nn.Module): + + """ + Simple Linear Projection Layer + (For encoder outputs to predict glosses) + """ + + # pylint: disable=unused-argument + def __init__( + self, + embedding_dim: int, + input_size: int, + num_heads: int, + freeze: bool = False, + norm_type: str = "batch", + activation_type: str = "softsign", + scale: bool = False, + scale_factor: float = None, + **kwargs + ): + """ + Create new embeddings for the vocabulary. + Use scaling for the Transformer. 
+ + :param embedding_dim: + :param input_size: + :param freeze: freeze the embeddings during training + """ + super().__init__() + + self.embedding_dim = embedding_dim + self.input_size = input_size + self.ln = nn.Linear(self.input_size, self.embedding_dim) + + self.norm_type = norm_type + if self.norm_type: + self.norm = MaskedNorm( + norm_type=norm_type, num_groups=num_heads, num_features=embedding_dim + ) + + self.activation_type = activation_type + if self.activation_type: + self.activation = get_activation(activation_type) + + self.scale = scale + if self.scale: + if scale_factor: + self.scale_factor = scale_factor + else: + self.scale_factor = math.sqrt(self.embedding_dim) + + if freeze: + freeze_params(self) + + # pylint: disable=arguments-differ + def forward(self, x: Tensor, mask: Tensor) -> Tensor: + """ + :param mask: frame masks + :param x: input frame features + :return: embedded representation for `x` + """ + + x = self.ln(x) + + if self.norm_type: + x = self.norm(x, mask) + + if self.activation_type: + x = self.activation(x) + + if self.scale: + return x * self.scale_factor + else: + return x + + def __repr__(self): + return "%s(embedding_dim=%d, input_size=%d)" % ( + self.__class__.__name__, + self.embedding_dim, + self.input_size, + ) + +def get_timestep_embedding( + timesteps: torch.Tensor, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the + embeddings. :return: an [N x dim] Tensor of positional embeddings. 
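+
+    Concretely, with half_dim = embedding_dim // 2 and i in [0, half_dim), the
+    frequencies are freq_i = exp(-log(max_period) * i / (half_dim - downscale_freq_shift)),
+    and the output concatenates sin(scale * t * freq) with cos(scale * t * freq) along the
+    last dimension (the two halves are swapped when flip_sin_to_cos is set, and the result
+    is zero-padded when embedding_dim is odd).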
+ """ + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange( + start=0, end=half_dim, dtype=torch.float32, device=timesteps.device + ) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +class TimestepEmbedding(nn.Module): + def __init__(self, channel: int, time_embed_dim: int, act_fn: str = "silu"): + super().__init__() + + self.linear_1 = nn.Linear(channel, time_embed_dim) + self.act = None + if act_fn == "silu": + self.act = nn.SiLU() + self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim) + + def forward(self, sample): + sample = self.linear_1(sample) + + if self.act is not None: + sample = self.act(sample) + + sample = self.linear_2(sample) + return sample + + +class Timesteps(nn.Module): + def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float): + super().__init__() + self.num_channels = num_channels + self.flip_sin_to_cos = flip_sin_to_cos + self.downscale_freq_shift = downscale_freq_shift + + def forward(self, timesteps): + t_emb = get_timestep_embedding( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + ) + return t_emb diff --git a/Evaluator_272/mld/models/architectures/tools/transformer_layers.py b/Evaluator_272/mld/models/architectures/tools/transformer_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..741e9f28ee69037fe4d210789ed100c1803e4107 --- /dev/null +++ b/Evaluator_272/mld/models/architectures/tools/transformer_layers.py @@ -0,0 +1,281 @@ +# -*- coding: utf-8 -*- +import math +import torch +import torch.nn as nn +from torch import Tensor + +# Took from https://github.com/joeynmt/joeynmt/blob/fb66afcbe1beef9acd59283bcc084c4d4c1e6343/joeynmt/transformer_layers.py + + +# pylint: disable=arguments-differ +class MultiHeadedAttention(nn.Module): + """ + Multi-Head Attention module from "Attention is All You Need" + + Implementation modified from OpenNMT-py. + https://github.com/OpenNMT/OpenNMT-py + """ + + def __init__(self, num_heads: int, size: int, dropout: float = 0.1): + """ + Create a multi-headed attention layer. + :param num_heads: the number of heads + :param size: model size (must be divisible by num_heads) + :param dropout: probability of dropping a unit + """ + super().__init__() + + assert size % num_heads == 0 + + self.head_size = head_size = size // num_heads + self.model_size = size + self.num_heads = num_heads + + self.k_layer = nn.Linear(size, num_heads * head_size) + self.v_layer = nn.Linear(size, num_heads * head_size) + self.q_layer = nn.Linear(size, num_heads * head_size) + + self.output_layer = nn.Linear(size, size) + self.softmax = nn.Softmax(dim=-1) + self.dropout = nn.Dropout(dropout) + + def forward(self, k: Tensor, v: Tensor, q: Tensor, mask: Tensor = None): + """ + Computes multi-headed attention. + + :param k: keys [B, M, D] with M being the sentence length. 
+ :param v: values [B, M, D] + :param q: query [B, M, D] + :param mask: optional mask [B, 1, M] or [B, M, M] + :return: + """ + batch_size = k.size(0) + num_heads = self.num_heads + + # project the queries (q), keys (k), and values (v) + k = self.k_layer(k) + v = self.v_layer(v) + q = self.q_layer(q) + + # reshape q, k, v for our computation to [batch_size, num_heads, ..] + k = k.view(batch_size, -1, num_heads, self.head_size).transpose(1, 2) + v = v.view(batch_size, -1, num_heads, self.head_size).transpose(1, 2) + q = q.view(batch_size, -1, num_heads, self.head_size).transpose(1, 2) + + # compute scores + q = q / math.sqrt(self.head_size) + + # batch x num_heads x query_len x key_len + scores = torch.matmul(q, k.transpose(2, 3)) + # torch.Size([48, 8, 183, 183]) + + # apply the mask (if we have one) + # we add a dimension for the heads to it below: [B, 1, 1, M] + if mask is not None: + scores = scores.masked_fill(~mask.unsqueeze(1), float('-inf')) + + # apply attention dropout and compute context vectors. + attention = self.softmax(scores) + attention = self.dropout(attention) + # torch.Size([48, 8, 183, 183]) [bs, nheads, time, time] (for decoding) + + # v: torch.Size([48, 8, 183, 32]) (32 is 256/8) + # get context vector (select values with attention) and reshape + # back to [B, M, D] + context = torch.matmul(attention, v) # torch.Size([48, 8, 183, 32]) + context = context.transpose(1, 2).contiguous().view( + batch_size, -1, num_heads * self.head_size) + # torch.Size([48, 183, 256]) put back to 256 (combine the heads) + + output = self.output_layer(context) + # torch.Size([48, 183, 256]): 1 output per time step + + return output + + +# pylint: disable=arguments-differ +class PositionwiseFeedForward(nn.Module): + """ + Position-wise Feed-forward layer + Projects to ff_size and then back down to input_size. + """ + + def __init__(self, input_size, ff_size, dropout=0.1): + """ + Initializes position-wise feed-forward layer. + :param input_size: dimensionality of the input. + :param ff_size: dimensionality of intermediate representation + :param dropout: + """ + super().__init__() + self.layer_norm = nn.LayerNorm(input_size, eps=1e-6) + self.pwff_layer = nn.Sequential( + nn.Linear(input_size, ff_size), + nn.ReLU(), + nn.Dropout(dropout), + nn.Linear(ff_size, input_size), + nn.Dropout(dropout), + ) + + def forward(self, x): + x_norm = self.layer_norm(x) + return self.pwff_layer(x_norm) + x + + +# pylint: disable=arguments-differ +class PositionalEncoding(nn.Module): + """ + Pre-compute position encodings (PE). + In forward pass, this adds the position-encodings to the + input for as many time steps as necessary. + + Implementation based on OpenNMT-py. + https://github.com/OpenNMT/OpenNMT-py + """ + + def __init__(self, + size: int = 0, + max_len: int = 5000): + """ + Positional Encoding with maximum length max_len + :param size: + :param max_len: + :param dropout: + """ + if size % 2 != 0: + raise ValueError("Cannot use sin/cos positional encoding with " + "odd dim (got dim={:d})".format(size)) + pe = torch.zeros(max_len, size) + position = torch.arange(0, max_len).unsqueeze(1) + div_term = torch.exp((torch.arange(0, size, 2, dtype=torch.float) * + -(math.log(10000.0) / size))) + pe[:, 0::2] = torch.sin(position.float() * div_term) + pe[:, 1::2] = torch.cos(position.float() * div_term) + pe = pe.unsqueeze(0) # shape: [1, size, max_len] + super().__init__() + self.register_buffer('pe', pe) + self.dim = size + + def forward(self, emb): + """Embed inputs. 
+ Args: + emb (FloatTensor): Sequence of word vectors + ``(seq_len, batch_size, self.dim)`` + """ + # Add position encodings + return emb + self.pe[:, :emb.size(1)] + + +class TransformerEncoderLayer(nn.Module): + """ + One Transformer encoder layer has a Multi-head attention layer plus + a position-wise feed-forward layer. + """ + + def __init__(self, + size: int = 0, + ff_size: int = 0, + num_heads: int = 0, + dropout: float = 0.1): + """ + A single Transformer layer. + :param size: + :param ff_size: + :param num_heads: + :param dropout: + """ + super().__init__() + + self.layer_norm = nn.LayerNorm(size, eps=1e-6) + self.src_src_att = MultiHeadedAttention(num_heads, size, + dropout=dropout) + self.feed_forward = PositionwiseFeedForward(size, ff_size=ff_size, + dropout=dropout) + self.dropout = nn.Dropout(dropout) + self.size = size + + # pylint: disable=arguments-differ + def forward(self, x: Tensor, mask: Tensor) -> Tensor: + """ + Forward pass for a single transformer encoder layer. + First applies layer norm, then self attention, + then dropout with residual connection (adding the input to the result), + and then a position-wise feed-forward layer. + + :param x: layer input + :param mask: input mask + :return: output tensor + """ + x_norm = self.layer_norm(x) + h = self.src_src_att(x_norm, x_norm, x_norm, mask) + h = self.dropout(h) + x + o = self.feed_forward(h) + return o + + +class TransformerDecoderLayer(nn.Module): + """ + Transformer decoder layer. + + Consists of self-attention, source-attention, and feed-forward. + """ + + def __init__(self, + size: int = 0, + ff_size: int = 0, + num_heads: int = 0, + dropout: float = 0.1): + """ + Represents a single Transformer decoder layer. + + It attends to the source representation and the previous decoder states. + + :param size: model dimensionality + :param ff_size: size of the feed-forward intermediate layer + :param num_heads: number of heads + :param dropout: dropout to apply to input + """ + super().__init__() + self.size = size + + self.trg_trg_att = MultiHeadedAttention(num_heads, size, + dropout=dropout) + self.src_trg_att = MultiHeadedAttention(num_heads, size, + dropout=dropout) + + self.feed_forward = PositionwiseFeedForward(size, ff_size=ff_size, + dropout=dropout) + + self.x_layer_norm = nn.LayerNorm(size, eps=1e-6) + self.dec_layer_norm = nn.LayerNorm(size, eps=1e-6) + + self.dropout = nn.Dropout(dropout) + + # pylint: disable=arguments-differ + def forward(self, + x: Tensor = None, + memory: Tensor = None, + src_mask: Tensor = None, + trg_mask: Tensor = None) -> Tensor: + """ + Forward pass of a single Transformer decoder layer. 
+
+        :param x: inputs
+        :param memory: source representations
+        :param src_mask: source mask
+        :param trg_mask: target mask (so as to not condition on future steps)
+        :return: output tensor
+        """
+        # decoder/target self-attention
+        x_norm = self.x_layer_norm(x)  # torch.Size([48, 183, 256])
+        h1 = self.trg_trg_att(x_norm, x_norm, x_norm, mask=trg_mask)
+        h1 = self.dropout(h1) + x
+
+        # source-target attention
+        h1_norm = self.dec_layer_norm(h1)  # torch.Size([48, 183, 256]) (same for memory)
+        h2 = self.src_trg_att(memory, memory, h1_norm, mask=src_mask)
+
+        # final position-wise feed-forward layer
+        o = self.feed_forward(self.dropout(h2) + h1)
+
+        return o
diff --git a/Evaluator_272/mld/models/architectures/vision_transformer.py b/Evaluator_272/mld/models/architectures/vision_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..d47cf9a4bb7f0e4d6092b55a798f5baec4b7228c
--- /dev/null
+++ b/Evaluator_272/mld/models/architectures/vision_transformer.py
@@ -0,0 +1,954 @@
+"""
+This script is borrowed from https://github.com/rwightman/pytorch-image-models.
+Adhere to their licence to use this script
+
+We hacked it a little bit to make it happy in our framework.
+"""
+import torch
+import torch.nn as nn
+import torch.utils.model_zoo as model_zoo
+import math
+import warnings
+import random
+import numpy as np
+import joblib
+
+from collections import OrderedDict
+from functools import partial
+from itertools import repeat
+# from torch._six import container_abcs
+
+from mld.utils.maed_utils import DropPath, determine_output_feature_dim, load_state_dict
+from mld.models.architectures.hrnet import get_hrnet
+from mld.models.architectures.resnetv2 import ResNetV2
+from .ghost_nas_network import get_ghostnas
+from .ghost_nas_network_tiny import get_ghostnas as get_ghostnas_tiny
+# from torchvision.models.utils import load_state_dict_from_url
+
+
+model_urls = {
'vit_tiny_patch16_224': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth', + 'vit_small_patch16_224': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth', + 'vit_base_patch16_224': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth', + 'vit_base_patch16_384': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth', + 'vit_base_patch32_384': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p32_384-830016f5.pth', + 'vit_large_patch16_224': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth', + 'vit_large_patch16_384': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth', + 'vit_large_patch32_384': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth', + 'vit_base_resnet50_224_in21k': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth', +} + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + # type: (Tensor, float, float, float, float) -> Tensor + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. 
+ Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +class Mlp(nn.Module): + + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + + def __init__(self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0., + proj_drop=0., + st_mode='vanilla'): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim**-0.5 + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.mode = st_mode + if self.mode == 'parallel': + self.ts_attn = nn.Linear(dim * 2, dim * 2) + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + else: + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj_drop = nn.Dropout(proj_drop) + + self.attn_count_s = None + self.attn_count_t = None + + def forward(self, x, seqlen=1): + B, N, C = x.shape + + if self.mode == 'series': + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, + C // self.num_heads).permute( + 2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[ + 2] # make torchscript happy (cannot use tensor as tuple) + x = self.forward_spatial(q, k, v) + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, + C // self.num_heads).permute( + 2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[ + 2] # make torchscript happy (cannot use tensor as tuple) + x = self.forward_temporal(q, k, v, seqlen=seqlen) + elif self.mode == 'parallel': + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, + C // self.num_heads).permute( + 2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[ + 2] # make torchscript happy (cannot use tensor as tuple) + x_t = self.forward_temporal(q, k, v, seqlen=seqlen) + x_s = self.forward_spatial(q, k, v) + + alpha = torch.cat([x_s, x_t], dim=-1) + alpha = alpha.mean(dim=1, keepdim=True) + alpha = self.ts_attn(alpha).reshape(B, 1, C, 2) + alpha = alpha.softmax(dim=-1) + #self.count_attn(alpha) + + x = x_t * alpha[:, :, :, 1] + x_s * alpha[:, :, :, 0] + #x = (x_t + x_s) / 2 + elif self.mode == 'coupling': + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, + C // self.num_heads).permute( + 2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[ + 2] # make torchscript happy (cannot use tensor as tuple) + x = self.forward_coupling(q, k, v, seqlen=seqlen) + elif self.mode == 'vanilla': + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, + C // self.num_heads).permute( + 2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[ + 2] # make torchscript happy (cannot use tensor as tuple) + x = self.forward_spatial(q, k, v) + elif self.mode == 'temporal': + x = x.mean(dim=1, keepdim=True) + N = 1 + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, + C // self.num_heads).permute( + 2, 0, 3, 1, 4) + q, k, v = 
qkv[0], qkv[1], qkv[ + 2] # make torchscript happy (cannot use tensor as tuple) + x = self.forward_temporal(q, k, v, seqlen=seqlen) + else: + raise NotImplementedError(self.mode) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def reshape_T(self, x, seqlen=1, inverse=False): + if not inverse: + N, C = x.shape[-2:] + x = x.reshape(-1, seqlen, self.num_heads, N, C).transpose(1, 2) + x = x.reshape(-1, self.num_heads, seqlen * N, C) #(B, H, TN, c) + else: + TN, C = x.shape[-2:] + x = x.reshape(-1, self.num_heads, seqlen, TN // seqlen, + C).transpose(1, 2) + x = x.reshape(-1, self.num_heads, TN // seqlen, C) #(BT, H, N, C) + return x + + def forward_coupling(self, q, k, v, seqlen=8): + BT, _, N, C = q.shape + q = self.reshape_T(q, seqlen) + k = self.reshape_T(k, seqlen) + v = self.reshape_T(v, seqlen) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = attn @ v + x = self.reshape_T(x, seqlen, inverse=True) + x = x.transpose(1, 2).reshape(BT, N, C * self.num_heads) + return x + + def forward_spatial(self, q, k, v): + B, _, N, C = q.shape + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = attn @ v + x = x.transpose(1, 2).reshape(B, N, C * self.num_heads) + return x + + def forward_temporal(self, q, k, v, seqlen=8): + B, _, N, C = q.shape + qt = q.reshape(-1, seqlen, self.num_heads, N, + C).permute(0, 2, 3, 1, 4) #(B, H, N, T, C) + kt = k.reshape(-1, seqlen, self.num_heads, N, + C).permute(0, 2, 3, 1, 4) #(B, H, N, T, C) + vt = v.reshape(-1, seqlen, self.num_heads, N, + C).permute(0, 2, 3, 1, 4) #(B, H, N, T, C) + + attn = (qt @ kt.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = attn @ vt #(B, H, N, T, C) + x = x.permute(0, 3, 2, 1, 4).reshape(B, N, C * self.num_heads) + return x + + def count_attn(self, attn): + attn = attn.detach().cpu().numpy() + attn = attn.mean(axis=1) + attn_t = attn[:, :, 1].mean(axis=1) + attn_s = attn[:, :, 0].mean(axis=1) + if self.attn_count_s is None: + self.attn_count_s = attn_s + self.attn_count_t = attn_t + else: + self.attn_count_s = np.concatenate([self.attn_count_s, attn_s], + axis=0) + self.attn_count_t = np.concatenate([self.attn_count_t, attn_t], + axis=0) + + +class Block(nn.Module): + + def __init__(self, + dim, + num_heads, + mlp_ratio=4., + qkv_bias=False, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + st_mode='vanilla'): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention(dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + st_mode=st_mode) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath( + drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop) + + def forward(self, x, seqlen=1): + x = x + self.drop_path(self.attn(self.norm1(x), seqlen)) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + img_size = tuple(repeat(img_size, 2)) + patch_size = tuple(repeat(patch_size, 2)) + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // + patch_size[0]) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2d(in_chans, + embed_dim, + kernel_size=patch_size, + stride=patch_size) + + def forward(self, x): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).transpose(1, 2) + return x + + +class HybridEmbed(nn.Module): + """ CNN Feature Map Embedding + Extract feature map from CNN, flatten, project to embedding dim. + """ + + def __init__(self, + backbone, + img_size=224, + feature_size=None, + in_chans=3, + embed_dim=768): + super().__init__() + assert isinstance(backbone, nn.Module) + img_size = tuple(repeat(img_size, 2)) + self.img_size = img_size + self.backbone = backbone + if feature_size is None: + feature_size, feature_dim = determine_output_feature_dim( + inp_size=(1, in_chans, img_size[0], img_size[1]), + model=self.backbone) + else: + feature_size = tuple(repeat(feature_size, n)) + feature_dim = self.backbone.feature_info.channels()[-1] + self.num_patches = feature_size[0] * feature_size[1] + self.proj = nn.Conv2d(feature_dim, embed_dim, 1) + + def forward(self, x): + x = self.backbone(x) + if isinstance(x, (list, tuple)): + x = x[ + -1] # last feature if backbone outputs list/tuple of features + x = self.proj(x).flatten(2).transpose(1, 2) + return x + + +class VisionTransformer(nn.Module): + """ Vision Transformer with support for patch or hybrid CNN input stage + """ + + def __init__(self, + img_size=224, + patch_size=16, + in_chans=3, + num_classes=1000, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4., + qkv_bias=False, + qk_scale=None, + representation_size=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + hybrid_backbone=None, + norm_layer=nn.LayerNorm, + st_mode='vanilla'): + super().__init__() + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + + if hybrid_backbone is not None: + self.patch_embed = HybridEmbed(hybrid_backbone, + img_size=img_size, + in_chans=in_chans, + embed_dim=embed_dim) + else: + self.patch_embed = PatchEmbed(img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter( + torch.zeros(1, num_patches + 1, embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth) + ] # stochastic depth decay rule + self.blocks = nn.ModuleList([ + Block(dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + 
drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + st_mode=st_mode) for i in range(depth) + ]) + self.norm = norm_layer(embed_dim) + self.st_mode = st_mode + + # Representation layer + if representation_size: + self.num_features = representation_size + self.pre_logits = nn.Sequential( + OrderedDict([('fc', nn.Linear(embed_dim, representation_size)), + ('act', nn.Tanh())])) + else: + self.pre_logits = nn.Identity() + + # Classifier head + self.head = nn.Linear( + embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + trunc_normal_(self.pos_embed, std=.02) + trunc_normal_(self.cls_token, std=.02) + + if st_mode in ['coupling', 'parallel', 'series']: + self.temp_embed = nn.Parameter(torch.zeros(1, 16, 1, embed_dim)) + trunc_normal_(self.temp_embed, std=.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear( + self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x, seqlen=1): + B = x.shape[0] + x = self.patch_embed(x) + + cls_tokens = self.cls_token.expand( + B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + x = x + self.pos_embed + + if self.st_mode in ['coupling', 'parallel', 'series']: + _, N, C = x.shape + x = x.reshape(-1, seqlen, N, C) + self.temp_embed[:, :seqlen, :, :] + x = x.reshape(B, N, C) + + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x, seqlen) + + x = self.norm(x)[:, 0] + x = self.pre_logits(x) + return x + + def forward(self, x, seqlen=1): + x = self.forward_features(x, seqlen) + x = self.head(x) + return x + + +def _conv_filter(state_dict, patch_size=16): + """ convert patch embedding weight from manual patchify + linear proj to conv""" + out_dict = {} + for k, v in state_dict.items(): + if 'patch_embed.proj.weight' in k and len(v.shape) < 4: + v = v.reshape((v.shape[0], 3, patch_size, patch_size)) + out_dict[k] = v + return out_dict + + +def vit_small_patch16_224(pretrained=False, strict=True, **kwargs): + if pretrained: + # NOTE my scale was wrong for original weights, leaving this here until I have better ones for this model + kwargs.setdefault('qk_scale', 768**-0.5) + model = VisionTransformer(patch_size=16, + embed_dim=768, + depth=8, + num_heads=8, + mlp_ratio=3., + **kwargs) + if pretrained: + state_dict = model_zoo.load_url(model_urls['vit_small_patch16_224'], + progress=False, + map_location='cpu') + state_dict = _conv_filter(state_dict) + if kwargs['num_classes'] != 1000: + del state_dict['head.weight'] + del state_dict['head.bias'] + model.load_state_dict(state_dict, strict=strict) + return model + + +def vit_base_patch16_224(pretrained=False, strict=True, **kwargs): + model = VisionTransformer(patch_size=16, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs) + if pretrained: + state_dict = model_zoo.load_url(model_urls['vit_base_patch16_224'], + progress=False, + map_location='cpu') + state_dict = 
_conv_filter(state_dict) + if kwargs['num_classes'] != 1000: + del state_dict['head.weight'] + del state_dict['head.bias'] + model.load_state_dict(state_dict, strict=strict) + return model + + +def vit_base_patch16_384(pretrained=False, strict=True, **kwargs): + model = VisionTransformer(img_size=384, + patch_size=16, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs) + if pretrained: + state_dict = model_zoo.load_url(model_urls['vit_base_patch16_384'], + progress=False, + map_location='cpu') + if kwargs['num_classes'] != 1000: + del state_dict['head.weight'] + del state_dict['head.bias'] + model.load_state_dict(state_dict, strict=strict) + return model + + +def vit_base_patch32_384(pretrained=False, strict=True, **kwargs): + model = VisionTransformer(img_size=384, + patch_size=32, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs) + if pretrained: + state_dict = model_zoo.load_url(model_urls['vit_base_patch32_384'], + progress=False, + map_location='cpu') + if kwargs['num_classes'] != 1000: + del state_dict['head.weight'] + del state_dict['head.bias'] + model.load_state_dict(state_dict, strict=strict) + return model + + +def vit_large_patch16_224(pretrained=False, strict=True, **kwargs): + model = VisionTransformer(patch_size=16, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs) + if pretrained: + state_dict = model_zoo.load_url(model_urls['vit_large_patch16_224'], + progress=False, + map_location='cpu') + if kwargs['num_classes'] != 1000: + del state_dict['head.weight'] + del state_dict['head.bias'] + model.load_state_dict(state_dict, strict=strict) + return model + + +def vit_large_patch16_384(pretrained=False, strict=True, **kwargs): + model = VisionTransformer(img_size=384, + patch_size=16, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs) + if pretrained: + state_dict = model_zoo.load_url(model_urls['vit_large_patch16_384'], + progress=False, + map_location='cpu') + if kwargs['num_classes'] != 1000: + del state_dict['head.weight'] + del state_dict['head.bias'] + model.load_state_dict(state_dict, strict=strict) + return model + + +def vit_large_patch32_384(pretrained=False, strict=True, **kwargs): + model = VisionTransformer(img_size=384, + patch_size=32, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs) + if pretrained: + state_dict = model_zoo.load_url(model_urls['vit_large_patch32_384'], + progress=False, + map_location='cpu') + if kwargs['num_classes'] != 1000: + del state_dict['head.weight'] + del state_dict['head.bias'] + model.load_state_dict(state_dict, strict=strict) + return model + + +def vit_huge_patch16_224(pretrained=False, **kwargs): + model = VisionTransformer(patch_size=16, + embed_dim=1280, + depth=32, + num_heads=16, + mlp_ratio=4, + **kwargs) + model.default_cfg = default_cfgs['vit_huge_patch16_224'] + return model + + +def vit_huge_patch32_384(pretrained=False, **kwargs): + model = VisionTransformer(img_size=384, + patch_size=32, + embed_dim=1280, + depth=32, + num_heads=16, + mlp_ratio=4, + **kwargs) + model.default_cfg = default_cfgs['vit_huge_patch32_384'] + return model + + +def vit_base_resnet50_224_in21k(pretrained=False, strict=True, 
**kwargs): + """ R50+ViT-B/16 hybrid model from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. + """ + # create a ResNetV2 w/o pre-activation, that uses StdConv and GroupNorm and has 3 stages, no head + backbone = ResNetV2(layers=(3, 4, 9), + num_classes=0, + global_pool='', + in_chans=kwargs.get('in_chans', 3), + preact=False, + stem_type='same') + model = VisionTransformer(patch_size=16, + embed_dim=768, + depth=12, + num_heads=12, + hybrid_backbone=backbone, + mlp_ratio=4, + qkv_bias=True, + representation_size=768, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs) + if pretrained: + state_dict = model_zoo.load_url( + model_urls['vit_base_resnet50_224_in21k'], + progress=False, + map_location='cpu') + state_dict = _conv_filter(state_dict) + if kwargs['num_classes'] != 1000: + del state_dict['head.weight'] + del state_dict['head.bias'] + model.load_state_dict(state_dict, strict=strict) + return model + + +def vit_custom_resnet50_224_in21k(num_blocks, + num_heads, + st_mode, + pretrained=True, + **kwargs): + """ Hybrid model with a R50 and a Vit of custom layers . + """ + # create a ResNetV2 w/o pre-activation, that uses StdConv and GroupNorm and has 3 stages, no head + backbone = ResNetV2(layers=(3, 4, 9), + num_classes=0, + global_pool='', + in_chans=kwargs.get('in_chans', 3), + preact=False, + stem_type='same') + model = VisionTransformer(patch_size=16, + embed_dim=768, + depth=num_blocks, + num_heads=num_heads, + hybrid_backbone=backbone, + mlp_ratio=4, + qkv_bias=True, + representation_size=768, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + st_mode=st_mode, + **kwargs) + if pretrained: + state_dict = model_zoo.load_url( + model_urls['vit_base_resnet50_224_in21k'], + progress=False, + map_location='cpu') + state_dict = _conv_filter(state_dict) + del state_dict['head.weight'] + del state_dict['head.bias'] + model.load_state_dict(state_dict, strict=False) + return model + + +def vit_custom_resnet50_320_in21k(image_size, + num_blocks, + num_heads, + st_mode, + pretrained=True, + **kwargs): + """ Hybrid model with a R50 and a Vit of custom layers . + """ + + # create a ResNetV2 w/o pre-activation, that uses StdConv and GroupNorm and has 3 stages, no head + backbone = ResNetV2(layers=(3, 4, 9), + num_classes=0, + global_pool='', + in_chans=kwargs.get('in_chans', 3), + preact=False, + stem_type='same') + model = VisionTransformer(img_size=image_size, + patch_size=16, + embed_dim=768, + depth=num_blocks, + num_heads=num_heads, + hybrid_backbone=backbone, + mlp_ratio=4, + qkv_bias=True, + representation_size=768, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + st_mode=st_mode, + **kwargs) + if pretrained: + state_dict = model_zoo.load_url( + model_urls['vit_base_resnet50_224_in21k'], + progress=False, + map_location='cpu') + state_dict = _conv_filter(state_dict) + del state_dict['head.weight'] + del state_dict['head.bias'] + del state_dict['pos_embed'] + model.load_state_dict(state_dict, strict=False) + return model + + +def vit_custom_ghostnet_224_in21k(num_blocks, + num_heads, + st_mode, + pretrained=True, + embed_dim=768, + tiny=False, + **kwargs): + """ Hybrid model with a R50 and a Vit of custom layers . 
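+    (Note: despite the wording above, the hybrid backbone built below is GhostNAS,
+    using the tiny variant when `tiny=True`, not a ResNet-50.)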
+ """ + # create a ResNetV2 w/o pre-activation, that uses StdConv and GroupNorm and has 3 stages, no head + if tiny: + backbone = get_ghostnas_tiny(flops=170) + else: + backbone = get_ghostnas(flops=170) + model = VisionTransformer(patch_size=16, + embed_dim=embed_dim, + depth=num_blocks, + num_heads=num_heads, + hybrid_backbone=backbone, + mlp_ratio=4, + qkv_bias=True, + representation_size=embed_dim, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + st_mode=st_mode, + **kwargs) + if pretrained: + # state_dict = model_zoo.load_url(model_urls['vit_base_resnet50_224_in21k'], progress=False, map_location='cpu') + PRETRAINED = "/apdcephfs/share_1227775/sylvainliu/data/smpldatas/ghostnas_170M_pretrain_1141226.pth" + state_dict = torch.load(PRETRAINED) + state_dict = _conv_filter(state_dict) + # del state_dict['head.weight'] + # del state_dict['head.bias'] + model.patch_embed.load_state_dict(state_dict, strict=False) + return model + + +def vit_custom_hrnet48_224_in21k(image_size, + num_blocks, + num_heads, + st_mode, + pretrained=True, + **kwargs): + """ Hybrid model with a R50 and a Vit of custom layers . + """ + + # create a ResNetV2 w/o pre-activation, that uses StdConv and GroupNorm and has 3 stages, no head + backbone = get_hrnet(model_type='hrnet48', input_size=224, pretrained=True) + + model = VisionTransformer(img_size=image_size, + patch_size=16, + embed_dim=768, + depth=num_blocks, + num_heads=num_heads, + hybrid_backbone=backbone, + mlp_ratio=4, + qkv_bias=True, + representation_size=768, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + st_mode=st_mode, + **kwargs) + if pretrained: + state_dict = model_zoo.load_url( + model_urls['vit_base_resnet50_224_in21k'], + progress=False, + map_location='cpu') + state_dict = _conv_filter(state_dict) + model = load_state_dict(model, state_dict) + return model + + +def vit_custom_hrnet48_320_in21k(image_size, + num_blocks, + num_heads, + st_mode, + pretrained=True, + **kwargs): + """ Hybrid model with a R50 and a Vit of custom layers . 
+ """ + + # create a ResNetV2 w/o pre-activation, that uses StdConv and GroupNorm and has 3 stages, no head + backbone = get_hrnet(model_type='hrnet48', input_size=320, pretrained=True) + + model = VisionTransformer(img_size=image_size, + patch_size=16, + embed_dim=768, + depth=num_blocks, + num_heads=num_heads, + hybrid_backbone=backbone, + mlp_ratio=4, + qkv_bias=True, + representation_size=768, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + st_mode=st_mode, + **kwargs) + if pretrained: + state_dict = model_zoo.load_url( + model_urls['vit_base_resnet50_224_in21k'], + progress=False, + map_location='cpu') + state_dict = _conv_filter(state_dict) + model = load_state_dict(model, state_dict) + return model diff --git a/Evaluator_272/mld/models/architectures/vposert_vae.py b/Evaluator_272/mld/models/architectures/vposert_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..4941a6a08f3fd763b456b0e6c8b8931a28227086 --- /dev/null +++ b/Evaluator_272/mld/models/architectures/vposert_vae.py @@ -0,0 +1,113 @@ +from functools import reduce +from typing import List, Optional, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor, nn +from torch.distributions.distribution import Distribution + +from mld.models.architectures.tools.embeddings import (TimestepEmbedding, + Timesteps) +from mld.models.operator import PositionalEncoding +from mld.models.operator.cross_attention import ( + SkipTransformerEncoder, SkipTransformerDecoder, TransformerDecoder, + TransformerDecoderLayer, TransformerEncoder, TransformerEncoderLayer) +from mld.models.operator.position_encoding import build_position_encoding +from mld.utils.temos_utils import lengths_to_mask +''' +vae +skip connection encoder +skip connection decoder +mem for each decoder layer +''' + + +class VPosert(nn.Module): + + def __init__(self, cfg, **kwargs) -> None: + + super(VPosert, self).__init__() + + num_neurons = 512 + self.latentD = 256 + + # self.num_joints = 21 + n_features = 196 * 263 + + self.encoder_net = nn.Sequential( + BatchFlatten(), nn.BatchNorm1d(n_features), + nn.Linear(n_features, num_neurons), nn.LeakyReLU(), + nn.BatchNorm1d(num_neurons), nn.Dropout(0.1), + nn.Linear(num_neurons, num_neurons), + nn.Linear(num_neurons, num_neurons), + NormalDistDecoder(num_neurons, self.latentD)) + + self.decoder_net = nn.Sequential( + nn.Linear(self.latentD, num_neurons), + nn.LeakyReLU(), + nn.Dropout(0.1), + nn.Linear(num_neurons, num_neurons), + nn.LeakyReLU(), + nn.Linear(num_neurons, n_features), + ContinousRotReprDecoder(), + ) + + def forward(self, features: Tensor, lengths: Optional[List[int]] = None): + q_z = self.encode(features) + feats_rst = self.decode(q_z) + return feats_rst, q_z + + def encode(self, pose_body, lengths: Optional[List[int]] = None): + ''' + :param Pin: Nx(numjoints*3) + :param rep_type: 'matrot'/'aa' for matrix rotations or axis-angle + :return: + ''' + q_z = self.encoder_net(pose_body) + q_z_sample = q_z.rsample() + return q_z_sample.unsqueeze(0), q_z + + def decode(self, Zin, lengths: Optional[List[int]] = None): + bs = Zin.shape[0] + Zin = Zin[0] + + prec = self.decoder_net(Zin) + + return prec + + + +class BatchFlatten(nn.Module): + + def __init__(self): + super(BatchFlatten, self).__init__() + self._name = 'batch_flatten' + + def forward(self, x): + return x.view(x.shape[0], -1) + + +class ContinousRotReprDecoder(nn.Module): + + def __init__(self): + super(ContinousRotReprDecoder, self).__init__() + + def forward(self, 
module_input): + reshaped_input = module_input.view(-1, 196, 263) + + return reshaped_input + + +class NormalDistDecoder(nn.Module): + + def __init__(self, num_feat_in, latentD): + super(NormalDistDecoder, self).__init__() + + self.mu = nn.Linear(num_feat_in, latentD) + self.logvar = nn.Linear(num_feat_in, latentD) + + def forward(self, Xout): + return torch.distributions.normal.Normal(self.mu(Xout), + F.softplus(self.logvar(Xout))) diff --git a/Evaluator_272/mld/models/body_skeleton/__init__.py b/Evaluator_272/mld/models/body_skeleton/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Evaluator_272/mld/models/body_skeleton/paramUtil.py b/Evaluator_272/mld/models/body_skeleton/paramUtil.py new file mode 100644 index 0000000000000000000000000000000000000000..c8ea8218a329a828c48a24c8b298ffae84e0e2fe --- /dev/null +++ b/Evaluator_272/mld/models/body_skeleton/paramUtil.py @@ -0,0 +1,98 @@ +import numpy as np + +# Define a kinematic tree for the skeletal struture +kit_kinematic_chain = [[0, 11, 12, 13, 14, 15], [0, 16, 17, 18, 19, 20], [0, 1, 2, 3, 4], [3, 5, 6, 7], [3, 8, 9, 10]] + +kit_raw_offsets = np.array( + [ + [0, 0, 0], + [0, 1, 0], + [0, 1, 0], + [0, 1, 0], + [0, 1, 0], + [1, 0, 0], + [0, -1, 0], + [0, -1, 0], + [-1, 0, 0], + [0, -1, 0], + [0, -1, 0], + [1, 0, 0], + [0, -1, 0], + [0, -1, 0], + [0, 0, 1], + [0, 0, 1], + [-1, 0, 0], + [0, -1, 0], + [0, -1, 0], + [0, 0, 1], + [0, 0, 1] + ] +) + +t2m_raw_offsets = np.array([[0,0,0], + [1,0,0], + [-1,0,0], + [0,1,0], + [0,-1,0], + [0,-1,0], + [0,1,0], + [0,-1,0], + [0,-1,0], + [0,1,0], + [0,0,1], + [0,0,1], + [0,1,0], + [1,0,0], + [-1,0,0], + [0,0,1], + [0,-1,0], + [0,-1,0], + [0,-1,0], + [0,-1,0], + [0,-1,0], + [0,-1,0]]) + +# 30 +t2m_hand_raw_offsets = np.array([[1,0,0], # left_index1 + [1,0,0], # left_index2 + [1,0,0], # left_index3 + [1,0,0], # left_middle1 + [1,0,0], # left_middle2 + [1,0,0], # left_middle3 + [1,0,0], # left_pinky1 + [1,0,0], # left_pinky2 + [1,0,0], # left_pinky3 + [1,0,0], # left_ring1 + [1,0,0], # left_ring2 + [1,0,0], # left_ring3 + [1,0,0], # left_thumb1 + [1,0,0], # left_thumb2 + [1,0,0], # left_thumb3 + [-1,0,0], # right_index1 + [-1,0,0], # right_index2 + [-1,0,0], # right_index3 + [-1,0,0], # right_middle1 + [-1,0,0], # right_middle2 + [-1,0,0], # right_middle3 + [-1,0,0], # right_pinky1 + [-1,0,0], # right_pinky2 + [-1,0,0], # right_pinky3 + [-1,0,0], # right_ring1 + [-1,0,0], # right_ring2 + [-1,0,0], # right_ring3 + [-1,0,0], # right_thumb1 + [-1,0,0], # right_thumb2 + [-1,0,0],]) # right_thumb3 + +t2m_raw_body_hand_offsets = np.concatenate((t2m_raw_offsets, t2m_hand_raw_offsets), axis=0) + +t2m_kinematic_chain = [[0, 2, 5, 8, 11], [0, 1, 4, 7, 10], [0, 3, 6, 9, 12, 15], [9, 14, 17, 19, 21], [9, 13, 16, 18, 20]] +t2m_left_hand_chain = [[20, 22, 23, 24], [20, 34, 35, 36], [20, 25, 26, 27], [20, 31, 32, 33], [20, 28, 29, 30]] +t2m_right_hand_chain = [[21, 43, 44, 45], [21, 46, 47, 48], [21, 40, 41, 42], [21, 37, 38, 39], [21, 49, 50, 51]] + +t2m_body_hand_kinematic_chain = t2m_kinematic_chain + t2m_left_hand_chain + t2m_right_hand_chain + +kit_tgt_skel_id = '03950' + +t2m_tgt_skel_id = '000021' + diff --git a/Evaluator_272/mld/models/body_skeleton/quaternion.py b/Evaluator_272/mld/models/body_skeleton/quaternion.py new file mode 100644 index 0000000000000000000000000000000000000000..dca3d890080a4e91e3f275f442b0aed006562881 --- /dev/null +++ b/Evaluator_272/mld/models/body_skeleton/quaternion.py @@ -0,0 +1,423 @@ +# Copyright 
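
The kinematic chains above are plain lists of joint indices walked root-to-tip, so a parent-index array falls out of them directly. The snippet below is illustrative only (the import path assumes the Evaluator_272 folder is on PYTHONPATH) and mirrors how t2m_kinematic_chain is consumed elsewhere in this diff:

```python
import numpy as np
from mld.models.body_skeleton.paramUtil import t2m_kinematic_chain

# Flatten the five chains into one parent index per joint (-1 marks the root).
parents = np.full(22, -1, dtype=int)
for chain in t2m_kinematic_chain:
    for j in range(1, len(chain)):
        parents[chain[j]] = chain[j - 1]
print(parents[:6])   # [-1  0  0  0  1  2] -> joints 1-3 hang off the pelvis
```
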
(c) 2018-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import torch +import numpy as np + +_EPS4 = np.finfo(float).eps * 4.0 + +_FLOAT_EPS = np.finfo(np.float64).eps + +# PyTorch-backed implementations +def qinv(q): + assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)' + mask = torch.ones_like(q) + mask[..., 1:] = -mask[..., 1:] + return q * mask + + +def qinv_np(q): + assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)' + return qinv(torch.from_numpy(q).float()).numpy() + + +def qnormalize(q): + assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)' + return q / torch.norm(q, dim=-1, keepdim=True) + + +def qmul(q, r): + """ + Multiply quaternion(s) q with quaternion(s) r. + Expects two equally-sized tensors of shape (*, 4), where * denotes any number of dimensions. + Returns q*r as a tensor of shape (*, 4). + """ + assert q.shape[-1] == 4 + assert r.shape[-1] == 4 + + original_shape = q.shape + + # Compute outer product + terms = torch.bmm(r.view(-1, 4, 1), q.view(-1, 1, 4)) + + w = terms[:, 0, 0] - terms[:, 1, 1] - terms[:, 2, 2] - terms[:, 3, 3] + x = terms[:, 0, 1] + terms[:, 1, 0] - terms[:, 2, 3] + terms[:, 3, 2] + y = terms[:, 0, 2] + terms[:, 1, 3] + terms[:, 2, 0] - terms[:, 3, 1] + z = terms[:, 0, 3] - terms[:, 1, 2] + terms[:, 2, 1] + terms[:, 3, 0] + return torch.stack((w, x, y, z), dim=1).view(original_shape) + + +def qrot(q, v): + """ + Rotate vector(s) v about the rotation described by quaternion(s) q. + Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v, + where * denotes any number of dimensions. + Returns a tensor of shape (*, 3). + """ + assert q.shape[-1] == 4 + assert v.shape[-1] == 3 + assert q.shape[:-1] == v.shape[:-1] + + original_shape = list(v.shape) + # print(q.shape) + q = q.contiguous().view(-1, 4) + v = v.contiguous().view(-1, 3) + + qvec = q[:, 1:] + uv = torch.cross(qvec, v, dim=1) + uuv = torch.cross(qvec, uv, dim=1) + return (v + 2 * (q[:, :1] * uv + uuv)).view(original_shape) + + +def qeuler(q, order, epsilon=0, deg=True): + """ + Convert quaternion(s) q to Euler angles. + Expects a tensor of shape (*, 4), where * denotes any number of dimensions. + Returns a tensor of shape (*, 3). 
+ """ + assert q.shape[-1] == 4 + + original_shape = list(q.shape) + original_shape[-1] = 3 + q = q.view(-1, 4) + + q0 = q[:, 0] + q1 = q[:, 1] + q2 = q[:, 2] + q3 = q[:, 3] + + if order == 'xyz': + x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2)) + y = torch.asin(torch.clamp(2 * (q1 * q3 + q0 * q2), -1 + epsilon, 1 - epsilon)) + z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3)) + elif order == 'yzx': + x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3)) + y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3)) + z = torch.asin(torch.clamp(2 * (q1 * q2 + q0 * q3), -1 + epsilon, 1 - epsilon)) + elif order == 'zxy': + x = torch.asin(torch.clamp(2 * (q0 * q1 + q2 * q3), -1 + epsilon, 1 - epsilon)) + y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q1 * q1 + q2 * q2)) + z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q1 * q1 + q3 * q3)) + elif order == 'xzy': + x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3)) + y = torch.atan2(2 * (q0 * q2 + q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3)) + z = torch.asin(torch.clamp(2 * (q0 * q3 - q1 * q2), -1 + epsilon, 1 - epsilon)) + elif order == 'yxz': + x = torch.asin(torch.clamp(2 * (q0 * q1 - q2 * q3), -1 + epsilon, 1 - epsilon)) + y = torch.atan2(2 * (q1 * q3 + q0 * q2), 1 - 2 * (q1 * q1 + q2 * q2)) + z = torch.atan2(2 * (q1 * q2 + q0 * q3), 1 - 2 * (q1 * q1 + q3 * q3)) + elif order == 'zyx': + x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2)) + y = torch.asin(torch.clamp(2 * (q0 * q2 - q1 * q3), -1 + epsilon, 1 - epsilon)) + z = torch.atan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3)) + else: + raise + + if deg: + return torch.stack((x, y, z), dim=1).view(original_shape) * 180 / np.pi + else: + return torch.stack((x, y, z), dim=1).view(original_shape) + + +# Numpy-backed implementations + +def qmul_np(q, r): + q = torch.from_numpy(q).contiguous().float() + r = torch.from_numpy(r).contiguous().float() + return qmul(q, r).numpy() + + +def qrot_np(q, v): + q = torch.from_numpy(q).contiguous().float() + v = torch.from_numpy(v).contiguous().float() + return qrot(q, v).numpy() + + +def qeuler_np(q, order, epsilon=0, use_gpu=False): + if use_gpu: + q = torch.from_numpy(q).cuda().float() + return qeuler(q, order, epsilon).cpu().numpy() + else: + q = torch.from_numpy(q).contiguous().float() + return qeuler(q, order, epsilon).numpy() + + +def qfix(q): + """ + Enforce quaternion continuity across the time dimension by selecting + the representation (q or -q) with minimal distance (or, equivalently, maximal dot product) + between two consecutive frames. + + Expects a tensor of shape (L, J, 4), where L is the sequence length and J is the number of joints. + Returns a tensor of the same shape. + """ + assert len(q.shape) == 3 + assert q.shape[-1] == 4 + + result = q.copy() + dot_products = np.sum(q[1:] * q[:-1], axis=2) + mask = dot_products < 0 + mask = (np.cumsum(mask, axis=0) % 2).astype(bool) + result[1:][mask] *= -1 + return result + + +def euler2quat(e, order, deg=True): + """ + Convert Euler angles to quaternions. + """ + assert e.shape[-1] == 3 + + original_shape = list(e.shape) + original_shape[-1] = 4 + + e = e.view(-1, 3) + + ## if euler angles in degrees + if deg: + e = e * np.pi / 180. 
+ + x = e[:, 0] + y = e[:, 1] + z = e[:, 2] + + rx = torch.stack((torch.cos(x / 2), torch.sin(x / 2), torch.zeros_like(x), torch.zeros_like(x)), dim=1) + ry = torch.stack((torch.cos(y / 2), torch.zeros_like(y), torch.sin(y / 2), torch.zeros_like(y)), dim=1) + rz = torch.stack((torch.cos(z / 2), torch.zeros_like(z), torch.zeros_like(z), torch.sin(z / 2)), dim=1) + + result = None + for coord in order: + if coord == 'x': + r = rx + elif coord == 'y': + r = ry + elif coord == 'z': + r = rz + else: + raise + if result is None: + result = r + else: + result = qmul(result, r) + + # Reverse antipodal representation to have a non-negative "w" + if order in ['xyz', 'yzx', 'zxy']: + result *= -1 + + return result.view(original_shape) + + +def expmap_to_quaternion(e): + """ + Convert axis-angle rotations (aka exponential maps) to quaternions. + Stable formula from "Practical Parameterization of Rotations Using the Exponential Map". + Expects a tensor of shape (*, 3), where * denotes any number of dimensions. + Returns a tensor of shape (*, 4). + """ + assert e.shape[-1] == 3 + + original_shape = list(e.shape) + original_shape[-1] = 4 + e = e.reshape(-1, 3) + + theta = np.linalg.norm(e, axis=1).reshape(-1, 1) + w = np.cos(0.5 * theta).reshape(-1, 1) + xyz = 0.5 * np.sinc(0.5 * theta / np.pi) * e + return np.concatenate((w, xyz), axis=1).reshape(original_shape) + + +def euler_to_quaternion(e, order): + """ + Convert Euler angles to quaternions. + """ + assert e.shape[-1] == 3 + + original_shape = list(e.shape) + original_shape[-1] = 4 + + e = e.reshape(-1, 3) + + x = e[:, 0] + y = e[:, 1] + z = e[:, 2] + + rx = np.stack((np.cos(x / 2), np.sin(x / 2), np.zeros_like(x), np.zeros_like(x)), axis=1) + ry = np.stack((np.cos(y / 2), np.zeros_like(y), np.sin(y / 2), np.zeros_like(y)), axis=1) + rz = np.stack((np.cos(z / 2), np.zeros_like(z), np.zeros_like(z), np.sin(z / 2)), axis=1) + + result = None + for coord in order: + if coord == 'x': + r = rx + elif coord == 'y': + r = ry + elif coord == 'z': + r = rz + else: + raise + if result is None: + result = r + else: + result = qmul_np(result, r) + + # Reverse antipodal representation to have a non-negative "w" + if order in ['xyz', 'yzx', 'zxy']: + result *= -1 + + return result.reshape(original_shape) + + +def quaternion_to_matrix(quaternions): + """ + Convert rotations given as quaternions to rotation matrices. + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + Returns: + Rotation matrices as tensor of shape (..., 3, 3). 
+ """ + r, i, j, k = torch.unbind(quaternions, -1) + two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + + +def quaternion_to_matrix_np(quaternions): + q = torch.from_numpy(quaternions).contiguous().float() + return quaternion_to_matrix(q).numpy() + + +def quaternion_to_cont6d_np(quaternions): + rotation_mat = quaternion_to_matrix_np(quaternions) + cont_6d = np.concatenate([rotation_mat[..., 0], rotation_mat[..., 1]], axis=-1) + return cont_6d + + +def quaternion_to_cont6d(quaternions): + rotation_mat = quaternion_to_matrix(quaternions) + cont_6d = torch.cat([rotation_mat[..., 0], rotation_mat[..., 1]], dim=-1) + return cont_6d + + +def cont6d_to_matrix(cont6d): + assert cont6d.shape[-1] == 6, "The last dimension must be 6" + x_raw = cont6d[..., 0:3] + y_raw = cont6d[..., 3:6] + + x = x_raw / torch.norm(x_raw, dim=-1, keepdim=True) + z = torch.cross(x, y_raw, dim=-1) + z = z / torch.norm(z, dim=-1, keepdim=True) + + y = torch.cross(z, x, dim=-1) + + x = x[..., None] + y = y[..., None] + z = z[..., None] + + mat = torch.cat([x, y, z], dim=-1) + return mat + + +def cont6d_to_matrix_np(cont6d): + q = torch.from_numpy(cont6d).contiguous().float() + return cont6d_to_matrix(q).numpy() + + +def qpow(q0, t, dtype=torch.float): + ''' q0 : tensor of quaternions + t: tensor of powers + ''' + q0 = qnormalize(q0) + theta0 = torch.acos(q0[..., 0]) + + ## if theta0 is close to zero, add epsilon to avoid NaNs + mask = (theta0 <= 10e-10) * (theta0 >= -10e-10) + theta0 = (1 - mask) * theta0 + mask * 10e-10 + v0 = q0[..., 1:] / torch.sin(theta0).view(-1, 1) + + if isinstance(t, torch.Tensor): + q = torch.zeros(t.shape + q0.shape) + theta = t.view(-1, 1) * theta0.view(1, -1) + else: ## if t is a number + q = torch.zeros(q0.shape) + theta = t * theta0 + + q[..., 0] = torch.cos(theta) + q[..., 1:] = v0 * torch.sin(theta).unsqueeze(-1) + + return q.to(dtype) + + +def qslerp(q0, q1, t): + ''' + q0: starting quaternion + q1: ending quaternion + t: array of points along the way + + Returns: + Tensor of Slerps: t.shape + q0.shape + ''' + + q0 = qnormalize(q0) + q1 = qnormalize(q1) + q_ = qpow(qmul(q1, qinv(q0)), t) + + return qmul(q_, + q0.contiguous().view(torch.Size([1] * len(t.shape)) + q0.shape).expand(t.shape + q0.shape).contiguous()) + + +def qbetween(v0, v1): + ''' + find the quaternion used to rotate v0 to v1 + ''' + assert v0.shape[-1] == 3, 'v0 must be of the shape (*, 3)' + assert v1.shape[-1] == 3, 'v1 must be of the shape (*, 3)' + + v = torch.cross(v0, v1) + w = torch.sqrt((v0 ** 2).sum(dim=-1, keepdim=True) * (v1 ** 2).sum(dim=-1, keepdim=True)) + (v0 * v1).sum(dim=-1, + keepdim=True) + return qnormalize(torch.cat([w, v], dim=-1)) + + +def qbetween_np(v0, v1): + ''' + find the quaternion used to rotate v0 to v1 + ''' + assert v0.shape[-1] == 3, 'v0 must be of the shape (*, 3)' + assert v1.shape[-1] == 3, 'v1 must be of the shape (*, 3)' + + v0 = torch.from_numpy(v0).float() + v1 = torch.from_numpy(v1).float() + return qbetween(v0, v1).numpy() + + +def lerp(p0, p1, t): + if not isinstance(t, torch.Tensor): + t = torch.Tensor([t]) + + new_shape = t.shape + p0.shape + new_view_t = t.shape + torch.Size([1] * len(p0.shape)) + new_view_p = torch.Size([1] * 
len(t.shape)) + p0.shape + p0 = p0.view(new_view_p).expand(new_shape) + p1 = p1.view(new_view_p).expand(new_shape) + t = t.view(new_view_t).expand(new_shape) + + return p0 + t * (p1 - p0) diff --git a/Evaluator_272/mld/models/body_skeleton/skeleton.py b/Evaluator_272/mld/models/body_skeleton/skeleton.py new file mode 100644 index 0000000000000000000000000000000000000000..ac25a5773a04202c2059086dbf2e0d634fdf0290 --- /dev/null +++ b/Evaluator_272/mld/models/body_skeleton/skeleton.py @@ -0,0 +1,271 @@ +from .quaternion import * +import scipy.ndimage.filters as filters + +class Skeleton(object): + def __init__(self, offset, kinematic_tree): + self._raw_offset_np = offset.numpy() + self._raw_offset = offset.clone().detach().float() + self._kinematic_tree = kinematic_tree + self._offset = None + self._parents = [0] * len(self._raw_offset) + self._parents[0] = -1 + for chain in self._kinematic_tree: + for j in range(1, len(chain)): + self._parents[chain[j]] = chain[j-1] + + def njoints(self): + return len(self._raw_offset) + + def offset(self): + return self._offset + + def set_offset(self, offsets): + self._offset = offsets.clone().detach().float() + + def kinematic_tree(self): + return self._kinematic_tree + + def parents(self): + return self._parents + + # joints (batch_size, joints_num, 3) + def get_offsets_joints_batch(self, joints): + assert len(joints.shape) == 3 + _offsets = self._raw_offset.expand(joints.shape[0], -1, -1).clone() + for i in range(1, self._raw_offset.shape[0]): + _offsets[:, i] = torch.norm(joints[:, i] - joints[:, self._parents[i]], p=2, dim=1)[:, None] * _offsets[:, i] + + self._offset = _offsets.detach() + return _offsets + + # joints (joints_num, 3) + def get_offsets_joints(self, joints): + assert len(joints.shape) == 2 + _offsets = self._raw_offset.clone() + for i in range(1, self._raw_offset.shape[0]): + # print(joints.shape) + _offsets[i] = torch.norm(joints[i] - joints[self._parents[i]], p=2, dim=0) * _offsets[i] + + self._offset = _offsets.detach() + return _offsets + + # face_joint_idx should follow the order of right hip, left hip, right shoulder, left shoulder + # joints (batch_size, joints_num, 3) + def inverse_kinematics_np(self, joints, face_joint_idx, smooth_forward=False): + assert len(face_joint_idx) == 4 + '''Get Forward Direction''' + l_hip, r_hip, sdr_r, sdr_l = face_joint_idx + across1 = joints[:, r_hip] - joints[:, l_hip] + across2 = joints[:, sdr_r] - joints[:, sdr_l] + across = across1 + across2 + across = across / np.sqrt((across**2).sum(axis=-1))[:, np.newaxis] + # print(across1.shape, across2.shape) + + # forward (batch_size, 3) + forward = np.cross(np.array([[0, 1, 0]]), across, axis=-1) + if smooth_forward: + forward = filters.gaussian_filter1d(forward, 20, axis=0, mode='nearest') + # forward (batch_size, 3) + forward = forward / np.sqrt((forward**2).sum(axis=-1))[..., np.newaxis] + + '''Get Root Rotation''' + target = np.array([[0,0,1]]).repeat(len(forward), axis=0) + root_quat = qbetween_np(forward, target) + + '''Inverse Kinematics''' + # quat_params (batch_size, joints_num, 4) + # print(joints.shape[:-1]) + quat_params = np.zeros(joints.shape[:-1] + (4,)) + # print(quat_params.shape) + root_quat[0] = np.array([[1.0, 0.0, 0.0, 0.0]]) + quat_params[:, 0] = root_quat + # quat_params[0, 0] = np.array([[1.0, 0.0, 0.0, 0.0]]) + for chain in self._kinematic_tree: + R = root_quat + for j in range(len(chain) - 1): + # (batch, 3) + u = self._raw_offset_np[chain[j+1]][np.newaxis,...].repeat(len(joints), axis=0) + # print(u.shape) + # (batch, 3) + v 
= joints[:, chain[j+1]] - joints[:, chain[j]] + v = v / np.sqrt((v**2).sum(axis=-1))[:, np.newaxis] + # print(u.shape, v.shape) + rot_u_v = qbetween_np(u, v) + + R_loc = qmul_np(qinv_np(R), rot_u_v) + + quat_params[:,chain[j + 1], :] = R_loc + R = qmul_np(R, R_loc) + + return quat_params + + # joints (batch_size, frames, joints_num, 3) + def inverse_kinematics(self, joints, face_joint_idx, smooth_forward=False): + + bs = joints.shape[0] + frame = joints.shape[1] + joint_num = joints.shape[2] + + + joints = joints.reshape(-1, joints.shape[-2], joints.shape[-1]) + assert len(face_joint_idx) == 4 + '''Get Forward Direction''' + l_hip, r_hip, sdr_r, sdr_l = face_joint_idx + across1 = joints[:,r_hip] - joints[:, l_hip] + across2 = joints[:,sdr_r] - joints[:, sdr_l] + across = across1 + across2 + + # across = across / np.sqrt((across**2).sum(axis=-1))[:, np.newaxis] + # across = data / np.sqrt((data**2).sum(axis=-1))[:, np.newaxis] + across = across / torch.sqrt((across**2).sum(dim=-1)).unsqueeze(1) + + + # print(across1.shape, across2.shape) + + # forward (batch_size, 3) + # forward = np.cross(np.array([[0, 1, 0]]), across, axis=-1) + forward = torch.cross(torch.tensor([[0, 1, 0]], dtype=torch.float32).to(joints.device), across, dim=-1) + + if smooth_forward: + forward = filters.gaussian_filter1d(forward, 20, axis=0, mode='nearest') + # forward (batch_size, 3) + # forward = forward / np.sqrt((forward**2).sum(axis=-1))[..., np.newaxis] + # forward = torch.norm(forward, p=2, dim=-1).unsqueeze(-1) + forward = forward / torch.sqrt((forward**2).sum(dim=-1)).unsqueeze(-1) + + '''Get Root Rotation''' + # target = np.array([[0,0,1]]).repeat(len(forward), axis=0) + target = torch.tensor([[0, 0, 1]], dtype=torch.float32).expand(len(forward), -1).to(joints.device) + root_quat = qbetween(forward, target) + + '''Inverse Kinematics''' + # quat_params (batch_size, joints_num, 4) + # print(joints.shape[:-1]) + # quat_params = np.zeros(joints.shape[:-1] + (4,)) + quat_params = torch.zeros(joints.shape[:-1] + (4,)).to(joints.device) + # print(quat_params.shape) + root_quat[0] = torch.tensor([[1.0, 0.0, 0.0, 0.0]]) + quat_params[:, 0] = root_quat + # quat_params[0, 0] = np.array([[1.0, 0.0, 0.0, 0.0]]) + + for chain in self._kinematic_tree: + R = root_quat + for j in range(len(chain) - 1): + # (batch, 3) + u = torch.from_numpy(self._raw_offset_np[chain[j+1]][np.newaxis,...].repeat(len(joints), axis=0)).to(joints.device) + # print(u.shape) + # (batch, 3) + v = joints[:, chain[j+1]] - joints[:, chain[j]] + # v = v / np.sqrt((v**2).sum(axis=-1))[:, np.newaxis] + v = v / torch.sqrt((v**2).sum(dim=-1)).unsqueeze(1) + # v = torch.norm(v, p =2, dim=-1).unsqueeze(1) + # print(u.shape, v.shape) + rot_u_v = qbetween(u.float(), v) + + R_loc = qmul(qinv(R), rot_u_v) + + quat_params[:,chain[j + 1], :] = R_loc + R = qmul(R, R_loc) + + quat_params = quat_params.reshape(bs, frame, joint_num,4) + + return quat_params + + # Be sure root joint is at the beginning of kinematic chains + def forward_kinematics(self, quat_params, root_pos, skel_joints=None, do_root_R=True): + # quat_params (batch_size, joints_num, 4) + # joints (batch_size, joints_num, 3) + # root_pos (batch_size, 3) + if skel_joints is not None: + offsets = self.get_offsets_joints_batch(skel_joints) + if len(self._offset.shape) == 2: + offsets = self._offset.expand(quat_params.shape[0], -1, -1) + joints = torch.zeros(quat_params.shape[:-1] + (3,)) + joints[:, 0] = root_pos + for chain in self._kinematic_tree: + if do_root_R: + R = quat_params[:, 0] + else: + R = 
torch.tensor([[1.0, 0.0, 0.0, 0.0]]).expand(len(quat_params), -1).detach() + for i in range(1, len(chain)): + R = qmul(R, quat_params[:, chain[i]]) + offset_vec = offsets[:, chain[i]] + joints[:, chain[i]] = qrot(R, offset_vec) + joints[:, chain[i-1]] + return joints + + # Be sure root joint is at the beginning of kinematic chains + def forward_kinematics_np(self, quat_params, root_pos, skel_joints=None, do_root_R=True): + # quat_params (batch_size, joints_num, 4) + # joints (batch_size, joints_num, 3) + # root_pos (batch_size, 3) + if skel_joints is not None: + skel_joints = torch.from_numpy(skel_joints) + offsets = self.get_offsets_joints_batch(skel_joints) + if len(self._offset.shape) == 2: + offsets = self._offset.expand(quat_params.shape[0], -1, -1) + offsets = offsets.numpy() + joints = np.zeros(quat_params.shape[:-1] + (3,)) + joints[:, 0] = root_pos + for chain in self._kinematic_tree: + if do_root_R: + R = quat_params[:, 0] + else: + R = np.array([[1.0, 0.0, 0.0, 0.0]]).repeat(len(quat_params), axis=0) + for i in range(1, len(chain)): + R = qmul_np(R, quat_params[:, chain[i]]) + offset_vec = offsets[:, chain[i]] + joints[:, chain[i]] = qrot_np(R, offset_vec) + joints[:, chain[i - 1]] + return joints + + def forward_kinematics_cont6d_np(self, cont6d_params, root_pos, skel_joints=None, do_root_R=True): + # cont6d_params (batch_size, joints_num, 6) + # joints (batch_size, joints_num, 3) + # root_pos (batch_size, 3) + if skel_joints is not None: + skel_joints = torch.from_numpy(skel_joints) + offsets = self.get_offsets_joints_batch(skel_joints) + if len(self._offset.shape) == 2: + offsets = self._offset.expand(cont6d_params.shape[0], -1, -1) + offsets = offsets.numpy() + joints = np.zeros(cont6d_params.shape[:-1] + (3,)) + joints[:, 0] = root_pos + for chain in self._kinematic_tree: + if do_root_R: + matR = cont6d_to_matrix_np(cont6d_params[:, 0]) + else: + matR = np.eye(3)[np.newaxis, :].repeat(len(cont6d_params), axis=0) + for i in range(1, len(chain)): + matR = np.matmul(matR, cont6d_to_matrix_np(cont6d_params[:, chain[i]])) + offset_vec = offsets[:, chain[i]][..., np.newaxis] + # print(matR.shape, offset_vec.shape) + joints[:, chain[i]] = np.matmul(matR, offset_vec).squeeze(-1) + joints[:, chain[i-1]] + return joints + + def forward_kinematics_cont6d(self, cont6d_params, root_pos, skel_joints=None, do_root_R=True): + # cont6d_params (batch_size, joints_num, 6) + # joints (batch_size, joints_num, 3) + # root_pos (batch_size, 3) + + if skel_joints is not None: + # skel_joints = torch.from_numpy(skel_joints) + offsets = self.get_offsets_joints_batch(skel_joints) + if len(self._offset.shape) == 2: + offsets = self._offset.expand(cont6d_params.shape[0], -1, -1) + joints = torch.zeros(cont6d_params.shape[:-1] + (3,)).to(cont6d_params.device) + joints[..., 0, :] = root_pos + for chain in self._kinematic_tree: + if do_root_R: + matR = cont6d_to_matrix(cont6d_params[:, 0]) + else: + matR = torch.eye(3).expand((len(cont6d_params), -1, -1)).detach() + for i in range(1, len(chain)): + matR = torch.matmul(matR, cont6d_to_matrix(cont6d_params[:, chain[i]])) + offset_vec = offsets[:, chain[i]].unsqueeze(-1).to(matR.device) + # print(matR.shape, offset_vec.shape) + joints[:, chain[i]] = torch.matmul(matR, offset_vec).squeeze(-1) + joints[:, chain[i-1]] + return joints + + + + + diff --git a/Evaluator_272/mld/models/get_model.py b/Evaluator_272/mld/models/get_model.py new file mode 100644 index 0000000000000000000000000000000000000000..ece21898323c317b77698ad1481bfcfb5aaa2640 --- /dev/null +++ 
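
To make the Skeleton data flow concrete: bone lengths are taken from a rest pose via get_offsets_joints_batch, and forward kinematics accumulates rotated offsets joint-by-joint along each chain. The run below is a sketch, not repo code; shapes are assumptions for the 22-joint HumanML3D body, and the imports assume the Evaluator_272 folder is on PYTHONPATH with the repo's dependencies installed:

```python
import numpy as np
import torch
from mld.models.body_skeleton.paramUtil import t2m_raw_offsets, t2m_kinematic_chain
from mld.models.body_skeleton.skeleton import Skeleton

skel = Skeleton(torch.from_numpy(t2m_raw_offsets), t2m_kinematic_chain)

batch, n_joints = 4, 22
rest_joints = np.random.randn(batch, n_joints, 3)   # rest pose; only bone lengths are used
quats = np.zeros((batch, n_joints, 4))
quats[..., 0] = 1.0                                 # identity rotations
root = np.zeros((batch, 3))

joints = skel.forward_kinematics_np(quats, root, skel_joints=rest_joints)
print(joints.shape)   # (4, 22, 3)
```
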
b/Evaluator_272/mld/models/get_model.py @@ -0,0 +1,17 @@ +import importlib + + +def get_model(cfg, datamodule, phase="train"): + modeltype = cfg.model.model_type + if modeltype in ["mld", "temos", "gpt"]: + return get_module(cfg, datamodule) + else: + raise ValueError(f"Invalid model type {modeltype}.") + + +def get_module(cfg, datamodule): + modeltype = cfg.model.model_type + model_module = importlib.import_module( + f".modeltype.{cfg.model.model_type}", package="mld.models") + Model = model_module.__getattribute__(f"{modeltype.upper()}") + return Model(cfg=cfg, datamodule=datamodule) diff --git a/Evaluator_272/mld/models/losses/__init__.py b/Evaluator_272/mld/models/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ea10916883c979d359245cd09fb688035aedc7ea --- /dev/null +++ b/Evaluator_272/mld/models/losses/__init__.py @@ -0,0 +1,2 @@ +from mld.models.losses.temos import TemosLosses +from mld.models.losses.tmost import TmostLosses diff --git a/Evaluator_272/mld/models/losses/actor.py b/Evaluator_272/mld/models/losses/actor.py new file mode 100644 index 0000000000000000000000000000000000000000..49be7134091b236eb014ee6368371b565ecf613e --- /dev/null +++ b/Evaluator_272/mld/models/losses/actor.py @@ -0,0 +1,118 @@ +import torch +import torch.nn as nn +from torchmetrics import Metric + +class ACTORLosses(Metric): + """ + Loss + Modify loss + + """ + def __init__(self, vae, mode, cfg): + super().__init__(dist_sync_on_step=cfg.LOSS.DIST_SYNC_ON_STEP) + + # Save parameters + self.vae = vae + self.mode = mode + + losses = [] + losses.append("recons_feature") + losses.append("recons_verts") + losses.append("recons_joints") + losses.append("recons_limb") + + # latent loss + losses.append("latent_st2sm") + + # KL loss + losses.append("kl_motion") + losses.append("total") + + for loss in losses: + self.register_buffer(loss, torch.tensor(0.0)) + self.register_buffer("count", torch.tensor(0)) + self.losses = losses + + self._losses_func = {} + self._params = {} + for loss in losses: + if loss !='total': + if loss.split('_')[0] == 'kl': + self._losses_func[loss] = KLLoss() + self._params[loss] = cfg.LOSS.LAMBDA_KL + elif loss.split('_')[0] == 'recons': + self._losses_func[loss] = torch.nn.SmoothL1Loss(reduction='mean') + self._params[loss] = cfg.LOSS.LAMBDA_REC + elif loss.split('_')[0] == 'cross': + self._losses_func[loss] = torch.nn.SmoothL1Loss(reduction='mean') + self._params[loss] = cfg.LOSS.LAMBDA_CROSS + elif loss.split('_')[0] =='latent': + self._losses_func[loss] = torch.nn.SmoothL1Loss(reduction='mean') + self._params[loss] = cfg.LOSS.LAMBDA_LATENT + elif loss.split('_')[0] =='cycle': + self._losses_func[loss] = torch.nn.SmoothL1Loss(reduction='mean') + self._params[loss] = cfg.LOSS.LAMBDA_CYCLE + else: + ValueError("This loss is not recognized.") + + + def update(self, rs_set, dist_ref): + total: float = 0.0 + # Compute the losses + # loss1 - reconstruction loss + total += self._update_loss("recons_feature", rs_set['m_rst'], rs_set['m_ref']) + # total += self._update_loss("recons_verts", rs_set['verts_rs'], rs_set['verts_ref']) + # total += self._update_loss("recons_joints", rs_set['joints_rs'], rs_set['joints_ref']) + # total += self._update_loss("recons_limb", rs_set['rs_base'], rs_set['m1']) + + # loss - text motion latent loss + total += self._update_loss("kl_motion", rs_set['dist_m'], dist_ref) + + self.total += total.detach() + self.count += 1 + + return total + + def compute(self, split): + count = getattr(self, "count") + return {loss: 
getattr(self, loss)/count for loss in self.losses} + + def _update_loss(self, loss: str, outputs, inputs): + # Update the loss + val = self._losses_func[loss](outputs, inputs) + getattr(self, loss).__iadd__(val.detach()) + # Return a weighted sum + weighted_loss = self._params[loss] * val + return weighted_loss + + def loss2logname(self, loss: str, split: str): + if loss == "total": + log_name = f"{loss}/{split}" + else: + loss_type, name = loss.split("_") + log_name = f"{loss_type}/{name}/{split}" + return log_name + + +class KLLoss: + def __init__(self): + pass + + def __call__(self, q, p): + div = torch.distributions.kl_divergence(q, p) + return div.mean() + + def __repr__(self): + return "KLLoss()" + + +class KLLossMulti: + def __init__(self): + self.klloss = KLLoss() + + def __call__(self, qlist, plist): + return sum([self.klloss(q, p) + for q, p in zip(qlist, plist)]) + + def __repr__(self): + return "KLLossMulti()" diff --git a/Evaluator_272/mld/models/losses/gpt.py b/Evaluator_272/mld/models/losses/gpt.py new file mode 100644 index 0000000000000000000000000000000000000000..50143562a7c9de4c34788b763d180eb77c3e7706 --- /dev/null +++ b/Evaluator_272/mld/models/losses/gpt.py @@ -0,0 +1,166 @@ +import numpy as np +import torch +import torch.nn as nn +from torchmetrics import Metric + +from mld.data.humanml.scripts.motion_process import (qrot, + recover_root_rot_pos) + +from .infonce import InfoNCE + + +class GPTLosses(Metric): + """ + MLD Loss + """ + + def __init__(self, cfg): + super().__init__(dist_sync_on_step=cfg.LOSS.DIST_SYNC_ON_STEP) + + # Save parameters + # self.vae = vae + # self.vae_type = cfg.model.vae_type + # self.mode = mode + self.cfg = cfg + # self.predict_epsilon = cfg.TRAIN.ABLATION.PREDICT_EPSILON + self.stage = cfg.TRAIN.STAGE + + assert self.stage in ["gpt"] + losses = [] + + + + losses.append("ce_motiontoken") + if self.cfg.TRAIN.use_tmr_supervision: + losses.append("contrastive_tmrsupervise") + self.infonce_temp = cfg.LOSS.INFONCE_TEMP + + # self.add_state("count", torch.tensor(0), dist_reduce_fx="mean") + + losses.append("total") + + for loss in losses: + self.add_state(loss, + default=torch.tensor(0.0), + dist_reduce_fx="sum") + # self.register_buffer(loss, torch.tensor(0.0)) + self.add_state("count", torch.tensor(0), dist_reduce_fx="sum") + + if self.stage in ['gpt']: + self.add_state("rightnum", + default=torch.tensor(0.0), + dist_reduce_fx="sum") + self.add_state("count_all_token", + default=torch.tensor(0.0), + dist_reduce_fx="sum") + + self.losses = losses + + + self._losses_func = {} + self._params = {} + for loss in losses: + if loss.split('_')[0] == 'inst': + self._losses_func[loss] = nn.MSELoss(reduction='mean') + self._params[loss] = 1 + elif loss.split('_')[0] == 'x': + self._losses_func[loss] = nn.MSELoss(reduction='mean') + self._params[loss] = 1 + elif loss.split('_')[0] == 'prior': + self._losses_func[loss] = nn.MSELoss(reduction='mean') + self._params[loss] = cfg.LOSS.LAMBDA_PRIOR + elif loss.split('_')[0] == 'kl': + if cfg.LOSS.LAMBDA_KL != 0.0: + self._losses_func[loss] = KLLoss() + self._params[loss] = cfg.LOSS.LAMBDA_KL + elif loss.split('_')[0] == 'recons': + self._losses_func[loss] = torch.nn.SmoothL1Loss( + reduction='mean') + self._params[loss] = cfg.LOSS.LAMBDA_REC + elif loss.split('_')[0] == 'gen': + self._losses_func[loss] = torch.nn.SmoothL1Loss( + reduction='mean') + self._params[loss] = cfg.LOSS.LAMBDA_GEN + elif loss.split('_')[0] == 'latent': + self._losses_func[loss] = torch.nn.SmoothL1Loss( + reduction='mean') + 
self._params[loss] = cfg.LOSS.LAMBDA_LATENT + elif loss.split('_')[0] == 'ce': + self._losses_func[loss] = torch.nn.CrossEntropyLoss( + reduction='mean') + self._params[loss] = 1 + elif loss.split('_')[0] == 'contrastive': + self._losses_func[loss] = InfoNCE(self.infonce_temp) + self._params[loss] = cfg.LOSS.LAMBDA_INFONCE + else: + ValueError("This loss is not recognized.") + + def update(self, rs_set): + total: float = 0.0 + + assert len(rs_set['m_rst']) == len(rs_set['m_ref']) + bs = len(rs_set['m_rst']) + + if self.stage in ['gpt']: + + if self.cfg.TRAIN.use_tmr_supervision: + total += self._update_loss("contrastive_tmrsupervise", (rs_set['supervise_motion_feat'], rs_set['supervise_text_feat']), rs_set['emb_dist']) + + for i in range(bs): + total += self._update_loss("ce_motiontoken", rs_set['m_rst'][i], rs_set['m_ref'][i]) / bs # rs_set['m_rst'][i] (16, 513) rs_set['m_ref'][i] (16) + probs = torch.softmax(rs_set['m_rst'][i], dim=-1) + _, cls_pred_index = torch.max(probs, dim=-1) # 16 + self.count_all_token += cls_pred_index.shape[0] + self.rightnum += (cls_pred_index.flatten(0) == rs_set['m_ref'][i].flatten(0)).sum().item() + + self.total += total.detach() + self.count += 1 + + return total + + def compute(self, split): + count = getattr(self, "count") + loss_dict = {loss: getattr(self, loss) / count for loss in self.losses} + loss_dict['ACC_token'] = self.rightnum / self.count_all_token + return loss_dict + + def _update_loss(self, loss: str, outputs, inputs): + # Update the loss + val = self._losses_func[loss](outputs, inputs) + getattr(self, loss).__iadd__(val.detach()) + # Return a weighted sum + weighted_loss = self._params[loss] * val + return weighted_loss + + def loss2logname(self, loss: str, split: str): + if loss == "total": + log_name = f"{loss}/{split}" + else: + loss_type, name = loss.split("_") + log_name = f"{loss_type}/{name}/{split}" + return log_name + + +class KLLoss: + + def __init__(self): + pass + + def __call__(self, q, p): + div = torch.distributions.kl_divergence(q, p) + return div.mean() + + def __repr__(self): + return "KLLoss()" + + +class KLLossMulti: + + def __init__(self): + self.klloss = KLLoss() + + def __call__(self, qlist, plist): + return sum([self.klloss(q, p) for q, p in zip(qlist, plist)]) + + def __repr__(self): + return "KLLossMulti()" diff --git a/Evaluator_272/mld/models/losses/infonce.py b/Evaluator_272/mld/models/losses/infonce.py new file mode 100644 index 0000000000000000000000000000000000000000..c0065c00c7b8fb4f57d48f8777e4094760725917 --- /dev/null +++ b/Evaluator_272/mld/models/losses/infonce.py @@ -0,0 +1,45 @@ +import torch +import torch.nn.functional as F +import numpy as np + +class InfoNCE: + def __init__(self, t): + # pass + self.t = t + + def __call__(self, f, dist): + ''' + f_motion: N x d + f_text: N x d + ''' + t = self.t + f_motion, f_text = f[0], f[1] + N, d = f_motion.shape[0], f_motion.shape[1] + + + Emb_motion = F.normalize(f_motion, dim=1) + Emb_text = F.normalize(f_text, dim=1) + + t = torch.tensor(t).to(f_motion.device) + logits = torch.mm(Emb_motion, Emb_text.T) + # logits = torch.mm(Emb_motion, Emb_text.T) / torch.exp(t) + if dist is not None: + text_logits = dist.detach() + mask = torch.where(torch.logical_and(text_logits > 0.85, text_logits < 1.0-1e-100), torch.tensor(float('-inf')).to(f_motion.device), torch.tensor(1.0e100).to(f_motion.device)) + mask.diagonal().fill_(float('inf')) + logits = torch.min(mask, logits) + # mask = torch.where((torch.logical_and(text_logits > 0.985, text_logits < 1.0-1e-100)), 
torch.tensor(float('-inf')).cuda(), torch.tensor(1.0e100).cuda()) + # logits = torch.min(mask, logits) + + N = f_motion.shape[0] + labels = torch.arange(N).to(f_motion.device) + + loss_m = F.cross_entropy(logits / t, labels) + loss_t = F.cross_entropy(logits.T / t, labels) + + loss = (loss_m + loss_t) / 2 + + return loss + + def __repr__(self): + return "InfoNCE()" \ No newline at end of file diff --git a/Evaluator_272/mld/models/losses/kl.py b/Evaluator_272/mld/models/losses/kl.py new file mode 100644 index 0000000000000000000000000000000000000000..6532c0e70e037b4e3bf4c57fb25f21a3bbe5c4a6 --- /dev/null +++ b/Evaluator_272/mld/models/losses/kl.py @@ -0,0 +1,23 @@ +import torch + +class KLLoss: + def __init__(self): + pass + + def __call__(self, q, p): + div = torch.distributions.kl_divergence(q, p) + return div.mean() + + def __repr__(self): + return "KLLoss()" + +class KLLossMulti: + def __init__(self): + self.klloss = KLLoss() + + def __call__(self, qlist, plist): + return sum([self.klloss(q, p) + for q, p in zip(qlist, plist)]) + + def __repr__(self): + return "KLLossMulti()" diff --git a/Evaluator_272/mld/models/losses/mld.py b/Evaluator_272/mld/models/losses/mld.py new file mode 100644 index 0000000000000000000000000000000000000000..1316ce1e9930d82be8c805797e803afd1e30c0ec --- /dev/null +++ b/Evaluator_272/mld/models/losses/mld.py @@ -0,0 +1,340 @@ +import numpy as np +import torch +import torch.nn as nn +from torchmetrics import Metric + +from mld.data.humanml.scripts.motion_process import (qrot, + recover_root_rot_pos) + + +class MLDLosses(Metric): + """ + MLD Loss + """ + + def __init__(self, vae, mode, cfg): + super().__init__(dist_sync_on_step=cfg.LOSS.DIST_SYNC_ON_STEP) + + # Save parameters + # self.vae = vae + self.vae_type = cfg.model.vae_type + self.mode = mode + self.cfg = cfg + self.predict_epsilon = cfg.TRAIN.ABLATION.PREDICT_EPSILON + self.stage = cfg.TRAIN.STAGE + + losses = [] + + # diffusion loss + if self.stage in ['diffusion', 'vae_diffusion']: + # instance noise loss + losses.append("inst_loss") + losses.append("x_loss") + if self.cfg.LOSS.LAMBDA_PRIOR != 0.0: + # prior noise loss + losses.append("prior_loss") + + if self.stage in ['vae', 'vae_diffusion']: + # reconstruction loss + losses.append("recons_feature") + losses.append("recons_verts") + losses.append("recons_joints") + losses.append("recons_limb") + + losses.append("gen_feature") + losses.append("gen_joints") + + # KL loss + if self.vae_type in ['mld_dual_vae']: + losses.append("kl_motionbody") + losses.append("kl_motionhand") + else: + losses.append("kl_motion") + + # vel Loss + if cfg.LOSS.Velocity_loss: + losses.append("recons_velocity") + + if self.stage not in ['vae', 'diffusion', 'vae_diffusion']: + raise ValueError(f"Stage {self.stage} not supported") + + losses.append("total") + + for loss in losses: + self.add_state(loss, + default=torch.tensor(0.0), + dist_reduce_fx="sum") + # self.register_buffer(loss, torch.tensor(0.0)) + self.add_state("count", torch.tensor(0), dist_reduce_fx="sum") + self.losses = losses + + self._losses_func = {} + self._params = {} + for loss in losses: + if loss.split('_')[0] == 'inst': + self._losses_func[loss] = nn.MSELoss(reduction='mean') + self._params[loss] = 1 + elif loss.split('_')[0] == 'x': + self._losses_func[loss] = nn.MSELoss(reduction='mean') + self._params[loss] = 1 + elif loss.split('_')[0] == 'prior': + self._losses_func[loss] = nn.MSELoss(reduction='mean') + self._params[loss] = cfg.LOSS.LAMBDA_PRIOR + if loss.split('_')[0] == 'kl': + if 
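
Stepping back to infonce.py above: the loss builds an N x N cosine-similarity matrix between L2-normalized motion and text embeddings, applies a symmetric cross-entropy against the diagonal, and can optionally mask near-duplicate text pairs through the dist argument. A minimal call with random features and no filtering; batch size, feature dimension, and temperature are illustrative choices, not values read from a config:

```python
import torch
from mld.models.losses.infonce import InfoNCE

loss_fn = InfoNCE(t=0.1)                  # temperature
f_motion = torch.randn(32, 256)           # motion embeddings
f_text = torch.randn(32, 256)             # paired text embeddings (row i matches row i)
loss = loss_fn((f_motion, f_text), None)  # dist=None disables the similarity-based mask
```
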
cfg.LOSS.LAMBDA_KL != 0.0: + self._losses_func[loss] = KLLoss() + self._params[loss] = cfg.LOSS.LAMBDA_KL + elif loss.split('_')[0] == 'recons': + self._losses_func[loss] = torch.nn.SmoothL1Loss( + reduction='mean') + self._params[loss] = cfg.LOSS.LAMBDA_REC + elif loss.split('_')[0] == 'gen': + self._losses_func[loss] = torch.nn.SmoothL1Loss( + reduction='mean') + self._params[loss] = cfg.LOSS.LAMBDA_GEN + elif loss.split('_')[0] == 'latent': + self._losses_func[loss] = torch.nn.SmoothL1Loss( + reduction='mean') + self._params[loss] = cfg.LOSS.LAMBDA_LATENT + + else: + ValueError("This loss is not recognized.") + if loss.split('_')[-1] == 'joints': + self._params[loss] = cfg.LOSS.LAMBDA_JOINT + + def update(self, rs_set): + total: float = 0.0 + # Compute the losses + # Compute instance loss + if self.stage in ["vae", "vae_diffusion"]: + total += self._update_loss("recons_feature", rs_set['m_rst'], + rs_set['m_ref']) + total += self._update_loss("recons_joints", rs_set['joints_rst'], + rs_set['joints_ref']) + if self.vae_type in ["mld_dual_vae"]: + total += self._update_loss("kl_motionbody", rs_set['body_dist_m'], rs_set['body_dist_ref']) + total += self._update_loss("kl_motionhand", rs_set['hand_dist_m'], rs_set['hand_dist_ref']) + else: + total += self._update_loss("kl_motion", rs_set['dist_m'], rs_set['dist_ref']) + + if self.cfg.LOSS.Velocity_loss: + total += self._update_loss("recons_velocity", rs_set['vel_rst'], rs_set['vel_ref']) + + if self.stage in ["diffusion", "vae_diffusion"]: + # predict noise + if self.predict_epsilon: + total += self._update_loss("inst_loss", rs_set['noise_pred'], + rs_set['noise']) + # predict x + else: + total += self._update_loss("x_loss", rs_set['pred'], + rs_set['latent']) + + if self.cfg.LOSS.LAMBDA_PRIOR != 0.0: + # loss - prior loss + total += self._update_loss("prior_loss", rs_set['noise_prior'], + rs_set['dist_m1']) + + if self.stage in ["vae_diffusion"]: + # loss + # noise+text_emb => diff_reverse => latent => decode => motion + total += self._update_loss("gen_feature", rs_set['gen_m_rst'], + rs_set['m_ref']) + total += self._update_loss("gen_joints", rs_set['gen_joints_rst'], + rs_set['joints_ref']) + + self.total += total.detach() + self.count += 1 + + return total + + def compute(self, split): + count = getattr(self, "count") + return {loss: getattr(self, loss) / count for loss in self.losses} + + def _update_loss(self, loss: str, outputs, inputs): + # Update the loss + val = self._losses_func[loss](outputs, inputs) + getattr(self, loss).__iadd__(val.detach()) + # Return a weighted sum + weighted_loss = self._params[loss] * val + return weighted_loss + + def loss2logname(self, loss: str, split: str): + if loss == "total": + log_name = f"{loss}/{split}" + else: + loss_type, name = loss.split("_") + log_name = f"{loss_type}/{name}/{split}" + return log_name + + + +class MLDLosses_no_joint(Metric): + """ + MLD Loss + """ + + def __init__(self, vae, mode, cfg): + super().__init__(dist_sync_on_step=cfg.LOSS.DIST_SYNC_ON_STEP) + + # Save parameters + # self.vae = vae + self.vae_type = cfg.TRAIN.ABLATION.VAE_TYPE + self.mode = mode + self.cfg = cfg + self.predict_epsilon = cfg.TRAIN.ABLATION.PREDICT_EPSILON + self.stage = cfg.TRAIN.STAGE + + losses = [] + + # diffusion loss + if self.stage in ['diffusion', 'vae_diffusion']: + # instance noise loss + losses.append("inst_loss") + losses.append("x_loss") + if self.cfg.LOSS.LAMBDA_PRIOR != 0.0: + # prior noise loss + losses.append("prior_loss") + + if self.stage in ['vae', 'vae_diffusion']: + # 
reconstruction loss + losses.append("recons_feature") + losses.append("recons_verts") + # losses.append("recons_joints") + losses.append("recons_limb") + + losses.append("gen_feature") + # losses.append("gen_joints") + + # KL loss + losses.append("kl_motion") + + if self.stage not in ['vae', 'diffusion', 'vae_diffusion']: + raise ValueError(f"Stage {self.stage} not supported") + + losses.append("total") + + for loss in losses: + self.add_state(loss, + default=torch.tensor(0.0), + dist_reduce_fx="sum") + # self.register_buffer(loss, torch.tensor(0.0)) + self.add_state("count", torch.tensor(0), dist_reduce_fx="sum") + self.losses = losses + + self._losses_func = {} + self._params = {} + for loss in losses: + if loss.split('_')[0] == 'inst': + self._losses_func[loss] = nn.MSELoss(reduction='mean') + self._params[loss] = 1 + elif loss.split('_')[0] == 'x': + self._losses_func[loss] = nn.MSELoss(reduction='mean') + self._params[loss] = 1 + elif loss.split('_')[0] == 'prior': + self._losses_func[loss] = nn.MSELoss(reduction='mean') + self._params[loss] = cfg.LOSS.LAMBDA_PRIOR + if loss.split('_')[0] == 'kl': + if cfg.LOSS.LAMBDA_KL != 0.0: + self._losses_func[loss] = KLLoss() + self._params[loss] = cfg.LOSS.LAMBDA_KL + elif loss.split('_')[0] == 'recons': + self._losses_func[loss] = torch.nn.SmoothL1Loss( + reduction='mean') + self._params[loss] = cfg.LOSS.LAMBDA_REC + elif loss.split('_')[0] == 'gen': + self._losses_func[loss] = torch.nn.SmoothL1Loss( + reduction='mean') + self._params[loss] = cfg.LOSS.LAMBDA_GEN + elif loss.split('_')[0] == 'latent': + self._losses_func[loss] = torch.nn.SmoothL1Loss( + reduction='mean') + self._params[loss] = cfg.LOSS.LAMBDA_LATENT + else: + ValueError("This loss is not recognized.") + if loss.split('_')[-1] == 'joints': + self._params[loss] = cfg.LOSS.LAMBDA_JOINT + + def update(self, rs_set): + total: float = 0.0 + # Compute the losses + # Compute instance loss + if self.stage in ["vae", "vae_diffusion"]: + total += self._update_loss("recons_feature", rs_set['m_rst'], + rs_set['m_ref']) + # total += self._update_loss("recons_joints", rs_set['joints_rst'], + # rs_set['joints_ref']) + total += self._update_loss("kl_motion", rs_set['dist_m'], rs_set['dist_ref']) + + if self.stage in ["diffusion", "vae_diffusion"]: + # predict noise + if self.predict_epsilon: + total += self._update_loss("inst_loss", rs_set['noise_pred'], + rs_set['noise']) + # predict x + else: + total += self._update_loss("x_loss", rs_set['pred'], + rs_set['latent']) + + if self.cfg.LOSS.LAMBDA_PRIOR != 0.0: + # loss - prior loss + total += self._update_loss("prior_loss", rs_set['noise_prior'], + rs_set['dist_m1']) + + if self.stage in ["vae_diffusion"]: + # loss + # noise+text_emb => diff_reverse => latent => decode => motion + total += self._update_loss("gen_feature", rs_set['gen_m_rst'], + rs_set['m_ref']) + # total += self._update_loss("gen_joints", rs_set['gen_joints_rst'], + # rs_set['joints_ref']) + + self.total += total.detach() + self.count += 1 + + return total + + def compute(self, split): + count = getattr(self, "count") + return {loss: getattr(self, loss) / count for loss in self.losses} + + def _update_loss(self, loss: str, outputs, inputs): + # Update the loss + val = self._losses_func[loss](outputs, inputs) + getattr(self, loss).__iadd__(val.detach()) + # Return a weighted sum + weighted_loss = self._params[loss] * val + return weighted_loss + + def loss2logname(self, loss: str, split: str): + if loss == "total": + log_name = f"{loss}/{split}" + else: + loss_type, name = 
loss.split("_") + log_name = f"{loss_type}/{name}/{split}" + return log_name + +class KLLoss: + + def __init__(self): + pass + + def __call__(self, q, p): + div = torch.distributions.kl_divergence(q, p) + return div.mean() + + def __repr__(self): + return "KLLoss()" + + +class KLLossMulti: + + def __init__(self): + self.klloss = KLLoss() + + def __call__(self, qlist, plist): + return sum([self.klloss(q, p) for q, p in zip(qlist, plist)]) + + def __repr__(self): + return "KLLossMulti()" diff --git a/Evaluator_272/mld/models/losses/temos.py b/Evaluator_272/mld/models/losses/temos.py new file mode 100644 index 0000000000000000000000000000000000000000..cd02527a5451943f7d837ff26dad0b7af8575ac5 --- /dev/null +++ b/Evaluator_272/mld/models/losses/temos.py @@ -0,0 +1,220 @@ +import torch +import torch.nn as nn +from torchmetrics import Metric +from .infonce import InfoNCE + + +class TemosLosses(Metric): + """ + Loss + Modify loss + refer to temos loss + add loss like deep-motion-editing + 'gen_loss_total': l_total, + 'gen_loss_adv': l_adv, + 'gen_loss_recon_all': l_rec, + 'gen_loss_recon_r': l_r_rec, + 'gen_loss_recon_s': l_s_rec, + 'gen_loss_feature_all': l_ft, + 'gen_loss_feature_r': l_ft_r, + 'gen_loss_feature_s': l_ft_s, + 'gen_loss_feature_t': l_ft_t, + 'gen_loss_quaternion': l_qt, + 'gen_loss_twist': l_tw, + 'gen_loss_triplet': l_triplet, + 'gen_loss_joint': l_joint, + + """ + + def __init__(self, vae, mode, cfg): + super().__init__(dist_sync_on_step=cfg.LOSS.DIST_SYNC_ON_STEP) + # Save parameters + self.vae = vae + self.mode = mode + + self.use_infonce = cfg.LOSS.USE_INFONCE + + if self.use_infonce: + self.infonce_temp = cfg.LOSS.INFONCE_TEMP + + loss_on_both = True + force_loss_on_jfeats = True + ablation_no_kl_combine = False + ablation_no_kl_gaussian = False + ablation_no_motionencoder = False + + infonce_use_latent = False + + self.loss_on_both = loss_on_both + self.ablation_no_kl_combine = ablation_no_kl_combine + self.ablation_no_kl_gaussian = ablation_no_kl_gaussian + self.ablation_no_motionencoder = ablation_no_motionencoder + + self.infonce_use_latent = infonce_use_latent + + losses = [] + if mode == "xyz" or force_loss_on_jfeats: + if not ablation_no_motionencoder: + losses.append("recons_jfeats2jfeats") + losses.append("recons_text2jfeats") + if mode == "smpl": + if not ablation_no_motionencoder: + losses.append("recons_rfeats2rfeats") + losses.append("recons_text2rfeats") + else: + ValueError("This mode is not recognized.") + + if vae or loss_on_both: + kl_losses = [] + if not ablation_no_kl_combine and not ablation_no_motionencoder: + kl_losses.extend(["kl_text2motion", "kl_motion2text"]) + if not ablation_no_kl_gaussian: + if ablation_no_motionencoder: + kl_losses.extend(["kl_text"]) + else: + kl_losses.extend(["kl_text", "kl_motion"]) + losses.extend(kl_losses) + + if not self.vae or loss_on_both: + if not ablation_no_motionencoder: + losses.append("latent_manifold") + losses.append("total") + + if self.use_infonce: + losses.append("contrastive_infonce") + + + for loss in losses: + self.register_buffer(loss, torch.tensor(0.0)) + self.register_buffer("count", torch.tensor(0)) + # self.register_buffer(loss, default=torch.tensor(0.0), dist_reduce_fx="sum") + # self.register_buffer("count", default=torch.tensor(0), dist_reduce_fx="sum") + self.losses = losses + + # Instantiate loss functions + # self._losses_func = {loss: hydra.utils.instantiate(kwargs[loss + "_func"]) + # for loss in losses if loss != "total"} + self._losses_func = {} + self._params = {} + + for loss in losses: 
+ if loss != 'total': + if loss.split('_')[0] == 'kl': + self._losses_func[loss] = KLLoss() + self._params[loss] = cfg.LOSS.LAMBDA_KL + elif loss.split('_')[0] == 'recons': + self._losses_func[loss] = torch.nn.SmoothL1Loss( + reduction='mean') + self._params[loss] = cfg.LOSS.LAMBDA_REC + elif loss.split('_')[0] == 'latent': + self._losses_func[loss] = torch.nn.SmoothL1Loss( + reduction='mean') + self._params[loss] = cfg.LOSS.LAMBDA_LATENT + elif loss.split('_')[0] == 'cycle': + self._losses_func[loss] = torch.nn.SmoothL1Loss( + reduction='mean') + self._params[loss] = cfg.LOSS.LAMBDA_CYCLE + elif loss.split('_')[0] == 'contrastive': + self._losses_func[loss] = InfoNCE(self.infonce_temp) + self._params[loss] = cfg.LOSS.LAMBDA_INFONCE + else: + ValueError("This loss is not recognized.") + + def update(self, + f_text=None, + f_motion=None, + f_ref=None, + lat_text=None, + lat_motion=None, + dis_text=None, + dis_motion=None, + dis_ref=None, + emb_dist=None): + total: float = 0.0 + + if self.mode == "xyz" or self.force_loss_on_jfeats: + if not self.ablation_no_motionencoder: + total += self._update_loss("recons_jfeats2jfeats", f_motion, + f_ref) + total += self._update_loss("recons_text2jfeats", f_text, f_ref) + + if self.mode == "smpl": + if not self.ablation_no_motionencoder: + total += self._update_loss("recons_rfeats2rfeats", + f_motion.rfeats, f_ref.rfeats) + total += self._update_loss("recons_text2rfeats", f_text.rfeats, + f_ref.rfeats) + + if self.vae or self.loss_on_both: + if not self.ablation_no_kl_combine and not self.ablation_no_motionencoder: + total += self._update_loss("kl_text2motion", dis_text, + dis_motion) + total += self._update_loss("kl_motion2text", dis_motion, + dis_text) + if not self.ablation_no_kl_gaussian: + total += self._update_loss("kl_text", dis_text, dis_ref) + if not self.ablation_no_motionencoder: + total += self._update_loss("kl_motion", dis_motion, + dis_ref) + if not self.vae or self.loss_on_both: + if not self.ablation_no_motionencoder: + total += self._update_loss("latent_manifold", lat_text, + lat_motion) + + if self.use_infonce: + if self.infonce_use_latent: + # print('use latent feature to calculate caontrastive loss') + total += self._update_loss("contrastive_infonce", (lat_text, lat_motion), emb_dist) + else: + total += self._update_loss("contrastive_infonce", (dis_motion.loc, dis_text.loc), emb_dist) + + + self.total += total.detach() + self.count += 1 + + return total + + def compute(self, split): + count = getattr(self, "count") + return {loss: getattr(self, loss) / count for loss in self.losses} + + def _update_loss(self, loss: str, outputs, inputs): + # Update the loss + val = self._losses_func[loss](outputs, inputs) + getattr(self, loss).__iadd__(val.detach()) + # Return a weighted sum + weighted_loss = self._params[loss] * val + return weighted_loss + + def loss2logname(self, loss: str, split: str): + if loss == "total": + log_name = f"{loss}/{split}" + else: + loss_type, name = loss.split("_") + log_name = f"{loss_type}/{name}/{split}" + return log_name + + +class KLLoss: + + def __init__(self): + pass + + def __call__(self, q, p): + div = torch.distributions.kl_divergence(q, p) + return div.mean() + + def __repr__(self): + return "KLLoss()" + + +class KLLossMulti: + + def __init__(self): + self.klloss = KLLoss() + + def __call__(self, qlist, plist): + return sum([self.klloss(q, p) for q, p in zip(qlist, plist)]) + + def __repr__(self): + return "KLLossMulti()" diff --git a/Evaluator_272/mld/models/losses/tmost.py 
b/Evaluator_272/mld/models/losses/tmost.py new file mode 100644 index 0000000000000000000000000000000000000000..14301e0ea2072b7b88216cf2e8a94f9c71741166 --- /dev/null +++ b/Evaluator_272/mld/models/losses/tmost.py @@ -0,0 +1,178 @@ +import torch +import torch.nn as nn +from torchmetrics import Metric + +class TmostLosses(Metric): + """ + Loss + Modify loss + refer to temos loss + add loss like deep-motion-editing + 'gen_loss_total': l_total, + 'gen_loss_adv': l_adv, + 'gen_loss_recon_all': l_rec, + 'gen_loss_recon_r': l_r_rec, + 'gen_loss_recon_s': l_s_rec, + 'gen_loss_feature_all': l_ft, + 'gen_loss_feature_r': l_ft_r, + 'gen_loss_feature_s': l_ft_s, + 'gen_loss_feature_t': l_ft_t, + 'gen_loss_quaternion': l_qt, + 'gen_loss_twist': l_tw, + 'gen_loss_triplet': l_triplet, + 'gen_loss_joint': l_joint, + + """ + def __init__(self, vae, mode, cfg): + super().__init__(dist_sync_on_step=cfg.LOSS.DIST_SYNC_ON_STEP) + + # Save parameters + self.vae = vae + self.mode = mode + + + losses = [] + losses.append("recons_mm2m") + losses.append("recons_t2m") + + losses.append("cross_mt2m") + losses.append("cross_tm2m") + + # cycle consistency loss + losses.append("cycle_cmsm2mContent") + losses.append("cycle_cmsm2mStyle") + + # latent loss + losses.append("latent_ct2cm") + losses.append("latent_st2sm") + + # KL loss + losses.append("kl_motion") + losses.append("kl_text") + losses.append("kl_ct2cm") + losses.append("kl_cm2ct") + + losses.append("total") + + for loss in losses: + self.register_buffer(loss, torch.tensor(0.0)) + self.register_buffer("count", torch.tensor(0)) + self.losses = losses + + self.ablation_cycle = cfg.TRAIN.ABLATION.CYCLE + + self._losses_func = {} + self._params = {} + for loss in losses: + if loss !='total': + if loss.split('_')[0] == 'kl': + self._losses_func[loss] = KLLoss() + self._params[loss] = cfg.LOSS.LAMBDA_KL + elif loss.split('_')[0] == 'recons': + self._losses_func[loss] = torch.nn.SmoothL1Loss(reduction='mean') + self._params[loss] = cfg.LOSS.LAMBDA_REC + elif loss.split('_')[0] == 'cross': + self._losses_func[loss] = torch.nn.SmoothL1Loss(reduction='mean') + self._params[loss] = cfg.LOSS.LAMBDA_CROSS + elif loss.split('_')[0] =='latent': + self._losses_func[loss] = torch.nn.SmoothL1Loss(reduction='mean') + self._params[loss] = cfg.LOSS.LAMBDA_LATENT + elif loss.split('_')[0] =='cycle': + self._losses_func[loss] = torch.nn.SmoothL1Loss(reduction='mean') + self._params[loss] = cfg.LOSS.LAMBDA_CYCLE + else: + ValueError("This loss is not recognized.") + + + def update(self, rs_set, dist_ref): + total: float = 0.0 + + # Compute the losses + """ + loss list + - triplet loss + - anchor style1 + - pos style2 + - neg diff_style + anchor = s_xa + pos = s_xpos + neg = self.gen.enc_style(co_data[diff_style], diff_style[-2:]) + l_triplet = self.triplet_loss(anchor, pos, neg) + - + """ + + + total += self._update_loss("recons_mm2m", rs_set['rs_cm1sm1'], rs_set['m1']) + total += self._update_loss("recons_t2m", rs_set['rs_ct1st1'], rs_set['m1']) + + # loss - cross reconstruction loss + total += self._update_loss("cross_mt2m", rs_set['rs_cm1st1'], rs_set['m1']) + total += self._update_loss("cross_tm2m", rs_set['rs_ct1sm1'], rs_set['m1']) + + + if self.ablation_cycle: + total += self._update_loss("cycle_cmsm2mContent", rs_set['cyc_rs_cm1sm1'], rs_set['m1']) + total += self._update_loss("cycle_cmsm2mStyle", rs_set['cyc_rs_cm2sm2'], rs_set['m2']) + + + total += self._update_loss("latent_ct2cm", rs_set['lat_ct1'], rs_set['lat_cm1']) + total += self._update_loss("latent_st2sm", 
rs_set['lat_st1'], rs_set['lat_sm1']) + + + total += self._update_loss("kl_motion", rs_set['dist_cm1'], dist_ref) + # total += self._update_loss("kl_motion", rs_set['dist_sm1'], dist_ref) + + total += self._update_loss("kl_text", rs_set['dist_ct1'], dist_ref) + # total += self._update_loss("kl_text", rs_set['dist_st1'], dist_ref) + + total += self._update_loss("kl_ct2cm", rs_set['dist_ct1'], rs_set['dist_cm1']) + total += self._update_loss("kl_cm2ct", rs_set['dist_cm1'], rs_set['dist_ct1']) + + self.total += total.detach() + self.count += 1 + + return total + + def compute(self, split): + count = getattr(self, "count") + return {loss: getattr(self, loss)/count for loss in self.losses} + + def _update_loss(self, loss: str, outputs, inputs): + # Update the loss + val = self._losses_func[loss](outputs, inputs) + getattr(self, loss).__iadd__(val.detach()) + # Return a weighted sum + weighted_loss = self._params[loss] * val + return weighted_loss + + def loss2logname(self, loss: str, split: str): + if loss == "total": + log_name = f"{loss}/{split}" + else: + loss_type, name = loss.split("_") + log_name = f"{loss_type}/{name}/{split}" + return log_name + + +class KLLoss: + def __init__(self): + pass + + def __call__(self, q, p): + div = torch.distributions.kl_divergence(q, p) + return div.mean() + + def __repr__(self): + return "KLLoss()" + + +class KLLossMulti: + def __init__(self): + self.klloss = KLLoss() + + def __call__(self, qlist, plist): + return sum([self.klloss(q, p) + for q, p in zip(qlist, plist)]) + + def __repr__(self): + return "KLLossMulti()" diff --git a/Evaluator_272/mld/models/losses/utils.py b/Evaluator_272/mld/models/losses/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..961744c30af867a3fa19337293cce6ee2cd5c268 --- /dev/null +++ b/Evaluator_272/mld/models/losses/utils.py @@ -0,0 +1,185 @@ + +import torch + +# --- +def keypoint_loss(self, pred_keypoints_2d, gt_keypoints_2d, openpose_weight, gt_weight): + """ + Compute 2D reprojection loss on the keypoints. + The loss is weighted by the confidence. + The available keypoints are different for each dataset. 
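    A short illustrative sketch of the confidence weighting below (the keypoint
    count, the 25-keypoint OpenPose split, the weight values and the MSE
    criterion are assumptions made for the example, not fixed by this function):

        import torch
        criterion = torch.nn.MSELoss(reduction='none')
        pred = torch.zeros(2, 49, 2)                  # [batch, keypoints, xy]
        gt = torch.rand(2, 49, 3)                     # last channel = confidence
        conf = gt[:, :, -1].unsqueeze(-1).clone()     # [batch, keypoints, 1]
        conf[:, :25] *= 0.5                           # openpose_weight (example)
        conf[:, 25:] *= 1.0                           # gt_weight (example)
        loss = (conf * criterion(pred, gt[:, :, :-1])).mean()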
+ """ + conf = gt_keypoints_2d[:, :, -1].unsqueeze(-1).clone() + conf[:, :25] *= openpose_weight + conf[:, 25:] *= gt_weight + loss = (conf * self.criterion_keypoints(pred_keypoints_2d, gt_keypoints_2d[:, :, :-1])).mean() + return loss + +def compute_2d_loss(model, batch): + ''' + keypoints loss + ''' + gt = batch["kp_2d"] + out = batch["pred_2d"] + mask = batch["mask"] + + gtmasked = gt[mask] + outmasked = out[mask] + loss = F.mse_loss(gtmasked, outmasked, reduction='mean') + return loss + +def compute_limb_loss(model, batch): + # limb position loss + x = batch["x_xyz"] + output = batch["output_xyz"] + mask = batch["mask"] + + # remove glob translation + # [bs njoint nfeats lenghs] = > [bs lengths njoints nfeats] + rootindex = JOINTSTYPE_ROOT[model.jointstype] + gt = x - x[:,:,[rootindex],:] + out = output - output[:,:,[rootindex],:] + + limbndex = JOINTSTYPE_LIMB[model.jointstype] + gtmasked = gt[:,:,limbndex,:][mask] + outmasked = out[:,:,limbndex,:][mask] + + loss = F.mse_loss(gtmasked, outmasked, reduction='mean') + return loss + +def compute_glob_loss(model, batch): + # glob rotation for the first (root) joint + x = batch["x"] + output = batch["output"] + mask = batch["mask"] + + # [bs njoint nfeats lenghs] = > [bs lengths njoints nfeats] + rootindex = JOINTSTYPE_ROOT[model.jointstype] + gtmasked = x[:,:,[rootindex],:][mask] + outmasked = output[:,:,[rootindex],:][mask] + + loss = F.mse_loss(gtmasked, outmasked, reduction='mean') + return loss + +def compute_theta_loss(model, batch): + x = batch['theta'] + output = batch["output_theta"] + mask = batch["mask"] + + gtmasked = x[mask] + outmasked = output[mask] + + # translation loss + root_index = THETA_MAP['root'] + w_root = batch["w_root"][mask][:,None] + gtmasked[:,root_index] *= w_root + outmasked[:,root_index] *= w_root + + loss = F.mse_loss(gtmasked, outmasked, reduction='mean') + return loss + +def compute_rc_loss(model, batch): + x = batch["x"] + output = batch["output"] + mask = batch["mask"] + + gtmasked = x[mask] + outmasked = output[mask] + + loss = F.mse_loss(gtmasked, outmasked, reduction='mean') + return loss + +def compute_rcxyz_loss(model, batch): + x = batch["x_xyz"] + output = batch["output_xyz"] + mask = batch["mask"] + + # dummpy + # ---ignore global output for no global dataset--- + root_index = THETA_MAP['root'] + w_root = batch["w_root"][mask][:,None,None] + trans = batch['theta'][:,:,None,root_index,...][mask] + output_trans = batch['output_theta'][:,:,None,root_index][mask] + + gtmasked = x[mask] + outmasked = output[mask] + + gtmasked -= trans*(1-w_root) + outmasked -= output_trans*(1-w_root) + # ------------------------------------------------- + loss = F.mse_loss(gtmasked, outmasked, reduction='mean') + return loss + +def compute_rcverts_loss(model, batch): + x = batch["x_vertices"] + output = batch["output_vertices"] + mask = batch["mask"] + + # dummy + # ---ignore global output for no global dataset--- + root_index = THETA_MAP['root'] + w_root = batch["w_root"][mask][:,None,None] + trans = batch['theta'][:,:,None,root_index,...][mask] + output_trans = batch['output_theta'][:,:,None,root_index][mask] + + gtmasked = x[mask] + outmasked = output[mask] + + gtmasked -= trans*(1-w_root) + outmasked -= output_trans*(1-w_root) + # ------------------------------------------------- + loss = F.mse_loss(gtmasked, outmasked, reduction='mean') + return loss + +def compute_vel_loss(model, batch): + x = batch["x"] + output = batch["output"] + gtvel = (x[:,1:,...] - x[:, :-1,...]) + outputvel = (output[:,1:,...] 
- output[:,1:,...]) + + mask = batch["mask"][:,1:] + + gtvelmasked = gtvel[mask] + outvelmasked = outputvel[mask] + + loss = F.mse_loss(gtvelmasked, outvelmasked, reduction='mean') + return loss + + +def compute_velxyz_loss(model, batch): + x = batch["x_xyz"] + output = batch["output_xyz"] + gtvel = (x[:,1:,...] - x[:,:-1,...]) + outputvel = (output[:,1:,...] - output[:,:-1,...]) + + mask = batch["mask"][:, 1:] + + gtvelmasked = gtvel[mask] + outvelmasked = outputvel[mask] + + loss = F.mse_loss(gtvelmasked, outvelmasked, reduction='mean') + return loss + + +def compute_hp_loss(model, batch): + loss = hessian_penalty(model.return_latent, batch, seed=torch.random.seed()) + return loss + + +def compute_kl_loss(model, batch): + mu, logvar = batch["mu"], batch["logvar"] + loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) + return loss + +_matching_ = {"rc": compute_rc_loss, "kl": compute_kl_loss, "hp": compute_hp_loss, + "rcxyz": compute_rcxyz_loss, + "vel": compute_vel_loss, "velxyz": compute_velxyz_loss, + "glob":compute_glob_loss, "limb":compute_limb_loss, "rcverts": compute_rcverts_loss, + "theta": compute_theta_loss, "2d": compute_2d_loss} + +def get_loss_function(ltype): + return _matching_[ltype] + + +def get_loss_names(): + return list(_matching_.keys()) +# --- diff --git a/Evaluator_272/mld/models/metrics/__init__.py b/Evaluator_272/mld/models/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ee054e1bf6bfb8d16eaec4249bf85d7b181b1f70 --- /dev/null +++ b/Evaluator_272/mld/models/metrics/__init__.py @@ -0,0 +1,12 @@ +from .compute import ComputeMetrics +from .mr import MRMetrics +from .tm2t import TM2TMetrics +from .tm2t_R256 import TM2TMetrics_R256 +from .tmr_tm2t import TMR_TM2TMetrics +from .mm import MMMetrics +# from .gru import HUMANACTMetrics +# from .stgcn import UESTCMetrics +from .uncond import UncondMetrics +from .compute_body_hand import ComputeMetrics_body_hand +# from .mr_body_hand import MRMetrics_body_hand +from .acc import ACCMetrics diff --git a/Evaluator_272/mld/models/metrics/acc.py b/Evaluator_272/mld/models/metrics/acc.py new file mode 100644 index 0000000000000000000000000000000000000000..dee70f79bbc2bee9e20474c60973c316f464e764 --- /dev/null +++ b/Evaluator_272/mld/models/metrics/acc.py @@ -0,0 +1,96 @@ +from typing import List +import random +import torch +from torch import Tensor +from torchmetrics import Metric +from torchmetrics.functional import pairwise_euclidean_distance +import os +from .utils import * + + + +class ACCMetrics(Metric): + + def __init__(self, + dist_sync_on_step=True, + **kwargs): + super().__init__(dist_sync_on_step=dist_sync_on_step) + + self.name = "acc" + + # add metrics + self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum") + self.add_state("count_seq", + default=torch.tensor(0), + dist_reduce_fx="sum") + self.metrics = [] + # Accuracy + self.add_state("accuracy", + default=torch.tensor(0.), + dist_reduce_fx="sum") + self.add_state("gt_accuracy", + default=torch.tensor(0.), + dist_reduce_fx="sum") + self.metrics.extend(["accuracy", "gt_accuracy"]) + + def compute(self, sanity_flag): + count = self.count.item() + count_seq = self.count_seq.item() + + # init metrics + metrics = {metric: getattr(self, metric) for metric in self.metrics} + + # if in sanity check stage then jump + if sanity_flag: + return metrics + + # Accuracy + self.accuracy = torch.trace(self.confusion) / torch.sum(self.confusion) + self.gt_accuracy = torch.trace(self.gt_confusion) / torch.sum( + 
self.gt_confusion) + + # cat all embeddings + all_labels = torch.cat(self.label_embeddings, axis=0) + all_genmotions = torch.cat(self.recmotion_embeddings, axis=0) + all_gtmotions = torch.cat(self.gtmotion_embeddings, axis=0) + all_gtmotions2 = all_gtmotions.clone()[ + torch.randperm(all_gtmotions.shape[0]), :] + genstats = calculate_activation_statistics(all_genmotions) + gtstats = calculate_activation_statistics(all_gtmotions) + gtstats2 = calculate_activation_statistics(all_gtmotions2) + + all_labels = all_labels.cpu() + + # calculate diversity and multimodality + self.Diversity, self.Multimodality = calculate_diversity_multimodality( + all_genmotions, + all_labels, + self.num_labels, + diversity_times=self.diversity_times, + multimodality_times=self.multimodality_times) + + self.gt_Diversity, self.gt_Multimodality = calculate_diversity_multimodality( + all_gtmotions, all_labels, self.num_labels) + + metrics.update( + {metric: getattr(self, metric) + for metric in self.metrics}) + + # Compute Fid + metrics["FID"] = calculate_fid(gtstats, genstats) + metrics["gt_FID"] = calculate_fid(gtstats, gtstats2) + + return {**metrics} + + def update( + self, + pred_idx: List, + label: List, + lengths: List[int] + ): + self.count += sum(lengths) + self.count_seq += len(lengths) + + + + diff --git a/Evaluator_272/mld/models/metrics/compute.py b/Evaluator_272/mld/models/metrics/compute.py new file mode 100644 index 0000000000000000000000000000000000000000..6a093e3e60f7ad1c9387e9b0c257c4306e4b01c0 --- /dev/null +++ b/Evaluator_272/mld/models/metrics/compute.py @@ -0,0 +1,196 @@ +from typing import List + +import torch +from einops import rearrange +from torch import Tensor +from torchmetrics import Metric + +from mld.models.tools.tools import remove_padding +from mld.transforms.joints2jfeats import Rifke +from mld.utils.geometry import matrix_of_angles + +from .utils import l2_norm, variance + + +class ComputeMetrics(Metric): + + def __init__(self, + njoints, + jointstype: str = "mmm", + force_in_meter: bool = True, + dist_sync_on_step=True, + **kwargs): + super().__init__(dist_sync_on_step=dist_sync_on_step) + if jointstype not in ["mmm", "humanml3d", "motionx", "motionx_v26"]: + raise NotImplementedError("This jointstype is not implemented.") + + self.name = 'APE and AVE' + self.jointstype = jointstype + self.rifke = Rifke(jointstype=jointstype, normalization=False) + + self.force_in_meter = force_in_meter + self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum") + self.add_state("count_seq", + default=torch.tensor(0), + dist_reduce_fx="sum") + + # APE + self.add_state("APE_root", + default=torch.tensor(0.), + dist_reduce_fx="sum") + self.add_state("APE_traj", + default=torch.tensor(0.), + dist_reduce_fx="sum") + self.add_state("APE_pose", + default=torch.zeros(njoints - 1), + dist_reduce_fx="sum") + self.add_state("APE_joints", + default=torch.zeros(njoints), + dist_reduce_fx="sum") + self.APE_metrics = ["APE_root", "APE_traj", "APE_pose", "APE_joints"] + + # AVE + self.add_state("AVE_root", + default=torch.tensor(0.), + dist_reduce_fx="sum") + self.add_state("AVE_traj", + default=torch.tensor(0.), + dist_reduce_fx="sum") + self.add_state("AVE_pose", + default=torch.zeros(njoints - 1), + dist_reduce_fx="sum") + self.add_state("AVE_joints", + default=torch.zeros(njoints), + dist_reduce_fx="sum") + self.AVE_metrics = ["AVE_root", "AVE_traj", "AVE_pose", "AVE_joints"] + + # All metric + self.metrics = self.APE_metrics + self.AVE_metrics + + def compute(self, sanity_flag): + count = 
self.count + APE_metrics = { + metric: getattr(self, metric) / count + for metric in self.APE_metrics + } + + # Compute average of APEs + APE_metrics["APE_mean_pose"] = self.APE_pose.mean() / count + APE_metrics["APE_mean_joints"] = self.APE_joints.mean() / count + + # Remove arrays + APE_metrics.pop("APE_pose") + APE_metrics.pop("APE_joints") + + count_seq = self.count_seq + AVE_metrics = { + metric: getattr(self, metric) / count_seq + for metric in self.AVE_metrics + } + + # Compute average of AVEs + AVE_metrics["AVE_mean_pose"] = self.AVE_pose.mean() / count_seq + AVE_metrics["AVE_mean_joints"] = self.AVE_joints.mean() / count_seq + + # Remove arrays + AVE_metrics.pop("AVE_pose") + AVE_metrics.pop("AVE_joints") + + return {**APE_metrics, **AVE_metrics} + + def update(self, jts_text: Tensor, jts_ref: Tensor, lengths: List[int]): + self.count += sum(lengths) + self.count_seq += len(lengths) + + jts_text, poses_text, root_text, traj_text = self.transform( + jts_text, lengths) + jts_ref, poses_ref, root_ref, traj_ref = self.transform( + jts_ref, lengths) + + for i in range(len(lengths)): + self.APE_root += l2_norm(root_text[i], root_ref[i], dim=1).sum() + self.APE_pose += l2_norm(poses_text[i], poses_ref[i], dim=2).sum(0) + self.APE_traj += l2_norm(traj_text[i], traj_ref[i], dim=1).sum() + self.APE_joints += l2_norm(jts_text[i], jts_ref[i], dim=2).sum(0) + + root_sigma_text = variance(root_text[i], lengths[i], dim=0) + root_sigma_ref = variance(root_ref[i], lengths[i], dim=0) + self.AVE_root += l2_norm(root_sigma_text, root_sigma_ref, dim=0) + + traj_sigma_text = variance(traj_text[i], lengths[i], dim=0) + traj_sigma_ref = variance(traj_ref[i], lengths[i], dim=0) + self.AVE_traj += l2_norm(traj_sigma_text, traj_sigma_ref, dim=0) + + poses_sigma_text = variance(poses_text[i], lengths[i], dim=0) + poses_sigma_ref = variance(poses_ref[i], lengths[i], dim=0) + self.AVE_pose += l2_norm(poses_sigma_text, poses_sigma_ref, dim=1) + + jts_sigma_text = variance(jts_text[i], lengths[i], dim=0) + jts_sigma_ref = variance(jts_ref[i], lengths[i], dim=0) + self.AVE_joints += l2_norm(jts_sigma_text, jts_sigma_ref, dim=1) + + def transform(self, joints: Tensor, lengths): + features = self.rifke(joints) + + ret = self.rifke.extract(features) + root_y, poses_features, vel_angles, vel_trajectory_local = ret + # already have the good dimensionality + angles = torch.cumsum(vel_angles, dim=-1) + # First frame should be 0, but if infered it is better to ensure it + angles = angles - angles[..., [0]] + + cos, sin = torch.cos(angles), torch.sin(angles) + rotations = matrix_of_angles(cos, sin, inv=False) + + # Get back the local poses + poses_local = rearrange(poses_features, + "... (joints xyz) -> ... 
joints xyz", + xyz=3) + + # Rotate the poses + poses = torch.einsum("...lj,...jk->...lk", poses_local[..., [0, 2]], + rotations) + poses = torch.stack( + (poses[..., 0], poses_local[..., 1], poses[..., 1]), axis=-1) + + # Rotate the vel_trajectory + vel_trajectory = torch.einsum("...j,...jk->...k", vel_trajectory_local, + rotations) + # Integrate the trajectory + # Already have the good dimensionality + trajectory = torch.cumsum(vel_trajectory, dim=-2) + # First frame should be 0, but if infered it is better to ensure it + trajectory = trajectory - trajectory[..., [0], :] + + # get the root joint + root = torch.cat( + (trajectory[..., :, [0]], root_y[..., None], trajectory[..., :, + [1]]), + dim=-1) + + # Add the root joints (which is still zero) + poses = torch.cat((0 * poses[..., [0], :], poses), -2) + # put back the root joint y + poses[..., 0, 1] = root_y + + # Add the trajectory globally + poses[..., [0, 2]] += trajectory[..., None, :] + if self.force_in_meter: + # different jointstypes have different scale factors + if self.jointstype == 'mmm': + factor = 1000.0 + elif self.jointstype in ['humanml3d', 'motionx', 'motionx_v26']: + factor = 1000.0 * 0.75 / 480.0 + else: + raise NotImplementedError("This jointstype is not implemented.") + + # return results in meters + return (remove_padding(poses / factor, lengths), # torch.Size([32, 196, 52, 3]) + remove_padding(poses_local / factor, lengths), #torch.Size([32, 196, 51, 3]) + remove_padding(root / factor, lengths), + remove_padding(trajectory / factor, lengths)) + else: + return (remove_padding(poses, lengths), + remove_padding(poses_local, + lengths), remove_padding(root, lengths), + remove_padding(trajectory, lengths)) diff --git a/Evaluator_272/mld/models/metrics/compute_best.py b/Evaluator_272/mld/models/metrics/compute_best.py new file mode 100644 index 0000000000000000000000000000000000000000..fc3bf47a746e968f2df23a48b1db613815aa4127 --- /dev/null +++ b/Evaluator_272/mld/models/metrics/compute_best.py @@ -0,0 +1,60 @@ +from typing import List + +import torch +from einops import rearrange +from torch import Tensor +from torchmetrics import Metric +import numpy as np +from .compute import ComputeMetrics, l2_norm, variance + + +class ComputeMetricsBest(ComputeMetrics): + def update(self, jts_text_: List[Tensor], jts_ref_: List[Tensor], lengths: List[List[int]]): + self.count += sum(lengths[0]) + self.count_seq += len(lengths[0]) + + ntrials = len(jts_text_) + metrics = [] + for index in range(ntrials): + jts_text, poses_text, root_text, traj_text = self.transform(jts_text_[index], lengths[index]) + jts_ref, poses_ref, root_ref, traj_ref = self.transform(jts_ref_[index], lengths[index]) + + mets = [] + for i in range(len(lengths[index])): + APE_root = l2_norm(root_text[i], root_ref[i], dim=1).sum() + APE_pose = l2_norm(poses_text[i], poses_ref[i], dim=2).sum(0) + APE_traj = l2_norm(traj_text[i], traj_ref[i], dim=1).sum() + APE_joints = l2_norm(jts_text[i], jts_ref[i], dim=2).sum(0) + + root_sigma_text = variance(root_text[i], lengths[index][i], dim=0) + root_sigma_ref = variance(root_ref[i], lengths[index][i], dim=0) + AVE_root = l2_norm(root_sigma_text, root_sigma_ref, dim=0) + + traj_sigma_text = variance(traj_text[i], lengths[index][i], dim=0) + traj_sigma_ref = variance(traj_ref[i], lengths[index][i], dim=0) + AVE_traj = l2_norm(traj_sigma_text, traj_sigma_ref, dim=0) + + poses_sigma_text = variance(poses_text[i], lengths[index][i], dim=0) + poses_sigma_ref = variance(poses_ref[i], lengths[index][i], dim=0) + AVE_pose = 
l2_norm(poses_sigma_text, poses_sigma_ref, dim=1) + + jts_sigma_text = variance(jts_text[i], lengths[index][i], dim=0) + jts_sigma_ref = variance(jts_ref[i], lengths[index][i], dim=0) + AVE_joints = l2_norm(jts_sigma_text, jts_sigma_ref, dim=1) + + met = [APE_root, APE_pose, APE_traj, APE_joints, + AVE_root, AVE_pose, AVE_traj, AVE_joints] + mets.append(met) + metrics.append(mets) + + # Quick hacks + mmm = metrics[np.argmin([x[0][0] for x in metrics])] + APE_root, APE_pose, APE_traj, APE_joints, AVE_root, AVE_pose, AVE_traj, AVE_joints = mmm[0] + self.APE_root += APE_root + self.APE_pose += APE_pose + self.APE_traj += APE_traj + self.APE_joints += APE_joints + self.AVE_root += AVE_root + self.AVE_pose += AVE_pose + self.AVE_traj += AVE_traj + self.AVE_joints += AVE_joints diff --git a/Evaluator_272/mld/models/metrics/compute_body_hand.py b/Evaluator_272/mld/models/metrics/compute_body_hand.py new file mode 100644 index 0000000000000000000000000000000000000000..b5eadf90d0ecc9783e0481aa4ee9bb549b7144a2 --- /dev/null +++ b/Evaluator_272/mld/models/metrics/compute_body_hand.py @@ -0,0 +1,286 @@ +from typing import List + +import torch +from einops import rearrange +from torch import Tensor +from torchmetrics import Metric + +from mld.models.tools.tools import remove_padding +from mld.transforms.joints2jfeats import Rifke +from mld.utils.geometry import matrix_of_angles + +from .utils import l2_norm, variance + + +class ComputeMetrics_body_hand(Metric): + + def __init__(self, + njoints, + jointstype: str = "mmm", + force_in_meter: bool = False, + dist_sync_on_step=True, + **kwargs): + super().__init__(dist_sync_on_step=dist_sync_on_step) + + if jointstype not in ["mmm", "humanml3d", "motionx", 'motionx_v26']: + raise NotImplementedError("This jointstype is not implemented.") + self.name = 'APE and AVE' + self.jointstype = jointstype + self.rifke = Rifke(jointstype=jointstype, normalization=False) + + self.force_in_meter = force_in_meter + self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum") + self.add_state("count_seq", + default=torch.tensor(0), + dist_reduce_fx="sum") + # APE + self.add_state("APE_root", + default=torch.tensor(0.), + dist_reduce_fx="sum") + self.add_state("APE_traj", + default=torch.tensor(0.), + dist_reduce_fx="sum") + self.add_state("APE_pose", + default=torch.zeros(njoints - 1), + dist_reduce_fx="sum") + + self.add_state("APE_pose_body", + default=torch.zeros(22 - 1), + dist_reduce_fx="sum") + + self.add_state("APE_pose_hand", + default=torch.zeros(30), + dist_reduce_fx="sum") + + + self.add_state("APE_joints", + default=torch.zeros(njoints), + dist_reduce_fx="sum") + + + self.add_state("APE_joints_body", + default=torch.zeros(22), + dist_reduce_fx="sum") + + self.add_state("APE_joints_hand", + default=torch.zeros(30), + dist_reduce_fx="sum") + + self.APE_metrics = ["APE_root", "APE_traj", "APE_pose", "APE_pose_body", "APE_pose_hand", "APE_joints", "APE_joints_body", "APE_joints_hand"] + + # AVE + self.add_state("AVE_root", + default=torch.tensor(0.), + dist_reduce_fx="sum") + self.add_state("AVE_traj", + default=torch.tensor(0.), + dist_reduce_fx="sum") + self.add_state("AVE_pose", + default=torch.zeros(njoints - 1), + dist_reduce_fx="sum") + + self.add_state("AVE_pose_body", + default=torch.zeros(22 - 1), + dist_reduce_fx="sum") + + self.add_state("AVE_pose_hand", + default=torch.zeros(30), + dist_reduce_fx="sum") + + self.add_state("AVE_joints", + default=torch.zeros(njoints), + dist_reduce_fx="sum") + + self.add_state("AVE_joints_body", + 
default=torch.zeros(22), + dist_reduce_fx="sum") + + self.add_state("AVE_joints_hand", + default=torch.zeros(30), + dist_reduce_fx="sum") + + + self.AVE_metrics = ["AVE_root", "AVE_traj", "AVE_pose", "AVE_pose_body", "AVE_pose_hand", "AVE_joints", "AVE_joints_body", "AVE_joints_hand"] + + # All metric + self.metrics = self.APE_metrics + self.AVE_metrics + + def compute(self, sanity_flag): + count = self.count + APE_metrics = { + metric: getattr(self, metric) / count + for metric in self.APE_metrics + } + + # Compute average of APEs + APE_metrics["APE_mean_pose"] = self.APE_pose.mean() / count + APE_metrics["APE_mean_pose_body"] = self.APE_pose_body.mean() / count + APE_metrics["APE_mean_pose_hand"] = self.APE_pose_hand.mean() / count + APE_metrics["APE_mean_joints"] = self.APE_joints.mean() / count + APE_metrics["APE_mean_joints_body"] = self.APE_joints_body.mean() / count + APE_metrics["APE_mean_joints_hand"] = self.APE_joints_hand.mean() / count + + # Remove arrays + APE_metrics.pop("APE_pose") + APE_metrics.pop("APE_pose_body") + APE_metrics.pop("APE_pose_hand") + APE_metrics.pop("APE_joints") + APE_metrics.pop("APE_joints_body") + APE_metrics.pop("APE_joints_hand") + + count_seq = self.count_seq + AVE_metrics = { + metric: getattr(self, metric) / count_seq + for metric in self.AVE_metrics + } + + # Compute average of AVEs + AVE_metrics["AVE_mean_pose"] = self.AVE_pose.mean() / count_seq + AVE_metrics["AVE_mean_pose_body"] = self.AVE_pose_body.mean() / count_seq + AVE_metrics["AVE_mean_pose_hand"] = self.AVE_pose_hand.mean() / count_seq + AVE_metrics["AVE_mean_joints"] = self.AVE_joints.mean() / count_seq + AVE_metrics["AVE_mean_joints_body"] = self.AVE_joints_body.mean() / count_seq + AVE_metrics["AVE_mean_joints_hand"] = self.AVE_joints_hand.mean() / count_seq + + # Remove arrays + AVE_metrics.pop("AVE_pose") + AVE_metrics.pop("AVE_pose_body") + AVE_metrics.pop("AVE_pose_hand") + AVE_metrics.pop("AVE_joints") + AVE_metrics.pop("AVE_joints_body") + AVE_metrics.pop("AVE_joints_hand") + + + return {**APE_metrics, **AVE_metrics} + + def update(self, jts_text: Tensor, jts_ref: Tensor, lengths: List[int]): + self.count += sum(lengths) + self.count_seq += len(lengths) + + jts_text, poses_text, root_text, traj_text = self.transform( + jts_text, lengths) + jts_ref, poses_ref, root_ref, traj_ref = self.transform( + jts_ref, lengths) + + + for i in range(len(lengths)): + jts_text_body = jts_text[i][..., :22, :] + jts_text_hand = jts_text[i][..., 22:, :] + jts_ref_body = jts_ref[i][..., :22, :] + jts_ref_hand = jts_ref[i][..., 22:, :] + + + poses_text_body = poses_text[i][..., :21, :] + poses_text_hand = poses_text[i][..., 21:, :] + poses_ref_body = poses_ref[i][..., :21, :] + poses_ref_hand = poses_ref[i][..., 21:, :] + + self.APE_root += l2_norm(root_text[i], root_ref[i], dim=1).sum() + self.APE_pose += l2_norm(poses_text[i], poses_ref[i], dim=2).sum(0) + self.APE_pose_body += l2_norm(poses_text_body, poses_ref_body, dim=2).sum(0) + self.APE_pose_hand += l2_norm(poses_text_hand, poses_ref_hand, dim=2).sum(0) + + self.APE_traj += l2_norm(traj_text[i], traj_ref[i], dim=1).sum() + self.APE_joints += l2_norm(jts_text[i], jts_ref[i], dim=2).sum(0) + self.APE_joints_body += l2_norm(jts_text_body, jts_ref_body, dim=2).sum(0) + self.APE_joints_hand += l2_norm(jts_text_hand, jts_ref_hand, dim=2).sum(0) + + root_sigma_text = variance(root_text[i], lengths[i], dim=0) + root_sigma_ref = variance(root_ref[i], lengths[i], dim=0) + self.AVE_root += l2_norm(root_sigma_text, root_sigma_ref, dim=0) + + 
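# The AVE (Average Variance Error) accumulations here (root above; trajectory,
# pose and joints below) all follow one pattern: for each sequence i, take the
# variance over time (dim=0) of a quantity for both the text-conditioned result
# and the reference, then accumulate the L2 distance between the two variances,
# roughly
#
#     AVE_x += || Var_t(x_text[i]) - Var_t(x_ref[i]) ||_2
#
# compute() later divides these sums by count_seq. The same pattern is applied
# to the trajectory, the pose features (full / body-only / hand-only) and the
# joint positions (full / body-only / hand-only).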
traj_sigma_text = variance(traj_text[i], lengths[i], dim=0) + traj_sigma_ref = variance(traj_ref[i], lengths[i], dim=0) + self.AVE_traj += l2_norm(traj_sigma_text, traj_sigma_ref, dim=0) + + poses_sigma_text = variance(poses_text[i], lengths[i], dim=0) + poses_sigma_ref = variance(poses_ref[i], lengths[i], dim=0) + self.AVE_pose += l2_norm(poses_sigma_text, poses_sigma_ref, dim=1) + + poses_body_sigma_text = variance(poses_text_body, lengths[i], dim=0) + poses_body_sigma_ref = variance(poses_ref_body, lengths[i], dim=0) + self.AVE_pose_body += l2_norm(poses_body_sigma_text, poses_body_sigma_ref, dim=1) + + + poses_hand_sigma_text = variance(poses_text_hand, lengths[i], dim=0) + poses_hand_sigma_ref = variance(poses_ref_hand, lengths[i], dim=0) + self.AVE_pose_hand += l2_norm(poses_hand_sigma_text, poses_hand_sigma_ref, dim=1) + + + jts_sigma_text = variance(jts_text[i], lengths[i], dim=0) + jts_sigma_ref = variance(jts_ref[i], lengths[i], dim=0) + self.AVE_joints += l2_norm(jts_sigma_text, jts_sigma_ref, dim=1) + + jts_body_sigma_text = variance(jts_text_body, lengths[i], dim=0) + jts_body_sigma_ref = variance(jts_ref_body, lengths[i], dim=0) + self.AVE_joints_body += l2_norm(jts_body_sigma_text, jts_body_sigma_ref, dim=1) + + jts_hand_sigma_text = variance(jts_text_hand, lengths[i], dim=0) + jts_hand_sigma_ref = variance(jts_ref_hand, lengths[i], dim=0) + self.AVE_joints_hand += l2_norm(jts_hand_sigma_text, jts_hand_sigma_ref, dim=1) + + + + def transform(self, joints: Tensor, lengths): + features = self.rifke(joints) + + ret = self.rifke.extract(features) + root_y, poses_features, vel_angles, vel_trajectory_local = ret + # already have the good dimensionality + angles = torch.cumsum(vel_angles, dim=-1) + # First frame should be 0, but if infered it is better to ensure it + angles = angles - angles[..., [0]] + + cos, sin = torch.cos(angles), torch.sin(angles) + rotations = matrix_of_angles(cos, sin, inv=False) + + # Get back the local poses + poses_local = rearrange(poses_features, + "... (joints xyz) -> ... 
joints xyz", + xyz=3) + + # Rotate the poses + poses = torch.einsum("...lj,...jk->...lk", poses_local[..., [0, 2]], + rotations) + poses = torch.stack( + (poses[..., 0], poses_local[..., 1], poses[..., 1]), axis=-1) + + # Rotate the vel_trajectory + vel_trajectory = torch.einsum("...j,...jk->...k", vel_trajectory_local, + rotations) + # Integrate the trajectory + # Already have the good dimensionality + trajectory = torch.cumsum(vel_trajectory, dim=-2) + # First frame should be 0, but if infered it is better to ensure it + trajectory = trajectory - trajectory[..., [0], :] + + # get the root joint + root = torch.cat( + (trajectory[..., :, [0]], root_y[..., None], trajectory[..., :, + [1]]), + dim=-1) + + # Add the root joints (which is still zero) + poses = torch.cat((0 * poses[..., [0], :], poses), -2) + # put back the root joint y + poses[..., 0, 1] = root_y + + # Add the trajectory globally + poses[..., [0, 2]] += trajectory[..., None, :] + if self.force_in_meter: + # different jointstypes have different scale factors + if self.jointstype == 'mmm': + factor = 1000.0 + elif self.jointstype in ['humanml3d', 'motionx']: + factor = 1000.0 * 0.75 / 480.0 + + # return results in meters + return (remove_padding(poses / factor, lengths), # torch.Size([32, 196, 52, 3]) + remove_padding(poses_local / factor, lengths), #torch.Size([32, 196, 51, 3]) + remove_padding(root / factor, lengths), + remove_padding(trajectory / factor, lengths)) + else: + return (remove_padding(poses, lengths), + remove_padding(poses_local, + lengths), remove_padding(root, lengths), + remove_padding(trajectory, lengths)) diff --git a/Evaluator_272/mld/models/metrics/compute_worst.py b/Evaluator_272/mld/models/metrics/compute_worst.py new file mode 100644 index 0000000000000000000000000000000000000000..95b489ad75b536ba1b1e9f0be4063afa4d010086 --- /dev/null +++ b/Evaluator_272/mld/models/metrics/compute_worst.py @@ -0,0 +1,60 @@ +from typing import List + +import torch +from einops import rearrange +from torch import Tensor +from torchmetrics import Metric +import numpy as np +from .compute import ComputeMetrics, l2_norm, variance + + +class ComputeMetricsWorst(ComputeMetrics): + def update(self, jts_text_: List[Tensor], jts_ref_: List[Tensor], lengths: List[List[int]]): + self.count += sum(lengths[0]) + self.count_seq += len(lengths[0]) + + ntrials = len(jts_text_) + metrics = [] + for index in range(ntrials): + jts_text, poses_text, root_text, traj_text = self.transform(jts_text_[index], lengths[index]) + jts_ref, poses_ref, root_ref, traj_ref = self.transform(jts_ref_[index], lengths[index]) + + mets = [] + for i in range(len(lengths[index])): + APE_root = l2_norm(root_text[i], root_ref[i], dim=1).sum() + APE_pose = l2_norm(poses_text[i], poses_ref[i], dim=2).sum(0) + APE_traj = l2_norm(traj_text[i], traj_ref[i], dim=1).sum() + APE_joints = l2_norm(jts_text[i], jts_ref[i], dim=2).sum(0) + + root_sigma_text = variance(root_text[i], lengths[index][i], dim=0) + root_sigma_ref = variance(root_ref[i], lengths[index][i], dim=0) + AVE_root = l2_norm(root_sigma_text, root_sigma_ref, dim=0) + + traj_sigma_text = variance(traj_text[i], lengths[index][i], dim=0) + traj_sigma_ref = variance(traj_ref[i], lengths[index][i], dim=0) + AVE_traj = l2_norm(traj_sigma_text, traj_sigma_ref, dim=0) + + poses_sigma_text = variance(poses_text[i], lengths[index][i], dim=0) + poses_sigma_ref = variance(poses_ref[i], lengths[index][i], dim=0) + AVE_pose = l2_norm(poses_sigma_text, poses_sigma_ref, dim=1) + + jts_sigma_text = variance(jts_text[i], 
lengths[index][i], dim=0) + jts_sigma_ref = variance(jts_ref[i], lengths[index][i], dim=0) + AVE_joints = l2_norm(jts_sigma_text, jts_sigma_ref, dim=1) + + met = [APE_root, APE_pose, APE_traj, APE_joints, + AVE_root, AVE_pose, AVE_traj, AVE_joints] + mets.append(met) + metrics.append(mets) + + # Quick hacks + mmm = metrics[np.argmax([x[0][0] for x in metrics])] + APE_root, APE_pose, APE_traj, APE_joints, AVE_root, AVE_pose, AVE_traj, AVE_joints = mmm[0] + self.APE_root += APE_root + self.APE_pose += APE_pose + self.APE_traj += APE_traj + self.APE_joints += APE_joints + self.AVE_root += AVE_root + self.AVE_pose += AVE_pose + self.AVE_traj += AVE_traj + self.AVE_joints += AVE_joints diff --git a/Evaluator_272/mld/models/metrics/mm.py b/Evaluator_272/mld/models/metrics/mm.py new file mode 100644 index 0000000000000000000000000000000000000000..5bc23a4800972121ba1934a9c7c8215f2c5b1450 --- /dev/null +++ b/Evaluator_272/mld/models/metrics/mm.py @@ -0,0 +1,62 @@ +from typing import List + +import torch +from torch import Tensor +from torchmetrics import Metric +from torchmetrics.functional import pairwise_euclidean_distance + +from .utils import * + + +class MMMetrics(Metric): + full_state_update = True + + def __init__(self, mm_num_times=10, dist_sync_on_step=True, **kwargs): + super().__init__(dist_sync_on_step=dist_sync_on_step) + + self.name = "MultiModality scores" + + self.mm_num_times = mm_num_times + + self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum") + self.add_state("count_seq", + default=torch.tensor(0), + dist_reduce_fx="sum") + + self.metrics = ["MultiModality"] + self.add_state("MultiModality", + default=torch.tensor(0.), + dist_reduce_fx="sum") + + # chached batches + self.add_state("mm_motion_embeddings", default=[], dist_reduce_fx=None) + + def compute(self, sanity_flag): + count = self.count.item() + count_seq = self.count_seq.item() + + # init metrics + metrics = {metric: getattr(self, metric) for metric in self.metrics} + + # if in sanity check stage then jump + if sanity_flag: + return metrics + + # cat all embeddings + all_mm_motions = torch.cat(self.mm_motion_embeddings, + axis=0).cpu().numpy() + metrics['MultiModality'] = calculate_multimodality_np( + all_mm_motions, self.mm_num_times) + + return {**metrics} + + def update( + self, + mm_motion_embeddings: Tensor, + lengths: List[int], + ): + self.count += sum(lengths) + self.count_seq += len(lengths) + + # store all mm motion embeddings + self.mm_motion_embeddings.append(mm_motion_embeddings) diff --git a/Evaluator_272/mld/models/metrics/mr.py b/Evaluator_272/mld/models/metrics/mr.py new file mode 100644 index 0000000000000000000000000000000000000000..d50d105ea1fa5e9fd59e5fc87a0196a8a4af4ad8 --- /dev/null +++ b/Evaluator_272/mld/models/metrics/mr.py @@ -0,0 +1,106 @@ +from typing import List + +import torch +from torch import Tensor +from torchmetrics import Metric + +from .utils import * + + +# motion reconstruction metric +class MRMetrics(Metric): + + def __init__(self, + njoints, + jointstype: str = "mmm", + force_in_meter: bool = True, + align_root: bool = True, + dist_sync_on_step=True, + **kwargs): + super().__init__(dist_sync_on_step=dist_sync_on_step) + + if jointstype not in ["mmm", "humanml3d", "motionx", "motionx_v26"]: + raise NotImplementedError("This jointstype is not implemented.") + + self.name = 'Motion Reconstructions' + self.jointstype = jointstype + self.align_root = align_root + self.force_in_meter = force_in_meter + + self.add_state("count", default=torch.tensor(0), 
dist_reduce_fx="sum") + self.add_state("count_seq", + default=torch.tensor(0), + dist_reduce_fx="sum") + + self.add_state("MPJPE", + default=torch.tensor([0.0]), + dist_reduce_fx="sum") + + + self.add_state("PAMPJPE", + default=torch.tensor([0.0]), + dist_reduce_fx="sum") + + + + self.add_state("ACCEL", + default=torch.tensor([0.0]), + dist_reduce_fx="sum") + + + # todo + # self.add_state("ROOT", default=torch.tensor([0.0]), dist_reduce_fx="sum") + + self.MR_metrics = ["MPJPE", "PAMPJPE", "ACCEL"] + + # All metric + self.metrics = self.MR_metrics + + def compute(self, sanity_flag): + if self.force_in_meter: + # different jointstypes have different scale factors + # if self.jointstype == 'mmm': + # factor = 1000.0 + # elif self.jointstype == 'humanml3d': + # factor = 1000.0 * 0.75 / 480 + factor = 1000.0 + else: + factor = 1.0 + + count = self.count + count_seq = self.count_seq + mr_metrics = {} + mr_metrics["MPJPE"] = self.MPJPE / count * factor + + mr_metrics["PAMPJPE"] = self.PAMPJPE / count * factor + + # accel error: joints_gt[:-2] - 2 * joints_gt[1:-1] + joints_gt[2:] + # n-2 for each sequences + mr_metrics["ACCEL"] = self.ACCEL / (count - 2 * count_seq) * factor + + return mr_metrics + + def update(self, joints_rst: Tensor, joints_ref: Tensor, + lengths: List[int]): + assert joints_rst.shape == joints_ref.shape + assert joints_rst.dim() == 4 + # (bs, seq, njoint=22, 3) + + self.count += sum(lengths) + self.count_seq += len(lengths) + + # avoid cuda error of DDP in pampjpe + rst = joints_rst.detach().cpu() + ref = joints_ref.detach().cpu() + + # align root joints index + if self.align_root and self.jointstype in ['mmm', 'humanml3d', 'motionx']: + align_inds = [0] + else: + align_inds = None + + for i in range(len(lengths)): + self.MPJPE += torch.sum( + calc_mpjpe(rst[i], ref[i], align_inds=align_inds)) + self.PAMPJPE += torch.sum(calc_pampjpe(rst[i], ref[i])) + self.ACCEL += torch.sum(calc_accel(rst[i], ref[i])) diff --git a/Evaluator_272/mld/models/metrics/retrieval_recall.py b/Evaluator_272/mld/models/metrics/retrieval_recall.py new file mode 100644 index 0000000000000000000000000000000000000000..5483db36d99155efcad2f9e14781340e85c4b1ef --- /dev/null +++ b/Evaluator_272/mld/models/metrics/retrieval_recall.py @@ -0,0 +1,170 @@ +from typing import List + +import torch +from torch import Tensor +from torchmetrics import Metric +from torchmetrics.functional import pairwise_euclidean_distance + +from .utils import * + + +class Retrieval_Recall_Metrics(Metric): + full_state_update = True + + def __init__(self, + # top_k=3, + # R_size=32, + # diversity_times=300, + mode = ['all'], + dist_sync_on_step=True, + **kwargs): + super().__init__(dist_sync_on_step=dist_sync_on_step) + + self.name = "Retrieval_recall" + + if 'small_batch' in mode: + self.R_size = R_size + + # self.top_k = top_k + self.top_k = ['1', '2', '3', '5', '10'] + + self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum") + self.add_state("count_seq", + default=torch.tensor(0), + dist_reduce_fx="sum") + + self.metrics = [] + # Matching scores + # self.add_state("Matching_score", + # default=torch.tensor(0.0), + # dist_reduce_fx="sum") + # self.add_state("gt_Matching_score", + # default=torch.tensor(0.0), + # dist_reduce_fx="sum") + # self.Matching_metrics = ["Matching_score", "gt_Matching_score"] + self.Matching_metrics = [] + for k in self.top_k: + self.add_state( + f"R_precision_top_{k}", + default=torch.tensor(0.0), + dist_reduce_fx="sum", + ) + self.Matching_metrics.append(f"R_precision_top_{k}") + for k in 
self.top_k: + self.add_state( + f"gt_R_precision_top_{k}", + default=torch.tensor(0.0), + dist_reduce_fx="sum", + ) + self.Matching_metrics.append(f"gt_R_precision_top_{k}") + + self.metrics.extend(self.Matching_metrics) + + # chached batches + self.add_state("text_embeddings", default=[], dist_reduce_fx=None) + self.add_state("recmotion_embeddings", default=[], dist_reduce_fx=None) + self.add_state("gtmotion_embeddings", default=[], dist_reduce_fx=None) + + def compute(self, sanity_flag): + count = self.count.item() + count_seq = self.count_seq.item() + + # init metrics + metrics = {metric: getattr(self, metric) for metric in self.metrics} + + # if in sanity check stage then jump + if sanity_flag: + return metrics + # cat all embeddings + shuffle_idx = torch.randperm(count_seq) + all_texts = torch.cat(self.text_embeddings, + axis=0).cpu()[shuffle_idx, :] + all_genmotions = torch.cat(self.recmotion_embeddings, + axis=0).cpu()[shuffle_idx, :] + all_gtmotions = torch.cat(self.gtmotion_embeddings, + axis=0).cpu()[shuffle_idx, :] + + # Compute r-precision + assert count_seq > self.R_size + # print("**********************************") + # print(count_seq) + top_k_mat = torch.zeros((self.top_k, )) + for i in range(count_seq // self.R_size): + # [bs=32, 1*256] + group_texts = all_texts[i * self.R_size:(i + 1) * self.R_size] + # [bs=32, 1*256] + group_motions = all_genmotions[i * self.R_size:(i + 1) * + self.R_size] + # dist_mat = pairwise_euclidean_distance(group_texts, group_motions) + # [bs=32, 32] + dist_mat = euclidean_distance_matrix(group_texts, + group_motions).nan_to_num() + # print(dist_mat[:5]) + self.Matching_score += dist_mat.trace() + argsmax = torch.argsort(dist_mat, dim=1) + top_k_mat += calculate_top_k(argsmax, top_k=self.top_k).sum(axis=0) + R_count = count_seq // self.R_size * self.R_size + metrics["Matching_score"] = self.Matching_score / R_count + for k in range(self.top_k): + metrics[f"R_precision_top_{str(k+1)}"] = top_k_mat[k] / R_count + + # Compute r-precision with gt + assert count_seq > self.R_size + top_k_mat = torch.zeros((self.top_k, )) + for i in range(count_seq // self.R_size): + # [bs=32, 1*256] + group_texts = all_texts[i * self.R_size:(i + 1) * self.R_size] + # [bs=32, 1*256] + group_motions = all_gtmotions[i * self.R_size:(i + 1) * + self.R_size] + # [bs=32, 32] + dist_mat = euclidean_distance_matrix(group_texts, + group_motions).nan_to_num() + # match score + self.gt_Matching_score += dist_mat.trace() + argsmax = torch.argsort(dist_mat, dim=1) + top_k_mat += calculate_top_k(argsmax, top_k=self.top_k).sum(axis=0) + metrics["gt_Matching_score"] = self.gt_Matching_score / R_count + for k in range(self.top_k): + metrics[f"gt_R_precision_top_{str(k+1)}"] = top_k_mat[k] / R_count + + # tensor -> numpy for FID + all_genmotions = all_genmotions.numpy() + all_gtmotions = all_gtmotions.numpy() + + # Compute fid + mu, cov = calculate_activation_statistics_np(all_genmotions) + # gt_mu, gt_cov = calculate_activation_statistics_np(all_gtmotions) + gt_mu, gt_cov = calculate_activation_statistics_np(all_gtmotions) + metrics["FID"] = calculate_frechet_distance_np(gt_mu, gt_cov, mu, cov) + + # Compute diversity + assert count_seq > self.diversity_times + metrics["Diversity"] = calculate_diversity_np(all_genmotions, + self.diversity_times) + metrics["gt_Diversity"] = calculate_diversity_np( + all_gtmotions, self.diversity_times) + + return {**metrics} + + def update( + self, + text_embeddings: Tensor, + recmotion_embeddings: Tensor, + gtmotion_embeddings: Tensor, + lengths: 
List[int], + ): + self.count += sum(lengths) + self.count_seq += len(lengths) + + # [bs, nlatent*ndim] <= [bs, nlatent, ndim] + text_embeddings = torch.flatten(text_embeddings, start_dim=1).detach() + recmotion_embeddings = torch.flatten(recmotion_embeddings, + start_dim=1).detach() + gtmotion_embeddings = torch.flatten(gtmotion_embeddings, + start_dim=1).detach() + + # store all texts and motions + self.text_embeddings.append(text_embeddings) + self.recmotion_embeddings.append(recmotion_embeddings) + self.gtmotion_embeddings.append(gtmotion_embeddings) diff --git a/Evaluator_272/mld/models/metrics/tm2t.py b/Evaluator_272/mld/models/metrics/tm2t.py new file mode 100644 index 0000000000000000000000000000000000000000..9a12ac4f0e3f2b446a95f85dcf66b9f437148065 --- /dev/null +++ b/Evaluator_272/mld/models/metrics/tm2t.py @@ -0,0 +1,180 @@ +from typing import List + +import torch +from torch import Tensor +from torchmetrics import Metric +from torchmetrics.functional import pairwise_euclidean_distance + +from .utils import * + + +class TM2TMetrics(Metric): + full_state_update = True + + def __init__(self, + top_k=3, + R_size=32, + diversity_times=300, + dist_sync_on_step=True, + **kwargs): + super().__init__(dist_sync_on_step=dist_sync_on_step) + + self.name = "matching, fid, and diversity scores" + + self.top_k = top_k + self.R_size = R_size + self.diversity_times = diversity_times + + self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum") + self.add_state("count_seq", + default=torch.tensor(0), + dist_reduce_fx="sum") + + self.metrics = [] + # Matching scores + self.add_state("Matching_score", + default=torch.tensor(0.0), + dist_reduce_fx="sum") + self.add_state("gt_Matching_score", + default=torch.tensor(0.0), + dist_reduce_fx="sum") + self.Matching_metrics = ["Matching_score", "gt_Matching_score"] + for k in range(1, top_k + 1): + self.add_state( + f"R_precision_top_{str(k)}", + default=torch.tensor(0.0), + dist_reduce_fx="sum", + ) + self.Matching_metrics.append(f"R_precision_top_{str(k)}") + for k in range(1, top_k + 1): + self.add_state( + f"gt_R_precision_top_{str(k)}", + default=torch.tensor(0.0), + dist_reduce_fx="sum", + ) + self.Matching_metrics.append(f"gt_R_precision_top_{str(k)}") + + self.metrics.extend(self.Matching_metrics) + + # Fid + self.add_state("FID", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.metrics.append("FID") + + # Diversity + self.add_state("Diversity", + default=torch.tensor(0.0), + dist_reduce_fx="sum") + self.add_state("gt_Diversity", + default=torch.tensor(0.0), + dist_reduce_fx="sum") + self.metrics.extend(["Diversity", "gt_Diversity"]) + + # chached batches + self.add_state("text_embeddings", default=[], dist_reduce_fx=None) + self.add_state("recmotion_embeddings", default=[], dist_reduce_fx=None) + self.add_state("gtmotion_embeddings", default=[], dist_reduce_fx=None) + + def compute(self, sanity_flag): + count = self.count.item() + count_seq = self.count_seq.item() + + # init metrics + metrics = {metric: getattr(self, metric) for metric in self.metrics} + + # if in sanity check stage then jump + if sanity_flag: + return metrics + + # cat all embeddings + shuffle_idx = torch.randperm(count_seq) + all_texts = torch.cat(self.text_embeddings, + axis=0).cpu()[shuffle_idx, :] + all_genmotions = torch.cat(self.recmotion_embeddings, + axis=0).cpu()[shuffle_idx, :] + all_gtmotions = torch.cat(self.gtmotion_embeddings, + axis=0).cpu()[shuffle_idx, :] + + # Compute r-precision + assert count_seq > self.R_size + # 
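# R-precision protocol used below: the pooled text / motion embeddings are
# shuffled, split into groups of R_size (32 here), and a Euclidean distance
# matrix between each group's text and motion embeddings is computed; a sample
# counts as retrieved at top-k if its paired motion is among the k nearest
# motions to its text. Matching_score accumulates the trace of the distance
# matrix, i.e. the distance between each text and its own motion. A tiny
# self-contained sketch of the counting step, in comments (shapes and values
# are illustrative only, not taken from this diff):
#
#     import torch
#     dist = torch.rand(32, 32)                # text-to-motion distances
#     order = torch.argsort(dist, dim=1)       # nearest motions per text
#     top3 = (order[:, :3] == torch.arange(32).unsqueeze(1)).any(dim=1)
#     r_precision_top3 = top3.float().mean()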
print("**********************************") + # print(count_seq) + top_k_mat = torch.zeros((self.top_k, )) + for i in range(count_seq // self.R_size): + # [bs=32, 1*256] + group_texts = all_texts[i * self.R_size:(i + 1) * self.R_size] + # [bs=32, 1*256] + group_motions = all_genmotions[i * self.R_size:(i + 1) * + self.R_size] + # dist_mat = pairwise_euclidean_distance(group_texts, group_motions) + # [bs=32, 32] + dist_mat = euclidean_distance_matrix(group_texts, + group_motions).nan_to_num() + # print(dist_mat[:5]) + self.Matching_score += dist_mat.trace() + argsmax = torch.argsort(dist_mat, dim=1) + top_k_mat += calculate_top_k(argsmax, top_k=self.top_k).sum(axis=0) + R_count = count_seq // self.R_size * self.R_size + metrics["Matching_score"] = self.Matching_score / R_count + for k in range(self.top_k): + metrics[f"R_precision_top_{str(k+1)}"] = top_k_mat[k] / R_count + + # Compute r-precision with gt + assert count_seq > self.R_size + top_k_mat = torch.zeros((self.top_k, )) + for i in range(count_seq // self.R_size): + # [bs=32, 1*256] + group_texts = all_texts[i * self.R_size:(i + 1) * self.R_size] + # [bs=32, 1*256] + group_motions = all_gtmotions[i * self.R_size:(i + 1) * + self.R_size] + # [bs=32, 32] + dist_mat = euclidean_distance_matrix(group_texts, + group_motions).nan_to_num() + # match score + self.gt_Matching_score += dist_mat.trace() + argsmax = torch.argsort(dist_mat, dim=1) + top_k_mat += calculate_top_k(argsmax, top_k=self.top_k).sum(axis=0) + metrics["gt_Matching_score"] = self.gt_Matching_score / R_count + for k in range(self.top_k): + metrics[f"gt_R_precision_top_{str(k+1)}"] = top_k_mat[k] / R_count + + # tensor -> numpy for FID + all_genmotions = all_genmotions.numpy() + all_gtmotions = all_gtmotions.numpy() + + # Compute fid + mu, cov = calculate_activation_statistics_np(all_genmotions) + # gt_mu, gt_cov = calculate_activation_statistics_np(all_gtmotions) + gt_mu, gt_cov = calculate_activation_statistics_np(all_gtmotions) + metrics["FID"] = calculate_frechet_distance_np(gt_mu, gt_cov, mu, cov) + + # Compute diversity + assert count_seq > self.diversity_times + metrics["Diversity"] = calculate_diversity_np(all_genmotions, + self.diversity_times) + metrics["gt_Diversity"] = calculate_diversity_np( + all_gtmotions, self.diversity_times) + + return {**metrics} + + def update( + self, + text_embeddings: Tensor, + recmotion_embeddings: Tensor, + gtmotion_embeddings: Tensor, + lengths: List[int], + ): + self.count += sum(lengths) + self.count_seq += len(lengths) + + # [bs, nlatent*ndim] <= [bs, nlatent, ndim] + text_embeddings = torch.flatten(text_embeddings, start_dim=1).detach() + recmotion_embeddings = torch.flatten(recmotion_embeddings, + start_dim=1).detach() + gtmotion_embeddings = torch.flatten(gtmotion_embeddings, + start_dim=1).detach() + + # store all texts and motions + self.text_embeddings.append(text_embeddings) + self.recmotion_embeddings.append(recmotion_embeddings) + self.gtmotion_embeddings.append(gtmotion_embeddings) diff --git a/Evaluator_272/mld/models/metrics/tm2t_R256.py b/Evaluator_272/mld/models/metrics/tm2t_R256.py new file mode 100644 index 0000000000000000000000000000000000000000..6302cdbb1cd131d97ef11e7267cde18749996d15 --- /dev/null +++ b/Evaluator_272/mld/models/metrics/tm2t_R256.py @@ -0,0 +1,167 @@ +from typing import List + +import torch +from torch import Tensor +from torchmetrics import Metric +from torchmetrics.functional import pairwise_euclidean_distance + +from .utils import * + + +class TM2TMetrics_R256(Metric): + full_state_update = 
True + + def __init__(self, + top_k=10, + R_size=256, + diversity_times=300, + dist_sync_on_step=True, + **kwargs): + super().__init__(dist_sync_on_step=dist_sync_on_step) + + self.name = "matching, fid, and diversity scores" + + self.top_k = top_k + self.R_size = R_size + # self.diversity_times = diversity_times + + self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum") + self.add_state("count_seq", + default=torch.tensor(0), + dist_reduce_fx="sum") + + self.metrics = [] + # Matching scores + # self.add_state("Matching_score", + # default=torch.tensor(0.0), + # dist_reduce_fx="sum") + # self.add_state("gt_Matching_score", + # default=torch.tensor(0.0), + # dist_reduce_fx="sum") + self.Matching_metrics = [] + for k in range(1, top_k + 1): + self.add_state( + f"R_precision_top_{str(k)}_{str(self.R_size)}", + default=torch.tensor(0.0), + dist_reduce_fx="sum", + ) + self.Matching_metrics.append(f"R_precision_top_{str(k)}_{str(self.R_size)}") + for k in range(1, top_k + 1): + self.add_state( + f"gt_R_precision_top_{str(k)}_{str(self.R_size)}", + default=torch.tensor(0.0), + dist_reduce_fx="sum", + ) + self.Matching_metrics.append(f"gt_R_precision_top_{str(k)}_{str(self.R_size)}") + + self.metrics.extend(self.Matching_metrics) + + # chached batches + self.add_state("text_embeddings", default=[], dist_reduce_fx=None) + self.add_state("recmotion_embeddings", default=[], dist_reduce_fx=None) + self.add_state("gtmotion_embeddings", default=[], dist_reduce_fx=None) + + def compute(self, sanity_flag): + count = self.count.item() + count_seq = self.count_seq.item() + + # init metrics + metrics = {metric: getattr(self, metric) for metric in self.metrics} + + # if in sanity check stage then jump + if sanity_flag: + return metrics + + # cat all embeddings + shuffle_idx = torch.randperm(count_seq) + all_texts = torch.cat(self.text_embeddings, + axis=0).cpu()[shuffle_idx, :] + all_genmotions = torch.cat(self.recmotion_embeddings, + axis=0).cpu()[shuffle_idx, :] + all_gtmotions = torch.cat(self.gtmotion_embeddings, + axis=0).cpu()[shuffle_idx, :] + + # Compute r-precision + assert count_seq > self.R_size + # print("**********************************") + # print(count_seq) + top_k_mat = torch.zeros((self.top_k, )) + for i in range(count_seq // self.R_size): + # [bs=32, 1*256] + group_texts = all_texts[i * self.R_size:(i + 1) * self.R_size] + # [bs=32, 1*256] + group_motions = all_genmotions[i * self.R_size:(i + 1) * + self.R_size] + # dist_mat = pairwise_euclidean_distance(group_texts, group_motions) + # [bs=32, 32] + dist_mat = euclidean_distance_matrix(group_texts, + group_motions).nan_to_num() + # print(dist_mat[:5]) + # self.Matching_score += dist_mat.trace() + argsmax = torch.argsort(dist_mat, dim=1) + top_k_mat += calculate_top_k(argsmax, top_k=self.top_k).sum(axis=0) + R_count = count_seq // self.R_size * self.R_size + # metrics["Matching_score"] = self.Matching_score / R_count + for k in range(self.top_k): + metrics[f"R_precision_top_{str(k+1)}_{self.R_size}"] = top_k_mat[k] / R_count + + # Compute r-precision with gt + assert count_seq > self.R_size + top_k_mat = torch.zeros((self.top_k, )) + for i in range(count_seq // self.R_size): + # [bs=32, 1*256] + group_texts = all_texts[i * self.R_size:(i + 1) * self.R_size] + # [bs=32, 1*256] + group_motions = all_gtmotions[i * self.R_size:(i + 1) * + self.R_size] + # [bs=32, 32] + dist_mat = euclidean_distance_matrix(group_texts, + group_motions).nan_to_num() + # match score + # self.gt_Matching_score += dist_mat.trace() + argsmax = 
torch.argsort(dist_mat, dim=1) + top_k_mat += calculate_top_k(argsmax, top_k=self.top_k).sum(axis=0) + # metrics["gt_Matching_score"] = self.gt_Matching_score / R_count + for k in range(self.top_k): + metrics[f"gt_R_precision_top_{str(k+1)}_{self.R_size}"] = top_k_mat[k] / R_count + + # tensor -> numpy for FID + # all_genmotions = all_genmotions.numpy() + # all_gtmotions = all_gtmotions.numpy() + + # Compute fid + # mu, cov = calculate_activation_statistics_np(all_genmotions) + # # gt_mu, gt_cov = calculate_activation_statistics_np(all_gtmotions) + # gt_mu, gt_cov = calculate_activation_statistics_np(all_gtmotions) + # metrics["FID"] = calculate_frechet_distance_np(gt_mu, gt_cov, mu, cov) + + # Compute diversity + # assert count_seq > self.diversity_times + # metrics["Diversity"] = calculate_diversity_np(all_genmotions, + # self.diversity_times) + # metrics["gt_Diversity"] = calculate_diversity_np( + # all_gtmotions, self.diversity_times) + + return {**metrics} + + def update( + self, + text_embeddings: Tensor, + recmotion_embeddings: Tensor, + gtmotion_embeddings: Tensor, + lengths: List[int], + ): + self.count += sum(lengths) + self.count_seq += len(lengths) + + # [bs, nlatent*ndim] <= [bs, nlatent, ndim] + text_embeddings = torch.flatten(text_embeddings, start_dim=1).detach() + recmotion_embeddings = torch.flatten(recmotion_embeddings, + start_dim=1).detach() + gtmotion_embeddings = torch.flatten(gtmotion_embeddings, + start_dim=1).detach() + + # store all texts and motions + self.text_embeddings.append(text_embeddings) + self.recmotion_embeddings.append(recmotion_embeddings) + self.gtmotion_embeddings.append(gtmotion_embeddings) diff --git a/Evaluator_272/mld/models/metrics/tmr_tm2t.py b/Evaluator_272/mld/models/metrics/tmr_tm2t.py new file mode 100644 index 0000000000000000000000000000000000000000..72ca7c79fe80549c991ee57765ac4ca74e94e5ba --- /dev/null +++ b/Evaluator_272/mld/models/metrics/tmr_tm2t.py @@ -0,0 +1,188 @@ +from typing import List + +import torch +from torch import Tensor +from torchmetrics import Metric +from torchmetrics.functional import pairwise_euclidean_distance + +from .utils import * + + +class TMR_TM2TMetrics(Metric): + full_state_update = True + + def __init__(self, + top_k=3, + R_size=32, + diversity_times=300, + dist_sync_on_step=True, + **kwargs): + super().__init__(dist_sync_on_step=dist_sync_on_step) + + self.name = "matching" + + self.top_k = top_k + self.R_size = R_size + # self.diversity_times = diversity_times + + self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum") + self.add_state("count_seq", + default=torch.tensor(0), + dist_reduce_fx="sum") + + self.metrics = [] + # Matching scores + self.add_state("TMR_Matching_score", + default=torch.tensor(0.0), + dist_reduce_fx="sum") + self.add_state("TMR_gt_Matching_score", + default=torch.tensor(0.0), + dist_reduce_fx="sum") + self.Matching_metrics = ["TMR_Matching_score", "TMR_gt_Matching_score"] + for k in range(1, top_k + 1): + self.add_state( + f"TMR_R_precision_top_{str(k)}", + default=torch.tensor(0.0), + dist_reduce_fx="sum", + ) + self.Matching_metrics.append(f"TMR_R_precision_top_{str(k)}") + for k in range(1, top_k + 1): + self.add_state( + f"TMR_gt_R_precision_top_{str(k)}", + default=torch.tensor(0.0), + dist_reduce_fx="sum", + ) + self.Matching_metrics.append(f"TMR_gt_R_precision_top_{str(k)}") + + self.metrics.extend(self.Matching_metrics) + + # Fid + # self.add_state("FID", default=torch.tensor(0.0), dist_reduce_fx="sum") + # self.metrics.append("FID") + + # Diversity + # 
self.add_state("Diversity", + # default=torch.tensor(0.0), + # dist_reduce_fx="sum") + # self.add_state("gt_Diversity", + # default=torch.tensor(0.0), + # dist_reduce_fx="sum") + # self.metrics.extend(["Diversity", "gt_Diversity"]) + + # chached batches + self.add_state("text_embeddings", default=[], dist_reduce_fx=None) + self.add_state("recmotion_embeddings", default=[], dist_reduce_fx=None) + self.add_state("gtmotion_embeddings", default=[], dist_reduce_fx=None) + + def compute(self, sanity_flag): + count = self.count.item() + count_seq = self.count_seq.item() + + # init metrics + metrics = {metric: getattr(self, metric) for metric in self.metrics} + + # if in sanity check stage then jump + if sanity_flag: + return metrics + + # cat all embeddings + shuffle_idx = torch.randperm(count_seq) + all_texts = torch.cat(self.text_embeddings, + axis=0).cpu()[shuffle_idx, :] + all_genmotions = torch.cat(self.recmotion_embeddings, + axis=0).cpu()[shuffle_idx, :] + all_gtmotions = torch.cat(self.gtmotion_embeddings, + axis=0).cpu()[shuffle_idx, :] + + # Compute r-precision + assert count_seq > self.R_size + + top_k_mat = torch.zeros((self.top_k, )) + for i in range(count_seq // self.R_size): + # [bs=32, 1*256] + group_texts = all_texts[i * self.R_size:(i + 1) * self.R_size] + # [bs=32, 1*256] + group_motions = all_genmotions[i * self.R_size:(i + 1) * + self.R_size] + # dist_mat = pairwise_euclidean_distance(group_texts, group_motions) + # [bs=32, 32] + group_texts = torch.nn.functional.normalize(group_texts, dim=1) + group_motions = torch.nn.functional.normalize(group_motions, dim=1) + + dist_mat = euclidean_distance_matrix(group_texts, + group_motions).nan_to_num() + # print(dist_mat[:5]) + self.TMR_Matching_score += dist_mat.trace() + argsmax = torch.argsort(dist_mat, dim=1) + top_k_mat += calculate_top_k(argsmax, top_k=self.top_k).sum(axis=0) + R_count = count_seq // self.R_size * self.R_size + metrics["TMR_Matching_score"] = self.TMR_Matching_score / R_count + for k in range(self.top_k): + metrics[f"TMR_R_precision_top_{str(k+1)}"] = top_k_mat[k] / R_count + + # Compute r-precision with gt + assert count_seq > self.R_size + top_k_mat = torch.zeros((self.top_k, )) + for i in range(count_seq // self.R_size): + # [bs=32, 1*256] + group_texts = all_texts[i * self.R_size:(i + 1) * self.R_size] + # [bs=32, 1*256] + group_motions = all_gtmotions[i * self.R_size:(i + 1) * + self.R_size] + # [bs=32, 32] + + group_texts = torch.nn.functional.normalize(group_texts, dim=1) + group_motions = torch.nn.functional.normalize(group_motions, dim=1) + + dist_mat = euclidean_distance_matrix(group_texts, + group_motions).nan_to_num() + + # match score + self.TMR_gt_Matching_score += dist_mat.trace() + argsmax = torch.argsort(dist_mat, dim=1) + top_k_mat += calculate_top_k(argsmax, top_k=self.top_k).sum(axis=0) + metrics["TMR_gt_Matching_score"] = self.TMR_gt_Matching_score / R_count + for k in range(self.top_k): + metrics[f"TMR_gt_R_precision_top_{str(k+1)}"] = top_k_mat[k] / R_count + + # tensor -> numpy for FID + # all_genmotions = all_genmotions.numpy() + # all_gtmotions = all_gtmotions.numpy() + + # Compute fid + # mu, cov = calculate_activation_statistics_np(all_genmotions) + # gt_mu, gt_cov = calculate_activation_statistics_np(all_gtmotions) + # gt_mu, gt_cov = calculate_activation_statistics_np(all_gtmotions) + # metrics["FID"] = calculate_frechet_distance_np(gt_mu, gt_cov, mu, cov) + + # Compute diversity + # assert count_seq > self.diversity_times + # metrics["Diversity"] = 
calculate_diversity_np(all_genmotions, + # self.diversity_times) + # metrics["gt_Diversity"] = calculate_diversity_np( + # all_gtmotions, self.diversity_times) + + return {**metrics} + + def update( + self, + text_embeddings: Tensor, + recmotion_embeddings: Tensor, + gtmotion_embeddings: Tensor, + lengths: List[int], + ): + + self.count += sum(lengths) + self.count_seq += len(lengths) + + # [bs, nlatent*ndim] <= [bs, nlatent, ndim] + text_embeddings = torch.flatten(text_embeddings, start_dim=1).detach() + recmotion_embeddings = torch.flatten(recmotion_embeddings, + start_dim=1).detach() + gtmotion_embeddings = torch.flatten(gtmotion_embeddings, + start_dim=1).detach() + + # store all texts and motions + self.text_embeddings.append(text_embeddings) + self.recmotion_embeddings.append(recmotion_embeddings) + self.gtmotion_embeddings.append(gtmotion_embeddings) diff --git a/Evaluator_272/mld/models/metrics/uncond.py b/Evaluator_272/mld/models/metrics/uncond.py new file mode 100644 index 0000000000000000000000000000000000000000..ef1cb27c1944b6d57a832bca512b7888233537bd --- /dev/null +++ b/Evaluator_272/mld/models/metrics/uncond.py @@ -0,0 +1,120 @@ +from typing import List + +import torch +from torch import Tensor +from torchmetrics import Metric +from torchmetrics.functional import pairwise_euclidean_distance + +from .utils import * + + +class UncondMetrics(Metric): + full_state_update = True + + def __init__(self, + top_k=3, + R_size=32, + diversity_times=300, + dist_sync_on_step=True, + **kwargs): + super().__init__(dist_sync_on_step=dist_sync_on_step) + + self.name = "fid, kid, and diversity scores" + + self.top_k = top_k + self.R_size = R_size + self.diversity_times = 300 + + self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum") + self.add_state("count_seq", + default=torch.tensor(0), + dist_reduce_fx="sum") + + self.metrics = [] + + # KID + self.add_state("KID_mean", + default=torch.tensor(0.0), + dist_reduce_fx="mean") + self.add_state("KID_std", + default=torch.tensor(0.0), + dist_reduce_fx="mean") + self.metrics.extend(["KID_mean", "KID_std"]) + # Fid + self.add_state("FID", default=torch.tensor(0.0), dist_reduce_fx="mean") + self.metrics.append("FID") + + # Diversity + self.add_state("Diversity", + default=torch.tensor(0.0), + dist_reduce_fx="sum") + self.add_state("gt_Diversity", + default=torch.tensor(0.0), + dist_reduce_fx="sum") + self.metrics.extend(["Diversity", "gt_Diversity"]) + + # chached batches + self.add_state("recmotion_embeddings", default=[], dist_reduce_fx=None) + self.add_state("gtmotion_embeddings", default=[], dist_reduce_fx=None) + + def compute(self, sanity_flag): + count = self.count.item() + count_seq = self.count_seq.item() + + # init metrics + metrics = {metric: getattr(self, metric) for metric in self.metrics} + + # if in sanity check stage then jump + if sanity_flag: + return metrics + + # cat all embeddings + all_gtmotions = torch.cat(self.gtmotion_embeddings, axis=0).cpu() + all_genmotions = torch.cat(self.recmotion_embeddings, axis=0).cpu() + + # Compute kid + + KID_mean, KID_std = calculate_kid(all_gtmotions, all_genmotions) + metrics["KID_mean"] = KID_mean + metrics["KID_std"] = KID_std + + # tensor -> numpy for FID + all_genmotions = all_genmotions.numpy() + all_gtmotions = all_gtmotions.numpy() + + # Compute fid + mu, cov = calculate_activation_statistics_np(all_genmotions) + # gt_mu, gt_cov = calculate_activation_statistics_np(all_gtmotions) + gt_mu, gt_cov = calculate_activation_statistics_np(all_gtmotions) + metrics["FID"] = 
calculate_frechet_distance_np(gt_mu, gt_cov, mu, cov) + + # Compute diversity + assert count_seq > self.diversity_times + print(all_genmotions.shape) + print(all_gtmotions.shape) + metrics["Diversity"] = calculate_diversity_np(all_genmotions, + self.diversity_times) + metrics["gt_Diversity"] = calculate_diversity_np( + all_gtmotions, self.diversity_times) + + return {**metrics} + + def update( + self, + gtmotion_embeddings: Tensor, + lengths: List[int], + recmotion_embeddings=None, + ): + self.count += sum(lengths) + self.count_seq += len(lengths) + + # [bs, nlatent*ndim] <= [bs, nlatent, ndim] + if recmotion_embeddings is not None: + recmotion_embeddings = torch.flatten(recmotion_embeddings, + start_dim=1).detach() + # store all texts and motions + self.recmotion_embeddings.append(recmotion_embeddings) + gtmotion_embeddings = torch.flatten(gtmotion_embeddings, + start_dim=1).detach() + + self.gtmotion_embeddings.append(gtmotion_embeddings) diff --git a/Evaluator_272/mld/models/metrics/utils.py b/Evaluator_272/mld/models/metrics/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..955e5705732abbe7a1c1d050f72cab8efed26999 --- /dev/null +++ b/Evaluator_272/mld/models/metrics/utils.py @@ -0,0 +1,644 @@ +import numpy as np +import scipy.linalg +import torch +from torch import linalg +import sys + + +def l2_norm(x1, x2, dim): + return torch.linalg.vector_norm(x1 - x2, ord=2, dim=dim) + + +def variance(x, T, dim): + mean = x.mean(dim) + out = (x - mean)**2 + out = out.sum(dim) + return out / (T - 1) + + +def sqrtm(input): + m = input.detach().cpu().numpy().astype(np.float64_) + sqrtm = torch.from_numpy(scipy.linalg.sqrtm(m)).to(input) + return sqrtm + + +# (X - X_train)*(X - X_train) = -2X*X_train + X*X + X_train*X_train +def euclidean_distance_matrix(matrix1, matrix2): + """ + Params: + -- matrix1: N1 x D + -- matrix2: N2 x D + Returns: + -- dist: N1 x N2 + dist[i, j] == distance(matrix1[i], matrix2[j]) + """ + assert matrix1.shape[1] == matrix2.shape[1] + d1 = -2 * torch.mm(matrix1, matrix2.T) # shape (num_test, num_train) + d2 = torch.sum(torch.square(matrix1), axis=1, + keepdims=True) # shape (num_test, 1) + d3 = torch.sum(torch.square(matrix2), axis=1) # shape (num_train, ) + dists = torch.sqrt(d1 + d2 + d3) # broadcasting + return dists + + +def euclidean_distance_matrix_np(matrix1, matrix2): + """ + Params: + -- matrix1: N1 x D + -- matrix2: N2 x D + Returns: + -- dist: N1 x N2 + dist[i, j] == distance(matrix1[i], matrix2[j]) + """ + assert matrix1.shape[1] == matrix2.shape[1] + d1 = -2 * np.dot(matrix1, matrix2.T) # shape (num_test, num_train) + d2 = np.sum(np.square(matrix1), axis=1, + keepdims=True) # shape (num_test, 1) + d3 = np.sum(np.square(matrix2), axis=1) # shape (num_train, ) + dists = np.sqrt(d1 + d2 + d3) # broadcasting + return dists + + +def calculate_top_k(mat, top_k): + size = mat.shape[0] + gt_mat = (torch.unsqueeze(torch.arange(size), + 1).to(mat.device).repeat_interleave(size, 1)) + bool_mat = mat == gt_mat + correct_vec = False + top_k_list = [] + for i in range(top_k): + # print(correct_vec, bool_mat[:, i]) + correct_vec = correct_vec | bool_mat[:, i] + # print(correct_vec) + top_k_list.append(correct_vec[:, None]) + top_k_mat = torch.cat(top_k_list, dim=1) + return top_k_mat + + +def calculate_activation_statistics(activations): + """ + Params: + -- activation: num_samples x dim_feat + Returns: + -- mu: dim_feat + -- sigma: dim_feat x dim_feat + """ + activations = activations.cpu().numpy() + mu = np.mean(activations, axis=0) + sigma = 
np.cov(activations, rowvar=False) + return mu, sigma + + +def calculate_activation_statistics_np(activations): + """ + Params: + -- activation: num_samples x dim_feat + Returns: + -- mu: dim_feat + -- sigma: dim_feat x dim_feat + """ + mu = np.mean(activations, axis=0) + cov = np.cov(activations, rowvar=False) + return mu, cov + + +# def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): +# """Numpy implementation of the Frechet Distance. +# The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) +# and X_2 ~ N(mu_2, C_2) is +# d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). +# Stable version by Dougal J. Sutherland. +# Params: +# -- mu1 : Numpy array containing the activations of a layer of the +# inception net (like returned by the function 'get_predictions') +# for generated samples. +# -- mu2 : The sample mean over activations, precalculated on an +# representative data set. +# -- sigma1: The covariance matrix over activations for generated samples. +# -- sigma2: The covariance matrix over activations, precalculated on an +# representative data set. +# Returns: +# -- : The Frechet Distance. +# """ + +# mu1 = torch.atleast_1d(mu1) +# mu2 = torch.atleast_1d(mu2) + +# sigma1 = torch.atleast_2d(sigma1) +# sigma2 = torch.atleast_2d(sigma2) + +# assert mu1.shape == mu2.shape, \ +# 'Training and test mean vectors have different lengths' +# assert sigma1.shape == sigma2.shape, \ +# 'Training and test covariances have different dimensions' + +# diff = mu1 - mu2 + +# # Product might be almost singular +# # covmean, _ = sqrtm(sigma1.dot(sigma2), disp=False) +# covmean = sqrtm(torch.mm(sigma1,sigma2)) +# if not torch.isfinite(covmean).all(): +# msg = ('fid calculation produces singular product; ' +# 'adding %s to diagonal of cov estimates') % eps +# print(msg) +# offset = torch.eye(sigma1.shape[0]) * eps +# # covmean = sqrtm((sigma1 + offset).dot(sigma2 + offset)) +# covmean = sqrtm(torch.mm(sigma1+ offset,sigma2+ offset)) + +# # Numerical error might give slight imaginary component +# if torch.is_complex(covmean): +# if not torch.allclose(torch.diagonal(covmean).imag, 0, atol=1e-3): +# m = torch.max(torch.abs(covmean.imag)) +# raise ValueError('Imaginary component {}'.format(m)) +# covmean = covmean.real + +# tr_covmean = torch.trace(covmean) + +# return (diff.dot(diff) + torch.trace(sigma1) + +# torch.trace(sigma2) - 2 * tr_covmean) + + +def calculate_frechet_distance_np(mu1, sigma1, mu2, sigma2, eps=1e-6): + """Numpy implementation of the Frechet Distance. + The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) + and X_2 ~ N(mu_2, C_2) is + d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). + Stable version by Dougal J. Sutherland. + Params: + -- mu1 : Numpy array containing the activations of a layer of the + inception net (like returned by the function 'get_predictions') + for generated samples. + -- mu2 : The sample mean over activations, precalculated on an + representative data set. + -- sigma1: The covariance matrix over activations for generated samples. + -- sigma2: The covariance matrix over activations, precalculated on an + representative data set. + Returns: + -- : The Frechet Distance. 
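+ Note: the matrix square root sqrt(C_1*C_2) is computed with scipy.linalg.sqrtm;
+ if the product is nearly singular, eps is added to the covariance diagonals and
+ any small imaginary component caused by numerical error is discarded.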
+ """ + + mu1 = np.atleast_1d(mu1) + mu2 = np.atleast_1d(mu2) + + sigma1 = np.atleast_2d(sigma1) + sigma2 = np.atleast_2d(sigma2) + + assert (mu1.shape == mu2.shape + ), "Training and test mean vectors have different lengths" + assert (sigma1.shape == sigma2.shape + ), "Training and test covariances have different dimensions" + + diff = mu1 - mu2 + # Product might be almost singular + covmean, _ = scipy.linalg.sqrtm(sigma1.dot(sigma2), disp=False) + if not np.isfinite(covmean).all(): + msg = ("fid calculation produces singular product; " + "adding %s to diagonal of cov estimates") % eps + print(msg) + offset = np.eye(sigma1.shape[0]) * eps + covmean = scipy.linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + + # Numerical error might give slight imaginary component + if np.iscomplexobj(covmean): + if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): + m = np.max(np.abs(covmean.imag)) + raise ValueError("Imaginary component {}".format(m)) + # print("Imaginary component {}".format(m)) + covmean = covmean.real + tr_covmean = np.trace(covmean) + + return diff.dot(diff) + np.trace(sigma1) + np.trace( + sigma2) - 2 * tr_covmean + + +def calculate_diversity(activation, diversity_times): + assert len(activation.shape) == 2 + assert activation.shape[0] > diversity_times + num_samples = activation.shape[0] + + first_indices = np.random.choice(num_samples, + diversity_times, + replace=False) + second_indices = np.random.choice(num_samples, + diversity_times, + replace=False) + dist = linalg.norm(activation[first_indices] - activation[second_indices], + axis=1) + return dist.mean() + + +def calculate_diversity_np(activation, diversity_times): + assert len(activation.shape) == 2 + assert activation.shape[0] > diversity_times + num_samples = activation.shape[0] + + first_indices = np.random.choice(num_samples, + diversity_times, + replace=False) + second_indices = np.random.choice(num_samples, + diversity_times, + replace=False) + dist = scipy.linalg.norm(activation[first_indices] - + activation[second_indices], + axis=1) + return dist.mean() + + +def calculate_multimodality_np(activation, multimodality_times): + assert len(activation.shape) == 3 + assert activation.shape[1] > multimodality_times + num_per_sent = activation.shape[1] + + first_dices = np.random.choice(num_per_sent, + multimodality_times, + replace=False) + second_dices = np.random.choice(num_per_sent, + multimodality_times, + replace=False) + dist = scipy.linalg.norm(activation[:, first_dices] - + activation[:, second_dices], + axis=2) + return dist.mean() + + +# motion reconstructions metrics + + +def batch_compute_similarity_transform_torch(S1, S2): + """ + Computes a similarity transform (sR, t) that takes + a set of 3D points S1 (3 x N) closest to a set of 3D points S2, + where R is an 3x3 rotation matrix, t 3x1 translation, s scale. + i.e. solves the orthogonal Procrutes problem. + """ + transposed = False + if S1.shape[0] != 3 and S1.shape[0] != 2: + S1 = S1.permute(0, 2, 1) + S2 = S2.permute(0, 2, 1) + transposed = True + assert S2.shape[1] == S1.shape[1] + + # 1. Remove mean. + mu1 = S1.mean(axis=-1, keepdims=True) + mu2 = S2.mean(axis=-1, keepdims=True) + + X1 = S1 - mu1 + X2 = S2 - mu2 + + # 2. Compute variance of X1 used for scale. + var1 = torch.sum(X1**2, dim=1).sum(dim=1) + + # 3. The outer product of X1 and X2. + K = X1.bmm(X2.permute(0, 2, 1)) + + # 4. Solution that Maximizes trace(R'K) is R=U*V', where U, V are + # singular vectors of K. 
+ U, s, V = torch.svd(K) + + # Construct Z that fixes the orientation of R to get det(R)=1. + Z = torch.eye(U.shape[1], device=S1.device).unsqueeze(0) + Z = Z.repeat(U.shape[0], 1, 1) + Z[:, -1, -1] *= torch.sign(torch.det(U.bmm(V.permute(0, 2, 1)))) + + # Construct R. + R = V.bmm(Z.bmm(U.permute(0, 2, 1))) + + # 5. Recover scale. + scale = torch.cat([torch.trace(x).unsqueeze(0) for x in R.bmm(K)]) / var1 + + # 6. Recover translation. + t = mu2 - (scale.unsqueeze(-1).unsqueeze(-1) * (R.bmm(mu1))) + + # 7. Error: + S1_hat = scale.unsqueeze(-1).unsqueeze(-1) * R.bmm(S1) + t + + if transposed: + S1_hat = S1_hat.permute(0, 2, 1) + + return S1_hat, (scale, R, t) + + +def compute_mpjpe(preds, + target, + valid_mask=None, + pck_joints=None, + sample_wise=True): + """ + Mean per-joint position error (i.e. mean Euclidean distance) + often referred to as "Protocol #1" in many papers. + """ + assert preds.shape == target.shape, print(preds.shape, + target.shape) # BxJx3 + mpjpe = torch.norm(preds - target, p=2, dim=-1) # BxJ + + if pck_joints is None: + if sample_wise: + mpjpe_seq = ((mpjpe * valid_mask.float()).sum(-1) / + valid_mask.float().sum(-1) + if valid_mask is not None else mpjpe.mean(-1)) + else: + mpjpe_seq = mpjpe[valid_mask] if valid_mask is not None else mpjpe + return mpjpe_seq + else: + mpjpe_pck_seq = mpjpe[:, pck_joints] + return mpjpe_pck_seq + + +def align_by_parts(joints, align_inds=None): + if align_inds is None: + return joints + pelvis = joints[:, align_inds].mean(1) + return joints - torch.unsqueeze(pelvis, dim=1) + +def align_by_parts_hand(joints, align_inds=None): + if align_inds is None: + return joints + + align_inds_left_wrist = align_inds[0] + align_inds_right_wrist = align_inds[1] + + left_wrist = joints[:, align_inds_left_wrist].mean(1) + right_wrist = joints[:, align_inds_right_wrist].mean(1) + result = torch.cat((joints[...,22:22+15,:] - torch.unsqueeze(left_wrist, dim=1), joints[...,22+15:,:] - torch.unsqueeze(right_wrist, dim=1)), dim=-2) + return result + +def calc_mpjpe(preds, target, align_inds=[0], sample_wise=True, trans=None): + # Expects BxJx3 + valid_mask = target[:, :, 0] != -2.0 + + + if align_inds is not None: + preds_aligned = align_by_parts(preds, align_inds=align_inds) + if trans is not None: + preds_aligned += trans + target_aligned = align_by_parts(target, align_inds=align_inds) + else: + preds_aligned, target_aligned = preds, target + mpjpe_each = compute_mpjpe(preds_aligned, + target_aligned, + valid_mask=valid_mask, + sample_wise=sample_wise) + return mpjpe_each + + +def calc_mpjpe_hand(preds, target, align_inds, sample_wise=True, trans=None): + assert len(align_inds) == 2 + # align_inds_left_wrist = align_inds[0] + # align_inds_right_wrist = align_inds[1] + # Expects BxJx3 + valid_mask = target[:, :, 0] != -2.0 + # valid_mask = torch.BoolTensor(target[:, :, 0].shape) + if align_inds is not None: + preds_aligned = align_by_parts_hand(preds, align_inds=align_inds) + if trans is not None: + preds_aligned += trans + target_aligned = align_by_parts_hand(target, align_inds=align_inds) + else: + preds_aligned, target_aligned = preds, target + + + mpjpe_each = compute_mpjpe(preds_aligned[...,-30:,:], + target_aligned[...,-30:,:], + valid_mask=valid_mask[..., -30:], + sample_wise=sample_wise) + return mpjpe_each + + + + +def calc_accel(preds, target): + """ + Mean joint acceleration error + often referred to as "Protocol #1" in many papers. 
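+ Acceleration is approximated by the second finite difference x[t-1] - 2*x[t] + x[t+1];
+ the returned sequence is, per frame, the mean (over joints) L2 distance between
+ predicted and ground-truth accelerations.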
+ """ + assert preds.shape == target.shape, print(preds.shape, + target.shape) # BxJx3 + assert preds.dim() == 3 + # Expects BxJx3 + # valid_mask = torch.BoolTensor(target[:, :, 0].shape) + accel_gt = target[:-2] - 2 * target[1:-1] + target[2:] + accel_pred = preds[:-2] - 2 * preds[1:-1] + preds[2:] + normed = torch.linalg.norm(accel_pred - accel_gt, dim=-1) + accel_seq = normed.mean(1) + return accel_seq + + +def calc_pampjpe(preds, target, sample_wise=True, return_transform_mat=False): + # Expects BxJx3 + target, preds = target.float(), preds.float() + # extracting the keypoints that all samples have valid annotations + # valid_mask = (target[:, :, 0] != -2.).sum(0) == len(target) + # preds_tranformed, PA_transform = batch_compute_similarity_transform_torch(preds[:, valid_mask], target[:, valid_mask]) + # pa_mpjpe_each = compute_mpjpe(preds_tranformed, target[:, valid_mask], sample_wise=sample_wise) + + preds_tranformed, PA_transform = batch_compute_similarity_transform_torch( + preds, target) + pa_mpjpe_each = compute_mpjpe(preds_tranformed, + target, + sample_wise=sample_wise) + + if return_transform_mat: + return pa_mpjpe_each, PA_transform + else: + return pa_mpjpe_each + + +# from action2motion +def calculate_diversity_multimodality(activations, + labels, + num_labels, + diversity_times=200, + multimodality_times=20): + labels = labels.long() + num_motions = activations.shape[0] # len(labels) + + diversity = 0 + + first_indices = np.random.randint(0, num_motions, diversity_times) + second_indices = np.random.randint(0, num_motions, diversity_times) + for first_idx, second_idx in zip(first_indices, second_indices): + diversity += torch.dist(activations[first_idx, :], + activations[second_idx, :]) + diversity /= diversity_times + + multimodality = 0 + label_quotas = np.zeros(num_labels) + label_quotas[labels.unique( + )] = multimodality_times # if a label does not appear in batch, its quota remains zero + while np.any(label_quotas > 0): + # print(label_quotas) + first_idx = np.random.randint(0, num_motions) + first_label = labels[first_idx] + if not label_quotas[first_label]: + continue + + second_idx = np.random.randint(0, num_motions) + second_label = labels[second_idx] + while first_label != second_label: + second_idx = np.random.randint(0, num_motions) + second_label = labels[second_idx] + + label_quotas[first_label] -= 1 + + first_activation = activations[first_idx, :] + second_activation = activations[second_idx, :] + multimodality += torch.dist(first_activation, second_activation) + + multimodality /= (multimodality_times * num_labels) + + return diversity, multimodality + + +def calculate_fid(statistics_1, statistics_2): + return calculate_frechet_distance_np(statistics_1[0], statistics_1[1], + statistics_2[0], statistics_2[1]) + + +# from: https://github.com/abdulfatir/gan-metrics-pytorch/blob/master/kid_score.py +def polynomial_mmd_averages(codes_g, + codes_r, + n_subsets=50, + subset_size=1000, + ret_var=True, + output=sys.stdout, + **kernel_args): + m = min(codes_g.shape[0], codes_r.shape[0]) + mmds = np.zeros(n_subsets) + if ret_var: + vars = np.zeros(n_subsets) + choice = np.random.choice + + replace = subset_size < len(codes_g) + + for i in range(n_subsets): + g = codes_g[choice(len(codes_g), subset_size, replace=replace)] + r = codes_r[choice(len(codes_r), subset_size, replace=replace)] + o = polynomial_mmd(g, r, **kernel_args, var_at_m=m, ret_var=ret_var) + if ret_var: + mmds[i], vars[i] = o + else: + mmds[i] = o + + return (mmds, vars) if ret_var else mmds + + +def 
polynomial_mmd(codes_g, + codes_r, + degree=3, + gamma=None, + coef0=1, + var_at_m=None, + ret_var=True): + from sklearn.metrics.pairwise import polynomial_kernel + + # use k(x, y) = (gamma + coef0)^degree + # default gamma is 1 / dim + X = codes_g + Y = codes_r + + K_XX = polynomial_kernel(X, degree=degree, gamma=gamma, coef0=coef0) + K_YY = polynomial_kernel(Y, degree=degree, gamma=gamma, coef0=coef0) + K_XY = polynomial_kernel(X, Y, degree=degree, gamma=gamma, coef0=coef0) + + return _mmd2_and_variance(K_XX, + K_XY, + K_YY, + var_at_m=var_at_m, + ret_var=ret_var) + + +def _mmd2_and_variance(K_XX, + K_XY, + K_YY, + unit_diagonal=False, + mmd_est='unbiased', + block_size=1024, + var_at_m=None, + ret_var=True): + # based on + # https://github.com/dougalsutherland/opt-mmd/blob/master/two_sample/mmd.py + # but changed to not compute the full kernel matrix at once + m = K_XX.shape[0] + assert K_XX.shape == (m, m) + assert K_XY.shape == (m, m) + assert K_YY.shape == (m, m) + if var_at_m is None: + var_at_m = m + + # Get the various sums of kernels that we'll use + # Kts drop the diagonal, but we don't need to compute them explicitly + if unit_diagonal: + diag_X = diag_Y = 1 + sum_diag_X = sum_diag_Y = m + sum_diag2_X = sum_diag2_Y = m + else: + diag_X = np.diagonal(K_XX) + diag_Y = np.diagonal(K_YY) + + sum_diag_X = diag_X.sum() + sum_diag_Y = diag_Y.sum() + + sum_diag2_X = _sqn(diag_X) + sum_diag2_Y = _sqn(diag_Y) + + Kt_XX_sums = K_XX.sum(axis=1) - diag_X + Kt_YY_sums = K_YY.sum(axis=1) - diag_Y + K_XY_sums_0 = K_XY.sum(axis=0) + K_XY_sums_1 = K_XY.sum(axis=1) + + Kt_XX_sum = Kt_XX_sums.sum() + Kt_YY_sum = Kt_YY_sums.sum() + K_XY_sum = K_XY_sums_0.sum() + + if mmd_est == 'biased': + mmd2 = ((Kt_XX_sum + sum_diag_X) / (m * m) + (Kt_YY_sum + sum_diag_Y) / + (m * m) - 2 * K_XY_sum / (m * m)) + else: + assert mmd_est in {'unbiased', 'u-statistic'} + mmd2 = (Kt_XX_sum + Kt_YY_sum) / (m * (m - 1)) + if mmd_est == 'unbiased': + mmd2 -= 2 * K_XY_sum / (m * m) + else: + mmd2 -= 2 * (K_XY_sum - np.trace(K_XY)) / (m * (m - 1)) + + if not ret_var: + return mmd2 + + Kt_XX_2_sum = _sqn(K_XX) - sum_diag2_X + Kt_YY_2_sum = _sqn(K_YY) - sum_diag2_Y + K_XY_2_sum = _sqn(K_XY) + + dot_XX_XY = Kt_XX_sums.dot(K_XY_sums_1) + dot_YY_YX = Kt_YY_sums.dot(K_XY_sums_0) + + m1 = m - 1 + m2 = m - 2 + zeta1_est = ( + 1 / (m * m1 * m2) * + (_sqn(Kt_XX_sums) - Kt_XX_2_sum + _sqn(Kt_YY_sums) - Kt_YY_2_sum) - 1 / + (m * m1)**2 * (Kt_XX_sum**2 + Kt_YY_sum**2) + 1 / (m * m * m1) * + (_sqn(K_XY_sums_1) + _sqn(K_XY_sums_0) - 2 * K_XY_2_sum) - + 2 / m**4 * K_XY_sum**2 - 2 / (m * m * m1) * (dot_XX_XY + dot_YY_YX) + + 2 / (m**3 * m1) * (Kt_XX_sum + Kt_YY_sum) * K_XY_sum) + zeta2_est = (1 / (m * m1) * (Kt_XX_2_sum + Kt_YY_2_sum) - 1 / (m * m1)**2 * + (Kt_XX_sum**2 + Kt_YY_sum**2) + 2 / (m * m) * K_XY_2_sum - + 2 / m**4 * K_XY_sum**2 - 4 / (m * m * m1) * + (dot_XX_XY + dot_YY_YX) + 4 / (m**3 * m1) * + (Kt_XX_sum + Kt_YY_sum) * K_XY_sum) + var_est = (4 * (var_at_m - 2) / (var_at_m * (var_at_m - 1)) * zeta1_est + + 2 / (var_at_m * (var_at_m - 1)) * zeta2_est) + + return mmd2, var_est + + +def _sqn(arr): + flat = np.ravel(arr) + return flat.dot(flat) + + +def calculate_kid(real_activations, generated_activations): + kid_values = polynomial_mmd_averages(real_activations, + generated_activations, + n_subsets=100) + results = (kid_values[0].mean(), kid_values[0].std()) + return results diff --git a/Evaluator_272/mld/models/modeltype/__init__.py b/Evaluator_272/mld/models/modeltype/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Evaluator_272/mld/models/modeltype/base.py b/Evaluator_272/mld/models/modeltype/base.py new file mode 100644 index 0000000000000000000000000000000000000000..493119bf6bee4167f99455c45404adbc0e183e48 --- /dev/null +++ b/Evaluator_272/mld/models/modeltype/base.py @@ -0,0 +1,424 @@ +import os +from pathlib import Path +import numpy as np +import torch +from pytorch_lightning import LightningModule +# from mld.models.metrics import ComputeMetrics, MRMetrics, TM2TMetrics, TM2TMetrics_R256, MMMetrics, HUMANACTMetrics, UESTCMetrics, UncondMetrics, ComputeMetrics_body_hand, MRMetrics_body_hand, ACCMetrics, TMR_TM2TMetrics +from mld.models.metrics import TMR_TM2TMetrics +from os.path import join as pjoin +from collections import OrderedDict + + +class BaseModel(LightningModule): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.times = [] + + def __post_init__(self): + trainable, nontrainable = 0, 0 + for p in self.parameters(): + if p.requires_grad: + trainable += np.prod(p.size()) + else: + nontrainable += np.prod(p.size()) + + self.hparams.n_params_trainable = trainable + self.hparams.n_params_nontrainable = nontrainable + + def training_step(self, batch, batch_idx): + return self.allsplit_step("train", batch, batch_idx) + + def validation_step(self, batch, batch_idx): + return self.allsplit_step("val", batch, batch_idx) + + def test_step(self, batch, batch_idx): + if len(self.times) *self.cfg.TEST.BATCH_SIZE % (100) > 0 and len(self.times) > 0: + print(f"Average time per sample ({self.cfg.TEST.BATCH_SIZE*len(self.times)}): ", np.mean(self.times)/self.cfg.TEST.BATCH_SIZE) + return self.allsplit_step("test", batch, batch_idx) + + def predict_step(self, batch, batch_idx): + return self.forward(batch) + + def allsplit_epoch_end(self, split: str, outputs): + dico = {} + + if split in ["train", "val"]: + losses = self.losses[split] + loss_dict = losses.compute(split) + losses.reset() + dico.update({ + losses.loss2logname(loss, split): value.item() + for loss, value in loss_dict.items() if not torch.isnan(value) + }) + + if split in ["val", "test"]: + + if self.trainer.datamodule.is_mm and ("TM2TMetrics" in self.metrics_dict or "TM2TMetrics_R256" in self.metrics_dict): + metrics_dicts = ['MMMetrics'] + else: + metrics_dicts = self.metrics_dict + for metric in metrics_dicts: + metrics_dict = getattr( + self, + metric).compute(sanity_flag=self.trainer.sanity_checking) + # reset metrics + getattr(self, metric).reset() + dico.update({ + f"Metrics/{metric}": value.item() + for metric, value in metrics_dict.items() + }) + if split != "test": + dico.update({ + "epoch": float(self.trainer.current_epoch), + "step": float(self.trainer.current_epoch), + }) + # don't write sanity check into log + if not self.trainer.sanity_checking: + self.log_dict(dico, sync_dist=True, rank_zero_only=True) + + def training_epoch_end(self, outputs): + return self.allsplit_epoch_end("train", outputs) + + def validation_epoch_end(self, outputs): + # # ToDo + # # re-write vislization checkpoint? 
+ # # visualize validation + # parameters = {"xx",xx} + # vis_path = viz_epoch(self, dataset, epoch, parameters, module=None, + # folder=parameters["folder"], writer=None, exps=f"_{dataset_val.dataset_name}_"+val_set) + return self.allsplit_epoch_end("val", outputs) + + def test_epoch_end(self, outputs): + self.save_npy(outputs) + self.cfg.TEST.REP_I = self.cfg.TEST.REP_I + 1 + + return self.allsplit_epoch_end("test", outputs) + + def on_save_checkpoint(self, checkpoint): + # don't save clip to checkpoint + state_dict = checkpoint['state_dict'] + clip_k = [] + for k, v in state_dict.items(): + if 'text_encoder' in k: + clip_k.append(k) + for k in clip_k: + del checkpoint['state_dict'][k] + + def on_load_checkpoint(self, checkpoint): + # restore clip state_dict to checkpoint + clip_state_dict = self.text_encoder.state_dict() + new_state_dict = OrderedDict() + for k, v in clip_state_dict.items(): + new_state_dict['text_encoder.' + k] = v + for k, v in checkpoint['state_dict'].items(): + if 'text_encoder' not in k: + new_state_dict[k] = v + checkpoint['state_dict'] = new_state_dict + + def load_state_dict(self, state_dict, strict=True): + # load clip state_dict to checkpoint + if hasattr(self, 'text_encoder'): + clip_state_dict = self.text_encoder.state_dict() + new_state_dict = OrderedDict() + for k, v in clip_state_dict.items(): + new_state_dict['text_encoder.' + k] = v + for k, v in state_dict.items(): + if 'text_encoder' not in k: + new_state_dict[k] = v + else: + new_state_dict = state_dict + + super().load_state_dict(new_state_dict, strict) + + def configure_optimizers(self): + return {"optimizer": self.optimizer} + + def configure_metrics(self): + for metric in self.metrics_dict: + if metric == "TemosMetric": + self.TemosMetric = ComputeMetrics( + njoints=self.njoints, + jointstype=self.cfg.DATASET.JOINT_TYPE, + dist_sync_on_step=self.cfg.METRIC.DIST_SYNC_ON_STEP, + ) + + elif metric == "TemosMetric_body_hand": + self.TemosMetric_body_hand = ComputeMetrics_body_hand( + njoints=self.njoints, + jointstype=self.cfg.DATASET.JOINT_TYPE, + dist_sync_on_step=self.cfg.METRIC.DIST_SYNC_ON_STEP, + ) + + elif metric == "TM2TMetrics": + self.TM2TMetrics = TM2TMetrics( + diversity_times=30 + if self.debug else self.cfg.TEST.DIVERSITY_TIMES, + dist_sync_on_step=self.cfg.METRIC.DIST_SYNC_ON_STEP, + ) + elif metric == 'TM2TMetrics_R256': + self.TM2TMetrics_R256 = TM2TMetrics_R256( + diversity_times=30 + if self.debug else self.cfg.TEST.DIVERSITY_TIMES, + dist_sync_on_step=self.cfg.METRIC.DIST_SYNC_ON_STEP, + ) + elif metric == 'TMR_TM2TMetrics': + self.TMR_TM2TMetrics = TMR_TM2TMetrics( + diversity_times=30 + if self.debug else self.cfg.TEST.DIVERSITY_TIMES, + dist_sync_on_step=self.cfg.METRIC.DIST_SYNC_ON_STEP, + ) + elif metric == "MRMetrics": + self.MRMetrics = MRMetrics( + njoints=self.njoints, + jointstype=self.cfg.DATASET.JOINT_TYPE, + dist_sync_on_step=self.cfg.METRIC.DIST_SYNC_ON_STEP, + ) + + elif metric == "MRMetrics_body_hand": + self.MRMetrics_body_hand = MRMetrics_body_hand( + njoints=self.njoints, + jointstype=self.cfg.DATASET.JOINT_TYPE, + dist_sync_on_step=self.cfg.METRIC.DIST_SYNC_ON_STEP, + ) + + elif metric == "HUMANACTMetrics": + self.HUMANACTMetrics = HUMANACTMetrics( + datapath=os.path.join(self.cfg.model.humanact12_rec_path, + "humanact12_gru.tar"), + diversity_times=30 + if self.debug else self.cfg.TEST.DIVERSITY_TIMES, + multimodality_times=self.cfg.TEST.MM_NUM_TIMES, + dist_sync_on_step=self.cfg.METRIC.DIST_SYNC_ON_STEP, + ) + elif metric == "UESTCMetrics": + 
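+ # UESTC evaluator, constructed like the HumanAct12 branch above, with the
+ # diversity and multimodality repeat counts taken from the test configuration.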
self.UESTCMetrics = UESTCMetrics( + cfg=self.cfg, + diversity_times=30 + if self.debug else self.cfg.TEST.DIVERSITY_TIMES, + multimodality_times=self.cfg.TEST.MM_NUM_TIMES, + dist_sync_on_step=self.cfg.METRIC.DIST_SYNC_ON_STEP, + ) + elif metric == "UncondMetrics": + self.UncondMetrics = UncondMetrics( + diversity_times=30 + if self.debug else self.cfg.TEST.DIVERSITY_TIMES, + dist_sync_on_step=self.cfg.METRIC.DIST_SYNC_ON_STEP, + ) + elif metric == "ACCMetrics": + self.ACCMetrics = ACCMetrics(dist_sync_on_step=self.cfg.METRIC.DIST_SYNC_ON_STEP) + else: + raise NotImplementedError( + f"Do not support Metric Type {metric}") + if "TM2TMetrics" in self.metrics_dict or "UncondMetrics" in self.metrics_dict or "TM2TMetrics_R256" in self.metrics_dict: + self.MMMetrics = MMMetrics( + mm_num_times=self.cfg.TEST.MM_NUM_TIMES, + dist_sync_on_step=self.cfg.METRIC.DIST_SYNC_ON_STEP, + ) + + def save_npy(self, outputs): + + cfg = self.cfg + output_dir = Path( + os.path.join( + cfg.FOLDER, + str(cfg.model.model_type), + str(cfg.NAME), + "samples", + )) + + if cfg.TEST.SAVE_PREDICTIONS and cfg.TEST.REP_I + 1 == cfg.TEST.REPLICATION_TIMES: + if cfg.TEST.inference_vq_code: + if self.vae_type in ["hvq", "hvq_body_hand"]: + name = [i[2] for i in outputs] + motion_code_t = [i[0] for i in outputs] + motion_code_b = [i[1] for i in outputs] + else: + name = [i[1] for i in outputs] + outputs = [i[0] for i in outputs] + + else: + if cfg.DATASET.MOTION_TYPE == 'vector_263': + lengths = [i[1] for i in outputs] + texts = [i[2] for i in outputs] + outputs = [i[0] for i in outputs] + elif cfg.DATASET.MOTION_TYPE == 'smplx_212': + if cfg.TRAIN.use_joints: + lengths = [i[1] for i in outputs] + gen_motions = [self.datamodule.renormt2m_back(i[0]) for i in outputs] + ref_motions = [self.datamodule.renormt2m_back(i[2]) for i in outputs] + else: + return + elif cfg.DATASET.MOTION_TYPE in ['ric_rot']: + lengths = [i[1] for i in outputs] + gen_motions = [i[0] for i in outputs] + ref_motions = [i[2] for i in outputs] + else: + raise NotImplementedError + + if cfg.TEST.DATASETS[0].lower() in ["humanml3d", "kit"]: + if cfg.TEST.inference_vq_code: + for i in range(len(outputs)): + if self.vae_type in ["hvq", "hvq_body_hand"]: + for bid in range( + min(cfg.TEST.BATCH_SIZE, motion_code_t[i].shape[0])): + + motion_vqcode_t = motion_code_t[i][bid].cpu().numpy()[None, :] + motion_vqcode_b = motion_code_b[i][bid].cpu().numpy()[None, :] + motion_name = name[i][bid] + + assert cfg.TEST.REPLICATION_TIMES == 1 + + motion_name = f"{motion_name}.npy" + output_dir_t = Path( + os.path.join(f'./datasets/{cfg.TEST.DATASETS[0]}/vq_tokens', str(cfg.model.vae_type), 'motion_vqcode_t')) + output_dir_b = Path( + os.path.join(f'./datasets/{cfg.TEST.DATASETS[0]}/vq_tokens', str(cfg.model.vae_type), 'motion_vqcode_b')) + # save predictions results + npypath_t = output_dir_t / motion_name + npypath_b = output_dir_b / motion_name + + np.save(npypath_t, motion_vqcode_t) + np.save(npypath_b, motion_vqcode_b) + + + + else: + for bid in range( + min(cfg.TEST.BATCH_SIZE, outputs[i].shape[0])): + motion_vqcode = outputs[i][bid].cpu().numpy()[None, :] + motion_name = name[i][bid] + + assert cfg.TEST.REPLICATION_TIMES == 1 + + motion_name = f"{motion_name}.npy" + output_dir = Path( + os.path.join(f'./datasets/{cfg.TEST.DATASETS[0]}/vq_tokens', str(cfg.model.vae_type))) + # save predictions results + npypath = output_dir / motion_name + np.save(npypath, motion_vqcode) + + + else: + keyids = self.trainer.datamodule.test_dataset.name_list + for i in range(len(outputs)): 
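+ # Walk the cached test batches; for each sample, save the generated joints as an
+ # .npy file keyed by the dataset name list and write the paired caption under text/.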
+ for bid in range( + min(cfg.TEST.BATCH_SIZE, outputs[i].shape[0])): + keyid = keyids[i * cfg.TEST.BATCH_SIZE + bid] + gen_joints = outputs[i][bid].cpu().numpy() + text = texts[i][bid] + + if cfg.TEST.REPLICATION_TIMES > 1: + name = f"{keyid}_{cfg.TEST.REP_I}" + else: + name = f"{keyid}.npy" + # save predictions results + npypath = output_dir / name + np.save(npypath, gen_joints) + + textpath = output_dir / 'text' / (name + '.txt') + os.makedirs(os.path.split(textpath)[0], exist_ok=True) + with open(textpath, "w") as f: + f.write(text) + elif cfg.TEST.DATASETS[0].lower() in ["humanact12", "uestc"]: + keyids = range(len(self.trainer.datamodule.test_dataset)) + for i in range(len(outputs)): + for bid in range( + min(cfg.TEST.BATCH_SIZE, outputs[i].shape[0])): + keyid = keyids[i * cfg.TEST.BATCH_SIZE + bid] + gen_joints = outputs[i][bid].cpu() + gen_joints = gen_joints.permute(2, 0, + 1)[:lengths[i][bid], + ...].numpy() + if cfg.TEST.REPLICATION_TIMES > 1: + name = f"{keyid}_{cfg.TEST.REP_I}" + else: + name = f"{keyid}.npy" + # save predictions results + npypath = output_dir / name + np.save(npypath, gen_joints) + elif cfg.TEST.DATASETS[0].lower() in ["motionx", 'motionx_v26']: + + + if cfg.TEST.inference_vq_code: + for i in range(len(outputs)): + if self.vae_type in ["hvq", "hvq_body_hand"]: + for bid in range( + min(cfg.TEST.BATCH_SIZE, motion_code_t[i].shape[0])): + motion_vqcode_t = motion_code_t[i][bid].cpu().numpy()[None, :] + motion_vqcode_b = motion_code_b[i][bid].cpu().numpy()[None, :] + motion_name = name[i][bid] + + assert cfg.TEST.REPLICATION_TIMES == 1 + + motion_name = f"{motion_name}.npy" + if cfg.TEST.DATASETS[0].lower() == 'motionx_v26': + output_dir_t = Path( + os.path.join(f'./datasets/Motion-X-V26/vq_tokens', str(cfg.model.vae_type), 'motion_vqcode_t')) + output_dir_b = Path( + os.path.join(f'./datasets/Motion-X-V26/vq_tokens', str(cfg.model.vae_type), 'motion_vqcode_b')) + elif cfg.TEST.DATASETS[0].lower() == 'motionx': + output_dir_t = Path( + os.path.join(f'./datasets/Motion-X/vq_tokens', str(cfg.model.vae_type), 'motion_vqcode_t')) + output_dir_b = Path( + os.path.join(f'./datasets/Motion-X/vq_tokens', str(cfg.model.vae_type), 'motion_vqcode_b')) + else: + raise NotImplementedError + # save predictions results + + npypath_t = output_dir_t / motion_name + npypath_b = output_dir_b / motion_name + + npypath_t_ref_parent_directory = os.path.dirname(npypath_t) + if not os.path.exists(npypath_t_ref_parent_directory): + os.makedirs(npypath_t_ref_parent_directory) + + npypath_b_parent_directory = os.path.dirname(npypath_b) + if not os.path.exists(npypath_b_parent_directory): + os.makedirs(npypath_b_parent_directory) + + np.save(npypath_t, motion_vqcode_t) + np.save(npypath_b, motion_vqcode_b) + + + + + else: + for bid in range( + min(cfg.TEST.BATCH_SIZE, outputs[i].shape[0])): + motion_vqcode = outputs[i][bid].cpu().numpy()[None, :] + motion_name = name[i][bid] + + assert cfg.TEST.REPLICATION_TIMES == 1 + + motion_name = f"{motion_name}.npy" + output_dir = Path( + os.path.join(f'./datasets/Motion-X/vq_tokens', str(cfg.model.vae_type))) + # save predictions results + + npypath = output_dir / motion_name + npypath_parent_directory = os.path.dirname(npypath) + if not os.path.exists(npypath_parent_directory): + os.makedirs(npypath_parent_directory) + np.save(npypath, motion_vqcode) + + + + else: + + keyids = self.trainer.datamodule.test_dataset.name_list + for i in range(len(gen_motions)): + for bid in range( + min(cfg.TEST.BATCH_SIZE, gen_motions[i].shape[0])): + keyid = 
keyids[i * cfg.TEST.BATCH_SIZE + bid] + gen_joints = gen_motions[i][bid].cpu().numpy() + ref_joints = ref_motions[i][bid].cpu().numpy() + + gen_name = f"{keyid}.npy" + ref_name = f"{keyid}_gt.npy" + # save predictions results + npypath = output_dir / gen_name + os.makedirs(os.path.split(npypath)[0], exist_ok=True) + np.save(npypath, gen_joints) + + diff --git a/Evaluator_272/mld/models/modeltype/mld.py b/Evaluator_272/mld/models/modeltype/mld.py new file mode 100644 index 0000000000000000000000000000000000000000..2bc1d99837201a4271b74d47696f2262fd9cd14e --- /dev/null +++ b/Evaluator_272/mld/models/modeltype/mld.py @@ -0,0 +1,3001 @@ +import inspect +import os +from mld.transforms.rotation2xyz import Rotation2xyz +import numpy as np +import torch +from torch import Tensor +from torch.optim import AdamW +from torchmetrics import MetricCollection +import time +from mld.config import instantiate_from_config +from os.path import join as pjoin +from mld.models.architectures import ( + mld_denoiser, + mld_dual_vae, + mld_vae, + vposert_vae, + t2m_motionenc, + t2m_textenc, + vposert_vae, +) +from mld.models.losses.mld import MLDLosses, MLDLosses_no_joint +from mld.models.losses.vqvae import VQVAELosses +from mld.models.modeltype.base import BaseModel +from mld.utils.temos_utils import remove_padding + +from mld.models.architectures.temos.textencoder.distillbert_actor import DistilbertActorAgnosticEncoder +from mld.models.architectures.temos.motionencoder.actor import ActorAgnosticEncoder + +from .base import BaseModel +from .smplx_layer import smplx_layer + +from ..body_skeleton.skeleton import Skeleton +from ..body_skeleton.paramUtil import * + +from collections import OrderedDict +from sentence_transformers import SentenceTransformer + +import copy + + +class MLD(BaseModel): + """ + Stage 1 vae + Stage 2 diffusion + """ + + def __init__(self, cfg, datamodule, **kwargs): + super().__init__() + + self.cfg = cfg + self.stage = cfg.TRAIN.STAGE + self.condition = cfg.model.condition + self.is_vae = cfg.model.vae + self.predict_epsilon = cfg.TRAIN.ABLATION.PREDICT_EPSILON + self.nfeats = cfg.DATASET.NFEATS + self.njoints = cfg.DATASET.NJOINTS + self.debug = cfg.DEBUG + self.latent_dim = cfg.model.latent_dim + self.guidance_scale = cfg.model.guidance_scale + self.guidance_uncodp = cfg.model.guidance_uncondp + self.datamodule = datamodule + + if 'MINOR_MOTION_TYPE' in cfg.DATASET: + self.input_format = cfg.DATASET.MINOR_MOTION_TYPE + else: + self.input_format = cfg.DATASET.MOTION_TYPE + + self.motion_type = cfg.DATASET.MOTION_TYPE + + self.eval_on_text = cfg.EVAL.eval_on_text + # + try: + self.vae_type = cfg.model.vae_type + except: + self.vae_type = cfg.model.motion_vae.target.split( + ".")[-1].lower().replace("vae", "") + + self.text_encoder = instantiate_from_config(cfg.model.text_encoder) + + self.smplx_model = smplx_layer() + + self.smplx_model.eval() + for p in self.smplx_model.parameters(): + p.requires_grad = False + + if self.vae_type != "no": + # + self.vae = instantiate_from_config(cfg.model.motion_vae) + + # Don't train the motion encoder and decoder + if self.stage == "diffusion": + if self.vae_type in ["mld", "vposert", "actor", "humanvq"]: + self.vae.training = False + for p in self.vae.parameters(): + p.requires_grad = False + elif self.vae_type == "no": + pass + else: + self.motion_encoder.training = False + for p in self.motion_encoder.parameters(): + p.requires_grad = False + self.motion_decoder.training = False + for p in self.motion_decoder.parameters(): + p.requires_grad = False + 
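+ # The denoiser and the two schedulers (sampling scheduler and training noise
+ # scheduler) are built from the config; when the model predicts motions instead of
+ # noise (predict_epsilon is False), their prediction_type is switched to 'sample'.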
+ self.denoiser = instantiate_from_config(cfg.model.denoiser) + if not self.predict_epsilon: + cfg.model.scheduler.params['prediction_type'] = 'sample' + cfg.model.noise_scheduler.params['prediction_type'] = 'sample' + self.scheduler = instantiate_from_config(cfg.model.scheduler) + self.noise_scheduler = instantiate_from_config( + cfg.model.noise_scheduler) + + if cfg.EVAL.eval_on_text: + if self.condition in ["text", "text_uncond", 'text_all', 'text_face', 'text_body', 'text_hand', 'text_face_body', 'text_seperate', 'only_pose_concat', 'only_pose_fusion']: + self._get_t2m_evaluator(cfg) + + if cfg.EVAL.use_tmr_eval: + if self.condition in ["text", "text_uncond", 'text_all', 'text_face', 'text_body', 'text_hand', 'text_face_body', 'text_seperate', 'only_pose_concat', 'only_pose_fusion']: + self._get_tmr_t2m_evaluator(cfg) + + if cfg.TRAIN.OPTIM.TYPE.lower() == "adamw": + self.optimizer = AdamW(lr=cfg.TRAIN.OPTIM.LR, + params=self.parameters()) + else: + raise NotImplementedError( + "Do not support other optimizer for now.") + + if cfg.LOSS.TYPE == "mld": + # assert cfg.DATASET.MOTION_TYPE in ['vector_263', 'root_position'] + self._losses = MetricCollection({ + split: MLDLosses(vae=self.is_vae, mode="xyz", cfg=cfg) + for split in ["losses_train", "losses_test", "losses_val"] + }) + + elif cfg.LOSS.TYPE == "vqvae": + + self._losses = MetricCollection({ + split: VQVAELosses(vae=self.is_vae, mode="xyz", cfg=cfg) + for split in ["losses_train", "losses_test", "losses_val"] + }) + + elif cfg.LOSS.TYPE == 'mld_no_joint': + # assert 'smpl' not in cfg.DATASET.MOTION_TYPE + self._losses = MetricCollection({ + split: MLDLosses_no_joint(vae=self.is_vae, mode="xyz", cfg=cfg) + for split in ["losses_train", "losses_test", "losses_val"] + }) + + else: + raise NotImplementedError( + "MotionCross model only supports mld losses.") + + # if cfg.LOSS.TYPE == 'mld_no_joint': + # assert cfg.TRAIN.use_joints == False + + self.losses = { + key: self._losses["losses_" + key] + for key in ["train", "test", "val"] + } + + self.metrics_dict = cfg.METRIC.TYPE + self.configure_metrics() + + # If we want to overide it at testing time + + if eval("self.cfg.TRAIN.DATASETS")[0].lower() == 'humanml3d': + n_raw_offsets = torch.from_numpy(t2m_raw_offsets) + kinematic_chain = t2m_kinematic_chain + elif eval("self.cfg.TRAIN.DATASETS")[0].lower() == 'kit': + n_raw_offsets = torch.from_numpy(kit_raw_offsets) + kinematic_chain = kit_kinematic_chain + elif eval("self.cfg.TRAIN.DATASETS")[0].lower() in ['motionx', 'motionx_v25', 'motionx_v26']: + n_raw_offsets = torch.from_numpy(t2m_raw_body_hand_offsets) + body_raw_offsets = n_raw_offsets[:22] + hand_raw_offsets = n_raw_offsets[22:] + kinematic_chain = t2m_body_hand_kinematic_chain + body_kinemantic_chain = t2m_kinematic_chain + hand_kinemantic_chain = t2m_left_hand_chain + t2m_right_hand_chain + else: + raise NotImplementedError + + + self.skel=None + if self.input_format in ['root_rot6d']: + example_data = np.load(os.path.join('./HumanML3D-1/joints', '000021' + '.npy')) + example_data = example_data.reshape(len(example_data), -1, 3) + example_data = torch.from_numpy(example_data) + tgt_skel = Skeleton(n_raw_offsets, kinematic_chain) + # (joints_num, 3) + tgt_offsets = tgt_skel.get_offsets_joints(example_data[0]) + self.skel = Skeleton(n_raw_offsets, kinematic_chain) + self.skel.set_offset(tgt_offsets) + + elif self.input_format in ['root_body_pos_vel_hand_rot']: + + example_data = np.load('./datasets/Motion-X/motion_data/joint/humanml/000021.npy') + example_data = 
example_data.reshape(len(example_data), -1, 3) + example_data = torch.from_numpy(example_data) + + example_data = example_data[:, :52] + + body_example_data = example_data[:, :22] + tgt_body_skel = Skeleton(body_raw_offsets, body_kinemantic_chain) + + tgt_skel = Skeleton(n_raw_offsets, kinematic_chain) + + # (joints_num, 3) + tgt_body_skel_offsets = tgt_body_skel.get_offsets_joints(body_example_data[0]) + tgt_skel_offsets = tgt_skel.get_offsets_joints(example_data[0]) + + body_skel = Skeleton(body_raw_offsets, body_kinemantic_chain) + all_skel = Skeleton(n_raw_offsets, kinematic_chain) + + body_skel.set_offset(tgt_body_skel_offsets) + all_skel.set_offset(tgt_skel_offsets) + + self.skel = (body_skel, all_skel) + # self.skel.set_offset(tgt_offsets) + + + + + self.sample_mean = False + self.fact = None + self.do_classifier_free_guidance = self.guidance_scale > 1.0 + if self.condition in ['text', 'text_uncond', "text_all", 'text_body', 'text_hand', 'text_face_body', 'text_face', "text_seperate", "only_pose_concat", "only_pose_fusion"]: + self.feats2joints = datamodule.feats2joints + self.renorm2ori = datamodule.renorm2ori + if self.cfg.model.vae_type == 'hvq_body_hand_face': + self.facerenorm2ori = datamodule.facerenorm2ori + elif self.condition == 'action': + self.rot2xyz = Rotation2xyz(smpl_path=cfg.DATASET.SMPL_PATH) + self.feats2joints_eval = lambda sample, mask: self.rot2xyz( + sample.view(*sample.shape[:-1], 6, 25).permute(0, 3, 2, 1), + mask=mask, + pose_rep='rot6d', + glob=True, + translation=True, + jointstype='smpl', + vertstrans=True, + betas=None, + beta=0, + glob_rot=None, + get_rotations_back=False) + self.feats2joints = lambda sample, mask: self.rot2xyz( + sample.view(*sample.shape[:-1], 6, 25).permute(0, 3, 2, 1), + mask=mask, + pose_rep='rot6d', + glob=True, + translation=True, + jointstype='vertices', + vertstrans=False, + betas=None, + beta=0, + glob_rot=None, + get_rotations_back=False) + + def _get_t2m_evaluator(self, cfg): + """ + load T2M text encoder and motion encoder for evaluating + """ + + + # init module + if cfg.model.eval_text_source == 'token': + + self.t2m_textencoder = t2m_textenc.TextEncoderBiGRUCo(word_size=cfg.model.t2m_textencoder.dim_word, + pos_size=cfg.model.t2m_textencoder.dim_pos_ohot, + hidden_size=cfg.model.t2m_textencoder.dim_text_hidden, + output_size=cfg.model.t2m_textencoder.dim_coemb_hidden, + ) + elif cfg.model.eval_text_source == 'only_text_token': + + self.t2m_textencoder = t2m_textenc.TextEncoderBiGRUCoV2(word_size=cfg.model.t2m_textencoder.dim_word, + hidden_size=cfg.model.t2m_textencoder.dim_text_hidden, + output_size=cfg.model.t2m_textencoder.dim_coemb_hidden, + ) + + elif cfg.model.eval_text_source in ['caption']: + + + if cfg.model.eval_text_encode_way == 'clip': + self.t2m_textencoder, clip_preprocess = clip.load("ViT-B/32", device=opt.device, jit=False) # Must set jit=False for training + clip.model.convert_weights(text_enc)# Actually this line is unnecessary since clip by default already on float16 + self.t2m_textencoder.eval() + for p in text_enc.parameters(): + p.requires_grad = False + + elif cfg.model.eval_text_encode_way == 't5': + os.environ["TOKENIZERS_PARALLELISM"] = "false" + self.t2m_textencoder = SentenceTransformer('sentence-transformers/sentence-t5-xl').to(opt.device) + self.t2m_textencoder.eval() + for p in self.t2m_textencoder.parameters(): + p.requires_grad = False + + elif 'GRU' in cfg.model.eval_text_encode_way: + self.t2m_textencoder = t2m_textenc.TextEncoderBiGRUCoV2(word_size=cfg.model.t2m_textencoder.dim_word, 
+ hidden_size=cfg.model.t2m_textencoder.dim_text_hidden, + output_size=cfg.model.t2m_textencoder.dim_coemb_hidden, + ) + else: + raise NotImplementedError + + + + if cfg.DATASET.MOTION_TYPE in ['vector_263', 'ric_rot', 'vector_263_ori_humanml']: + self.t2m_moveencoder = t2m_motionenc.MovementConvEncoder( + input_size=cfg.DATASET.NFEATS - 4, + hidden_size=cfg.model.t2m_motionencoder.dim_move_hidden, + output_size=cfg.model.t2m_motionencoder.dim_move_latent, + ) + elif cfg.DATASET.MOTION_TYPE in ['smplx_212', 'smplx_159']: + self.t2m_moveencoder = t2m_motionenc.MovementConvEncoder( + input_size=cfg.DATASET.NFEATS, + hidden_size=cfg.model.t2m_motionencoder.dim_move_hidden, + output_size=cfg.model.t2m_motionencoder.dim_move_latent, + ) + + else: + raise NotImplementedError + + self.t2m_motionencoder = t2m_motionenc.MotionEncoderBiGRUCo( + input_size=cfg.model.t2m_motionencoder.dim_move_latent, + hidden_size=cfg.model.t2m_motionencoder.dim_motion_hidden, + output_size=cfg.model.t2m_motionencoder.dim_motion_latent, + ) + + # load pretrianed + dataname = cfg.TEST.DATASETS[0] + dataname = "t2m" if dataname == "humanml3d" else dataname + # t2m_checkpoint = torch.load( + # os.path.join(cfg.model.t2m_path, dataname, cfg.DATASET.VERSION, cfg.DATASET.MOTION_TYPE, + # "text_mot_match_glove_6B_caption_bs_256/model/finest.tar")) + + + minor_motin_type = cfg.DATASET.MINOR_MOTION_TYPE if 'MINOR_MOTION_TYPE' in cfg.DATASET else '' + + + if dataname in ['motionx', 'motionx_v25', 'motionx_v26']: + + if 'TEXT_TYPE' in cfg.DATASET: + if cfg.DATASET.TEXT_TYPE == 'vicuna1.5_13b': + + t2m_checkpoint = torch.load( + os.path.join(cfg.model.t2m_path, dataname, cfg.DATASET.VERSION, cfg.DATASET.MOTION_TYPE, minor_motin_type, + "text_mot_match_glove_6B_caption_bs_256_text_vicuna1.5/model/finest.tar"), map_location=torch.device('cpu')) + elif cfg.DATASET.TEXT_TYPE == 'vicuna1.5_13b_add_subject': + t2m_checkpoint = torch.load( + os.path.join(cfg.model.t2m_path, dataname, cfg.DATASET.VERSION, cfg.DATASET.MOTION_TYPE, minor_motin_type, + "text_mot_match_glove_6B_caption_bs_256_text_vicuna1.5_add_subject/model/finest.tar"), map_location=torch.device('cpu')) + + else: + t2m_checkpoint = torch.load( + os.path.join(cfg.model.t2m_path, dataname, cfg.DATASET.VERSION, cfg.DATASET.MOTION_TYPE, minor_motin_type, + "text_mot_match_glove_6B_caption_bs_256/model/finest.tar"), map_location=torch.device('cpu')) + else: + t2m_checkpoint = torch.load( + os.path.join(cfg.model.t2m_path, dataname, + "text_mot_match/model/finest.tar"), map_location=torch.device('cpu')) + + self.t2m_textencoder.load_state_dict(t2m_checkpoint["text_encoder"]) + + self.t2m_moveencoder.load_state_dict( + t2m_checkpoint["movement_encoder"]) + + + self.t2m_motionencoder.load_state_dict( + t2m_checkpoint["motion_encoder"]) + + # freeze params + self.t2m_textencoder.eval() + self.t2m_moveencoder.eval() + self.t2m_motionencoder.eval() + for p in self.t2m_textencoder.parameters(): + p.requires_grad = False + for p in self.t2m_moveencoder.parameters(): + p.requires_grad = False + for p in self.t2m_motionencoder.parameters(): + p.requires_grad = False + + + def _get_tmr_t2m_evaluator(self, cfg): + """ + load tmr T2M text encoder and motion encoder for evaluating + """ + + assert cfg.model.eval_text_source in ['caption'] + + self.t2m_TMR_textencoder_eval = DistilbertActorAgnosticEncoder('distilbert-base-uncased', num_layers=4) + self.t2m_TMR_motionencoder_eval = ActorAgnosticEncoder(nfeats=cfg.DATASET.NFEATS, vae =True, num_layers=4) + + + # load pretrianed + dataname = 
cfg.TEST.DATASETS[0] + dataname = "t2m" if dataname == "humanml3d" else dataname + # t2m_checkpoint = torch.load( + # os.path.join(cfg.model.t2m_path, dataname, cfg.DATASET.VERSION, cfg.DATASET.MOTION_TYPE, + # "text_mot_match_glove_6B_caption_bs_256/model/finest.tar")) + + + minor_motin_type = cfg.DATASET.MINOR_MOTION_TYPE if 'MINOR_MOTION_TYPE' in cfg.DATASET else '' + if dataname in ['motionx', 'motionx_v25', 'motionx_v26']: + t2m_checkpoint = torch.load( + os.path.join(cfg.model.t2m_path, dataname, cfg.DATASET.VERSION, cfg.DATASET.MOTION_TYPE, minor_motin_type, "TMR_pretrain_new/epoch=59.ckpt"), map_location=torch.device('cpu')) + state_dict = t2m_checkpoint["state_dict"] + else: + t2m_checkpoint = torch.load( + os.path.join(cfg.model.t2m_path, dataname, + "text_mot_match/model/finest.tar"), map_location=torch.device('cpu')) + + tmr_textencoder_dict = OrderedDict() + for k, v in state_dict.items(): + # print(k) + if k.split(".")[0] == "textencoder": + name = k.replace("textencoder.", "") + tmr_textencoder_dict[name] = v + + self.t2m_TMR_textencoder_eval.load_state_dict(tmr_textencoder_dict, strict=True) + + + tmr_motionencoder_dict = OrderedDict() + for k, v in state_dict.items(): + # print(k) + if k.split(".")[0] == "motionencoder": + name = k.replace("motionencoder.", "") + tmr_motionencoder_dict[name] = v + + self.t2m_TMR_motionencoder_eval.load_state_dict(tmr_motionencoder_dict, strict=True) + + # freeze params + self.t2m_TMR_textencoder_eval.freeze() + self.t2m_TMR_motionencoder_eval.freeze() + self.t2m_TMR_textencoder_eval.eval() + self.t2m_TMR_motionencoder_eval.eval() + for p in self.t2m_TMR_textencoder_eval.parameters(): + p.requires_grad = False + for p in self.t2m_TMR_motionencoder_eval.parameters(): + p.requires_grad = False + + def sample_from_distribution( + self, + dist, + *, + fact=None, + sample_mean=False, + ) -> Tensor: + fact = fact if fact is not None else self.fact + sample_mean = sample_mean if sample_mean is not None else self.sample_mean + + if sample_mean: + return dist.loc.unsqueeze(0) + + # Reparameterization trick + if fact is None: + return dist.rsample().unsqueeze(0) + + # Resclale the eps + eps = dist.rsample() - dist.loc + z = dist.loc + fact * eps + + # add latent size + z = z.unsqueeze(0) + return z + + def forward(self, batch): + texts = batch["text"] + lengths = batch["length"] + if self.cfg.TEST.COUNT_TIME: + self.starttime = time.time() + + if self.stage in ['diffusion', 'vae_diffusion']: + # diffusion reverse + if self.do_classifier_free_guidance: + uncond_tokens = [""] * len(texts) + if self.condition == 'text': + uncond_tokens.extend(texts) + elif self.condition == 'text_uncond': + uncond_tokens.extend(uncond_tokens) + texts = uncond_tokens + text_emb = self.text_encoder(texts) + z = self._diffusion_reverse(text_emb, lengths) + elif self.stage in ['vae']: + motions = batch['motion'] + z, dist_m = self.vae.encode(motions, lengths) + + with torch.no_grad(): + # ToDo change mcross actor to same api + if self.vae_type in ["mld","actor"]: + feats_rst = self.vae.decode(z, lengths) + elif self.vae_type == "no": + feats_rst = z.permute(1, 0, 2) + + if self.cfg.TEST.COUNT_TIME: + self.endtime = time.time() + elapsed = self.endtime - self.starttime + self.times.append(elapsed) + if len(self.times) % 100 == 0: + meantime = np.mean( + self.times[-100:]) / self.cfg.TEST.BATCH_SIZE + print( + f'100 iter mean Time (batch_size: {self.cfg.TEST.BATCH_SIZE}): {meantime}', + ) + if len(self.times) % 1000 == 0: + meantime = np.mean( + self.times[-1000:]) / 
self.cfg.TEST.BATCH_SIZE + print( + f'1000 iter mean Time (batch_size: {self.cfg.TEST.BATCH_SIZE}): {meantime}', + ) + with open(pjoin(self.cfg.FOLDER_EXP, 'times.txt'), 'w') as f: + for line in self.times: + f.write(str(line)) + f.write('\n') + joints = self.feats2joints(feats_rst.detach().cpu()) + return remove_padding(joints, lengths) + + def gen_from_latent(self, batch): + z = batch["latent"] + lengths = batch["length"] + + feats_rst = self.vae.decode(z, lengths) + + # feats => joints + joints = self.feats2joints(feats_rst.detach().cpu()) + return remove_padding(joints, lengths) + + def recon_from_motion(self, batch): + feats_ref = batch["motion"] + length = batch["length"] + + z, dist = self.vae.encode(feats_ref, length) + feats_rst = self.vae.decode(z, length) + + # feats => joints + joints = self.feats2joints(feats_rst.detach().cpu()) + joints_ref = self.feats2joints(feats_ref.detach().cpu()) + return remove_padding(joints, + length), remove_padding(joints_ref, length) + + def _diffusion_reverse(self, encoder_hidden_states, lengths=None): + # init latents + bsz = encoder_hidden_states.shape[0] + if self.do_classifier_free_guidance: + bsz = bsz // 2 + if self.vae_type == "no": + assert lengths is not None, "no vae (diffusion only) need lengths for diffusion" + latents = torch.randn( + (bsz, max(lengths), self.cfg.DATASET.NFEATS), + device=encoder_hidden_states.device, + dtype=torch.float, + ) + else: + latents = torch.randn( + (bsz, self.latent_dim[0], self.latent_dim[-1]), + device=encoder_hidden_states.device, + dtype=torch.float, + ) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + # set timesteps + self.scheduler.set_timesteps( + self.cfg.model.scheduler.num_inference_timesteps) + timesteps = self.scheduler.timesteps.to(encoder_hidden_states.device) + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, and between [0, 1] + extra_step_kwargs = {} + if "eta" in set( + inspect.signature(self.scheduler.step).parameters.keys()): + extra_step_kwargs["eta"] = self.cfg.model.scheduler.eta + + # reverse + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = (torch.cat( + [latents] * + 2) if self.do_classifier_free_guidance else latents) + lengths_reverse = (lengths * 2 if self.do_classifier_free_guidance + else lengths) + # latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + # predict the noise residual + noise_pred = self.denoiser( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=encoder_hidden_states, + lengths=lengths_reverse, + )[0] + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * ( + noise_pred_text - noise_pred_uncond) + # text_embeddings_for_guidance = encoder_hidden_states.chunk( + # 2)[1] if self.do_classifier_free_guidance else encoder_hidden_states + latents = self.scheduler.step(noise_pred, t, latents, + **extra_step_kwargs).prev_sample + # if self.predict_epsilon: + # latents = self.scheduler.step(noise_pred, t, latents, + # **extra_step_kwargs).prev_sample + # else: + # # predict x for standard diffusion model + # # compute the previous noisy sample x_t -> x_t-1 + # latents = self.scheduler.step(noise_pred, + # t, + # latents, + # 
**extra_step_kwargs).prev_sample + + # [batch_size, 1, latent_dim] -> [1, batch_size, latent_dim] + latents = latents.permute(1, 0, 2) + return latents + + def _diffusion_reverse_tsne(self, encoder_hidden_states, lengths=None): + # init latents + bsz = encoder_hidden_states.shape[0] + if self.do_classifier_free_guidance: + bsz = bsz // 2 + if self.vae_type == "no": + assert lengths is not None, "no vae (diffusion only) need lengths for diffusion" + latents = torch.randn( + (bsz, max(lengths), self.cfg.DATASET.NFEATS), + device=encoder_hidden_states.device, + dtype=torch.float, + ) + else: + latents = torch.randn( + (bsz, self.latent_dim[0], self.latent_dim[-1]), + device=encoder_hidden_states.device, + dtype=torch.float, + ) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + # set timesteps + self.scheduler.set_timesteps( + self.cfg.model.scheduler.num_inference_timesteps) + timesteps = self.scheduler.timesteps.to(encoder_hidden_states.device) + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, and between [0, 1] + extra_step_kwargs = {} + if "eta" in set( + inspect.signature(self.scheduler.step).parameters.keys()): + extra_step_kwargs["eta"] = self.cfg.model.scheduler.eta + + # reverse + latents_t = [] + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = (torch.cat( + [latents] * + 2) if self.do_classifier_free_guidance else latents) + lengths_reverse = (lengths * 2 if self.do_classifier_free_guidance + else lengths) + # latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + # predict the noise residual + noise_pred = self.denoiser( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=encoder_hidden_states, + lengths=lengths_reverse, + )[0] + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * ( + noise_pred_text - noise_pred_uncond) + # text_embeddings_for_guidance = encoder_hidden_states.chunk( + # 2)[1] if self.do_classifier_free_guidance else encoder_hidden_states + latents = self.scheduler.step(noise_pred, t, latents, + **extra_step_kwargs).prev_sample + # [batch_size, 1, latent_dim] -> [1, batch_size, latent_dim] + latents_t.append(latents.permute(1,0,2)) + # [1, batch_size, latent_dim] -> [t, batch_size, latent_dim] + latents_t = torch.cat(latents_t) + return latents_t + + def _diffusion_process(self, latents, encoder_hidden_states, lengths=None): + """ + heavily from https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py + """ + # our latent [batch_size, n_token=1 or 5 or 10, latent_dim=256] + # sd latent [batch_size, [n_token0=64,n_token1=64], latent_dim=4] + # [n_token, batch_size, latent_dim] -> [batch_size, n_token, latent_dim] + latents = latents.permute(1, 0, 2) + + # Sample noise that we'll add to the latents + # [batch_size, n_token, latent_dim] + noise = torch.randn_like(latents) + bsz = latents.shape[0] + # Sample a random timestep for each motion + timesteps = torch.randint( + 0, + self.noise_scheduler.config.num_train_timesteps, + (bsz, ), + device=latents.device, + ) + timesteps = timesteps.long() + # Add noise to the latents according to the noise magnitude at each timestep + noisy_latents = 
self.noise_scheduler.add_noise(latents.clone(), noise, + timesteps) + # Predict the noise residual + noise_pred = self.denoiser( + sample=noisy_latents, + timestep=timesteps, + encoder_hidden_states=encoder_hidden_states, + lengths=lengths, + return_dict=False, + )[0] + # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. + if self.cfg.LOSS.LAMBDA_PRIOR != 0.0: + noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) + noise, noise_prior = torch.chunk(noise, 2, dim=0) + else: + noise_pred_prior = 0 + noise_prior = 0 + n_set = { + "noise": noise, + "noise_prior": noise_prior, + "noise_pred": noise_pred, + "noise_pred_prior": noise_pred_prior, + } + if not self.predict_epsilon: + n_set["pred"] = noise_pred + n_set["latent"] = latents + return n_set + + + + + + def train_vae_forward(self, batch): + feats_ref = batch["motion"] + lengths = batch["length"] + + if self.vae_type in ["hvq_body_hand_face"]: + face_ref = batch["face_motion"] + + + joint_mask = batch["joint_mask"] + + if self.vae_type in ["mld", "vposert", "actor"]: + motion_z, dist_m = self.vae.encode(feats_ref, lengths) + feats_rst = self.vae.decode(motion_z, lengths) + elif self.vae_type in ["humanvq", "spatial_MLP_vqvae", "spatial_transformer_vqvae"]: + feats_rst, (commit_x, commit_x_d), perplexity = self.vae.forward(feats_ref) + elif self.vae_type in ["hvq"]: + feats_rst, (commit_x_t, commit_x_d_t), (commit_x_b, commit_x_d_b) = self.vae.forward(feats_ref) + elif self.vae_type in ["hvq_body_hand"]: + feats_rst, (commit_x_t, commit_x_d_t), (commit_x_b, commit_x_d_b) = self.vae.forward(feats_ref) + elif self.vae_type in ["hvq_body_hand_face"]: + feats_rst, (commit_x_f, commit_x_d_f), (commit_x_t, commit_x_d_t), (commit_x_b, commit_x_d_b) = self.vae.forward(torch.cat((feats_ref, face_ref), dim=2)) + face_rst = feats_rst[:, :, -53:] + feats_rst = feats_rst[:, :, :-53] + elif self.vae_type in ["rvq"]: + feats_rst, (commit_x, commit_x_d1, commit_x_d2), perplexity = self.vae.forward(feats_ref) + + elif self.vae_type in ["mld_dual_vae"]: + body_motion_z, hand_motion_z, body_dist_m, hand_dist_m = self.vae.encode(feats_ref, lengths) + feats_rst = self.vae.decode(body_motion_z, hand_motion_z, lengths) + elif self.vae_type in ["dual_human_vq"]: + feats_rst, (body_commit_x, body_commit_x_d), (hand_commit_x, hand_commit_x_d), body_perplexity, hand_perplexity = self.vae.forward(feats_ref) + else: + raise TypeError("vae_type must be mcross or actor") + + # prepare for metric + if self.vae_type in ["mld", "vposert", "actor"]: + recons_z, dist_rm = self.vae.encode(feats_rst, lengths) + elif self.vae_type in ["mld_dual_vae"]: + body_recons_z, hand_recons_z, body_dist_rm, hand_dist_rm = self.vae.encode(feats_ref, lengths) + + # joints recover + if self.condition in ["text", "text_all", 'text_hand', 'text_body', 'text_face', "text_seperate", "only_pose_concat", "only_pose_fusion"]: + + if self.input_format in ['vector_263', 'vector_263_ori_humanml', 'root_position', 'root_position_vel', 'root_position_rot6d', 'all', 'root_body_pos_vel_hand_all', 'root_body_pos_vel_hand_pos_vel', 'root_body_pos_vel_hand_pos', 'root_position_vel_only_body', 'root_body_pos_vel_hand_pos_vel_hand_wrist']: + joints_rst = self.feats2joints(feats_rst, self.input_format) # feats_rst.shape (bs, seq, 67) joints_rst.shape (bs, seq, 22, 3) + joints_ref = self.feats2joints(feats_ref, self.input_format) + elif self.input_format in ['root_rot6d']: + joints_rst = self.feats2joints(feats_rst, skel=self.skel, 
motion_type=self.input_format) + joints_rst = joints_rst.view(feats_rst.shape[0], feats_rst.shape[1], self.njoints, 3) + joints_ref = self.feats2joints(feats_ref, skel=self.skel, motion_type=self.input_format) + joints_ref = joints_ref.view(feats_ref.shape[0], feats_ref.shape[1], self.njoints, 3) + elif self.input_format in ['smplx_212', 'smplx_159'] and self.cfg.TRAIN.use_joints: + joints_rst = self.feats2joints(feats_rst, self.input_format, self.smplx_model) + joints_ref = self.feats2joints(feats_ref, self.input_format, self.smplx_model) + elif self.input_format == 'root_body_pos_vel_hand_rot': + + joints_rst = self.feats2joints(feats_rst, skel=self.skel, motion_type=self.input_format) + joints_rst = joints_rst.view(feats_rst.shape[0], feats_rst.shape[1], self.njoints, 3) + joints_ref = self.feats2joints(feats_ref, skel=self.skel, motion_type=self.input_format) + joints_ref = joints_ref.view(feats_ref.shape[0], feats_ref.shape[1], self.njoints, 3) + elif self.input_format in ['smplx_212', 'smplx_159'] and (not self.cfg.TRAIN.use_joints): + pass + + else: + raise NotImplementedError + + elif self.condition == "action": + mask = batch["mask"] + joints_rst = self.feats2joints(feats_rst, mask) + joints_ref = self.feats2joints(feats_ref, mask) + + if self.vae_type in ["mld", "vposert", "actor"]: + if dist_m is not None: + if self.is_vae: + # Create a centred normal distribution to compare with + mu_ref = torch.zeros_like(dist_m.loc) + scale_ref = torch.ones_like(dist_m.scale) + dist_ref = torch.distributions.Normal(mu_ref, scale_ref) + else: + dist_ref = dist_m + + elif self.vae_type in ["mld_dual_vae"]: + if body_dist_m is not None: + if self.is_vae: + # Create a centred normal distribution to compare with + body_mu_ref = torch.zeros_like(body_dist_m.loc) + body_scale_ref = torch.ones_like(body_dist_m.scale) + body_dist_ref = torch.distributions.Normal(body_mu_ref, body_scale_ref) + else: + body_dist_ref = body_dist_m + + if hand_dist_m is not None: + if self.is_vae: + # Create a centred normal distribution to compare with + hand_mu_ref = torch.zeros_like(hand_dist_m.loc) + hand_scale_ref = torch.ones_like(hand_dist_m.scale) + hand_dist_ref = torch.distributions.Normal(hand_mu_ref, hand_scale_ref) + else: + hand_dist_ref = hand_dist_m + + + # cut longer part over max length + min_len = min(feats_ref.shape[1], feats_rst.shape[1]) + + if self.vae_type in ["humanvq", "spatial_MLP_vqvae", "spatial_transformer_vqvae"]: + + rs_set = { + "m_ref": feats_ref[:, :min_len, :], + "m_rst": feats_rst[:, :min_len, :], + "commit_x": commit_x, + "commit_x_d": commit_x_d + } + + return rs_set + + elif self.vae_type in ["rvq"]: + rs_set = { + "m_ref": feats_ref[:, :min_len, :], + "m_rst": feats_rst[:, :min_len, :], + "commit_x": commit_x, + "commit_x_d1": commit_x_d1, + "commit_x_d2": commit_x_d2 + } + + return rs_set + + + elif self.vae_type in ["dual_human_vq"]: + rs_set = { + "m_ref": feats_ref[:, :min_len, :], + "m_rst": feats_rst[:, :min_len, :], + "body_commit_x": body_commit_x, + "hand_commit_x": hand_commit_x, + "body_commit_x_d": body_commit_x_d, + "hand_commit_x_d": hand_commit_x_d, + } + + + return rs_set + + + elif self.vae_type in ["hvq", "hvq_body_hand"]: + rs_set = { + "m_ref": feats_ref[:, :min_len, :], + "m_rst": feats_rst[:, :min_len, :], + "commit_x_t": commit_x_t, + "commit_x_d_t": commit_x_d_t, + "commit_x_b": commit_x_b , + "commit_x_d_b": commit_x_d_b, + # "" + } + + elif self.vae_type in ['hvq_body_hand_face']: + rs_set = { + "m_ref": feats_ref[:, :min_len, :], + "m_rst": feats_rst[:, 
:min_len, :], + "fm_ref": face_ref[:, :min_len, :], + "fm_rst": face_rst[:, :min_len, :], + "commit_x_t": commit_x_t, + "commit_x_d_t": commit_x_d_t, + "commit_x_b": commit_x_b , + "commit_x_d_b": commit_x_d_b, + "commit_x_f": commit_x_f, + "commit_x_d_f": commit_x_d_f + # "" + } + + # return rs_set + + + if self.vae_type in ['mld_dual_vae']: + + if self.cfg.TRAIN.use_joints: + rs_set = { + "m_ref": feats_ref[:, :min_len, :], + "m_rst": feats_rst[:, :min_len, :], + # [bs, ntoken, nfeats]<= [ntoken, bs, nfeats] + "body_lat_m": body_motion_z.permute(1, 0, 2), + "hand_lat_m": hand_motion_z.permute(1, 0, 2), + "body_lat_rm": body_recons_z.permute(1, 0, 2), + "hand_lat_rm": hand_recons_z.permute(1, 0, 2), + "joints_ref": joints_ref, + "joints_rst": joints_rst, + "body_dist_m": body_dist_m, + "hand_dist_m": hand_dist_m, + "body_dist_ref": body_dist_ref, + "hand_dist_ref": hand_dist_ref, + } + else: + + rs_set = { + "m_ref": feats_ref[:, :min_len, :], + "m_rst": feats_rst[:, :min_len, :], + # [bs, ntoken, nfeats]<= [ntoken, bs, nfeats] + "body_lat_m": body_motion_z.permute(1, 0, 2), + "hand_lat_m": hand_motion_z.permute(1, 0, 2), + "body_lat_rm": body_recons_z.permute(1, 0, 2), + "hand_lat_rm": hand_recons_z.permute(1, 0, 2), + "body_dist_m": body_dist_m, + "hand_dist_m": hand_dist_m, + "body_dist_ref": body_dist_ref, + "hand_dist_ref": hand_dist_ref, + } + + # return rs_set + + elif self.vae_type in ["mld"]: + if self.cfg.TRAIN.use_joints: + rs_set = { + "m_ref": feats_ref[:, :min_len, :], + "m_rst": feats_rst[:, :min_len, :], + # [bs, ntoken, nfeats]<= [ntoken, bs, nfeats] + "lat_m": motion_z.permute(1, 0, 2), + "lat_rm": recons_z.permute(1, 0, 2), + "joints_ref": joints_ref, + "joints_rst": joints_rst, + "dist_m": dist_m, + "dist_ref": dist_ref, + } + else: + rs_set = { + "m_ref": feats_ref[:, :min_len, :], + "m_rst": feats_rst[:, :min_len, :], + # [bs, ntoken, nfeats]<= [ntoken, bs, nfeats] + "lat_m": motion_z.permute(1, 0, 2), + "lat_rm": recons_z.permute(1, 0, 2), + "dist_m": dist_m, + "dist_ref": dist_ref, + } + + else: + if self.cfg.TRAIN.use_joints: + # rs_set = { + # "m_ref": feats_ref[:, :min_len, :], + # "m_rst": feats_rst[:, :min_len, :], + # # [bs, ntoken, nfeats]<= [ntoken, bs, nfeats] + # "lat_m": motion_z.permute(1, 0, 2), + # "lat_rm": recons_z.permute(1, 0, 2), + # "joints_ref": joints_ref, + # "joints_rst": joints_rst, + # "dist_m": dist_m, + # "dist_ref": dist_ref, + # } + rs_set["joints_ref"] = joints_ref + rs_set["joints_rst"] = joints_rst + else: + rs_set = { + "m_ref": feats_ref[:, :min_len, :], + "m_rst": feats_rst[:, :min_len, :], + # [bs, ntoken, nfeats]<= [ntoken, bs, nfeats] + "lat_m": motion_z.permute(1, 0, 2), + "lat_rm": recons_z.permute(1, 0, 2), + "dist_m": dist_m, + "dist_ref": dist_ref, + } + + + if self.cfg.LOSS.hand_mask: + rs_set['joint_mask'] = batch['joint_mask'][:, :min_len, :] + + + if self.cfg.LOSS.Velocity_loss: + vel_ref = feats_ref[:, :min_len, :][:, 1:, 3:] - feats_ref[:, :min_len, :][:, :-1, 3:] + vel_rst = feats_rst[:, :min_len, :][:, 1:, 3:] - feats_rst[:, :min_len, :][:, :-1, 3:] + rs_set['vel_rst'] = vel_rst + rs_set['vel_ref'] = vel_ref + + + return rs_set + + def train_diffusion_forward(self, batch): + feats_ref = batch["motion"] + lengths = batch["length"] + # motion encode + with torch.no_grad(): + if self.vae_type in ["mld", "vposert", "actor"]: + z, dist = self.vae.encode(feats_ref, lengths) + elif self.vae_type == "no": + z = feats_ref.permute(1, 0, 2) + else: + raise TypeError("vae_type must be mcross or actor") + + if self.condition 
in ["text", "text_uncond"]: + text = batch["text"] + # classifier free guidance: randomly drop text during training + text = [ + "" if np.random.rand(1) < self.guidance_uncodp else i + for i in text + ] + # text encode + cond_emb = self.text_encoder(text) + elif self.condition in ["text_all"]: + text = [] + + for i in range(len(batch["text"])): + text.append(batch["text"][i] +' ' + batch['face_text'][i] + ' ' + batch["body_text"][i] + ' ' + batch["hand_text"][i]) + # text = batch["text"] +' ' + batch["body_text"] + ' ' + batch["hand_text"] + # classifier free guidance: randomly drop text during training + text = [ + "" if np.random.rand(1) < self.guidance_uncodp else i + for i in text + ] + # text encode + cond_emb = self.text_encoder(text) + elif self.condition in ["text_face"]: + text = [] + + for i in range(len(batch["text"])): + text.append(batch["text"][i] +' ' + batch['face_text'][i]) + + text = [ + "" if np.random.rand(1) < self.guidance_uncodp else i + for i in text + ] + # text encode + cond_emb = self.text_encoder(text) + elif self.condition in ["text_body"]: + text = [] + + for i in range(len(batch["text"])): + text.append(batch["text"][i] +' ' + batch['body_text'][i]) + + # text = batch["text"] +' ' + batch["body_text"] + ' ' + batch["hand_text"] + # classifier free guidance: randomly drop text during training + text = [ + "" if np.random.rand(1) < self.guidance_uncodp else i + for i in text + ] + # text encode + cond_emb = self.text_encoder(text) + elif self.condition in ["text_hand"]: + text = [] + + for i in range(len(batch["text"])): + text.append(batch["text"][i] +' ' + batch['hand_text'][i]) + + # text = batch["text"] +' ' + batch["body_text"] + ' ' + batch["hand_text"] + # classifier free guidance: randomly drop text during training + text = [ + "" if np.random.rand(1) < self.guidance_uncodp else i + for i in text + ] + # text encode + cond_emb = self.text_encoder(text) + + elif self.condition in ['text_face_body']: + text = [] + + for i in range(len(batch["text"])): + text.append(batch["text"][i] +' ' + batch['face_text'][i] + ' ' + batch["body_text"][i]) + + # text = batch["text"] +' ' + batch["body_text"] + ' ' + batch["hand_text"] + # classifier free guidance: randomly drop text during training + text = [ + "" if np.random.rand(1) < self.guidance_uncodp else i + for i in text + ] + # text encode + cond_emb = self.text_encoder(text) + + elif self.condition in ["text_seperate"]: + + text = [] + for i in range(len(batch["text"])): + text.append((batch["text"][i], batch["face_text"][i], batch["body_text"][i], batch["hand_text"][i])) + + # text = batch["text"] +' ' + batch["body_text"] + ' ' + batch["hand_text"] + # classifier free guidance: randomly drop text during training + text = [ + ("", "", "", "") if np.random.rand(1) < self.guidance_uncodp else i + for i in text + ] + # text encode + + semantic_text = [] + face_text = [] + body_text = [] + hand_text = [] + for i in range(len(text)): + semantic_text.append(text[i][0]) + face_text.append(text[i][1]) + body_text.append(text[i][2]) + hand_text.append(text[i][3]) + + cond_emb_semantic = self.text_encoder(semantic_text) + cond_emb_face = self.text_encoder(face_text) + cond_emb_body = self.text_encoder(body_text) + cond_emb_hand = self.text_encoder(hand_text) + + cond_emb = self.linear_fusion(cond_emb_semantic, cond_emb_face, cond_emb_body, cond_emb_hand) + + elif self.condition in ["only_pose_concat"]: + text = [] + for i in range(len(batch["text"])): + text.append(batch["face_text"][i] +' ' + batch["body_text"][i] + ' 
' + batch["hand_text"][i]) + + # text = batch["text"] +' ' + batch["body_text"] + ' ' + batch["hand_text"] + # classifier free guidance: randomly drop text during training + text = [ + "" if np.random.rand(1) < self.guidance_uncodp else i + for i in text + ] + # text encode + cond_emb = self.text_encoder(text) + + elif self.condition in ["only_pose_fusion"]: + + text = [] + for i in range(len(batch["text"])): + text.append((batch["face_text"][i], batch["body_text"][i], batch["hand_text"][i])) + + # text = batch["text"] +' ' + batch["body_text"] + ' ' + batch["hand_text"] + # classifier free guidance: randomly drop text during training + text = [ + ("", "", "") if np.random.rand(1) < self.guidance_uncodp else i + for i in text + ] + # text encode + + face_text = [] + body_text = [] + hand_text = [] + for i in range(len(text)): + face_text.append(text[i][0]) + body_text.append(text[i][1]) + hand_text.append(text[i][2]) + + cond_emb_face = self.text_encoder(face_text) + cond_emb_body = self.text_encoder(body_text) + cond_emb_hand = self.text_encoder(hand_text) + + + cond_emb = self.linear_fusion(None,cond_emb_face, cond_emb_body, cond_emb_hand) + # emb_cat = torch.cat((cond_emb_face, cond_emb_body), axis=1) + # emb_cat = emb_cat.view(emb_cat.size(0), -1) + # cond_emb = self.emb_fuse(emb_cat).unsqueeze(1) + + + elif self.condition in ['action']: + action = batch['action'] + # text encode + cond_emb = action + else: + raise TypeError(f"condition type {self.condition} not supported") + + # diffusion process return with noise and noise_pred + n_set = self._diffusion_process(z, cond_emb, lengths) + return {**n_set} + + def test_diffusion_forward(self, batch, finetune_decoder=False): + lengths = batch["length"] + + if self.condition in ["text", "text_uncond"]: + # get text embeddings + if self.do_classifier_free_guidance: + uncond_tokens = [""] * len(lengths) + if self.condition == 'text': + texts = batch["text"] + uncond_tokens.extend(texts) + elif self.condition == 'text_uncond': + uncond_tokens.extend(uncond_tokens) + texts = uncond_tokens + cond_emb = self.text_encoder(texts) + elif self.condition in ['action']: + cond_emb = batch['action'] + if self.do_classifier_free_guidance: + cond_emb = torch.cat( + cond_emb, + torch.zeros_like(batch['action'], + dtype=batch['action'].dtype)) + else: + raise TypeError(f"condition type {self.condition} not supported") + + # diffusion reverse + with torch.no_grad(): + z = self._diffusion_reverse(cond_emb, lengths) + + with torch.no_grad(): + if self.vae_type in ["mld", "vposert", "actor"]: + feats_rst = self.vae.decode(z, lengths) + elif self.vae_type == "no": + feats_rst = z.permute(1, 0, 2) + else: + raise TypeError("vae_type must be mcross or actor or mld") + + joints_rst = self.feats2joints(feats_rst) + + rs_set = { + "m_rst": feats_rst, + # [bs, ntoken, nfeats]<= [ntoken, bs, nfeats] + "lat_t": z.permute(1, 0, 2), + "joints_rst": joints_rst, + } + + # prepare gt/refer for metric + if "motion" in batch.keys() and not finetune_decoder: + feats_ref = batch["motion"].detach() + with torch.no_grad(): + if self.vae_type in ["mld", "vposert", "actor"]: + motion_z, dist_m = self.vae.encode(feats_ref, lengths) + recons_z, dist_rm = self.vae.encode(feats_rst, lengths) + elif self.vae_type == "no": + motion_z = feats_ref + recons_z = feats_rst + + joints_ref = self.feats2joints(feats_ref) + + rs_set["m_ref"] = feats_ref + rs_set["lat_m"] = motion_z.permute(1, 0, 2) + rs_set["lat_rm"] = recons_z.permute(1, 0, 2) + rs_set["joints_ref"] = joints_ref + return rs_set + 
+ def t2m_eval(self, batch): + texts = batch["text"] + motions = batch["motion"].detach().clone() + lengths = batch["length"] + word_embs = batch["word_embs"].detach().clone() + pos_ohot = batch["pos_ohot"].detach().clone() + text_lengths = batch["text_len"].detach().clone() + + if self.vae_type in ["hvq_body_hand_face"]: + face_ref = batch["face_motion"] + # start + start = time.time() + + if self.trainer.datamodule.is_mm: + texts = texts * self.cfg.TEST.MM_NUM_REPEATS + motions = motions.repeat_interleave(self.cfg.TEST.MM_NUM_REPEATS, + dim=0) + lengths = lengths * self.cfg.TEST.MM_NUM_REPEATS + word_embs = word_embs.repeat_interleave( + self.cfg.TEST.MM_NUM_REPEATS, dim=0) + pos_ohot = pos_ohot.repeat_interleave(self.cfg.TEST.MM_NUM_REPEATS, + dim=0) + text_lengths = text_lengths.repeat_interleave( + self.cfg.TEST.MM_NUM_REPEATS, dim=0) + + if self.stage in ['diffusion', 'vae_diffusion']: + # diffusion reverse + if self.do_classifier_free_guidance: + uncond_tokens = [""] * len(texts) + if self.condition == 'text': + uncond_tokens.extend(texts) + elif self.condition == 'text_uncond': + uncond_tokens.extend(uncond_tokens) + texts = uncond_tokens + text_emb = self.text_encoder(texts) + z = self._diffusion_reverse(text_emb, lengths) + elif self.stage in ['vae']: + if self.vae_type in ["mld", "vposert", "actor"]: + z, dist_m = self.vae.encode(motions, lengths) + elif self.vae_type in ["humanvq", "spatial_MLP_vqvae", "spatial_transformer_vqvae"]: + quants = self.vae.encode(motions) + elif self.vae_type in ["hvq", "hvq_body_hand"]: + _, _, _, _, id_t, id_b = self.vae.encode(motions) + elif self.vae_type in ["rvq"]: + quants_1, quants_2 = self.vae.encode(motions) + elif self.vae_type in ["dual_human_vq"]: + body_quants, hand_quants = self.vae.encode(motions) + elif self.vae_type == "hvq_body_hand_face": + _, _, _, _, _, _, id_f, id_t, id_b = self.vae.encode(motions) + else: + raise TypeError("Not supported vae type!") + if self.condition in ['text_uncond']: + # uncond random sample + z = torch.randn_like(z) + + with torch.no_grad(): + if self.vae_type in ["mld", "vposert", "actor"]: + feats_rst = self.vae.decode(z, lengths) + elif self.vae_type in ["humanvq", "spatial_MLP_vqvae", "spatial_transformer_vqvae"]: + feats_rst = self.vae.forward_decoder(quants) + feats_rst = feats_rst.reshape(motions.shape[0], motions.shape[1], -1) + elif self.vae_type in ["hvq", "hvq_body_hand"]: + feats_rst = self.vae.forward_decoder(id_t, id_b) + elif self.vae_type in ["hvq_body_hand_face"]: + feats_rst = self.vae.forward_decoder(id_f, id_t, id_b) + face_rst = feats_rst[:, :, -53:] + feats_rst = feats_rst[:, :, :-53] + + elif self.vae_type in ["rvq"]: + feats_rst = self.vae.forward_decoder(quants_1, quants_2) + feats_rst = feats_rst.reshape(motions.shape[0], motions.shape[1], -1) + elif self.vae_type in ["dual_human_vq"]: + feats_rst = self.vae.forward_decoder(body_quants, hand_quants) + feats_rst = feats_rst.reshape(motions.shape[0], motions.shape[1], -1) + elif self.vae_type == "no": + feats_rst = z.permute(1, 0, 2) + else: + raise TypeError("Not supported vae type!") + + # end time + end = time.time() + self.times.append(end - start) + # joints recover + joints_rst = self.feats2joints(feats_rst, self.input_format) + joints_ref = self.feats2joints(motions, self.input_format) + + + + # renorm for t2m evaluators + feats_rst = self.datamodule.renorm4t2m(feats_rst) + motions = self.datamodule.renorm4t2m(motions) + + # t2m motion encoder + m_lens = lengths.copy() + m_lens = torch.tensor(m_lens, device=motions.device) 
+ align_idx = np.argsort(m_lens.data.tolist())[::-1].copy() + motions = motions[align_idx] + feats_rst = feats_rst[align_idx] + m_lens = m_lens[align_idx] + m_lens = torch.div(m_lens, + self.cfg.DATASET.HUMANML3D.UNIT_LEN, + rounding_mode="floor") + + recons_mov = self.t2m_moveencoder(feats_rst[..., :-4]).detach() + recons_emb = self.t2m_motionencoder(recons_mov, m_lens) + motion_mov = self.t2m_moveencoder(motions[..., :-4]).detach() + motion_emb = self.t2m_motionencoder(motion_mov, m_lens) + + # t2m text encoder + if self.cfg.model.eval_text_source == 'token': + text_emb = self.t2m_textencoder(word_embs, pos_ohot,text_lengths)[align_idx] + elif self.cfg.model.eval_text_source == 'only_text_token': + text_emb = self.t2m_textencoder(word_embs, text_lengths)[align_idx] + elif self.cfg.model.eval_text_source in ['caption']: + if self.cfg.model.eval_text_encode_way == 'clip': + raise NotImplementedError + + elif self.cfg.model.eval_text_encode_way == 't5': + raise NotImplementedError + + elif 'GRU' in self.cfg.model.eval_text_encode_way: + text_emb = self.t2m_textencoder(word_embs, text_lengths)[align_idx] + else: + raise NotImplementedError + + + rs_set = { + "m_ref": motions, + "m_rst": feats_rst, + "lat_t": text_emb, + "lat_m": motion_emb, + "lat_rm": recons_emb, + "joints_ref": joints_ref, + "joints_rst": joints_rst, + } + return rs_set + + + def tmr_t2m_eval(self, batch): + + texts = batch["text"] + texts_ori = copy.deepcopy(batch["text"]) + motions = batch["motion"].detach().clone() + lengths = batch["length"] + word_embs = batch["word_embs"].detach().clone() + pos_ohot = batch["pos_ohot"].detach().clone() + text_lengths = batch["text_len"].detach().clone() + + name = batch["retrieval_name"] + + # start + start = time.time() + + if self.trainer.datamodule.is_mm: + texts = texts * self.cfg.TEST.MM_NUM_REPEATS + texts_ori = texts_ori * self.cfg.TEST.MM_NUM_REPEATS + + motions = motions.repeat_interleave(self.cfg.TEST.MM_NUM_REPEATS, + dim=0) + lengths = lengths * self.cfg.TEST.MM_NUM_REPEATS + word_embs = word_embs.repeat_interleave( + self.cfg.TEST.MM_NUM_REPEATS, dim=0) + pos_ohot = pos_ohot.repeat_interleave(self.cfg.TEST.MM_NUM_REPEATS, + dim=0) + text_lengths = text_lengths.repeat_interleave( + self.cfg.TEST.MM_NUM_REPEATS, dim=0) + + + + if self.stage in ['diffusion', 'vae_diffusion']: + # diffusion reverse + if self.do_classifier_free_guidance: + uncond_tokens = [""] * len(texts) + if self.condition == 'text': + uncond_tokens.extend(texts) + elif self.condition == 'text_uncond': + uncond_tokens.extend(uncond_tokens) + texts = uncond_tokens + text_emb = self.text_encoder(texts) + z = self._diffusion_reverse(text_emb, lengths) # 1, 30 , 256 + elif self.stage in ['vae']: + if self.vae_type in ["mld", "vposert", "actor"]: + z, dist_m = self.vae.encode(motions, lengths) + elif self.vae_type in ["humanvq", "spatial_MLP_vqvae", "spatial_transformer_vqvae"]: + quants = self.vae.encode(motions) + elif self.vae_type in ["hvq", "hvq_body_hand"]: + _, _, _, _, id_t, id_b = self.vae.encode(motions) + elif self.vae_type in ["dual_human_vq"]: + body_quants, hand_quants = self.vae.encode(motions) + elif self.vae_type in ["rvq"]: + quants_1, quants_2 = self.vae.encode(motions) + else: + raise TypeError("Not supported vae type!") + if self.condition in ['text_uncond']: + # uncond random sample + z = torch.randn_like(z) + + with torch.no_grad(): + if self.vae_type in ["mld", "vposert", "actor"]: + feats_rst = self.vae.decode(z, lengths) # 30, 180, 313 + elif self.vae_type in ["humanvq", 
"spatial_MLP_vqvae", "spatial_transformer_vqvae"]: + feats_rst = self.vae.forward_decoder(quants) + feats_rst = feats_rst.reshape(motions.shape[0], motions.shape[1], -1) + elif self.vae_type in ["hvq", "hvq_body_hand"]: + feats_rst = self.vae.forward_decoder(id_t, id_b) + elif self.vae_type in ["rvq"]: + feats_rst = self.vae.forward_decoder(quants_1, quants_2) + feats_rst = feats_rst.reshape(motions.shape[0], motions.shape[1], -1) + elif self.vae_type in ["dual_human_vq"]: + feats_rst = self.vae.forward_decoder(body_quants, hand_quants) + feats_rst = feats_rst.reshape(motions.shape[0], motions.shape[1], -1) + elif self.vae_type == "no": + feats_rst = z.permute(1, 0, 2) + else: + raise TypeError("Not supported vae type!") + + # end time + end = time.time() + self.times.append(end - start) + # joints recover + joints_rst = self.feats2joints(feats_rst, self.input_format) + joints_ref = self.feats2joints(motions, self.input_format) + + + + # renorm for t2m evaluators + feats_rst_before_renorm4t2m = feats_rst.clone() + motions_before_renorm4t2m = motions.clone() + + feats_rst = self.datamodule.renorm4t2m(feats_rst) + motions = self.datamodule.renorm4t2m(motions) + + # t2m motion encoder + m_lens = lengths.copy() + m_lens = torch.tensor(m_lens, device=motions.device) + align_idx = np.argsort(m_lens.data.tolist())[::-1].copy() + motions = motions[align_idx] + feats_rst = feats_rst[align_idx] + m_lens = m_lens[align_idx] + m_lens_ori = m_lens.clone() + feats_rst_before_renorm4t2m = feats_rst_before_renorm4t2m[align_idx] + motions_before_renorm4t2m = motions_before_renorm4t2m[align_idx] + m_lens = torch.div(m_lens, + self.cfg.DATASET.HUMANML3D.UNIT_LEN, + rounding_mode="floor") + + recons_mov = self.t2m_moveencoder(feats_rst[..., :-4]).detach() + recons_emb = self.t2m_motionencoder(recons_mov, m_lens) + motion_mov = self.t2m_moveencoder(motions[..., :-4]).detach() + motion_emb = self.t2m_motionencoder(motion_mov, m_lens) + + recons_emb_tmr = self.t2m_TMR_motionencoder_eval(feats_rst_before_renorm4t2m, m_lens_ori).loc + motion_emb_tmr = self.t2m_TMR_motionencoder_eval(motions_before_renorm4t2m, m_lens_ori).loc + + + + # t2m text encoder + assert self.cfg.model.eval_text_source in ['caption'] + + + if self.cfg.model.eval_text_encode_way == 'clip': + raise NotImplementedError + + elif self.cfg.model.eval_text_encode_way == 't5': + raise NotImplementedError + + elif 'GRU' in self.cfg.model.eval_text_encode_way: + text_emb = self.t2m_textencoder(word_embs, text_lengths)[align_idx] # 30 ,512 + else: + raise NotImplementedError + + + + text_emb_tmr = self.t2m_TMR_textencoder_eval(texts_ori).loc[align_idx] # 30 , 256 + + + rs_set = { + "m_ref": motions, + "m_rst": feats_rst, + "lat_t": text_emb, + "lat_t_tmr": text_emb_tmr, + "lat_m": motion_emb, + "lat_rm": recons_emb, + "lat_m_tmr": motion_emb_tmr, + "lat_rm_tmr": recons_emb_tmr, + "joints_ref": joints_ref, + "joints_rst": joints_rst, + } + + + return rs_set + + def t2m_eval_save_motion_token(self, batch): + + name = batch["name"] + motions = batch["motion"].detach().clone() + + # start + start = time.time() + + + + if self.stage in ['diffusion', 'vae_diffusion']: + # diffusion reverse + if self.do_classifier_free_guidance: + uncond_tokens = [""] * len(texts) + if self.condition == 'text': + uncond_tokens.extend(texts) + elif self.condition == 'text_uncond': + uncond_tokens.extend(uncond_tokens) + texts = uncond_tokens + text_emb = self.text_encoder(texts) + z = self._diffusion_reverse(text_emb, lengths) + elif self.stage in ['vae']: + if self.vae_type 
in ["mld", "vposert", "actor"]: + z, dist_m = self.vae.encode(motions, lengths) + elif self.vae_type in ["humanvq", "spatial_MLP_vqvae", "spatial_transformer_vqvae"]: + quants = self.vae.encode(motions) + elif self.vae_type in ["hvq", "hvq_body_hand"]: + _, _, _, _, id_t, id_b = self.vae.encode(motions) + else: + raise TypeError("Not supported vae type!") + if self.condition in ['text_uncond']: + # uncond random sample + z = torch.randn_like(z) + + with torch.no_grad(): + if self.vae_type in ["mld", "vposert", "actor"]: + feats_rst = self.vae.decode(z, lengths) + elif self.vae_type in ["humanvq", "spatial_MLP_vqvae", "spatial_transformer_vqvae"]: + feats_rst = self.vae.forward_decoder(quants) + feats_rst = feats_rst.reshape(motions.shape[0], motions.shape[1], -1) + elif self.vae_type in ["hvq", "hvq_body_hand"]: + feats_rst = self.vae.forward_decoder(id_t, id_b) + elif self.vae_type in ["rvq"]: + feats_rst = self.vae.forward_decoder(quants_1, quants_2) + feats_rst = feats_rst.reshape(motions.shape[0], motions.shape[1], -1) + elif self.vae_type == "no": + feats_rst = z.permute(1, 0, 2) + + # end time + end = time.time() + self.times.append(end - start) + + # joints recover + + joints_rst = self.feats2joints(feats_rst, skel=self.skel, motion_type=self.input_format) + + joints_rst = joints_rst.view(feats_rst.shape[0], feats_rst.shape[1], self.njoints, 3) + joints_ref = self.feats2joints(motions, skel=self.skel, motion_type=self.input_format) + joints_ref = joints_ref.view(motions.shape[0], motions.shape[1], self.njoints, 3) + + assert len(name) == 1 + + feats_rst = self.renorm2ori(feats_rst) + motions = self.renorm2ori(motions) + feats_rst_path = os.path.join(f"./visualization/visualization/{self.cfg.model.vae_type}_VAE_motionx_feats_rst_norm_back", name[0] + '.npy') + feats_ref_path = os.path.join(f"./visualization/visualization/{self.cfg.model.vae_type}_VAE_motionx_feats_ref_norm_back", name[0] + '.npy') + joitns_rst_path = os.path.join(f"./visualization/visualization/{self.cfg.model.vae_type}_VAE_motionx_joints_rst_norm_back", name[0] + '.npy') + joitns_ref_path = os.path.join(f"./visualization/visualization/{self.cfg.model.vae_type}_VAE_motionx_joints_ref_norm_back", name[0] + '.npy') + + feats_rst_parent_directory = os.path.dirname(feats_rst_path) + if not os.path.exists(feats_rst_parent_directory): + os.makedirs(feats_rst_parent_directory) + + feats_ref_parent_directory = os.path.dirname(feats_ref_path) + if not os.path.exists(feats_ref_parent_directory): + os.makedirs(feats_ref_parent_directory) + + joints_rst_parent_directory = os.path.dirname(joitns_rst_path) + if not os.path.exists(joints_rst_parent_directory): + os.makedirs(joints_rst_parent_directory) + + joints_ref_parent_directory = os.path.dirname(joitns_ref_path) + if not os.path.exists(joints_ref_parent_directory): + os.makedirs(joints_ref_parent_directory) + + + + np.save(feats_rst_path, feats_rst[0].detach().cpu().numpy()) + np.save(feats_ref_path, motions[0].detach().cpu().numpy()) + np.save(joitns_rst_path, joints_rst[0].detach().cpu().numpy()) + np.save(joitns_ref_path, joints_ref[0].detach().cpu().numpy()) + + + if self.vae_type in ["hvq", "hvq_body_hand"]: + rs_set = { + "m_ref": motions, + "m_rst": feats_rst, + "joints_ref": joints_ref, + "joints_rst": joints_rst, + "motion_code_t": id_t, + "motion_code_b": id_b, + "name": name + } + + else: + rs_set = { + "m_ref": motions, + "m_rst": feats_rst, + "joints_ref": joints_ref, + "joints_rst": joints_rst, + "motion_code": quants, + "name": name + } + return rs_set + + + 
def t2m_eval_cal_sort(self, batch): + + name = batch["name"] + motions = batch["motion"].detach().clone() + + start = time.time() + + + if self.stage in ['diffusion', 'vae_diffusion']: + # diffusion reverse + if self.do_classifier_free_guidance: + uncond_tokens = [""] * len(texts) + if self.condition == 'text': + uncond_tokens.extend(texts) + elif self.condition == 'text_uncond': + uncond_tokens.extend(uncond_tokens) + texts = uncond_tokens + text_emb = self.text_encoder(texts) + z = self._diffusion_reverse(text_emb, lengths) + elif self.stage in ['vae']: + if self.vae_type in ["mld", "vposert", "actor"]: + z, dist_m = self.vae.encode(motions, lengths) + elif self.vae_type in ["humanvq", "spatial_MLP_vqvae", "spatial_transformer_vqvae"]: + quants = self.vae.encode(motions) + elif self.vae_type in ["hvq", "hvq_body_hand"]: + _, _, _, _, id_t, id_b = self.vae.encode(motions) + else: + raise TypeError("Not supported vae type!") + if self.condition in ['text_uncond']: + # uncond random sample + z = torch.randn_like(z) + + with torch.no_grad(): + if self.vae_type in ["mld", "vposert", "actor"]: + feats_rst = self.vae.decode(z, lengths) + elif self.vae_type in ["humanvq", "spatial_MLP_vqvae", "spatial_transformer_vqvae"]: + feats_rst = self.vae.forward_decoder(quants) + feats_rst = feats_rst.reshape(motions.shape[0], motions.shape[1], -1) + elif self.vae_type in ["hvq", "hvq_body_hand"]: + feats_rst = self.vae.forward_decoder(id_t, id_b) + elif self.vae_type in ["rvq"]: + feats_rst = self.vae.forward_decoder(quants_1, quants_2) + feats_rst = feats_rst.reshape(motions.shape[0], motions.shape[1], -1) + elif self.vae_type == "no": + feats_rst = z.permute(1, 0, 2) + + # end time + end = time.time() + self.times.append(end - start) + + # joints recover + + joints_rst = self.feats2joints(feats_rst, skel=self.skel, motion_type=self.input_format) + + joints_rst = joints_rst.view(feats_rst.shape[0], feats_rst.shape[1], self.njoints, 3) + joints_ref = self.feats2joints(motions, skel=self.skel, motion_type=self.input_format) + joints_ref = joints_ref.view(motions.shape[0], motions.shape[1], self.njoints, 3) + + feats_rst = self.renorm2ori(feats_rst) + motions = self.renorm2ori(motions) + + + assert len(name) == 1 + + + if self.vae_type in ["hvq", "hvq_body_hand"]: + rs_set = { + "m_ref": motions, + "m_rst": feats_rst, + "joints_ref": joints_ref, + "joints_rst": joints_rst, + "motion_code_t": id_t, + "motion_code_b": id_b + } + + else: + rs_set = { + "m_ref": motions, + "m_rst": feats_rst, + "joints_ref": joints_ref, + "joints_rst": joints_rst, + "motion_code": quants + } + return rs_set + + def normal_eval(self, batch): + texts = batch["text"] + motions = batch["motion"].detach().clone() + lengths = batch["length"] + word_embs = batch["word_embs"].detach().clone() + pos_ohot = batch["pos_ohot"].detach().clone() + text_lengths = batch["text_len"].detach().clone() + + # start + start = time.time() + + if self.trainer.datamodule.is_mm: + texts = texts * self.cfg.TEST.MM_NUM_REPEATS + motions = motions.repeat_interleave(self.cfg.TEST.MM_NUM_REPEATS, + dim=0) + lengths = lengths * self.cfg.TEST.MM_NUM_REPEATS + word_embs = word_embs.repeat_interleave( + self.cfg.TEST.MM_NUM_REPEATS, dim=0) + pos_ohot = pos_ohot.repeat_interleave(self.cfg.TEST.MM_NUM_REPEATS, + dim=0) + text_lengths = text_lengths.repeat_interleave( + self.cfg.TEST.MM_NUM_REPEATS, dim=0) + + if self.stage in ['diffusion', 'vae_diffusion']: + # diffusion reverse + if self.do_classifier_free_guidance: + uncond_tokens = [""] * len(texts) + if 
self.condition == 'text': + uncond_tokens.extend(texts) + elif self.condition == 'text_uncond': + uncond_tokens.extend(uncond_tokens) + texts = uncond_tokens + text_emb = self.text_encoder(texts) + z = self._diffusion_reverse(text_emb, lengths) + elif self.stage in ['vae']: + if self.vae_type in ["mld", "vposert", "actor"]: + z, dist_m = self.vae.encode(motions, lengths) + elif self.vae_type in ["humanvq", "spatial_MLP_vqvae", "spatial_transformer_vqvae"]: + quants = self.vae.encode(motions) + elif self.vae_type in ["mld_dual_vae"]: + body_z, hand_z, body_dist_m, hand_dist_m = self.vae.encode(motions, lengths) + elif self.vae_type in ["dual_human_vq"]: + body_quants, hand_quants = self.vae.encode(motions) + elif self.vae_type in ["rvq"]: + quants_1, quants_2 = self.vae.encode(motions) + elif self.vae_type in ["hvq", "hvq_body_hand"]: + _, _, _, _, id_t, id_b = self.vae.encode(motions) + else: + raise TypeError("Not supported vae type!") + if self.condition in ['text_uncond']: + # uncond random sample + z = torch.randn_like(z) + + with torch.no_grad(): + if self.vae_type in ["mld", "vposert", "actor"]: + feats_rst = self.vae.decode(z, lengths) + elif self.vae_type in ["humanvq", "spatial_MLP_vqvae", "spatial_transformer_vqvae"]: + feats_rst = self.vae.forward_decoder(quants) + feats_rst = feats_rst.reshape(motions.shape[0], motions.shape[1], -1) + elif self.vae_type in ["mld_dual_vae"]: + feats_rst = self.vae.decode(body_z, hand_z, lengths) + elif self.vae_type in ["dual_human_vq"]: + feats_rst = self.vae.forward_decoder(body_quants, hand_quants) + feats_rst = feats_rst.reshape(motions.shape[0], motions.shape[1], -1) + elif self.vae_type in ["rvq"]: + feats_rst = self.vae.forward_decoder(quants_1, quants_2) + feats_rst = feats_rst.reshape(motions.shape[0], motions.shape[1], -1) + elif self.vae_type in ["hvq", "hvq_body_hand"]: + feats_rst = self.vae.forward_decoder(id_t, id_b) + + elif self.vae_type == "no": + feats_rst = z.permute(1, 0, 2) + else: + raise NotImplenetError + + # end time + end = time.time() + self.times.append(end - start) + + + joints_rst = self.feats2joints(feats_rst, skel=self.skel, motion_type=self.input_format) + joints_rst = joints_rst.view(feats_rst.shape[0], feats_rst.shape[1], self.njoints, 3) + joints_ref = self.feats2joints(motions, skel=self.skel, motion_type=self.input_format) + joints_ref = joints_ref.view(motions.shape[0], motions.shape[1], self.njoints, 3) + + + + rs_set = { + "m_ref": motions, + "m_rst": feats_rst, + "joints_ref": joints_ref, + "joints_rst": joints_rst, + } + return rs_set + + + def t2m_eval_smplx(self, batch): + + + texts = batch["text"] + motions = batch["motion"].detach().clone() + lengths = batch["length"] + word_embs = batch["word_embs"].detach().clone() + pos_ohot = batch["pos_ohot"].detach().clone() + text_lengths = batch["text_len"].detach().clone() + # start + start = time.time() + if self.trainer.datamodule.is_mm: + texts = texts * self.cfg.TEST.MM_NUM_REPEATS + motions = motions.repeat_interleave(self.cfg.TEST.MM_NUM_REPEATS, + dim=0) + lengths = lengths * self.cfg.TEST.MM_NUM_REPEATS + word_embs = word_embs.repeat_interleave( + self.cfg.TEST.MM_NUM_REPEATS, dim=0) + pos_ohot = pos_ohot.repeat_interleave(self.cfg.TEST.MM_NUM_REPEATS, + dim=0) + text_lengths = text_lengths.repeat_interleave( + self.cfg.TEST.MM_NUM_REPEATS, dim=0) + + if self.stage in ['diffusion', 'vae_diffusion']: + # diffusion reverse + if self.do_classifier_free_guidance: + uncond_tokens = [""] * len(texts) + if self.condition == 'text': + 
uncond_tokens.extend(texts) + elif self.condition == 'text_uncond': + uncond_tokens.extend(uncond_tokens) + texts = uncond_tokens + text_emb = self.text_encoder(texts) + z = self._diffusion_reverse(text_emb, lengths) + elif self.stage in ['vae']: + if self.vae_type in ["mld", "vposert", "actor"]: + z, dist_m = self.vae.encode(motions, lengths) + elif self.vae_type in ["humanvq", "spatial_MLP_vqvae", "spatial_transformer_vqvae"]: + quants = self.vae.encode(motions) + else: + raise TypeError("Not supported vae type!") + if self.condition in ['text_uncond']: + # uncond random sample + z = torch.randn_like(z) + + with torch.no_grad(): + if self.vae_type in ["mld", "vposert", "actor"]: + feats_rst = self.vae.decode(z, lengths) + elif self.vae_type in ["humanvq"]: + feats_rst = self.vae.forward_decoder(quants) + feats_rst = feats_rst.reshape(motions.shape[0], motions.shape[1], -1) + elif self.vae_type == "no": + feats_rst = z.permute(1, 0, 2) + else: + raise TypeError("Not supported vae type!") + + # end time + end = time.time() + self.times.append(end - start) + # + # joints recover + if self.cfg.TRAIN.use_joints: + joints_rst = self.feats2joints(feats_rst, self.motion_type, self.smplx_model) + joints_ref = self.feats2joints(motions, self.motion_type, self.smplx_model) + + + + # renorm for t2m evaluators + feats_rst = self.datamodule.renorm4t2m(feats_rst) + motions = self.datamodule.renorm4t2m(motions) + # t2m motion encoder + m_lens = lengths.copy() + m_lens = torch.tensor(m_lens, device=motions.device) + align_idx = np.argsort(m_lens.data.tolist())[::-1].copy() + motions = motions[align_idx] + feats_rst = feats_rst[align_idx] + m_lens = m_lens[align_idx] + m_lens = torch.div(m_lens, + self.cfg.DATASET.HUMANML3D.UNIT_LEN, + rounding_mode="floor") + + assert self.motion_type in ['smplx_212', 'smplx_159'] + + + + recons_mov = self.t2m_moveencoder(feats_rst).detach() + recons_emb = self.t2m_motionencoder(recons_mov, m_lens) + motion_mov = self.t2m_moveencoder(motions).detach() + motion_emb = self.t2m_motionencoder(motion_mov, m_lens) + + # t2m text encoder + if self.cfg.model.eval_text_source == 'token': + text_emb = self.t2m_textencoder(word_embs, pos_ohot,text_lengths)[align_idx] + elif self.cfg.model.eval_text_source == 'only_text_token': + text_emb = self.t2m_textencoder(word_embs, text_lengths)[align_idx] + elif self.cfg.model.eval_text_source in ['caption']: + if self.cfg.model.eval_text_encode_way == 'clip': + raise NotImplementedError + + elif self.cfg.model.eval_text_encode_way == 't5': + raise NotImplementedError + + elif 'GRU' in self.cfg.model.eval_text_encode_way: + text_emb = self.t2m_textencoder(word_embs, text_lengths)[align_idx] + else: + raise NotImplementedError + if self.cfg.TRAIN.use_joints: + rs_set = { + "m_ref": motions, + "m_rst": feats_rst, + "lat_t": text_emb, + "lat_m": motion_emb, + "lat_rm": recons_emb, + "joints_ref": joints_ref, + "joints_rst": joints_rst, + } + else: + rs_set = { + "m_ref": motions, + "m_rst": feats_rst, + "lat_t": text_emb, + "lat_m": motion_emb, + "lat_rm": recons_emb, + } + + return rs_set + + + + + def t2m_eval_smplx_save_motion_token(self, batch): + # texts = batch["text"] + name = batch["name"] + motions = batch["motion"].detach().clone() + # lengths = batch["length"] + # word_embs = batch["word_embs"].detach().clone() + # pos_ohot = batch["pos_ohot"].detach().clone() + # text_lengths = batch["text_len"].detach().clone() + # start + start = time.time() + + + if self.stage in ['diffusion', 'vae_diffusion']: + # diffusion reverse + if 
self.do_classifier_free_guidance: + uncond_tokens = [""] * len(texts) + if self.condition == 'text': + uncond_tokens.extend(texts) + elif self.condition == 'text_uncond': + uncond_tokens.extend(uncond_tokens) + texts = uncond_tokens + text_emb = self.text_encoder(texts) + z = self._diffusion_reverse(text_emb, lengths) + elif self.stage in ['vae']: + if self.vae_type in ["mld", "vposert", "actor"]: + z, dist_m = self.vae.encode(motions, lengths) + elif self.vae_type in ["humanvq", "spatial_MLP_vqvae", "spatial_transformer_vqvae"]: + quants = self.vae.encode(motions) + else: + raise TypeError("Not supported vae type!") + if self.condition in ['text_uncond']: + # uncond random sample + z = torch.randn_like(z) + + with torch.no_grad(): + if self.vae_type in ["mld", "vposert", "actor"]: + feats_rst = self.vae.decode(z, lengths) + elif self.vae_type in ["humanvq"]: + feats_rst = self.vae.forward_decoder(quants) + feats_rst = feats_rst.reshape(motions.shape[0], motions.shape[1], -1) + elif self.vae_type == "no": + feats_rst = z.permute(1, 0, 2) + else: + raise TypeError("Not supported vae type!") + + # end time + end = time.time() + self.times.append(end - start) + # joints recover + if self.cfg.TRAIN.use_joints: + joints_rst = self.feats2joints(feats_rst, self.motion_type, self.smplx_model) + joints_ref = self.feats2joints(motions, self.motion_type, self.smplx_model) + + + #############for save tokens############# + feats_rst = self.renorm2ori(feats_rst) + motions = self.renorm2ori(motions) + feats_rst_path = os.path.join(f"./visualization/visualization/test_case/{self.cfg.TRAIN.DATASETS[0]}/{self.cfg.model.vae_type}_VAE_motionx_feats_rst_norm_back", name[0] + '.npy') + feats_ref_path = os.path.join(f"./visualization/visualization/test_case/{self.cfg.TRAIN.DATASETS[0]}/{self.cfg.model.vae_type}_VAE_motionx_feats_ref_norm_back", name[0] + '.npy') + + feats_rst_parent_directory = os.path.dirname(feats_rst_path) + if not os.path.exists(feats_rst_parent_directory): + os.makedirs(feats_rst_parent_directory) + + feats_ref_parent_directory = os.path.dirname(feats_ref_path) + if not os.path.exists(feats_ref_parent_directory): + os.makedirs(feats_ref_parent_directory) + + + np.save(feats_rst_path, feats_rst[0].detach().cpu().numpy()) + np.save(feats_ref_path, motions[0].detach().cpu().numpy()) + + + + assert self.motion_type == ['smplx_212', 'smplx_159'] + + + if self.cfg.TRAIN.use_joints: + rs_set = { + "m_ref": motions, + "m_rst": feats_rst, + # "lat_t": text_emb, + "joints_ref": joints_ref, + "joints_rst": joints_rst, + "motion_code": quants, + "name": name + } + else: + rs_set = { + "m_ref": motions, + "m_rst": feats_rst, + # "lat_t": text_emb, + "motion_code": quants, + "name": name + } + + return rs_set + + + def t2m_eval_smplx_text_all(self, batch): + assert self.condition == 'text_all' + texts = [] + for i in range(len(batch["text"])): + texts.append(batch["text"][i] +' ' + batch['face_text'][i] + ' ' + batch["body_text"][i] + ' ' + batch["hand_text"][i]) + + motions = batch["motion"].detach().clone() + lengths = batch["length"] + word_embs = batch["word_embs"].detach().clone() + pos_ohot = batch["pos_ohot"].detach().clone() + text_lengths = batch["text_len"].detach().clone() + # start + start = time.time() + + if self.trainer.datamodule.is_mm: + texts = texts * self.cfg.TEST.MM_NUM_REPEATS + motions = motions.repeat_interleave(self.cfg.TEST.MM_NUM_REPEATS, + dim=0) + lengths = lengths * self.cfg.TEST.MM_NUM_REPEATS + word_embs = word_embs.repeat_interleave( + self.cfg.TEST.MM_NUM_REPEATS, dim=0) 
+ pos_ohot = pos_ohot.repeat_interleave(self.cfg.TEST.MM_NUM_REPEATS, + dim=0) + text_lengths = text_lengths.repeat_interleave( + self.cfg.TEST.MM_NUM_REPEATS, dim=0) + + if self.stage in ['diffusion', 'vae_diffusion']: + # diffusion reverse + if self.do_classifier_free_guidance: + uncond_tokens = [""] * len(texts) + if self.condition == 'text_all': + uncond_tokens.extend(texts) + elif self.condition == 'text_uncond': + uncond_tokens.extend(uncond_tokens) + texts = uncond_tokens + text_emb = self.text_encoder(texts) + z = self._diffusion_reverse(text_emb, lengths) + elif self.stage in ['vae']: + if self.vae_type in ["mld", "vposert", "actor"]: + z, dist_m = self.vae.encode(motions, lengths) + else: + raise TypeError("Not supported vae type!") + if self.condition in ['text_uncond']: + # uncond random sample + z = torch.randn_like(z) + + with torch.no_grad(): + if self.vae_type in ["mld", "vposert", "actor"]: + feats_rst = self.vae.decode(z, lengths) + elif self.vae_type == "no": + feats_rst = z.permute(1, 0, 2) + + # end time + end = time.time() + self.times.append(end - start) + + # joints recover + if self.cfg.TRAIN.use_joints: + joints_rst = self.feats2joints(feats_rst, self.motion_type, self.smplx_model) + joints_ref = self.feats2joints(motions, self.motion_type, self.smplx_model) + + # renorm for t2m evaluators + feats_rst = self.datamodule.renorm4t2m(feats_rst) + motions = self.datamodule.renorm4t2m(motions) + # t2m motion encoder + m_lens = lengths.copy() + m_lens = torch.tensor(m_lens, device=motions.device) + align_idx = np.argsort(m_lens.data.tolist())[::-1].copy() + motions = motions[align_idx] + feats_rst = feats_rst[align_idx] + m_lens = m_lens[align_idx] + m_lens = torch.div(m_lens, + self.cfg.DATASET.HUMANML3D.UNIT_LEN, + rounding_mode="floor") + + assert self.motion_type == 'smplx_212' + + + + recons_mov = self.t2m_moveencoder(feats_rst).detach() + recons_emb = self.t2m_motionencoder(recons_mov, m_lens) + motion_mov = self.t2m_moveencoder(motions).detach() + motion_emb = self.t2m_motionencoder(motion_mov, m_lens) + + # t2m text encoder + if self.cfg.model.eval_text_source == 'token': + text_emb = self.t2m_textencoder(word_embs, pos_ohot,text_lengths)[align_idx] + elif self.cfg.model.eval_text_source == 'only_text_token': + text_emb = self.t2m_textencoder(word_embs, text_lengths)[align_idx] + elif self.cfg.model.eval_text_source in ['caption']: + if self.cfg.model.eval_text_encode_way == 'clip': + raise NotImplementedError + + elif self.cfg.model.eval_text_encode_way == 't5': + raise NotImplementedError + + elif 'GRU' in self.cfg.model.eval_text_encode_way: + text_emb = self.t2m_textencoder(word_embs, text_lengths)[align_idx] + else: + raise NotImplementedError + if self.cfg.TRAIN.use_joints: + rs_set = { + "m_ref": motions, + "m_rst": feats_rst, + "lat_t": text_emb, + "lat_m": motion_emb, + "lat_rm": recons_emb, + "joints_ref": joints_ref, + "joints_rst": joints_rst, + } + else: + rs_set = { + "m_ref": motions, + "m_rst": feats_rst, + "lat_t": text_emb, + "lat_m": motion_emb, + "lat_rm": recons_emb, + } + + + return rs_set + + + + def t2m_eval_smplx_text_face(self, batch): + assert self.condition == 'text_face' + texts = [] + for i in range(len(batch["text"])): + texts.append(batch["text"][i] +' ' + batch['face_text'][i]) + + motions = batch["motion"].detach().clone() + lengths = batch["length"] + word_embs = batch["word_embs"].detach().clone() + pos_ohot = batch["pos_ohot"].detach().clone() + text_lengths = batch["text_len"].detach().clone() + # start + start = time.time() 
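+        # Multimodality (MM) evaluation: repeat each caption and its motion
+        # MM_NUM_REPEATS times so several generations per prompt can be compared.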
+ + if self.trainer.datamodule.is_mm: + texts = texts * self.cfg.TEST.MM_NUM_REPEATS + motions = motions.repeat_interleave(self.cfg.TEST.MM_NUM_REPEATS, + dim=0) + lengths = lengths * self.cfg.TEST.MM_NUM_REPEATS + word_embs = word_embs.repeat_interleave( + self.cfg.TEST.MM_NUM_REPEATS, dim=0) + pos_ohot = pos_ohot.repeat_interleave(self.cfg.TEST.MM_NUM_REPEATS, + dim=0) + text_lengths = text_lengths.repeat_interleave( + self.cfg.TEST.MM_NUM_REPEATS, dim=0) + + if self.stage in ['diffusion', 'vae_diffusion']: + # diffusion reverse + if self.do_classifier_free_guidance: + uncond_tokens = [""] * len(texts) + if self.condition == 'text_face': + uncond_tokens.extend(texts) + elif self.condition == 'text_uncond': + uncond_tokens.extend(uncond_tokens) + texts = uncond_tokens + text_emb = self.text_encoder(texts) + z = self._diffusion_reverse(text_emb, lengths) + elif self.stage in ['vae']: + if self.vae_type in ["mld", "vposert", "actor"]: + z, dist_m = self.vae.encode(motions, lengths) + else: + raise TypeError("Not supported vae type!") + if self.condition in ['text_uncond']: + # uncond random sample + z = torch.randn_like(z) + + with torch.no_grad(): + if self.vae_type in ["mld", "vposert", "actor"]: + feats_rst = self.vae.decode(z, lengths) + elif self.vae_type == "no": + feats_rst = z.permute(1, 0, 2) + + # end time + end = time.time() + self.times.append(end - start) + + # joints recover + if self.cfg.TRAIN.use_joints: + joints_rst = self.feats2joints(feats_rst, self.motion_type, self.smplx_model) + joints_ref = self.feats2joints(motions, self.motion_type, self.smplx_model) + + # renorm for t2m evaluators + feats_rst = self.datamodule.renorm4t2m(feats_rst) + motions = self.datamodule.renorm4t2m(motions) + # t2m motion encoder + m_lens = lengths.copy() + m_lens = torch.tensor(m_lens, device=motions.device) + align_idx = np.argsort(m_lens.data.tolist())[::-1].copy() + motions = motions[align_idx] + feats_rst = feats_rst[align_idx] + m_lens = m_lens[align_idx] + m_lens = torch.div(m_lens, + self.cfg.DATASET.HUMANML3D.UNIT_LEN, + rounding_mode="floor") + + assert self.motion_type == 'smplx_212' + + + + recons_mov = self.t2m_moveencoder(feats_rst).detach() + recons_emb = self.t2m_motionencoder(recons_mov, m_lens) + motion_mov = self.t2m_moveencoder(motions).detach() + motion_emb = self.t2m_motionencoder(motion_mov, m_lens) + + # t2m text encoder + if self.cfg.model.eval_text_source == 'token': + text_emb = self.t2m_textencoder(word_embs, pos_ohot,text_lengths)[align_idx] + elif self.cfg.model.eval_text_source == 'only_text_token': + text_emb = self.t2m_textencoder(word_embs, text_lengths)[align_idx] + elif self.cfg.model.eval_text_source in ['caption']: + if self.cfg.model.eval_text_encode_way == 'clip': + raise NotImplementedError + + elif self.cfg.model.eval_text_encode_way == 't5': + raise NotImplementedError + + elif 'GRU' in self.cfg.model.eval_text_encode_way: + text_emb = self.t2m_textencoder(word_embs, text_lengths)[align_idx] + else: + raise NotImplementedError + if self.cfg.TRAIN.use_joints: + rs_set = { + "m_ref": motions, + "m_rst": feats_rst, + "lat_t": text_emb, + "lat_m": motion_emb, + "lat_rm": recons_emb, + "joints_ref": joints_ref, + "joints_rst": joints_rst, + } + else: + rs_set = { + "m_ref": motions, + "m_rst": feats_rst, + "lat_t": text_emb, + "lat_m": motion_emb, + "lat_rm": recons_emb, + } + + + return rs_set + + + + + + def t2m_eval_smplx_text_body(self, batch): + assert self.condition == 'text_body' + texts = [] + for i in range(len(batch["text"])): + 
texts.append(batch["text"][i] +' ' + batch['body_text'][i]) + + motions = batch["motion"].detach().clone() + lengths = batch["length"] + word_embs = batch["word_embs"].detach().clone() + pos_ohot = batch["pos_ohot"].detach().clone() + text_lengths = batch["text_len"].detach().clone() + # start + start = time.time() + + if self.trainer.datamodule.is_mm: + texts = texts * self.cfg.TEST.MM_NUM_REPEATS + motions = motions.repeat_interleave(self.cfg.TEST.MM_NUM_REPEATS, + dim=0) + lengths = lengths * self.cfg.TEST.MM_NUM_REPEATS + word_embs = word_embs.repeat_interleave( + self.cfg.TEST.MM_NUM_REPEATS, dim=0) + pos_ohot = pos_ohot.repeat_interleave(self.cfg.TEST.MM_NUM_REPEATS, + dim=0) + text_lengths = text_lengths.repeat_interleave( + self.cfg.TEST.MM_NUM_REPEATS, dim=0) + + if self.stage in ['diffusion', 'vae_diffusion']: + # diffusion reverse + if self.do_classifier_free_guidance: + uncond_tokens = [""] * len(texts) + if self.condition == 'text_body': + uncond_tokens.extend(texts) + elif self.condition == 'text_uncond': + uncond_tokens.extend(uncond_tokens) + else: + raise NotImplementedError + texts = uncond_tokens + text_emb = self.text_encoder(texts) + z = self._diffusion_reverse(text_emb, lengths) + elif self.stage in ['vae']: + if self.vae_type in ["mld", "vposert", "actor"]: + z, dist_m = self.vae.encode(motions, lengths) + else: + raise TypeError("Not supported vae type!") + if self.condition in ['text_uncond']: + # uncond random sample + z = torch.randn_like(z) + + with torch.no_grad(): + if self.vae_type in ["mld", "vposert", "actor"]: + feats_rst = self.vae.decode(z, lengths) + elif self.vae_type == "no": + feats_rst = z.permute(1, 0, 2) + + # end time + end = time.time() + self.times.append(end - start) + + # joints recover + if self.cfg.TRAIN.use_joints: + joints_rst = self.feats2joints(feats_rst, self.motion_type, self.smplx_model) + joints_ref = self.feats2joints(motions, self.motion_type, self.smplx_model) + + # renorm for t2m evaluators + feats_rst = self.datamodule.renorm4t2m(feats_rst) + motions = self.datamodule.renorm4t2m(motions) + # t2m motion encoder + m_lens = lengths.copy() + m_lens = torch.tensor(m_lens, device=motions.device) + align_idx = np.argsort(m_lens.data.tolist())[::-1].copy() + motions = motions[align_idx] + feats_rst = feats_rst[align_idx] + m_lens = m_lens[align_idx] + m_lens = torch.div(m_lens, + self.cfg.DATASET.HUMANML3D.UNIT_LEN, + rounding_mode="floor") + + assert self.motion_type == 'smplx_212' + + + + recons_mov = self.t2m_moveencoder(feats_rst).detach() + recons_emb = self.t2m_motionencoder(recons_mov, m_lens) + motion_mov = self.t2m_moveencoder(motions).detach() + motion_emb = self.t2m_motionencoder(motion_mov, m_lens) + + # t2m text encoder + if self.cfg.model.eval_text_source == 'token': + text_emb = self.t2m_textencoder(word_embs, pos_ohot,text_lengths)[align_idx] + elif self.cfg.model.eval_text_source == 'only_text_token': + text_emb = self.t2m_textencoder(word_embs, text_lengths)[align_idx] + elif self.cfg.model.eval_text_source in ['caption']: + if self.cfg.model.eval_text_encode_way == 'clip': + raise NotImplementedError + + elif self.cfg.model.eval_text_encode_way == 't5': + raise NotImplementedError + + elif 'GRU' in self.cfg.model.eval_text_encode_way: + text_emb = self.t2m_textencoder(word_embs, text_lengths)[align_idx] + else: + raise NotImplementedError + if self.cfg.TRAIN.use_joints: + rs_set = { + "m_ref": motions, + "m_rst": feats_rst, + "lat_t": text_emb, + "lat_m": motion_emb, + "lat_rm": recons_emb, + "joints_ref": 
joints_ref, + "joints_rst": joints_rst, + } + else: + rs_set = { + "m_ref": motions, + "m_rst": feats_rst, + "lat_t": text_emb, + "lat_m": motion_emb, + "lat_rm": recons_emb, + } + + + return rs_set + + + + + def t2m_eval_smplx_text_hand(self, batch): + assert self.condition == 'text_hand' + texts = [] + for i in range(len(batch["text"])): + texts.append(batch["text"][i] +' ' + batch['hand_text'][i]) + + motions = batch["motion"].detach().clone() + lengths = batch["length"] + word_embs = batch["word_embs"].detach().clone() + pos_ohot = batch["pos_ohot"].detach().clone() + text_lengths = batch["text_len"].detach().clone() + # start + start = time.time() + + if self.trainer.datamodule.is_mm: + texts = texts * self.cfg.TEST.MM_NUM_REPEATS + motions = motions.repeat_interleave(self.cfg.TEST.MM_NUM_REPEATS, + dim=0) + lengths = lengths * self.cfg.TEST.MM_NUM_REPEATS + word_embs = word_embs.repeat_interleave( + self.cfg.TEST.MM_NUM_REPEATS, dim=0) + pos_ohot = pos_ohot.repeat_interleave(self.cfg.TEST.MM_NUM_REPEATS, + dim=0) + text_lengths = text_lengths.repeat_interleave( + self.cfg.TEST.MM_NUM_REPEATS, dim=0) + + if self.stage in ['diffusion', 'vae_diffusion']: + # diffusion reverse + if self.do_classifier_free_guidance: + uncond_tokens = [""] * len(texts) + if self.condition == 'text_hand': + uncond_tokens.extend(texts) + elif self.condition == 'text_uncond': + uncond_tokens.extend(uncond_tokens) + else: + raise NotImplementedError + texts = uncond_tokens + text_emb = self.text_encoder(texts) + z = self._diffusion_reverse(text_emb, lengths) + elif self.stage in ['vae']: + if self.vae_type in ["mld", "vposert", "actor"]: + z, dist_m = self.vae.encode(motions, lengths) + else: + raise TypeError("Not supported vae type!") + if self.condition in ['text_uncond']: + # uncond random sample + z = torch.randn_like(z) + + with torch.no_grad(): + if self.vae_type in ["mld", "vposert", "actor"]: + feats_rst = self.vae.decode(z, lengths) + elif self.vae_type == "no": + feats_rst = z.permute(1, 0, 2) + + # end time + end = time.time() + self.times.append(end - start) + + # joints recover + if self.cfg.TRAIN.use_joints: + joints_rst = self.feats2joints(feats_rst, self.motion_type, self.smplx_model) + joints_ref = self.feats2joints(motions, self.motion_type, self.smplx_model) + + # renorm for t2m evaluators + feats_rst = self.datamodule.renorm4t2m(feats_rst) + motions = self.datamodule.renorm4t2m(motions) + # t2m motion encoder + m_lens = lengths.copy() + m_lens = torch.tensor(m_lens, device=motions.device) + align_idx = np.argsort(m_lens.data.tolist())[::-1].copy() + motions = motions[align_idx] + feats_rst = feats_rst[align_idx] + m_lens = m_lens[align_idx] + m_lens = torch.div(m_lens, + self.cfg.DATASET.HUMANML3D.UNIT_LEN, + rounding_mode="floor") + + assert self.motion_type == 'smplx_212' + + + + recons_mov = self.t2m_moveencoder(feats_rst).detach() + recons_emb = self.t2m_motionencoder(recons_mov, m_lens) + motion_mov = self.t2m_moveencoder(motions).detach() + motion_emb = self.t2m_motionencoder(motion_mov, m_lens) + + # t2m text encoder + if self.cfg.model.eval_text_source == 'token': + text_emb = self.t2m_textencoder(word_embs, pos_ohot,text_lengths)[align_idx] + elif self.cfg.model.eval_text_source == 'only_text_token': + text_emb = self.t2m_textencoder(word_embs, text_lengths)[align_idx] + elif self.cfg.model.eval_text_source in ['caption']: + if self.cfg.model.eval_text_encode_way == 'clip': + raise NotImplementedError + + elif self.cfg.model.eval_text_encode_way == 't5': + raise 
NotImplementedError + + elif 'GRU' in self.cfg.model.eval_text_encode_way: + text_emb = self.t2m_textencoder(word_embs, text_lengths)[align_idx] + else: + raise NotImplementedError + if self.cfg.TRAIN.use_joints: + rs_set = { + "m_ref": motions, + "m_rst": feats_rst, + "lat_t": text_emb, + "lat_m": motion_emb, + "lat_rm": recons_emb, + "joints_ref": joints_ref, + "joints_rst": joints_rst, + } + else: + rs_set = { + "m_ref": motions, + "m_rst": feats_rst, + "lat_t": text_emb, + "lat_m": motion_emb, + "lat_rm": recons_emb, + } + + + return rs_set + + + + def t2m_eval_smplx_text_face_body(self, batch): + assert self.condition == 'text_face_body' + texts = [] + for i in range(len(batch["text"])): + texts.append(batch["text"][i] +' ' + batch['face_text'][i] + ' ' + batch["body_text"][i]) + + motions = batch["motion"].detach().clone() + lengths = batch["length"] + word_embs = batch["word_embs"].detach().clone() + pos_ohot = batch["pos_ohot"].detach().clone() + text_lengths = batch["text_len"].detach().clone() + # start + start = time.time() + + if self.trainer.datamodule.is_mm: + texts = texts * self.cfg.TEST.MM_NUM_REPEATS + motions = motions.repeat_interleave(self.cfg.TEST.MM_NUM_REPEATS, + dim=0) + lengths = lengths * self.cfg.TEST.MM_NUM_REPEATS + word_embs = word_embs.repeat_interleave( + self.cfg.TEST.MM_NUM_REPEATS, dim=0) + pos_ohot = pos_ohot.repeat_interleave(self.cfg.TEST.MM_NUM_REPEATS, + dim=0) + text_lengths = text_lengths.repeat_interleave( + self.cfg.TEST.MM_NUM_REPEATS, dim=0) + + if self.stage in ['diffusion', 'vae_diffusion']: + # diffusion reverse + if self.do_classifier_free_guidance: + uncond_tokens = [""] * len(texts) + if self.condition == 'text_face_body': + uncond_tokens.extend(texts) + elif self.condition == 'text_uncond': + uncond_tokens.extend(uncond_tokens) + else: + raise NotImplementedError + texts = uncond_tokens + text_emb = self.text_encoder(texts) + z = self._diffusion_reverse(text_emb, lengths) + elif self.stage in ['vae']: + if self.vae_type in ["mld", "vposert", "actor"]: + z, dist_m = self.vae.encode(motions, lengths) + else: + raise TypeError("Not supported vae type!") + if self.condition in ['text_uncond']: + # uncond random sample + z = torch.randn_like(z) + + with torch.no_grad(): + if self.vae_type in ["mld", "vposert", "actor"]: + feats_rst = self.vae.decode(z, lengths) + elif self.vae_type == "no": + feats_rst = z.permute(1, 0, 2) + + # end time + end = time.time() + self.times.append(end - start) + + # joints recover + if self.cfg.TRAIN.use_joints: + joints_rst = self.feats2joints(feats_rst, self.motion_type, self.smplx_model) + joints_ref = self.feats2joints(motions, self.motion_type, self.smplx_model) + + # renorm for t2m evaluators + feats_rst = self.datamodule.renorm4t2m(feats_rst) + motions = self.datamodule.renorm4t2m(motions) + # t2m motion encoder + m_lens = lengths.copy() + m_lens = torch.tensor(m_lens, device=motions.device) + align_idx = np.argsort(m_lens.data.tolist())[::-1].copy() + motions = motions[align_idx] + feats_rst = feats_rst[align_idx] + m_lens = m_lens[align_idx] + m_lens = torch.div(m_lens, + self.cfg.DATASET.HUMANML3D.UNIT_LEN, + rounding_mode="floor") + + assert self.motion_type == 'smplx_212' + + + recons_mov = self.t2m_moveencoder(feats_rst).detach() + recons_emb = self.t2m_motionencoder(recons_mov, m_lens) + motion_mov = self.t2m_moveencoder(motions).detach() + motion_emb = self.t2m_motionencoder(motion_mov, m_lens) + + # t2m text encoder + if self.cfg.model.eval_text_source == 'token': + text_emb = 
self.t2m_textencoder(word_embs, pos_ohot,text_lengths)[align_idx] + elif self.cfg.model.eval_text_source == 'only_text_token': + text_emb = self.t2m_textencoder(word_embs, text_lengths)[align_idx] + elif self.cfg.model.eval_text_source in ['caption']: + if self.cfg.model.eval_text_encode_way == 'clip': + raise NotImplementedError + + elif self.cfg.model.eval_text_encode_way == 't5': + raise NotImplementedError + + elif 'GRU' in self.cfg.model.eval_text_encode_way: + text_emb = self.t2m_textencoder(word_embs, text_lengths)[align_idx] + else: + raise NotImplementedError + if self.cfg.TRAIN.use_joints: + rs_set = { + "m_ref": motions, + "m_rst": feats_rst, + "lat_t": text_emb, + "lat_m": motion_emb, + "lat_rm": recons_emb, + "joints_ref": joints_ref, + "joints_rst": joints_rst, + } + else: + rs_set = { + "m_ref": motions, + "m_rst": feats_rst, + "lat_t": text_emb, + "lat_m": motion_emb, + "lat_rm": recons_emb, + } + + + return rs_set + + + + def a2m_eval(self, batch): + actions = batch["action"] + actiontexts = batch["action_text"] + motions = batch["motion"].detach().clone() + lengths = batch["length"] + + if self.do_classifier_free_guidance: + cond_emb = torch.cat((torch.zeros_like(actions), actions)) + + if self.stage in ['diffusion', 'vae_diffusion']: + z = self._diffusion_reverse(cond_emb, lengths) + elif self.stage in ['vae']: + if self.vae_type in ["mld", "vposert","actor"]: + z, dist_m = self.vae.encode(motions, lengths) + else: + raise TypeError("vae_type must be mcross or actor") + + with torch.no_grad(): + if self.vae_type in ["mld", "vposert","actor"]: + feats_rst = self.vae.decode(z, lengths) + elif self.vae_type == "no": + feats_rst = z.permute(1, 0, 2) + else: + raise TypeError("vae_type must be mcross or actor or mld") + + mask = batch["mask"] + joints_rst = self.feats2joints(feats_rst, mask) + joints_ref = self.feats2joints(motions, mask) + joints_eval_rst = self.feats2joints_eval(feats_rst, mask) + joints_eval_ref = self.feats2joints_eval(motions, mask) + + rs_set = { + "m_action": actions, + "m_ref": motions, + "m_rst": feats_rst, + "m_lens": lengths, + "joints_rst": joints_rst, + "joints_ref": joints_ref, + "joints_eval_rst": joints_eval_rst, + "joints_eval_ref": joints_eval_ref, + } + return rs_set + + def a2m_gt(self, batch): + actions = batch["action"] + actiontexts = batch["action_text"] + motions = batch["motion"].detach().clone() + lengths = batch["length"] + mask = batch["mask"] + + joints_ref = self.feats2joints(motions.to('cuda'), mask.to('cuda')) + + rs_set = { + "m_action": actions, + "m_text": actiontexts, + "m_ref": motions, + "m_lens": lengths, + "joints_ref": joints_ref, + } + return rs_set + + def eval_gt(self, batch, renoem=True): + + motions = batch["motion"].detach().clone() + lengths = batch["length"] + + # feats_rst = self.datamodule.renorm4t2m(feats_rst) + if renoem: + motions = self.datamodule.renorm4t2m(motions) + + # t2m motion encoder + m_lens = lengths.copy() + m_lens = torch.tensor(m_lens, device=motions.device) + align_idx = np.argsort(m_lens.data.tolist())[::-1].copy() + motions = motions[align_idx] + m_lens = m_lens[align_idx] + m_lens = torch.div(m_lens, + self.cfg.DATASET.HUMANML3D.UNIT_LEN, + rounding_mode="floor") + + word_embs = batch["word_embs"].detach() + pos_ohot = batch["pos_ohot"].detach() + text_lengths = batch["text_len"].detach() + + motion_mov = self.t2m_moveencoder(motions[..., :-4]).detach() + motion_emb = self.t2m_motionencoder(motion_mov, m_lens) + + # t2m text encoder + text_emb = self.t2m_textencoder(word_embs, pos_ohot, 
+ text_lengths)[align_idx] + + # joints recover + joints_ref = self.feats2joints(motions) + + rs_set = { + "m_ref": motions, + "lat_t": text_emb, + "lat_m": motion_emb, + "joints_ref": joints_ref, + } + return rs_set + + def allsplit_step(self, split: str, batch, batch_idx): + if split in ["train", "val"]: + if self.stage == "vae": + if self.vae_type in ["mld", "vposert","actor"]: + rs_set = self.train_vae_forward(batch) + rs_set["lat_t"] = rs_set["lat_m"] + else: + rs_set = self.train_vae_forward(batch) + elif self.stage == "diffusion": + rs_set = self.train_diffusion_forward(batch) + elif self.stage == "vae_diffusion": + vae_rs_set = self.train_vae_forward(batch) + diff_rs_set = self.train_diffusion_forward(batch) + t2m_rs_set = self.test_diffusion_forward(batch, + finetune_decoder=True) + # merge results + rs_set = { + **vae_rs_set, + **diff_rs_set, + "gen_m_rst": t2m_rs_set["m_rst"], + "gen_joints_rst": t2m_rs_set["joints_rst"], + "lat_t": t2m_rs_set["lat_t"], + } + + else: + raise ValueError(f"Not support this stage {self.stage}!") + loss = self.losses[split].update(rs_set) + if loss is None: + raise ValueError( + "Loss is None, this happend with torchmetrics > 0.7") + + # Compute the metrics - currently evaluate results from text to motion + if split in ["val", "test"]: + if self.condition in ['text', 'text_uncond', 'text_all', 'text_face', 'text_body', 'text_hand', 'text_face_body', 'text_seperate', 'only_pose_concat', 'only_pose_fusion']: + # use t2m evaluators + if self.input_format in ['vector_263', 'root_body_pos_vel_hand_pos_vel']: + if self.condition == 'text': + if self.cfg.TEST.inference_vq_code: + rs_set = self.t2m_eval_save_motion_token(batch) + else: + if self.cfg.EVAL.use_tmr_eval: + rs_set = self.tmr_t2m_eval(batch) + else: + rs_set = self.t2m_eval(batch) + else: + raise NotImplementedError + elif self.input_format in ['smplx_212', 'smplx_159']: + if self.condition == 'text': + if self.cfg.TEST.inference_vq_code: + rs_set = self.t2m_eval_smplx_save_motion_token(batch) + else: + rs_set = self.t2m_eval_smplx(batch) + elif self.condition == 'text_all': + rs_set = self.t2m_eval_smplx_text_all(batch) + elif self.condition == 'text_face': + rs_set = self.t2m_eval_smplx_text_face(batch) + elif self.condition == 'text_body': + rs_set = self.t2m_eval_smplx_text_body(batch) + elif self.condition == 'text_hand': + rs_set = self.t2m_eval_smplx_text_hand(batch) + elif self.condition == 'text_face_body': + rs_set = self.t2m_eval_smplx_text_face_body(batch) + else: + raise NotImplementedError + # elif self.input_format in ['root_position', 'root_position_vel', 'root_position_rot6d', 'root_rot6d', 'all', 'root_body_pos_vel_hand_all', 'root_body_pos_vel_hand_pos_vel', 'root_body_pos_vel_hand_pos', 'root_body_pos_vel_hand_rot', 'root_position_vel_only_body', 'root_body_pos_vel_hand_pos_vel_hand_wrist']: + elif not self.eval_on_text: + rs_set = self.normal_eval(batch) + else: + rs_set = self.t2m_eval(batch) + # else: + # raise NotImplementedError + + elif self.condition == 'action': + # use a2m evaluators + rs_set = self.a2m_eval(batch) + else: + raise NotImplementedError + # MultiModality evaluation sperately + if self.trainer.datamodule.is_mm: + metrics_dicts = ['MMMetrics'] + else: + metrics_dicts = self.metrics_dict + + # metrics_dicts = [] + for metric in metrics_dicts: + if metric == "TemosMetric": + phase = split if split != "val" else "eval" + if eval(f"self.cfg.{phase.upper()}.DATASETS")[0].lower( + ) not in [ + "humanml3d", + "kit", + "motionx", + "motionx_v25", + 'motionx_v26' 
+ ]: + raise TypeError( + "APE and AVE metrics only support humanml3d and kit datasets now" + ) + getattr(self, metric).update(rs_set["joints_rst"], + rs_set["joints_ref"], + batch["length"]) + + elif metric == "TemosMetric_body_hand": + phase = split if split != "val" else "eval" + if eval(f"self.cfg.{phase.upper()}.DATASETS")[0].lower( + ) not in [ + "humanml3d", + "kit", + "motionx", + "motionx_v25", + 'motionx_v26' + ]: + raise TypeError( + "APE and AVE metrics only support humanml3d and kit datasets now" + ) + getattr(self, metric).update(rs_set["joints_rst"], + rs_set["joints_ref"], + batch["length"]) + + elif metric == "TM2TMetrics": + getattr(self, metric).update( + # lat_t, latent encoded from diffusion-based text + # lat_rm, latent encoded from reconstructed motion + # lat_m, latent encoded from gt motion + # rs_set['lat_t'], rs_set['lat_rm'], rs_set['lat_m'], batch["length"]) + rs_set["lat_t"], + rs_set["lat_rm"], + rs_set["lat_m"], + batch["length"], + ) + elif metric == "TM2TMetrics_R256": + getattr(self, metric).update( + # lat_t, latent encoded from diffusion-based text + # lat_rm, latent encoded from reconstructed motion + # lat_m, latent encoded from gt motion + # rs_set['lat_t'], rs_set['lat_rm'], rs_set['lat_m'], batch["length"]) + rs_set["lat_t"], + rs_set["lat_rm"], + rs_set["lat_m"], + batch["length"], + ) + elif metric == "TMR_TM2TMetrics": + getattr(self, metric).update( + # lat_t, latent encoded from diffusion-based text + # lat_rm, latent encoded from reconstructed motion + # lat_m, latent encoded from gt motion + # rs_set['lat_t'], rs_set['lat_rm'], rs_set['lat_m'], batch["length"]) + rs_set["lat_t_tmr"], + rs_set["lat_rm_tmr"], + rs_set["lat_m_tmr"], + batch["length"], + ) + elif metric == "UncondMetrics": + getattr(self, metric).update( + recmotion_embeddings=rs_set["lat_rm"], + gtmotion_embeddings=rs_set["lat_m"], + lengths=batch["length"], + ) + elif metric in ["MRMetrics", "MRMetrics_body_hand"]: + if self.cfg.TEST.inference_vq_code: + getattr(self, metric).update(rs_set["joints_rst"], + rs_set["joints_ref"], + batch["length"], + rs_set["name"]) + else: + getattr(self, metric).update(rs_set["joints_rst"], + rs_set["joints_ref"], + batch["length"]) + + elif metric == "MMMetrics": + getattr(self, metric).update(rs_set["lat_rm"].unsqueeze(0), + batch["length"]) + elif metric == "HUMANACTMetrics": + getattr(self, metric).update(rs_set["m_action"], + rs_set["joints_eval_rst"], + rs_set["joints_eval_ref"], + rs_set["m_lens"]) + elif metric == "UESTCMetrics": + # the stgcn model expects rotations only + getattr(self, metric).update( + rs_set["m_action"], + rs_set["m_rst"].view(*rs_set["m_rst"].shape[:-1], 6, + 25).permute(0, 3, 2, 1)[:, :-1], + rs_set["m_ref"].view(*rs_set["m_ref"].shape[:-1], 6, + 25).permute(0, 3, 2, 1)[:, :-1], + rs_set["m_lens"]) + else: + raise TypeError(f"Not support this metric {metric}") + + # return forward output rather than loss during test + # self.datamodule.renorm4t2m + if split in ["test"]: + if self.cfg.TEST.inference_vq_code: + if self.vae_type in ["hvq", "hvq_body_hand"]: + return rs_set["motion_code_t"], rs_set["motion_code_b"], batch["name"] + else: + return rs_set["motion_code"], batch["name"] + + if self.motion_type == 'vector_263': + return rs_set["joints_rst"], batch["length"] + elif self.motion_type in ['smplx_212', 'smplx_159']: + if self.cfg.TRAIN.use_joints: + return rs_set["m_rst"], batch["length"], rs_set["m_ref"] + else: + return batch["length"] + elif self.motion_type in ['ric_rot']: + return rs_set["joints_rst"], 
batch["length"], rs_set["joints_ref"] + + else: + return batch["length"] + return loss diff --git a/Evaluator_272/mld/models/modeltype/temos.py b/Evaluator_272/mld/models/modeltype/temos.py new file mode 100644 index 0000000000000000000000000000000000000000..686de3c1fec9f67dbcb3400b4f84f00a2e44cab9 --- /dev/null +++ b/Evaluator_272/mld/models/modeltype/temos.py @@ -0,0 +1,662 @@ +from typing import List, Optional + +import torch +from torch import Tensor +from omegaconf import DictConfig +from mld.models.tools.tools import remove_padding + +from mld.models.metrics import ComputeMetrics +from torchmetrics import MetricCollection +from mld.models.modeltype.base import BaseModel +from torch.distributions.distribution import Distribution +from mld.config import instantiate_from_config + +from mld.models.losses.temos import TemosLosses +from torch.optim import AdamW +from sentence_transformers import SentenceTransformer + +from mld.models.architectures import t2m_textenc, t2m_motionenc +import os + +import time + +import numpy as np +import torch.nn.functional as f +from pathlib import Path + +class TEMOS(BaseModel): + def __init__(self, cfg, datamodule, **kwargs): + super().__init__() + + self.is_vae = cfg.model.vae + self.cfg = cfg + self.condition = cfg.model.condition + self.stage = cfg.TRAIN.STAGE + self.datamodule = datamodule + self.njoints = cfg.DATASET.NJOINTS + self.debug = cfg.DEBUG + self.motion_type = cfg.DATASET.MOTION_TYPE + + self.textencoder = instantiate_from_config(cfg.textencoder) + self.motionencoder = instantiate_from_config(cfg.motionencoder) + self.motiondecoder = instantiate_from_config(cfg.motiondecoder) + + + if self.condition in ["text", "text_uncond", 'text_all', 'text_face', 'text_body', 'text_hand', 'text_face_body', 'text_seperate', 'only_pose_concat', 'only_pose_fusion']: + self.feats2joints = datamodule.feats2joints + + if cfg.TRAIN.OPTIM.TYPE.lower() == "adamw": + self.optimizer = AdamW(lr=cfg.TRAIN.OPTIM.LR, + params=self.parameters()) + else: + raise NotImplementedError( + "Do not support other optimizer for now.") + + + self._losses = MetricCollection({ + split: TemosLosses(vae=self.is_vae, mode="xyz", cfg=cfg) + for split in ["losses_train", "losses_test", "losses_val"] + }) + + self.losses = {key: self._losses["losses_" + key] for key in ["train", "test", "val"]} + + self.metrics_dict = cfg.METRIC.TYPE + self.configure_metrics() + + # If we want to overide it at testing time + self.sample_mean = False + self.fact = None + + if self.cfg.LOSS.USE_INFONCE_FILTER: + self.filter_model = SentenceTransformer('sentence-transformers/paraphrase-MiniLM-L6-v2') + + self.retrieval_text_embedding = [] + self.retrieval_motion_embedding = [] + self.retrieval_sbert_embedding = [] + + self.retrieval_corres_name = [] + + self.gt_idx = 0 + + self.__post_init__() + + # Forward: text => motion + def forward(self, batch: dict) -> List[Tensor]: + datastruct_from_text = self.text_to_motion_forward(batch["text"], + batch["length"]) + + return remove_padding(datastruct_from_text.joints, batch["length"]) + + + def _get_t2m_evaluator(self, cfg): + """ + load T2M text encoder and motion encoder for evaluating + """ + + # init module + if cfg.model.eval_text_source == 'token': + + self.t2m_textencoder = t2m_textenc.TextEncoderBiGRUCo(word_size=cfg.model.t2m_textencoder.dim_word, + pos_size=cfg.model.t2m_textencoder.dim_pos_ohot, + hidden_size=cfg.model.t2m_textencoder.dim_text_hidden, + output_size=cfg.model.t2m_textencoder.dim_coemb_hidden, + ) + elif cfg.model.eval_text_source == 
'only_text_token': + + self.t2m_textencoder = t2m_textenc.TextEncoderBiGRUCoV2(word_size=cfg.model.t2m_textencoder.dim_word, + hidden_size=cfg.model.t2m_textencoder.dim_text_hidden, + output_size=cfg.model.t2m_textencoder.dim_coemb_hidden, + ) + + elif cfg.model.eval_text_source in ['caption']: + + if cfg.model.eval_text_encode_way == 'clip': + self.t2m_textencoder, clip_preprocess = clip.load("ViT-B/32", device=opt.device, jit=False) # Must set jit=False for training + clip.model.convert_weights(text_enc)# Actually this line is unnecessary since clip by default already on float16 + self.t2m_textencoder.eval() + for p in text_enc.parameters(): + p.requires_grad = False + + elif cfg.model.eval_text_encode_way == 't5': + os.environ["TOKENIZERS_PARALLELISM"] = "false" + self.t2m_textencoder = SentenceTransformer('sentence-transformers/sentence-t5-xl').to(opt.device) + self.t2m_textencoder.eval() + for p in self.t2m_textencoder.parameters(): + p.requires_grad = False + + elif 'GRU' in cfg.model.eval_text_encode_way: + self.t2m_textencoder = t2m_textenc.TextEncoderBiGRUCoV2(word_size=cfg.model.t2m_textencoder.dim_word, + hidden_size=cfg.model.t2m_textencoder.dim_text_hidden, + output_size=cfg.model.t2m_textencoder.dim_coemb_hidden, + ) + else: + raise NotImplementedError + + + + self.t2m_moveencoder = t2m_motionenc.MovementConvEncoder( + input_size=cfg.DATASET.NFEATS - 4, + hidden_size=cfg.model.t2m_motionencoder.dim_move_hidden, + output_size=cfg.model.t2m_motionencoder.dim_move_latent, + ) + + + self.t2m_motionencoder = t2m_motionenc.MotionEncoderBiGRUCo( + input_size=cfg.model.t2m_motionencoder.dim_move_latent, + hidden_size=cfg.model.t2m_motionencoder.dim_motion_hidden, + output_size=cfg.model.t2m_motionencoder.dim_motion_latent, + ) + + # load pretrianed + dataname = cfg.TEST.DATASETS[0] + + t2m_checkpoint = torch.load( + os.path.join(cfg.model.t2m_path, dataname, + "text_mot_match/model/finest.tar"), map_location=torch.device('cpu')) + + self.t2m_textencoder.load_state_dict(t2m_checkpoint["text_encoder"]) + + self.t2m_moveencoder.load_state_dict( + t2m_checkpoint["movement_encoder"]) + + + self.t2m_motionencoder.load_state_dict( + t2m_checkpoint["motion_encoder"]) + + # freeze params + self.t2m_textencoder.eval() + self.t2m_moveencoder.eval() + self.t2m_motionencoder.eval() + for p in self.t2m_textencoder.parameters(): + p.requires_grad = False + for p in self.t2m_moveencoder.parameters(): + p.requires_grad = False + for p in self.t2m_motionencoder.parameters(): + p.requires_grad = False + + + + def sample_from_distribution(self, distribution: Distribution, *, + fact: Optional[bool] = None, + sample_mean: Optional[bool] = False) -> Tensor: + fact = fact if fact is not None else self.fact + sample_mean = sample_mean if sample_mean is not None else self.sample_mean + + if sample_mean: + return distribution.loc + + # Reparameterization trick + if fact is None: + return distribution.rsample() + + # Resclale the eps + eps = distribution.rsample() - distribution.loc + latent_vector = distribution.loc + fact * eps + return latent_vector + + def text_to_motion_forward(self, text_sentences: List[str], lengths: List[int], *, + return_latent: bool = False): + # Encode the text to the latent space + if self.is_vae: + distribution = self.textencoder(text_sentences) + latent_vector = self.sample_from_distribution(distribution) + else: + distribution = None + latent_vector = self.textencoder(text_sentences) + + # Decode the latent vector to a motion + features = self.motiondecoder(latent_vector, 
lengths) + # datastruct = self.Datastruct(features=features) + + if not return_latent: + return features + return features, latent_vector, distribution + + def motion_to_motion_forward(self, features, + lengths: Optional[List[int]] = None, + return_latent: bool = False + ): + if self.is_vae: + distribution = self.motionencoder(features, lengths) + latent_vector = self.sample_from_distribution(distribution) + else: + distribution = None + latent_vector: Tensor = self.motionencoder(features, lengths) + + # Decode the latent vector to a motion + features = self.motiondecoder(latent_vector, lengths) + # datastruct = self.Datastruct(features=features) + + if not return_latent: + return features + return features, latent_vector, distribution + + + def save_embeddings(self, batch): + + with torch.no_grad(): + motion_all, text_all = None, None + sbert_embedding_all = None + + texts = batch["text"] + motions = batch["motion"].detach().clone() + lengths = batch["length"] + word_embs = batch["word_embs"].detach().clone() + pos_ohot = batch["pos_ohot"].detach().clone() + text_lengths = batch["text_len"].detach().clone() + retrieval_name = batch['retrieval_name'] + + text_embedding = self.textencoder(texts).loc + motion_embedding = self.motionencoder(motions, lengths).loc + + Emb_text = f.normalize(text_embedding, dim=1) + Emb_motion = f.normalize(motion_embedding, dim=1) + + if text_all == None: + text_all = Emb_text + else: + text_all = torch.cat((text_all, Emb_text), 0) + + if motion_all == None: + motion_all = Emb_motion + else: + motion_all = torch.cat((motion_all, Emb_motion), 0) + + if self.cfg.LOSS.USE_INFONCE_FILTER: + sbert_embedding = torch.tensor(self.filter_model.encode(texts)) # (bs, 384) + sbert_embedding = f.normalize(sbert_embedding, dim=1) + + if sbert_embedding_all == None: + sbert_embedding_all = sbert_embedding + else: + sbert_embedding_all = torch.cat((sbert_embedding_all, sbert_embedding), 0) + + self.retrieval_sbert_embedding.append(sbert_embedding_all.detach().cpu().numpy()) + + self.retrieval_text_embedding.append(text_all.detach().cpu().numpy()) + self.retrieval_motion_embedding.append(motion_all.detach().cpu().numpy()) + self.retrieval_corres_name.append(retrieval_name) + + + + def t2m_eval(self, batch): + retrieval_name = batch['retrieval_name'] + texts = batch["text"] + motions = batch["motion"].detach().clone() + lengths = batch["length"] + word_embs = batch["word_embs"].detach().clone() + pos_ohot = batch["pos_ohot"].detach().clone() + text_lengths = batch["text_len"].detach().clone() + + # start + start = time.time() + + if self.trainer.datamodule.is_mm: + texts = texts * self.cfg.TEST.MM_NUM_REPEATS + motions = motions.repeat_interleave(self.cfg.TEST.MM_NUM_REPEATS, + dim=0) + lengths = lengths * self.cfg.TEST.MM_NUM_REPEATS + word_embs = word_embs.repeat_interleave( + self.cfg.TEST.MM_NUM_REPEATS, dim=0) + pos_ohot = pos_ohot.repeat_interleave(self.cfg.TEST.MM_NUM_REPEATS, + dim=0) + text_lengths = text_lengths.repeat_interleave( + self.cfg.TEST.MM_NUM_REPEATS, dim=0) + + assert self.stage in ['temos'] + + # Encode the text/decode to a motion + with torch.no_grad(): + ret = self.text_to_motion_forward(texts, + lengths, + return_latent=True) + feat_from_text, latent_from_text, distribution_from_text = ret + + # Encode the motion/decode to a motion + ret = self.motion_to_motion_forward(motions, + lengths, + return_latent=True) + feat_from_motion, latent_from_motion, distribution_from_motion = ret + + # end time + end = time.time() + self.times.append(end - start) + + # 
joints recover + joints_ref = self.feats2joints(motions) + joints_rst = self.feats2joints(feat_from_text) + + # renorm for t2m evaluators + feats_rst = self.datamodule.renorm4t2m(feat_from_text) + motions = self.datamodule.renorm4t2m(motions) + # t2m motion encoder + m_lens = lengths.copy() + m_lens = torch.tensor(m_lens, device=motions.device) + align_idx = np.argsort(m_lens.data.tolist())[::-1].copy() + motions = motions[align_idx] + feats_rst = feats_rst[align_idx] + m_lens = m_lens[align_idx] + m_lens = torch.div(m_lens, + self.cfg.DATASET.HUMANML3D.UNIT_LEN, + rounding_mode="floor") + + + rs_set = { + "m_ref": motions, + "m_rst": feats_rst, + # "lat_t": text_emb, + # "lat_m": motion_emb, + # "lat_rm": recons_emb, + "joints_ref": joints_ref, + "joints_rst": joints_rst, + } + + return rs_set + + + def tmr_gt_eval(self, batch): + texts = batch["text"] + motions = batch["motion"].detach().clone() + lengths = batch["length"] + # word_embs = batch["word_embs"].detach().clone() + # pos_ohot = batch["pos_ohot"].detach().clone() + # text_lengths = batch["text_len"].detach().clone() + name = batch["retrieval_name"] + bs, seq = motions.shape[:2] + + # start + start = time.time() + + if self.trainer.datamodule.is_mm: + texts = texts * self.cfg.TEST.MM_NUM_REPEATS + motions = motions.repeat_interleave(self.cfg.TEST.MM_NUM_REPEATS, + dim=0) + lengths = lengths * self.cfg.TEST.MM_NUM_REPEATS + word_embs = word_embs.repeat_interleave( + self.cfg.TEST.MM_NUM_REPEATS, dim=0) + pos_ohot = pos_ohot.repeat_interleave(self.cfg.TEST.MM_NUM_REPEATS, + dim=0) + text_lengths = text_lengths.repeat_interleave( + self.cfg.TEST.MM_NUM_REPEATS, dim=0) + + bs = self.cfg.TEST.MM_NUM_REPEATS + + assert self.stage in ['temos'] + self.textencoder.eval() + self.motionencoder.eval() + self.motiondecoder.eval() + with torch.no_grad(): + + ret = self.text_to_motion_forward(texts, + lengths, + return_latent=True) + feat_from_text, latent_from_text, distribution_from_text = ret + # Encode the motion/decode to a motion + ret = self.motion_to_motion_forward(motions, + lengths, + return_latent=True) + feat_from_motion, latent_from_motion, distribution_from_motion = ret + + ret = self.motion_to_motion_forward(feat_from_motion, lengths, return_latent=True) + _, latent_from_motion_rst_motion, _ = ret + + # end time + end = time.time() + self.times.append(end - start) + # joints recover + joints_ref = self.feats2joints(motions) + joints_rst = self.feats2joints(feat_from_text) + + + # #########################saving output################### + feats_rst = self.datamodule.renorm4t2m(feat_from_text) + motions = self.datamodule.renorm4t2m(motions) + # t2m motion encoder + m_lens = lengths.copy() + m_lens = torch.tensor(m_lens, device=motions.device) + align_idx = np.argsort(m_lens.data.tolist())[::-1].copy() + motions = motions[align_idx] + feats_rst = feats_rst[align_idx] + m_lens = m_lens[align_idx] + m_lens = torch.div(m_lens, + self.cfg.DATASET.HUMANML3D_272.UNIT_LEN, + rounding_mode="floor") + + recons_emb_tmr = latent_from_motion_rst_motion[align_idx] + motion_emb_tmr = latent_from_motion[align_idx] + text_emb_tmr = latent_from_text[align_idx] + + self.textencoder.train() + self.motionencoder.train() + self.motiondecoder.train() + + rs_set = { + "m_ref": motions, + "lat_t_tmr": text_emb_tmr, + "lat_m_tmr": motion_emb_tmr, + "lat_rm_tmr": recons_emb_tmr, + "joints_ref": joints_ref, + "joints_rst": joints_rst, + } + return rs_set + + def allsplit_step(self, split: str, batch, batch_idx): + emb_dist = None + if 
self.cfg.LOSS.USE_INFONCE and self.cfg.LOSS.USE_INFONCE_FILTER: + with torch.no_grad(): + text_embedding = self.filter_model.encode(batch["text"]) + text_embedding = torch.tensor(text_embedding).to(batch['motion'][0]) + normalized = f.normalize(text_embedding, p=2, dim=1) + emb_dist = normalized.matmul(normalized.T) + + # Encode the text/decode to a motion + ret = self.text_to_motion_forward(batch["text"], + batch["length"], + return_latent=True) + feat_from_text, latent_from_text, distribution_from_text = ret + + # Encode the motion/decode to a motion + ret = self.motion_to_motion_forward(batch["motion"], + batch["length"], + return_latent=True) + feat_from_motion, latent_from_motion, distribution_from_motion = ret + + # GT data + # datastruct_ref = batch["datastruct"] + + # Compare to a Normal distribution + if self.is_vae: + # Create a centred normal distribution to compare with + mu_ref = torch.zeros_like(distribution_from_text.loc) + scale_ref = torch.ones_like(distribution_from_text.scale) + distribution_ref = torch.distributions.Normal(mu_ref, scale_ref) + else: + distribution_ref = None + # Compute the losses + loss = self.losses[split].update(f_text=feat_from_text, + f_motion=feat_from_motion, + f_ref=batch["motion"], + lat_text=latent_from_text, + lat_motion=latent_from_motion, + dis_text=distribution_from_text, + dis_motion=distribution_from_motion, + dis_ref=distribution_ref, + emb_dist=emb_dist) + + if loss is None: + raise ValueError("Loss is None, this happend with torchmetrics > 0.7") + + + if split in ["val", "test"]: + # self.save_embeddings(batch) + if self.cfg.EVAL.eval_self_on_gt: + rs_set = self.tmr_gt_eval(batch) + else: + if self.condition in ['text', 'text_uncond', 'text_all', 'text_face', 'text_body', 'text_hand', 'text_face_body', 'text_seperate', 'only_pose_concat', 'only_pose_fusion']: + # use t2m evaluators + rs_set = self.t2m_eval(batch) + elif self.condition == 'action': + # use a2m evaluators + rs_set = self.a2m_eval(batch) + else: + raise NotImplementedError + + # MultiModality evaluation sperately + if self.trainer.datamodule.is_mm: + metrics_dicts = ['MMMetrics'] + else: + metrics_dicts = self.metrics_dict + + for metric in metrics_dicts: + if metric == "TemosMetric": + phase = split if split != "val" else "eval" + if eval(f"self.cfg.{phase.upper()}.DATASETS")[0].lower( + ) not in [ + "humanml3d", + "kit" + ]: + raise TypeError( + "APE and AVE metrics only support humanml3d and kit datasets now" + ) + + getattr(self, metric).update(rs_set["joints_rst"], + rs_set["joints_ref"], + batch["length"]) + elif metric == "TM2TMetrics": + getattr(self, metric).update( + rs_set['lat_t'], + rs_set["lat_rm"], + rs_set["lat_m"], + batch["length"], + ) + elif metric == "UncondMetrics": + getattr(self, metric).update( + recmotion_embeddings=rs_set["lat_rm"], + gtmotion_embeddings=rs_set["lat_m"], + lengths=batch["length"], + ) + elif metric == "MRMetrics": + getattr(self, metric).update(rs_set["joints_rst"], + rs_set["joints_ref"], + batch["length"]) + elif metric == "MMMetrics": + getattr(self, metric).update(rs_set["lat_rm"].unsqueeze(0), + batch["length"]) + elif metric == "HUMANACTMetrics": + getattr(self, metric).update(rs_set["m_action"], + rs_set["joints_eval_rst"], + rs_set["joints_eval_ref"], + rs_set["m_lens"]) + elif metric == "TMR_TM2TMetrics": + getattr(self, metric).update( + rs_set["lat_t_tmr"], + rs_set["lat_rm_tmr"], + rs_set["lat_m_tmr"], + batch["length"], + ) + elif metric == "UESTCMetrics": + # the stgcn model expects rotations only + getattr(self, 
metric).update( + rs_set["m_action"], + rs_set["m_rst"].view(*rs_set["m_rst"].shape[:-1], 6, + 25).permute(0, 3, 2, 1)[:, :-1], + rs_set["m_ref"].view(*rs_set["m_ref"].shape[:-1], 6, + 25).permute(0, 3, 2, 1)[:, :-1], + rs_set["m_lens"]) + else: + raise TypeError(f"Not support this metric {metric}") + + + if split in ["test"]: + if self.motion_type == 'vector_263': + return rs_set["joints_rst"], batch["length"], batch["text"] + elif self.motion_type == 'smplx_212': + if self.cfg.TRAIN.use_joints: + return rs_set["m_rst"], batch["length"], rs_set["m_ref"] + else: + return batch["length"] + + return loss + + + def allsplit_epoch_end(self, split: str, outputs): + dico = {} + + if split in ["val", "test"]: + + if (self.trainer.current_epoch+1) % 1000 == 0: + output_dir = Path( + os.path.join( + self.cfg.FOLDER, + str(self.cfg.model.model_type), + str(self.cfg.NAME), + "embeddings", + split, + "epoch_" + str(self.trainer.current_epoch) + )) + + os.makedirs(output_dir, exist_ok=True) + + self.retrieval_text_embedding = torch.cat([i.view(-1, i.shape[-1]) for i in self.all_gather(self.retrieval_text_embedding)], dim=0) + self.retrieval_motion_embedding = torch.cat([i.view(-1, i.shape[-1]) for i in self.all_gather(self.retrieval_motion_embedding)], dim=0) + + + tmp_retrieval_name = [] + for i in self.all_gather(self.retrieval_corres_name): + tmp_retrieval_name += i + self.retrieval_corres_name = tmp_retrieval_name + with open(output_dir/"test_name_debug.txt", "w") as test_name_file: + for i in self.retrieval_corres_name: + test_name_file.write(i + '\n') + + if self.cfg.LOSS.USE_INFONCE_FILTER: + self.retrieval_sbert_embedding = torch.cat([i.view(-1, i.shape[-1]) for i in self.all_gather(self.retrieval_sbert_embedding)], dim=0) + np.save(output_dir/"sbert_embedding.npy", self.retrieval_sbert_embedding.detach().cpu().numpy()) + + + np.save(output_dir/"text_embedding.npy", self.retrieval_text_embedding.detach().cpu().numpy())# (2324, 256) + np.save(output_dir/"motion_embedding.npy", self.retrieval_motion_embedding.detach().cpu().numpy()) + + print('save embedding in {} at {}'.format(output_dir, self.trainer.current_epoch)) + + + self.retrieval_text_embedding = [] + self.retrieval_motion_embedding = [] + self.retrieval_sbert_embedding = [] + + if split in ["train", "val"]: + losses = self.losses[split] + loss_dict = losses.compute(split) + losses.reset() + dico.update({ + losses.loss2logname(loss, split): value.item() + for loss, value in loss_dict.items() if not torch.isnan(value) + }) + + if split in ["val", "test"]: + + if self.trainer.datamodule.is_mm and "TM2TMetrics" in self.metrics_dict: + metrics_dicts = ['MMMetrics'] + else: + metrics_dicts = self.metrics_dict + for metric in metrics_dicts: + metrics_dict = getattr( + self, + metric).compute(sanity_flag=self.trainer.sanity_checking) + # reset metrics + getattr(self, metric).reset() + dico.update({ + f"Metrics/{metric}": value.item() + for metric, value in metrics_dict.items() + }) + if split != "test": + dico.update({ + "epoch": float(self.trainer.current_epoch), + "step": float(self.trainer.current_epoch), + }) + # don't write sanity check into log + if not self.trainer.sanity_checking: + self.log_dict(dico, sync_dist=True, rank_zero_only=True) + + def training_epoch_end(self, outputs): + return self.allsplit_epoch_end("train", outputs) diff --git a/Evaluator_272/mld/models/operator/__init__.py b/Evaluator_272/mld/models/operator/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..1864a99ad508dba7501923a5e46c0b8f80c35473 --- /dev/null +++ b/Evaluator_272/mld/models/operator/__init__.py @@ -0,0 +1,4 @@ +from .adain import AdaptiveInstanceNorm1d +from .blocks import ConvBlock, LinearBlock +from .position_encoding_layer import PositionalEncoding + diff --git a/Evaluator_272/mld/models/operator/adain.py b/Evaluator_272/mld/models/operator/adain.py new file mode 100644 index 0000000000000000000000000000000000000000..3588f33e19fa3434ee2801f941c40566923abf41 --- /dev/null +++ b/Evaluator_272/mld/models/operator/adain.py @@ -0,0 +1,66 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class AdaptiveInstanceNorm1d(nn.Module): + def __init__(self, num_features, eps=1e-5, momentum=0.1): + super(AdaptiveInstanceNorm1d, self).__init__() + self.num_features = num_features + self.eps = eps + self.momentum = momentum + self.weight = None + self.bias = None + self.register_buffer('running_mean', torch.zeros(num_features)) + self.register_buffer('running_var', torch.ones(num_features)) + + def forward(self, x, direct_weighting=False, no_std=False): + assert self.weight is not None and \ + self.bias is not None, "Please assign AdaIN weight first" + # (bs, nfeats, nframe) <= (nframe, bs, nfeats) + x = x.permute(1,2,0) + + b, c = x.size(0), x.size(1) # batch size & channels + running_mean = self.running_mean.repeat(b) + running_var = self.running_var.repeat(b) + # self.weight = torch.ones_like(self.weight) + + if direct_weighting: + x_reshaped = x.contiguous().view(b * c) + if no_std: + out = x_reshaped + self.bias + else: + out = x_reshaped.mul(self.weight) + self.bias + out = out.view(b, c, *x.size()[2:]) + else: + x_reshaped = x.contiguous().view(1, b * c, *x.size()[2:]) + out = F.batch_norm( + x_reshaped, running_mean, running_var, self.weight, self.bias, + True, self.momentum, self.eps) + out = out.view(b, c, *x.size()[2:]) + + # (nframe, bs, nfeats) <= (bs, nfeats, nframe) + out = out.permute(2,0,1) + return out + + def __repr__(self): + return self.__class__.__name__ + '(' + str(self.num_features) + ')' + +def assign_adain_params(adain_params, model): + # assign the adain_params to the AdaIN layers in model + for m in model.modules(): + if m.__class__.__name__ == "AdaptiveInstanceNorm1d": + mean = adain_params[: , : m.num_features] + std = adain_params[: , m.num_features: 2 * m.num_features] + m.bias = mean.contiguous().view(-1) + m.weight = std.contiguous().view(-1) + if adain_params.size(1) > 2 * m.num_features: + adain_params = adain_params[: , 2 * m.num_features:] + + +def get_num_adain_params(model): + # return the number of AdaIN parameters needed by the model + num_adain_params = 0 + for m in model.modules(): + if m.__class__.__name__ == "AdaptiveInstanceNorm1d": + num_adain_params += 2 * m.num_features + return num_adain_params diff --git a/Evaluator_272/mld/models/operator/blocks.py b/Evaluator_272/mld/models/operator/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..1b44085b27d1c3980e6e035867d3341d0f276368 --- /dev/null +++ b/Evaluator_272/mld/models/operator/blocks.py @@ -0,0 +1,146 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from mld.models.operator import AdaptiveInstanceNorm1d + + +class MLP(nn.Module): + + def __init__(self, cfg, out_dim, is_init): + super(MLP, self).__init__() + dims = cfg.MODEL.MOTION_DECODER.MLP_DIM + n_blk = len(dims) + norm = 'none' + acti = 'lrelu' + + layers = [] + for i in range(n_blk - 1): + layers += 
LinearBlock(dims[i], dims[i + 1], norm=norm, acti=acti) + layers += LinearBlock(dims[-1], out_dim, norm='none', acti='none') + self.model = nn.Sequential(*layers) + + if is_init: + for m in self.modules(): + if isinstance(m, nn.Linear): + #nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + nn.init.constant_(m.weight, 1) + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def forward(self, x): + return self.model(x.view(x.size(0), -1)) + + +def ZeroPad1d(sizes): + return nn.ConstantPad1d(sizes, 0) + + +def get_acti_layer(acti='relu', inplace=True): + + if acti == 'relu': + return [nn.ReLU(inplace=inplace)] + elif acti == 'lrelu': + return [nn.LeakyReLU(0.2, inplace=inplace)] + elif acti == 'tanh': + return [nn.Tanh()] + elif acti == 'none': + return [] + else: + assert 0, "Unsupported activation: {}".format(acti) + + +def get_norm_layer(norm='none', norm_dim=None): + + if norm == 'bn': + return [nn.BatchNorm1d(norm_dim)] + elif norm == 'in': + # return [nn.InstanceNorm1d(norm_dim, affine=False)] # for rt42! + return [nn.InstanceNorm1d(norm_dim, affine=True)] + elif norm == 'adain': + return [AdaptiveInstanceNorm1d(norm_dim)] + elif norm == 'none': + return [] + else: + assert 0, "Unsupported normalization: {}".format(norm) + + +def get_dropout_layer(dropout=None): + if dropout is not None: + return [nn.Dropout(p=dropout)] + else: + return [] + + +def ConvLayers(kernel_size, + in_channels, + out_channels, + stride=1, + pad_type='reflect', + use_bias=True): + """ + returns a list of [pad, conv] => should be += to some list, then apply sequential + """ + + if pad_type == 'reflect': + pad = nn.ReflectionPad1d + elif pad_type == 'replicate': + pad = nn.ReplicationPad1d + elif pad_type == 'zero': + pad = ZeroPad1d + else: + assert 0, "Unsupported padding type: {}".format(pad_type) + + pad_l = (kernel_size - 1) // 2 + pad_r = kernel_size - 1 - pad_l + return [ + pad((pad_l, pad_r)), + nn.Conv1d(in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + bias=use_bias) + ] + + +def ConvBlock(kernel_size, + in_channels, + out_channels, + stride=1, + pad_type='reflect', + dropout=None, + norm='none', + acti='lrelu', + acti_first=False, + use_bias=True, + inplace=True): + """ + returns a list of [pad, conv, norm, acti] or [acti, pad, conv, norm] + """ + + layers = ConvLayers(kernel_size, + in_channels, + out_channels, + stride=stride, + pad_type=pad_type, + use_bias=use_bias) + layers += get_dropout_layer(dropout) + layers += get_norm_layer(norm, norm_dim=out_channels) + acti_layers = get_acti_layer(acti, inplace=inplace) + + if acti_first: + return acti_layers + layers + else: + return layers + acti_layers + + +def LinearBlock(in_dim, out_dim, dropout=None, norm='none', acti='relu'): + + use_bias = True + layers = [] + layers.append(nn.Linear(in_dim, out_dim, bias=use_bias)) + layers += get_dropout_layer(dropout) + layers += get_norm_layer(norm, norm_dim=out_dim) + layers += get_acti_layer(acti) + + return layers diff --git a/Evaluator_272/mld/models/operator/conv2d_gradfix.py b/Evaluator_272/mld/models/operator/conv2d_gradfix.py new file mode 100644 index 0000000000000000000000000000000000000000..64229c5a7fd04292140abac1f490619963009328 --- /dev/null +++ b/Evaluator_272/mld/models/operator/conv2d_gradfix.py @@ -0,0 +1,227 @@ +import contextlib +import warnings + +import torch +from torch import autograd +from torch.nn import functional as F + +enabled = True +weight_gradients_disabled = False + 
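+# `enabled` switches the custom conv2d autograd path on/off, and
+# `weight_gradients_disabled` is toggled by the `no_weight_gradients()` context
+# manager below so the backward pass skips gradients w.r.t. the convolution weights.
+#
+# Typical usage (a sketch; the surrounding training code decides when to call it):
+#   with no_weight_gradients():
+#       out = conv2d(x, weight, bias)  # backward produces no weight gradient here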
+ +@contextlib.contextmanager +def no_weight_gradients(): + global weight_gradients_disabled + + old = weight_gradients_disabled + weight_gradients_disabled = True + yield + weight_gradients_disabled = old + + +def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): + if could_use_op(input): + return conv2d_gradfix( + transpose=False, + weight_shape=weight.shape, + stride=stride, + padding=padding, + output_padding=0, + dilation=dilation, + groups=groups, + ).apply(input, weight, bias) + + return F.conv2d( + input=input, + weight=weight, + bias=bias, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + + +def conv_transpose2d( + input, + weight, + bias=None, + stride=1, + padding=0, + output_padding=0, + groups=1, + dilation=1, +): + if could_use_op(input): + return conv2d_gradfix( + transpose=True, + weight_shape=weight.shape, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + dilation=dilation, + ).apply(input, weight, bias) + + return F.conv_transpose2d( + input=input, + weight=weight, + bias=bias, + stride=stride, + padding=padding, + output_padding=output_padding, + dilation=dilation, + groups=groups, + ) + + +def could_use_op(input): + if (not enabled) or (not torch.backends.cudnn.enabled): + return False + + if input.device.type != "cuda": + return False + + if any(torch.__version__.startswith(x) for x in ["1.7.", "1.8."]): + return True + + warnings.warn( + f"conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d()." + ) + + return False + + +def ensure_tuple(xs, ndim): + xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim + + return xs + + +conv2d_gradfix_cache = dict() + + +def conv2d_gradfix( + transpose, weight_shape, stride, padding, output_padding, dilation, groups +): + ndim = 2 + weight_shape = tuple(weight_shape) + stride = ensure_tuple(stride, ndim) + padding = ensure_tuple(padding, ndim) + output_padding = ensure_tuple(output_padding, ndim) + dilation = ensure_tuple(dilation, ndim) + + key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups) + if key in conv2d_gradfix_cache: + return conv2d_gradfix_cache[key] + + common_kwargs = dict( + stride=stride, padding=padding, dilation=dilation, groups=groups + ) + + def calc_output_padding(input_shape, output_shape): + if transpose: + return [0, 0] + + return [ + input_shape[i + 2] + - (output_shape[i + 2] - 1) * stride[i] + - (1 - 2 * padding[i]) + - dilation[i] * (weight_shape[i + 2] - 1) + for i in range(ndim) + ] + + class Conv2d(autograd.Function): + @staticmethod + def forward(ctx, input, weight, bias): + if not transpose: + out = F.conv2d(input=input, weight=weight, bias=bias, **common_kwargs) + + else: + out = F.conv_transpose2d( + input=input, + weight=weight, + bias=bias, + output_padding=output_padding, + **common_kwargs, + ) + + ctx.save_for_backward(input, weight) + + return out + + @staticmethod + def backward(ctx, grad_output): + input, weight = ctx.saved_tensors + grad_input, grad_weight, grad_bias = None, None, None + + if ctx.needs_input_grad[0]: + p = calc_output_padding( + input_shape=input.shape, output_shape=grad_output.shape + ) + grad_input = conv2d_gradfix( + transpose=(not transpose), + weight_shape=weight_shape, + output_padding=p, + **common_kwargs, + ).apply(grad_output, weight, None) + + if ctx.needs_input_grad[1] and not weight_gradients_disabled: + grad_weight = Conv2dGradWeight.apply(grad_output, input) + + if 
ctx.needs_input_grad[2]: + grad_bias = grad_output.sum((0, 2, 3)) + + return grad_input, grad_weight, grad_bias + + class Conv2dGradWeight(autograd.Function): + @staticmethod + def forward(ctx, grad_output, input): + op = torch._C._jit_get_operation( + "aten::cudnn_convolution_backward_weight" + if not transpose + else "aten::cudnn_convolution_transpose_backward_weight" + ) + flags = [ + torch.backends.cudnn.benchmark, + torch.backends.cudnn.deterministic, + torch.backends.cudnn.allow_tf32, + ] + grad_weight = op( + weight_shape, + grad_output, + input, + padding, + stride, + dilation, + groups, + *flags, + ) + ctx.save_for_backward(grad_output, input) + + return grad_weight + + @staticmethod + def backward(ctx, grad_grad_weight): + grad_output, input = ctx.saved_tensors + grad_grad_output, grad_grad_input = None, None + + if ctx.needs_input_grad[0]: + grad_grad_output = Conv2d.apply(input, grad_grad_weight, None) + + if ctx.needs_input_grad[1]: + p = calc_output_padding( + input_shape=input.shape, output_shape=grad_output.shape + ) + grad_grad_input = conv2d_gradfix( + transpose=(not transpose), + weight_shape=weight_shape, + output_padding=p, + **common_kwargs, + ).apply(grad_output, grad_grad_weight, None) + + return grad_grad_output, grad_grad_input + + conv2d_gradfix_cache[key] = Conv2d + + return Conv2d diff --git a/Evaluator_272/mld/models/operator/cross_attention.py b/Evaluator_272/mld/models/operator/cross_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..deb1f053e575bd0940d12a9cc526a44f689f24c0 --- /dev/null +++ b/Evaluator_272/mld/models/operator/cross_attention.py @@ -0,0 +1,412 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +DETR Transformer class. +Copy-paste from torch.nn.Transformer with modifications: + * positional encodings are passed in MHattention + * extra LN at the end of encoder is removed + * decoder returns a stack of activations from all decoding layers +""" +import copy +from typing import List, Optional +from numpy import block + +import torch +import torch.nn.functional as F +from torch import Tensor, nn + + +class SkipTransformerEncoder(nn.Module): + def __init__(self, encoder_layer, num_layers, norm=None): + super().__init__() + self.d_model = encoder_layer.d_model + + self.num_layers = num_layers + self.norm = norm + + assert num_layers % 2 == 1 + + num_block = (num_layers-1)//2 + self.input_blocks = _get_clones(encoder_layer, num_block) + self.middle_block = _get_clone(encoder_layer) + self.output_blocks = _get_clones(encoder_layer, num_block) + self.linear_blocks = _get_clones(nn.Linear(2*self.d_model, self.d_model), num_block) + + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, src, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + x = src + + xs = [] + for module in self.input_blocks: + x = module(x, src_mask=mask, + src_key_padding_mask=src_key_padding_mask, pos=pos) + xs.append(x) + + x = self.middle_block(x, src_mask=mask, + src_key_padding_mask=src_key_padding_mask, pos=pos) + + for (module, linear) in zip(self.output_blocks, self.linear_blocks): + x = torch.cat([x, xs.pop()], dim=-1) + x = linear(x) + x = module(x, src_mask=mask, + src_key_padding_mask=src_key_padding_mask, pos=pos) + + if self.norm is not None: + x = self.norm(x) + return x + +class SkipTransformerDecoder(nn.Module): + def __init__(self, 
decoder_layer, num_layers, norm=None): + super().__init__() + self.d_model = decoder_layer.d_model + + self.num_layers = num_layers + self.norm = norm + + assert num_layers % 2 == 1 + + num_block = (num_layers-1)//2 + self.input_blocks = _get_clones(decoder_layer, num_block) + self.middle_block = _get_clone(decoder_layer) + self.output_blocks = _get_clones(decoder_layer, num_block) + self.linear_blocks = _get_clones(nn.Linear(2*self.d_model, self.d_model), num_block) + + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + x = tgt + + xs = [] + for module in self.input_blocks: + x = module(x, memory, tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + pos=pos, query_pos=query_pos) + xs.append(x) + + x = self.middle_block(x, memory, tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + pos=pos, query_pos=query_pos) + + for (module, linear) in zip(self.output_blocks, self.linear_blocks): + x = torch.cat([x, xs.pop()], dim=-1) + x = linear(x) + x = module(x, memory, tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + pos=pos, query_pos=query_pos) + + if self.norm is not None: + x = self.norm(x) + + return x + +class Transformer(nn.Module): + + def __init__(self, d_model=512, nhead=8, num_encoder_layers=6, + num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, + activation="relu", normalize_before=False, + return_intermediate_dec=False): + super().__init__() + + encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, + dropout, activation, normalize_before) + encoder_norm = nn.LayerNorm(d_model) if normalize_before else None + self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) + + decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, + dropout, activation, normalize_before) + decoder_norm = nn.LayerNorm(d_model) + self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm, + return_intermediate=return_intermediate_dec) + + self._reset_parameters() + + self.d_model = d_model + self.nhead = nhead + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, src, mask, query_embed, pos_embed): + # flatten NxCxHxW to HWxNxC + bs, c, h, w = src.shape + src = src.flatten(2).permute(2, 0, 1) + pos_embed = pos_embed.flatten(2).permute(2, 0, 1) + query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) + mask = mask.flatten(1) + + tgt = torch.zeros_like(query_embed) + memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) + hs = self.decoder(tgt, memory, memory_key_padding_mask=mask, + pos=pos_embed, query_pos=query_embed) + return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w) + + +class TransformerEncoder(nn.Module): + + def __init__(self, encoder_layer, num_layers, norm=None): + super().__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = 
num_layers + self.norm = norm + + def forward(self, src, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + output = src + + for layer in self.layers: + output = layer(output, src_mask=mask, + src_key_padding_mask=src_key_padding_mask, pos=pos) + + if self.norm is not None: + output = self.norm(output) + + return output + + +class TransformerDecoder(nn.Module): + + def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): + super().__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + self.return_intermediate = return_intermediate + + def forward(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + output = tgt + + intermediate = [] + + for layer in self.layers: + output = layer(output, memory, tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + pos=pos, query_pos=query_pos) + if self.return_intermediate: + intermediate.append(self.norm(output)) + + if self.norm is not None: + output = self.norm(output) + if self.return_intermediate: + intermediate.pop() + intermediate.append(output) + + if self.return_intermediate: + return torch.stack(intermediate) + + return output.unsqueeze(0) + + +class TransformerEncoderLayer(nn.Module): + + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, + activation="relu", normalize_before=False): + super().__init__() + self.d_model = d_model + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + q = k = self.with_pos_embed(src, pos) + src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, + key_padding_mask=src_key_padding_mask)[0] + src = src + self.dropout1(src2) + src = self.norm1(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = src + self.dropout2(src2) + src = self.norm2(src) + return src + + def forward_pre(self, src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + src2 = self.norm1(src) + q = k = self.with_pos_embed(src2, pos) + src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask, + key_padding_mask=src_key_padding_mask)[0] + src = src + self.dropout1(src2) + src2 = self.norm2(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) + src = src + self.dropout2(src2) + return src + + def forward(self, src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + if 
self.normalize_before: + return self.forward_pre(src, src_mask, src_key_padding_mask, pos) + return self.forward_post(src, src_mask, src_key_padding_mask, pos) + + +class TransformerDecoderLayer(nn.Module): + + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, + activation="relu", normalize_before=False): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.d_model = d_model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + + q = k = self.with_pos_embed(tgt, query_pos) + tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0] + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout3(tgt2) + tgt = self.norm3(tgt) + return tgt + + def forward_pre(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + tgt2 = self.norm1(tgt) + q = k = self.with_pos_embed(tgt2, query_pos) + tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout1(tgt2) + tgt2 = self.norm2(tgt) + tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0] + tgt = tgt + self.dropout2(tgt2) + tgt2 = self.norm3(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout3(tgt2) + return tgt + + def forward(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + if self.normalize_before: + return self.forward_pre(tgt, memory, tgt_mask, memory_mask, + tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) + return self.forward_post(tgt, memory, tgt_mask, memory_mask, + tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) + + +def _get_clone(module): + return 
copy.deepcopy(module) + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +def build_transformer(args): + return Transformer( + d_model=args.hidden_dim, + dropout=args.dropout, + nhead=args.nheads, + dim_feedforward=args.dim_feedforward, + num_encoder_layers=args.enc_layers, + num_decoder_layers=args.dec_layers, + normalize_before=args.pre_norm, + return_intermediate_dec=True, + ) + + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(F"activation should be relu/gelu, not {activation}.") \ No newline at end of file diff --git a/Evaluator_272/mld/models/operator/position_encoding.py b/Evaluator_272/mld/models/operator/position_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..d0a2bf7030ed445f12761b581741145a3ad98072 --- /dev/null +++ b/Evaluator_272/mld/models/operator/position_encoding.py @@ -0,0 +1,185 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Various positional encodings for the transformer. +""" +import math +from typing import List, Optional + +import numpy as np +import torch +from torch import Tensor, nn + +# from util.misc import NestedTensor + + +class NestedTensor(object): + + def __init__(self, tensors, mask: Optional[Tensor]): + self.tensors = tensors + self.mask = mask + + def to(self, device): + # type: (Device) -> NestedTensor # noqa + cast_tensor = self.tensors.to(device) + mask = self.mask + if mask is not None: + assert mask is not None + cast_mask = mask.to(device) + else: + cast_mask = None + return NestedTensor(cast_tensor, cast_mask) + + def decompose(self): + return self.tensors, self.mask + + def __repr__(self): + return str(self.tensors) + + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. + """ + + def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, tensor_list: NestedTensor): + x = tensor_list.tensors + mask = tensor_list.mask + assert mask is not None + not_mask = ~mask + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, + dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + +class PositionEmbeddingLearned(nn.Module): + """ + Absolute pos embedding, learned. 
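    (Two 50-entry nn.Embedding tables are learned, one per spatial axis; the forward
    pass concatenates a column embedding and a row embedding for every (h, w) location
    and broadcasts the result over the batch.)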
+ """ + + def __init__(self, num_pos_feats=256): + super().__init__() + self.row_embed = nn.Embedding(50, num_pos_feats) + self.col_embed = nn.Embedding(50, num_pos_feats) + self.reset_parameters() + + def reset_parameters(self): + nn.init.uniform_(self.row_embed.weight) + nn.init.uniform_(self.col_embed.weight) + + def forward(self, tensor_list: NestedTensor): + x = tensor_list.tensors + h, w = x.shape[-2:] + i = torch.arange(w, device=x.device) + j = torch.arange(h, device=x.device) + x_emb = self.col_embed(i) + y_emb = self.row_embed(j) + pos = torch.cat([ + x_emb.unsqueeze(0).repeat(h, 1, 1), + y_emb.unsqueeze(1).repeat(1, w, 1), + ], + dim=-1).permute(2, 0, 1).unsqueeze(0).repeat( + x.shape[0], 1, 1, 1) + return pos + + +class PositionEmbeddingSine1D(nn.Module): + + def __init__(self, d_model, max_len=500, batch_first=False): + super().__init__() + self.batch_first = batch_first + + pe = torch.zeros(max_len, d_model) + position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) + div_term = torch.exp(torch.arange( + 0, d_model, 2).float() * (-np.log(10000.0) / d_model)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0).transpose(0, 1) + + self.register_buffer('pe', pe) + + def forward(self, x): + # not used in the final model + if self.batch_first: + pos = self.pe.permute(1, 0, 2)[:, :x.shape[1], :] + else: + pos = self.pe[:x.shape[0], :] + return pos + + +class PositionEmbeddingLearned1D(nn.Module): + + def __init__(self, d_model, max_len=500, batch_first=False): + super().__init__() + self.batch_first = batch_first + # self.dropout = nn.Dropout(p=dropout) + + self.pe = nn.Parameter(torch.zeros(max_len, 1, d_model)) + # self.pe = pe.unsqueeze(0).transpose(0, 1) + + self.reset_parameters() + + def reset_parameters(self): + nn.init.uniform_(self.pe) + + def forward(self, x): + # not used in the final model + if self.batch_first: + pos = self.pe.permute(1, 0, 2)[:, :x.shape[1], :] + else: + x = x + self.pe[:x.shape[0], :] + return x + # return self.dropout(x) + + +def build_position_encoding(N_steps, + position_embedding="sine", + embedding_dim="1D"): + # N_steps = hidden_dim // 2 + if embedding_dim == "1D": + if position_embedding in ('v2', 'sine'): + position_embedding = PositionEmbeddingSine1D(N_steps) + elif position_embedding in ('v3', 'learned'): + position_embedding = PositionEmbeddingLearned1D(N_steps) + else: + raise ValueError(f"not supported {position_embedding}") + elif embedding_dim == "2D": + if position_embedding in ('v2', 'sine'): + # TODO find a better way of exposing other arguments + position_embedding = PositionEmbeddingSine(N_steps, normalize=True) + elif position_embedding in ('v3', 'learned'): + position_embedding = PositionEmbeddingLearned(N_steps) + else: + raise ValueError(f"not supported {position_embedding}") + else: + raise ValueError(f"not supported {embedding_dim}") + + return position_embedding diff --git a/Evaluator_272/mld/models/operator/position_encoding_layer.py b/Evaluator_272/mld/models/operator/position_encoding_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..699c860bf5d28c384390196b086d93552b2cff64 --- /dev/null +++ b/Evaluator_272/mld/models/operator/position_encoding_layer.py @@ -0,0 +1,30 @@ +import numpy as np +import torch +from torch import nn + + +class PositionalEncoding(nn.Module): + + def __init__(self, d_model, dropout=0.1, max_len=5000, batch_first=False): + super().__init__() + self.batch_first = batch_first + + self.dropout = 
nn.Dropout(p=dropout) + + pe = torch.zeros(max_len, d_model) + position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) + div_term = torch.exp(torch.arange( + 0, d_model, 2).float() * (-np.log(10000.0) / d_model)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0).transpose(0, 1) + + self.register_buffer("pe", pe) + + def forward(self, x): + # not used in the final model + if self.batch_first: + x = x + self.pe.permute(1, 0, 2)[:, : x.shape[1], :] + else: + x = x + self.pe[: x.shape[0], :] + return self.dropout(x) diff --git a/Evaluator_272/mld/models/operator/self_attention.py b/Evaluator_272/mld/models/operator/self_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Evaluator_272/mld/models/tools/__init__.py b/Evaluator_272/mld/models/tools/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Evaluator_272/mld/models/tools/hessian_penalty.py b/Evaluator_272/mld/models/tools/hessian_penalty.py new file mode 100644 index 0000000000000000000000000000000000000000..d5081cd7ba942a7c47ebce2dbcb50affa1007767 --- /dev/null +++ b/Evaluator_272/mld/models/tools/hessian_penalty.py @@ -0,0 +1,138 @@ +""" +## Adapted to work with our "batches" +Official PyTorch implementation of the Hessian Penalty regularization term from https://arxiv.org/pdf/2008.10599.pdf +Author: Bill Peebles +TensorFlow Implementation (GPU + Multi-Layer): hessian_penalty_tf.py +Simple Pure NumPy Implementation: hessian_penalty_np.py + +Simple use case where you want to apply the Hessian Penalty to the output of net w.r.t. net_input: +>>> from hessian_penalty_pytorch import hessian_penalty +>>> net = MyNeuralNet() +>>> net_input = sample_input() +>>> loss = hessian_penalty(net, z=net_input) # Compute hessian penalty of net's output w.r.t. net_input +>>> loss.backward() # Compute gradients w.r.t. net's parameters + +If your network takes multiple inputs, simply supply them to hessian_penalty as you do in the net's forward pass. In the +following example, we assume BigGAN.forward takes a second input argument "y". Note that we always take the Hessian +Penalty w.r.t. the z argument supplied to hessian_penalty: +>>> from hessian_penalty_pytorch import hessian_penalty +>>> net = BigGAN() +>>> z_input = sample_z_vector() +>>> class_label = sample_class_label() +>>> loss = hessian_penalty(net, z=net_input, y=class_label) +>>> loss.backward() +""" + +import torch + + +def hessian_penalty(G, batch, k=2, epsilon=0.1, reduction=torch.max, return_separately=False, G_z=None, **G_kwargs): + """ + Official PyTorch Hessian Penalty implementation. + + Note: If you want to regularize multiple network activations simultaneously, you need to + make sure the function G you pass to hessian_penalty returns a list of those activations when it's called with + G(z, **G_kwargs). Otherwise, if G returns a tensor the Hessian Penalty will only be computed for the final + output of G. 
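    A minimal sketch of how this batch-based adaptation might be called (here the
    batch is assumed to be a dict whose "x" entry holds the input the penalty is taken
    with respect to; any other keys are forwarded to G unchanged):
    >>> G = MyNeuralNet()                       # hypothetical network taking a batch dict
    >>> batch = {"x": torch.randn(8, 256)}      # "x" is the regularized input
    >>> loss = hessian_penalty(G, batch, k=2, epsilon=0.1)
    >>> loss.backward()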
+ + :param G: Function that maps input z to either a tensor or a list of tensors (activations) + :param z: Input to G that the Hessian Penalty will be computed with respect to + :param k: Number of Hessian directions to sample (must be >= 2) + :param epsilon: Amount to blur G before estimating Hessian (must be > 0) + :param reduction: Many-to-one function to reduce each pixel/neuron's individual hessian penalty into a final loss + :param return_separately: If False, hessian penalties for each activation output by G are automatically summed into + a final loss. If True, the hessian penalties for each layer will be returned in a list + instead. If G outputs a single tensor, setting this to True will produce a length-1 + list. + :param G_z: [Optional small speed-up] If you have already computed G(z, **G_kwargs) for the current training + iteration, then you can provide it here to reduce the number of forward passes of this method by 1 + :param G_kwargs: Additional inputs to G besides the z vector. For example, in BigGAN you + would pass the class label into this function via y= + + :return: A differentiable scalar (the hessian penalty), or a list of hessian penalties if return_separately is True + """ + if G_z is None: + G_z = G(batch, **G_kwargs) + z = batch["x"] + rademacher_size = torch.Size((k, *z.size())) # (k, N, z.size()) + dzs = epsilon * rademacher(rademacher_size, device=z.device) + second_orders = [] + for dz in dzs: # Iterate over each (N, z.size()) tensor in xs + central_second_order = multi_layer_second_directional_derivative(G, batch, dz, G_z, epsilon, **G_kwargs) + second_orders.append(central_second_order) # Appends a tensor with shape equal to G(z).size() + loss = multi_stack_var_and_reduce(second_orders, reduction, return_separately) # (k, G(z).size()) --> scalar + return loss + + +def rademacher(shape, device='cpu'): + """Creates a random tensor of size [shape] under the Rademacher distribution (P(x=1) == P(x=-1) == 0.5)""" + x = torch.empty(shape, device=device) + x.random_(0, 2) # Creates random tensor of 0s and 1s + x[x == 0] = -1 # Turn the 0s into -1s + return x + + +def multi_layer_second_directional_derivative(G, batch, dz, G_z, epsilon, **G_kwargs): + """Estimates the second directional derivative of G w.r.t. its input at z in the direction x""" + batch_plus = {**batch, "x": batch["x"] + dz} + batch_moins = {**batch, "x": batch["x"] - dz} + G_to_x = G(batch_plus, **G_kwargs) + G_from_x = G(batch_moins, **G_kwargs) + + G_to_x = listify(G_to_x) + G_from_x = listify(G_from_x) + G_z = listify(G_z) + + eps_sqr = epsilon ** 2 + sdd = [(G2x - 2 * G_z_base + Gfx) / eps_sqr for G2x, G_z_base, Gfx in zip(G_to_x, G_z, G_from_x)] + return sdd + + +def stack_var_and_reduce(list_of_activations, reduction=torch.max): + """Equation (5) from the paper.""" + second_orders = torch.stack(list_of_activations) # (k, N, C, H, W) + var_tensor = torch.var(second_orders, dim=0, unbiased=True) # (N, C, H, W) + penalty = reduction(var_tensor) # (1,) (scalar) + return penalty + + +def multi_stack_var_and_reduce(sdds, reduction=torch.max, return_separately=False): + """Iterate over all activations to be regularized, then apply Equation (5) to each.""" + sum_of_penalties = 0 if not return_separately else [] + for activ_n in zip(*sdds): + penalty = stack_var_and_reduce(activ_n, reduction) + sum_of_penalties += penalty if not return_separately else [penalty] + return sum_of_penalties + + +def listify(x): + """If x is already a list, do nothing. 
Otherwise, wrap x in a list.""" + if isinstance(x, list): + return x + else: + return [x] + + +def _test_hessian_penalty(): + """ + A simple multi-layer test to verify the implementation. + Function: G(z) = [z_0 * z_1, z_0**2 * z_1] + Ground Truth Hessian Penalty: [4, 16 * z_0**2] + """ + batch_size = 10 + nz = 2 + z = torch.randn(batch_size, nz) + def reduction(x): return x.abs().mean() + def G(z): return [z[:, 0] * z[:, 1], (z[:, 0] ** 2) * z[:, 1]] + ground_truth = [4, reduction(16 * z[:, 0] ** 2).item()] + # In this simple example, we use k=100 to reduce variance, but when applied to neural networks + # you will probably want to use a small k (e.g., k=2) due to memory considerations. + predicted = hessian_penalty(G, z, G_z=None, k=100, reduction=reduction, return_separately=True) + predicted = [p.item() for p in predicted] + print('Ground Truth: %s' % ground_truth) + print('Approximation: %s' % predicted) # This should be close to ground_truth, but not exactly correct + print('Difference: %s' % [str(100 * abs(p - gt) / gt) + '%' for p, gt in zip(predicted, ground_truth)]) + + +if __name__ == '__main__': + _test_hessian_penalty() diff --git a/Evaluator_272/mld/models/tools/tools.py b/Evaluator_272/mld/models/tools/tools.py new file mode 100644 index 0000000000000000000000000000000000000000..89ecab5616c1f0d46ed5bc9b348c5e6ad3ee603d --- /dev/null +++ b/Evaluator_272/mld/models/tools/tools.py @@ -0,0 +1,37 @@ +import torch.nn as nn + +def remove_padding(tensors, lengths): + return [tensor[:tensor_length] for tensor, tensor_length in zip(tensors, lengths)] + +class AutoParams(nn.Module): + def __init__(self, **kargs): + try: + for param in self.needed_params: + if param in kargs: + setattr(self, param, kargs[param]) + else: + raise ValueError(f"{param} is needed.") + except : + pass + + try: + for param, default in self.optional_params.items(): + if param in kargs and kargs[param] is not None: + setattr(self, param, kargs[param]) + else: + setattr(self, param, default) + except : + pass + super().__init__() + + +# taken from joeynmt repo +def freeze_params(module: nn.Module) -> None: + """ + Freeze the parameters of this module, + i.e. do not update them during training + + :param module: freeze parameters of this module + """ + for _, p in module.named_parameters(): + p.requires_grad = False diff --git a/Evaluator_272/mld/tools/__init__.py b/Evaluator_272/mld/tools/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Evaluator_272/mld/tools/geometry.py b/Evaluator_272/mld/tools/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..e6eafa2e1f2459a0f6f5ad1280c71e6a9625549e --- /dev/null +++ b/Evaluator_272/mld/tools/geometry.py @@ -0,0 +1,566 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# Check PYTORCH3D_LICENCE before use + +import functools +from typing import Optional + +import torch +import torch.nn.functional as F + + +""" +The transformation matrices returned from the functions in this file assume +the points on which the transformation will be applied are column vectors. +i.e. the R matrix is structured as + + R = [ + [Rxx, Rxy, Rxz], + [Ryx, Ryy, Ryz], + [Rzx, Rzy, Rzz], + ] # (3, 3) + +This matrix can be applied to column vectors by post multiplication +by the points e.g. 
+ + points = [[0], [1], [2]] # (3 x 1) xyz coordinates of a point + transformed_points = R * points + +To apply the same matrix to points which are row vectors, the R matrix +can be transposed and pre multiplied by the points: + +e.g. + points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point + transformed_points = points * R.transpose(1, 0) +""" + + +# Added +def matrix_of_angles(cos, sin, inv=False, dim=2): + assert dim in [2, 3] + sin = -sin if inv else sin + if dim == 2: + row1 = torch.stack((cos, -sin), axis=-1) + row2 = torch.stack((sin, cos), axis=-1) + return torch.stack((row1, row2), axis=-2) + elif dim == 3: + row1 = torch.stack((cos, -sin, 0*cos), axis=-1) + row2 = torch.stack((sin, cos, 0*cos), axis=-1) + row3 = torch.stack((0*sin, 0*cos, 1+0*cos), axis=-1) + return torch.stack((row1, row2, row3),axis=-2) + + +def quaternion_to_matrix(quaternions): + """ + Convert rotations given as quaternions to rotation matrices. + + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + r, i, j, k = torch.unbind(quaternions, -1) + two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + + +def _copysign(a, b): + """ + Return a tensor where each element has the absolute value taken from the, + corresponding element of a, with sign taken from the corresponding + element of b. This is like the standard copysign floating-point operation, + but is not careful about negative 0 and NaN. + + Args: + a: source tensor. + b: tensor whose signs will be used, of the same shape as a. + + Returns: + Tensor of the same shape as a with the signs of b. + """ + signs_differ = (a < 0) != (b < 0) + return torch.where(signs_differ, -a, a) + + +def _sqrt_positive_part(x): + """ + Returns torch.sqrt(torch.max(0, x)) + but with a zero subgradient where x is 0. + """ + ret = torch.zeros_like(x) + positive_mask = x > 0 + ret[positive_mask] = torch.sqrt(x[positive_mask]) + return ret + + +def matrix_to_quaternion(matrix): + """ + Convert rotations given as rotation matrices to quaternions. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + quaternions with real part first, as tensor of shape (..., 4). + """ + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.") + m00 = matrix[..., 0, 0] + m11 = matrix[..., 1, 1] + m22 = matrix[..., 2, 2] + o0 = 0.5 * _sqrt_positive_part(1 + m00 + m11 + m22) + x = 0.5 * _sqrt_positive_part(1 + m00 - m11 - m22) + y = 0.5 * _sqrt_positive_part(1 - m00 + m11 - m22) + z = 0.5 * _sqrt_positive_part(1 - m00 - m11 + m22) + o1 = _copysign(x, matrix[..., 2, 1] - matrix[..., 1, 2]) + o2 = _copysign(y, matrix[..., 0, 2] - matrix[..., 2, 0]) + o3 = _copysign(z, matrix[..., 1, 0] - matrix[..., 0, 1]) + return torch.stack((o0, o1, o2, o3), -1) + + +def _axis_angle_rotation(axis: str, angle): + """ + Return the rotation matrices for one of the rotations about an axis + of which Euler angles describe, for each value of the angle given. + + Args: + axis: Axis label "X" or "Y or "Z". 
+ angle: any shape tensor of Euler angles in radians + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + + cos = torch.cos(angle) + sin = torch.sin(angle) + one = torch.ones_like(angle) + zero = torch.zeros_like(angle) + + if axis == "X": + R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos) + if axis == "Y": + R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos) + if axis == "Z": + R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one) + + return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3)) + + +def euler_angles_to_matrix(euler_angles, convention: str): + """ + Convert rotations given as Euler angles in radians to rotation matrices. + + Args: + euler_angles: Euler angles in radians as tensor of shape (..., 3). + convention: Convention string of three uppercase letters from + {"X", "Y", and "Z"}. + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3: + raise ValueError("Invalid input euler angles.") + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + matrices = map(_axis_angle_rotation, convention, torch.unbind(euler_angles, -1)) + return functools.reduce(torch.matmul, matrices) + + +def _angle_from_tan( + axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool +): + """ + Extract the first or third Euler angle from the two members of + the matrix which are positive constant times its sine and cosine. + + Args: + axis: Axis label "X" or "Y or "Z" for the angle we are finding. + other_axis: Axis label "X" or "Y or "Z" for the middle axis in the + convention. + data: Rotation matrices as tensor of shape (..., 3, 3). + horizontal: Whether we are looking for the angle for the third axis, + which means the relevant entries are in the same row of the + rotation matrix. If not, they are in the same column. + tait_bryan: Whether the first and third axes in the convention differ. + + Returns: + Euler Angles in radians for each matrix in data as a tensor + of shape (...). + """ + + i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis] + if horizontal: + i2, i1 = i1, i2 + even = (axis + other_axis) in ["XY", "YZ", "ZX"] + if horizontal == even: + return torch.atan2(data[..., i1], data[..., i2]) + if tait_bryan: + return torch.atan2(-data[..., i2], data[..., i1]) + return torch.atan2(data[..., i2], -data[..., i1]) + + +def _index_from_letter(letter: str): + if letter == "X": + return 0 + if letter == "Y": + return 1 + if letter == "Z": + return 2 + + +def matrix_to_euler_angles(matrix, convention: str): + """ + Convert rotations given as rotation matrices to Euler angles in radians. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + convention: Convention string of three uppercase letters. + + Returns: + Euler angles in radians as tensor of shape (..., 3). 
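    Example (a minimal round-trip sanity check; the angles are kept well inside the
    principal range so the recovery is unique):
    >>> ang = torch.tensor([0.1, 0.2, 0.3])
    >>> R = euler_angles_to_matrix(ang, "XYZ")
    >>> torch.allclose(matrix_to_euler_angles(R, "XYZ"), ang, atol=1e-5)
    True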
+ """ + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.") + i0 = _index_from_letter(convention[0]) + i2 = _index_from_letter(convention[2]) + tait_bryan = i0 != i2 + if tait_bryan: + central_angle = torch.asin( + matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0) + ) + else: + central_angle = torch.acos(matrix[..., i0, i0]) + + o = ( + _angle_from_tan( + convention[0], convention[1], matrix[..., i2], False, tait_bryan + ), + central_angle, + _angle_from_tan( + convention[2], convention[1], matrix[..., i0, :], True, tait_bryan + ), + ) + return torch.stack(o, -1) + + +def random_quaternions( + n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False +): + """ + Generate random quaternions representing rotations, + i.e. versors with nonnegative real part. + + Args: + n: Number of quaternions in a batch to return. + dtype: Type to return. + device: Desired device of returned tensor. Default: + uses the current device for the default tensor type. + requires_grad: Whether the resulting tensor should have the gradient + flag set. + + Returns: + Quaternions as tensor of shape (N, 4). + """ + o = torch.randn((n, 4), dtype=dtype, device=device, requires_grad=requires_grad) + s = (o * o).sum(1) + o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None] + return o + + +def random_rotations( + n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False +): + """ + Generate random rotations as 3x3 rotation matrices. + + Args: + n: Number of rotation matrices in a batch to return. + dtype: Type to return. + device: Device of returned tensor. Default: if None, + uses the current device for the default tensor type. + requires_grad: Whether the resulting tensor should have the gradient + flag set. + + Returns: + Rotation matrices as tensor of shape (n, 3, 3). + """ + quaternions = random_quaternions( + n, dtype=dtype, device=device, requires_grad=requires_grad + ) + return quaternion_to_matrix(quaternions) + + +def random_rotation( + dtype: Optional[torch.dtype] = None, device=None, requires_grad=False +): + """ + Generate a single random 3x3 rotation matrix. + + Args: + dtype: Type to return + device: Device of returned tensor. Default: if None, + uses the current device for the default tensor type + requires_grad: Whether the resulting tensor should have the gradient + flag set + + Returns: + Rotation matrix as tensor of shape (3, 3). + """ + return random_rotations(1, dtype, device, requires_grad)[0] + + +def standardize_quaternion(quaternions): + """ + Convert a unit quaternion to a standard form: one in which the real + part is non negative. + + Args: + quaternions: Quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Standardized quaternions as tensor of shape (..., 4). + """ + return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions) + + +def quaternion_raw_multiply(a, b): + """ + Multiply two quaternions. + Usual torch rules for broadcasting apply. + + Args: + a: Quaternions as tensor of shape (..., 4), real part first. + b: Quaternions as tensor of shape (..., 4), real part first. 
+ + Returns: + The product of a and b, a tensor of quaternions shape (..., 4). + """ + aw, ax, ay, az = torch.unbind(a, -1) + bw, bx, by, bz = torch.unbind(b, -1) + ow = aw * bw - ax * bx - ay * by - az * bz + ox = aw * bx + ax * bw + ay * bz - az * by + oy = aw * by - ax * bz + ay * bw + az * bx + oz = aw * bz + ax * by - ay * bx + az * bw + return torch.stack((ow, ox, oy, oz), -1) + + +def quaternion_multiply(a, b): + """ + Multiply two quaternions representing rotations, returning the quaternion + representing their composition, i.e. the versor with nonnegative real part. + Usual torch rules for broadcasting apply. + + Args: + a: Quaternions as tensor of shape (..., 4), real part first. + b: Quaternions as tensor of shape (..., 4), real part first. + + Returns: + The product of a and b, a tensor of quaternions of shape (..., 4). + """ + ab = quaternion_raw_multiply(a, b) + return standardize_quaternion(ab) + + +def quaternion_invert(quaternion): + """ + Given a quaternion representing rotation, get the quaternion representing + its inverse. + + Args: + quaternion: Quaternions as tensor of shape (..., 4), with real part + first, which must be versors (unit quaternions). + + Returns: + The inverse, a tensor of quaternions of shape (..., 4). + """ + + return quaternion * quaternion.new_tensor([1, -1, -1, -1]) + + +def quaternion_apply(quaternion, point): + """ + Apply the rotation given by a quaternion to a 3D point. + Usual torch rules for broadcasting apply. + + Args: + quaternion: Tensor of quaternions, real part first, of shape (..., 4). + point: Tensor of 3D points of shape (..., 3). + + Returns: + Tensor of rotated points of shape (..., 3). + """ + if point.size(-1) != 3: + raise ValueError(f"Points are not in 3D, f{point.shape}.") + real_parts = point.new_zeros(point.shape[:-1] + (1,)) + point_as_quaternion = torch.cat((real_parts, point), -1) + out = quaternion_raw_multiply( + quaternion_raw_multiply(quaternion, point_as_quaternion), + quaternion_invert(quaternion), + ) + return out[..., 1:] + + +def axis_angle_to_matrix(axis_angle): + """ + Convert rotations given as axis/angle to rotation matrices. + + Args: + axis_angle: Rotations given as a vector in axis angle form, + as a tensor of shape (..., 3), where the magnitude is + the angle turned anticlockwise in radians around the + vector's direction. + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle)) + + +def matrix_to_axis_angle(matrix): + """ + Convert rotations given as rotation matrices to axis/angle. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + Rotations given as a vector in axis angle form, as a tensor + of shape (..., 3), where the magnitude is the angle + turned anticlockwise in radians around the vector's + direction. + """ + return quaternion_to_axis_angle(matrix_to_quaternion(matrix)) + + +def axis_angle_to_quaternion(axis_angle): + """ + Convert rotations given as axis/angle to quaternions. + + Args: + axis_angle: Rotations given as a vector in axis angle form, + as a tensor of shape (..., 3), where the magnitude is + the angle turned anticlockwise in radians around the + vector's direction. + + Returns: + quaternions with real part first, as tensor of shape (..., 4). 
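    Example (a minimal sketch: a rotation of pi/2 about the Z axis maps to the unit
    quaternion [cos(pi/4), 0, 0, sin(pi/4)]):
    >>> import math
    >>> axis_angle_to_quaternion(torch.tensor([0.0, 0.0, math.pi / 2]))
    tensor([0.7071, 0.0000, 0.0000, 0.7071])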
+ """ + angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True) + half_angles = 0.5 * angles + eps = 1e-6 + small_angles = angles.abs() < eps + sin_half_angles_over_angles = torch.empty_like(angles) + sin_half_angles_over_angles[~small_angles] = ( + torch.sin(half_angles[~small_angles]) / angles[~small_angles] + ) + # for x small, sin(x/2) is about x/2 - (x/2)^3/6 + # so sin(x/2)/x is about 1/2 - (x*x)/48 + sin_half_angles_over_angles[small_angles] = ( + 0.5 - (angles[small_angles] * angles[small_angles]) / 48 + ) + quaternions = torch.cat( + [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1 + ) + return quaternions + + +def quaternion_to_axis_angle(quaternions): + """ + Convert rotations given as quaternions to axis/angle. + + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Rotations given as a vector in axis angle form, as a tensor + of shape (..., 3), where the magnitude is the angle + turned anticlockwise in radians around the vector's + direction. + """ + norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True) + half_angles = torch.atan2(norms, quaternions[..., :1]) + angles = 2 * half_angles + eps = 1e-6 + small_angles = angles.abs() < eps + sin_half_angles_over_angles = torch.empty_like(angles) + sin_half_angles_over_angles[~small_angles] = ( + torch.sin(half_angles[~small_angles]) / angles[~small_angles] + ) + # for x small, sin(x/2) is about x/2 - (x/2)^3/6 + # so sin(x/2)/x is about 1/2 - (x*x)/48 + sin_half_angles_over_angles[small_angles] = ( + 0.5 - (angles[small_angles] * angles[small_angles]) / 48 + ) + return quaternions[..., 1:] / sin_half_angles_over_angles + + +def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor: + """ + Converts 6D rotation representation by Zhou et al. [1] to rotation matrix + using Gram--Schmidt orthogonalisation per Section B of [1]. + Args: + d6: 6D rotation representation, of size (*, 6) + + Returns: + batch of rotation matrices of size (*, 3, 3) + + [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. + On the Continuity of Rotation Representations in Neural Networks. + IEEE Conference on Computer Vision and Pattern Recognition, 2019. + Retrieved from http://arxiv.org/abs/1812.07035 + """ + + a1, a2 = d6[..., :3], d6[..., 3:] + b1 = F.normalize(a1, dim=-1) + b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1 + b2 = F.normalize(b2, dim=-1) + b3 = torch.cross(b1, b2, dim=-1) + return torch.stack((b1, b2, b3), dim=-2) + + +def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor: + """ + Converts rotation matrices to 6D rotation representation by Zhou et al. [1] + by dropping the last row. Note that 6D representation is not unique. + Args: + matrix: batch of rotation matrices of size (*, 3, 3) + + Returns: + 6D rotation representation, of size (*, 6) + + [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. + On the Continuity of Rotation Representations in Neural Networks. + IEEE Conference on Computer Vision and Pattern Recognition, 2019. 
+ Retrieved from http://arxiv.org/abs/1812.07035 + """ + return matrix[..., :2, :].clone().reshape(*matrix.size()[:-2], 6) diff --git a/Evaluator_272/mld/tools/logging.py b/Evaluator_272/mld/tools/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..216e521a1d203b8dc6f436dc91a6a6631419bb65 --- /dev/null +++ b/Evaluator_272/mld/tools/logging.py @@ -0,0 +1,40 @@ +import logging +import tqdm + + +class LevelsFilter(logging.Filter): + def __init__(self, levels): + self.levels = [getattr(logging, level) for level in levels] + + def filter(self, record): + return record.levelno in self.levels + + +class StreamToLogger(object): + """ + Fake file-like stream object that redirects writes to a logger instance. + """ + def __init__(self, logger, level): + self.logger = logger + self.level = level + self.linebuf = '' + + def write(self, buf): + for line in buf.rstrip().splitlines(): + self.logger.log(self.level, line.rstrip()) + + def flush(self): + pass + + +class TqdmLoggingHandler(logging.Handler): + def __init__(self, level=logging.NOTSET): + super().__init__(level) + + def emit(self, record): + try: + msg = self.format(record) + tqdm.tqdm.write(msg) + self.flush() + except Exception: + self.handleError(record) diff --git a/Evaluator_272/mld/tools/runid.py b/Evaluator_272/mld/tools/runid.py new file mode 100644 index 0000000000000000000000000000000000000000..619e7696481eb3f91a0133fda4bce947cc853580 --- /dev/null +++ b/Evaluator_272/mld/tools/runid.py @@ -0,0 +1,13 @@ +# +""" +runid util. +Taken from wandb.sdk.lib.runid +""" + +import shortuuid # type: ignore + + +def generate_id() -> str: + # ~3t run ids (36**8) + run_gen = shortuuid.ShortUUID(alphabet=list("0123456789abcdefghijklmnopqrstuvwxyz")) + return run_gen.random(8) \ No newline at end of file diff --git a/Evaluator_272/mld/transforms/__init__.py b/Evaluator_272/mld/transforms/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c7c0ab9179fd0ed93c98ce0e9c90d75484594faa --- /dev/null +++ b/Evaluator_272/mld/transforms/__init__.py @@ -0,0 +1,3 @@ +from .base import Transform +from .smpl import SMPLTransform +# from .xyz import XYZTransform diff --git a/Evaluator_272/mld/transforms/base.py b/Evaluator_272/mld/transforms/base.py new file mode 100644 index 0000000000000000000000000000000000000000..2685e2b3fba90e1400c87903e244eae617d99e8f --- /dev/null +++ b/Evaluator_272/mld/transforms/base.py @@ -0,0 +1,68 @@ +from dataclasses import dataclass, fields + + +class Transform: + + def collate(self, lst_datastruct): + from mld.datasets.utils import collate_tensor_with_padding + example = lst_datastruct[0] + + def collate_or_none(key): + if example[key] is None: + return None + key_lst = [x[key] for x in lst_datastruct] + return collate_tensor_with_padding(key_lst) + + kwargs = {key: collate_or_none(key) for key in example.datakeys} + + return self.Datastruct(**kwargs) + + +# Inspired from SMPLX library +# need to define "datakeys" and transforms +@dataclass +class Datastruct: + + def __getitem__(self, key): + return getattr(self, key) + + def __setitem__(self, key, value): + self.__dict__[key] = value + + def get(self, key, default=None): + return getattr(self, key, default) + + def __iter__(self): + return self.keys() + + def keys(self): + keys = [t.name for t in fields(self)] + return iter(keys) + + def values(self): + values = [getattr(self, t.name) for t in fields(self)] + return iter(values) + + def items(self): + data = [(t.name, getattr(self, t.name)) for t in fields(self)] + return 
iter(data) + + def to(self, *args, **kwargs): + for key in self.datakeys: + if self[key] is not None: + self[key] = self[key].to(*args, **kwargs) + return self + + @property + def device(self): + return self[self.datakeys[0]].device + + def detach(self): + + def detach_or_none(tensor): + if tensor is not None: + return tensor.detach() + return None + + kwargs = {key: detach_or_none(self[key]) for key in self.datakeys} + return self.transforms.Datastruct(**kwargs) diff --git a/Evaluator_272/mld/transforms/feats2smpl.py b/Evaluator_272/mld/transforms/feats2smpl.py new file mode 100644 index 0000000000000000000000000000000000000000..d3c8a5d9bfb844359e5910f3b9611c26190f70dc --- /dev/null +++ b/Evaluator_272/mld/transforms/feats2smpl.py @@ -0,0 +1,35 @@ +from os.path import join as pjoin + +import numpy as np +import torch + +import mld.data.humanml.utils.paramUtil as paramUtil +from mld.data.humanml.data.dataset import Text2MotionDatasetV2 +from mld.data.humanml.scripts.motion_process import recover_from_ric +from mld.data.humanml.utils.plot_script import plot_3d_motion + +skeleton = paramUtil.t2m_kinematic_chain + + + + +def main(): + data_root = '../datasets/humanml3d' + feastures_path = 'in.npy' + animation_save_path = 'in.mp4' + + fps = 20 + mean = np.load(pjoin(data_root, 'Mean.npy')) + std = np.load(pjoin(data_root, 'Std.npy')) + + motion = np.load(feastures_path) + motion = motion * std + mean + motion_rec = recover_from_ric(torch.tensor(motion), 22).cpu().numpy() + # with open('in_22.npy', 'wb') as f: + # np.save(f,motion_rec) + motion_rec = motion_rec * 1.3 + plot_3d_motion(animation_save_path, motion_rec, title='input', fps=fps) + + +if __name__ == '__main__': + main() diff --git a/Evaluator_272/mld/transforms/identity.py b/Evaluator_272/mld/transforms/identity.py new file mode 100644 index 0000000000000000000000000000000000000000..d1e5540c8da75f6f839cd0e12672906273276dcc --- /dev/null +++ b/Evaluator_272/mld/transforms/identity.py @@ -0,0 +1,28 @@ +from typing import Optional +from torch import Tensor + +from .base import Datastruct, dataclass, Transform + + +class IdentityTransform(Transform): + def __init__(self, **kwargs): + return + + def Datastruct(self, **kwargs): + return IdentityDatastruct(**kwargs) + + def __repr__(self): + return "IdentityTransform()" + + +@dataclass +class IdentityDatastruct(Datastruct): + transforms: IdentityTransform + + features: Optional[Tensor] = None + + def __post_init__(self): + self.datakeys = ["features"] + + def __len__(self): + return len(self.rfeats) diff --git a/Evaluator_272/mld/transforms/joints2jfeats/__init__.py b/Evaluator_272/mld/transforms/joints2jfeats/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0a924e845912842ec042b5b3195b8da7aee3f252 --- /dev/null +++ b/Evaluator_272/mld/transforms/joints2jfeats/__init__.py @@ -0,0 +1,2 @@ +from .base import Joints2Jfeats +from .rifke import Rifke diff --git a/Evaluator_272/mld/transforms/joints2jfeats/base.py b/Evaluator_272/mld/transforms/joints2jfeats/base.py new file mode 100644 index 0000000000000000000000000000000000000000..a61a3b42848f5c5d7f803cc700d1b08bb4ecdbc5 --- /dev/null +++ b/Evaluator_272/mld/transforms/joints2jfeats/base.py @@ -0,0 +1,34 @@ +from typing import Optional + +import torch +from torch import Tensor, nn +from pathlib import Path + + +class Joints2Jfeats(nn.Module): + def __init__(self, path: Optional[str] = None, + normalization: bool = False, + eps: float = 1e-12, + **kwargs) -> None: + if normalization and path is None: + raise 
TypeError("You should provide a path if normalization is on.") + + super().__init__() + self.normalization = normalization + self.eps = eps + + if normalization: + mean_path = Path(path) / "jfeats_mean.pt" + std_path = Path(path) / "jfeats_std.pt" + self.register_buffer('mean', torch.load(mean_path)) + self.register_buffer('std', torch.load(std_path)) + + def normalize(self, features: Tensor) -> Tensor: + if self.normalization: + features = (features - self.mean)/(self.std + self.eps) + return features + + def unnormalize(self, features: Tensor) -> Tensor: + if self.normalization: + features = features * self.std + self.mean + return features diff --git a/Evaluator_272/mld/transforms/joints2jfeats/rifke.py b/Evaluator_272/mld/transforms/joints2jfeats/rifke.py new file mode 100644 index 0000000000000000000000000000000000000000..db97bd8338abe9c527c1c23a4ca5c5ea738867b1 --- /dev/null +++ b/Evaluator_272/mld/transforms/joints2jfeats/rifke.py @@ -0,0 +1,142 @@ +from typing import Optional + +import torch +from einops import rearrange +from torch import Tensor +from mld.utils.geometry import matrix_of_angles +from .base import Joints2Jfeats +from .tools import get_forward_direction, get_floor, gaussian_filter1d, T # noqa + + +class Rifke(Joints2Jfeats): + + def __init__(self, + jointstype: str = "mmm", + path: Optional[str] = None, + normalization: bool = False, + forward_filter: bool = False, + **kwargs) -> None: + if jointstype not in ["mmm", "mmmns", 'humanml3d', "motionx", "motionx_v26"]: + print("This function assume that the root is the first index") + raise NotImplementedError("This jointstype is not implemented.") + + super().__init__(path=path, normalization=normalization) + self.jointstype = jointstype + self.forward_filter = forward_filter + + def forward(self, joints: Tensor) -> Tensor: + # Joints to rotation invariant poses (Holden et. al.) 
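        # Input:  joints of shape (..., time, njoints, 3), root joint at index 0.
        # Output: per-frame features laid out as
        #   [ root height (1) | local joint coordinates (3*(njoints-1)) |
        #     angular velocity of the forward direction (1) | local planar root velocity (2) ],
        # i.e. 3*(njoints-1) + 4 values, matching the slicing in extract() below.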
+ # Similar function than fke2rifke in Language2Pose repository + # Adapted to pytorch + # Put the origin center of the root joint instead of the ground projection + + poses = joints.clone() + poses[..., 1] -= get_floor(poses, jointstype=self.jointstype) + + translation = poses[..., 0, :].clone() + # Let the root have the Y translation + root_y = translation[..., 1] + + # Trajectory => Translation without gravity axis (Y) + trajectory = translation[..., [0, 2]] + + # Delete the root joints of the poses + poses = poses[..., 1:, :] + + # Remove the trajectory of the poses + poses[..., [0, 2]] -= trajectory[..., None, :] + + # Compute the trajectory + vel_trajectory = torch.diff(trajectory, dim=-2) + # 0 for the first one => keep the dimentionality + vel_trajectory = torch.cat( + (0 * vel_trajectory[..., [0], :], vel_trajectory), dim=-2) + + # Compute the forward direction + forward = get_forward_direction(poses, jointstype=self.jointstype) + if self.forward_filter: + # Smoothing to remove high frequencies + forward = gaussian_filter1d(forward, 2) + # normalize again to get real directions + forward = torch.nn.functional.normalize(forward, dim=-1) + + angles = T(torch.atan2(*T(forward))) + vel_angles = torch.diff(angles, dim=-1) + # 0 for the first one => keep the dimentionality + vel_angles = torch.cat((0 * vel_angles[..., [0]], vel_angles), dim=-1) + + # Construct the inverse rotation matrix + sin, cos = forward[..., 0], forward[..., 1] + rotations_inv = matrix_of_angles(cos, sin, inv=True) + + # Rotate the poses + poses_local = torch.einsum("...lj,...jk->...lk", poses[..., [0, 2]], + rotations_inv) + poses_local = torch.stack( + (poses_local[..., 0], poses[..., 1], poses_local[..., 1]), axis=-1) + + # stack the xyz joints into feature vectors + poses_features = rearrange(poses_local, + "... joints xyz -> ... (joints xyz)") + + # Rotate the vel_trajectory + vel_trajectory_local = torch.einsum("...j,...jk->...k", vel_trajectory, + rotations_inv) + # Stack things together + features = torch.cat((root_y[..., None], poses_features, + vel_angles[..., None], vel_trajectory_local), -1) + + # Normalize if needed + features = self.normalize(features) + return features + + def inverse(self, features: Tensor) -> Tensor: + features = self.unnormalize(features) + root_y, poses_features, vel_angles, vel_trajectory_local = self.extract( + features) + + # already have the good dimensionality + angles = torch.cumsum(vel_angles, dim=-1) + # First frame should be 0, but if infered it is better to ensure it + angles = angles - angles[..., [0]] + + cos, sin = torch.cos(angles), torch.sin(angles) + rotations = matrix_of_angles(cos, sin, inv=False) + + # Get back the poses + poses_local = rearrange(poses_features, + "... (joints xyz) -> ... 
joints xyz", + xyz=3) + + # Rotate the poses + poses = torch.einsum("...lj,...jk->...lk", poses_local[..., [0, 2]], + rotations) + poses = torch.stack( + (poses[..., 0], poses_local[..., 1], poses[..., 1]), axis=-1) + + # Rotate the vel_trajectory + vel_trajectory = torch.einsum("...j,...jk->...k", vel_trajectory_local, + rotations) + # Integrate the trajectory + # Already have the good dimensionality + trajectory = torch.cumsum(vel_trajectory, dim=-2) + # First frame should be 0, but if infered it is better to ensure it + trajectory = trajectory - trajectory[..., [0], :] + + # Add the root joints (which is still zero) + poses = torch.cat((0 * poses[..., [0], :], poses), -2) + + # put back the root joint y + poses[..., 0, 1] = root_y + + # Add the trajectory globally + poses[..., [0, 2]] += trajectory[..., None, :] + return poses + + def extract(self, features: Tensor) -> tuple: + root_y = features[..., 0] + poses_features = features[..., 1:-3] + vel_angles = features[..., -3] + vel_trajectory_local = features[..., -2:] + + return root_y, poses_features, vel_angles, vel_trajectory_local diff --git a/Evaluator_272/mld/transforms/joints2jfeats/tools.py b/Evaluator_272/mld/transforms/joints2jfeats/tools.py new file mode 100644 index 0000000000000000000000000000000000000000..a2ad8eceb937ef841d38f1183aabb856793a377c --- /dev/null +++ b/Evaluator_272/mld/transforms/joints2jfeats/tools.py @@ -0,0 +1,92 @@ +import torch +import torch.nn.functional as F + +from mld.utils.joints import mmm_joints, humanml3d_joints, motionx_joints + +# Get the indexes of particular body part + + +# .T is deprecated now for reversing a tensor +def T(x): + return x.permute(*torch.arange(x.ndim - 1, -1, -1)) + + +def get_forward_direction(poses, jointstype="mmm"): + if jointstype == "mmm" or jointstype == "mmmns": + joints = mmm_joints + elif jointstype == "humanml3d": + joints = humanml3d_joints + elif jointstype in ["motionx", "motionx_v26"]: + joints = motionx_joints + else: + raise TypeError('Only supports mmm, mmmns and humanl3d jointstype') + # Shoulders + LS, RS = joints.index("LS"), joints.index("RS") + # Hips + LH, RH = joints.index("LH"), joints.index("RH") + + across = poses[..., RH, :] - poses[..., LH, :] + poses[..., RS, :] - poses[ + ..., LS, :] + forward = torch.stack((-across[..., 2], across[..., 0]), axis=-1) + forward = torch.nn.functional.normalize(forward, dim=-1) + return forward + + +def get_floor(poses, jointstype="mmm"): + if jointstype == "mmm" or jointstype == "mmmns": + joints = mmm_joints + elif jointstype == "humanml3d": + joints = humanml3d_joints + elif jointstype in ["motionx", "motionx_v26"]: + joints = motionx_joints + else: + raise TypeError('Only supports mmm, mmmns and humanl3d jointstype') + ndim = len(poses.shape) + # Feet + LM, RM = joints.index("LMrot"), joints.index("RMrot") + LF, RF = joints.index("LF"), joints.index("RF") + # import pdb; pdb.set_trace() + foot_heights = poses[..., (LM, LF, RM, RF), 1].min(-1).values + floor_height = softmin(foot_heights, softness=0.5, dim=-1) + return T(floor_height[(ndim - 2) * [None]]) + + +def softmax(x, softness=1.0, dim=None): + maxi, mini = x.max(dim=dim).values, x.min(dim=dim).values + return maxi + torch.log(softness + torch.exp(mini - maxi)) + + +def softmin(x, softness=1.0, dim=0): + return -softmax(-x, softness=softness, dim=dim) + + +def gaussian_filter1d(_inputs, sigma, truncate=4.0): + # Code adapted/mixed from scipy library into pytorch + # 
https://github.com/scipy/scipy/blob/47bb6febaa10658c72962b9615d5d5aa2513fa3a/scipy/ndimage/filters.py#L211 + # and gaussian kernel + # https://github.com/scipy/scipy/blob/47bb6febaa10658c72962b9615d5d5aa2513fa3a/scipy/ndimage/filters.py#L179 + # Correspond to mode="nearest" and order = 0 + # But works batched + if len(_inputs.shape) == 2: + inputs = _inputs[None] + else: + inputs = _inputs + + sd = float(sigma) + radius = int(truncate * sd + 0.5) + sigma2 = sigma * sigma + x = torch.arange(-radius, + radius + 1, + device=inputs.device, + dtype=inputs.dtype) + phi_x = torch.exp(-0.5 / sigma2 * x**2) + phi_x = phi_x / phi_x.sum() + + # Conv1d weights + groups = inputs.shape[-1] + weights = torch.tile(phi_x, (groups, 1, 1)) + inputs = inputs.transpose(-1, -2) + outputs = F.conv1d(inputs, weights, padding="same", + groups=groups).transpose(-1, -2) + + return outputs.reshape(_inputs.shape) diff --git a/Evaluator_272/mld/transforms/joints2rots/config.py b/Evaluator_272/mld/transforms/joints2rots/config.py new file mode 100644 index 0000000000000000000000000000000000000000..91e3a646f456e0a78ee9ff177ff87c941f9c01ba --- /dev/null +++ b/Evaluator_272/mld/transforms/joints2rots/config.py @@ -0,0 +1,119 @@ +import numpy as np +from mld.utils.joints import mmm_joints, smplh2mmm_indexes + +# Map joints Name to SMPL joints idx +JOINT_MAP = { + 'MidHip': 0, + 'LHip': 1, + 'LKnee': 4, + 'LAnkle': 7, + 'LFoot': 10, + 'RHip': 2, + 'RKnee': 5, + 'RAnkle': 8, + 'RFoot': 11, + 'LShoulder': 16, + 'LElbow': 18, + 'LWrist': 20, + 'LHand': 22, + 'RShoulder': 17, + 'RElbow': 19, + 'RWrist': 21, + 'RHand': 23, + 'spine1': 3, + 'spine2': 6, + 'spine3': 9, + 'Neck': 12, + 'Head': 15, + 'LCollar': 13, + 'Rcollar': 14, + 'Nose': 24, + 'REye': 26, + 'LEye': 26, + 'REar': 27, + 'LEar': 28, + 'LHeel': 31, + 'RHeel': 34, + 'OP RShoulder': 17, + 'OP LShoulder': 16, + 'OP RHip': 2, + 'OP LHip': 1, + 'OP Neck': 12, +} + +mmm2smpl_correspondence = { + "root": "MidHip", + "BP": "spine1", + "BT": "spine3", + "BLN": "Neck", + "BUN": "Head", + "LS": "LShoulder", + "LE": "LElbow", + "LW": "LWrist", + "RS": "RShoulder", + "RE": "RElbow", + "RW": "RWrist", + "LH": "LHip", + "LK": "LKnee", + "LA": "LAnkle", + "LMrot": "LHeel", + "LF": "LFoot", + "RH": "RHip", + "RK": "RKnee", + "RA": "RAnkle", + "RMrot": "RHeel", + "RF": "RFoot" +} + +full_smpl_idx = range(24) +key_smpl_idx = [0, 1, 4, 7, 2, 5, 8, 17, 19, 21, 16, 18, 20] + +AMASS_JOINT_MAP = { + 'MidHip': 0, + 'LHip': 1, + 'LKnee': 4, + 'LAnkle': 7, + 'LFoot': 10, + 'RHip': 2, + 'RKnee': 5, + 'RAnkle': 8, + 'RFoot': 11, + 'LShoulder': 16, + 'LElbow': 18, + 'LWrist': 20, + 'RShoulder': 17, + 'RElbow': 19, + 'RWrist': 21, + 'spine1': 3, + 'spine2': 6, + 'spine3': 9, + 'Neck': 12, + 'Head': 15, + 'LCollar': 13, + 'Rcollar': 14, +} +amass_idx = range(22) +amass_smpl_idx = range(22) + +# cal mmm in smpl index +smpl2mmm_correspondence = { + val: key + for key, val in mmm2smpl_correspondence.items() +} +smpl2mmm_indexes = [JOINT_MAP[mmm2smpl_correspondence[x]] for x in mmm_joints] + +# cal mmm joints map +MMM_JOINT_MAP = { + val: JOINT_MAP[val] + for key, val in mmm2smpl_correspondence.items() +} + +# mmm_idx = range(21) +# mmm_smpl_dix = smpl2mmm_indexes +# mmm_smpl_dix = smplh2mmm_indexes +# todo - configable +SMPL_MODEL_DIR = "/apdcephfs/share_1227775/shingxchen/AIMotion/TMOSTData/deps/smpl_models/" +GMM_MODEL_DIR = "/apdcephfs/share_1227775/shingxchen/AIMotion/TMOSTData/deps/smpl_models/" +SMPL_MEAN_FILE = 
"/apdcephfs/share_1227775/shingxchen/AIMotion/TMOSTData/deps/smpl_models/neutral_smpl_mean_params.h5" +# for collsion +Part_Seg_DIR = "/apdcephfs/share_1227775/shingxchen/AIMotion/TMOSTData/deps/smpl_models/smplx_parts_segm.pkl" diff --git a/Evaluator_272/mld/transforms/joints2rots/customloss.py b/Evaluator_272/mld/transforms/joints2rots/customloss.py new file mode 100644 index 0000000000000000000000000000000000000000..2c3c3a530876113596f223324dc9dd0c002fd520 --- /dev/null +++ b/Evaluator_272/mld/transforms/joints2rots/customloss.py @@ -0,0 +1,217 @@ +import torch +import torch.nn.functional as F +import config + +# Guassian +def gmof(x, sigma): + """ + Geman-McClure error function + """ + x_squared = x ** 2 + sigma_squared = sigma ** 2 + return (sigma_squared * x_squared) / (sigma_squared + x_squared) + +# angle prior +def angle_prior(pose): + """ + Angle prior that penalizes unnatural bending of the knees and elbows + """ + # We subtract 3 because pose does not include the global rotation of the model + return torch.exp( + pose[:, [55 - 3, 58 - 3, 12 - 3, 15 - 3]] * torch.tensor([1., -1., -1, -1.], device=pose.device)) ** 2 + + +def perspective_projection(points, rotation, translation, + focal_length, camera_center): + """ + This function computes the perspective projection of a set of points. + Input: + points (bs, N, 3): 3D points + rotation (bs, 3, 3): Camera rotation + translation (bs, 3): Camera translation + focal_length (bs,) or scalar: Focal length + camera_center (bs, 2): Camera center + """ + batch_size = points.shape[0] + K = torch.zeros([batch_size, 3, 3], device=points.device) + K[:, 0, 0] = focal_length + K[:, 1, 1] = focal_length + K[:, 2, 2] = 1. + K[:, :-1, -1] = camera_center + + # Transform points + points = torch.einsum('bij,bkj->bki', rotation, points) + points = points + translation.unsqueeze(1) + + # Apply perspective distortion + projected_points = points / points[:, :, -1].unsqueeze(-1) + + # Apply camera intrinsics + projected_points = torch.einsum('bij,bkj->bki', K, projected_points) + + return projected_points[:, :, :-1] + + +def body_fitting_loss(body_pose, betas, model_joints, camera_t, camera_center, + joints_2d, joints_conf, pose_prior, + focal_length=5000, sigma=100, pose_prior_weight=4.78, + shape_prior_weight=5, angle_prior_weight=15.2, + output='sum'): + """ + Loss function for body fitting + """ + batch_size = body_pose.shape[0] + rotation = torch.eye(3, device=body_pose.device).unsqueeze(0).expand(batch_size, -1, -1) + + projected_joints = perspective_projection(model_joints, rotation, camera_t, + focal_length, camera_center) + + # Weighted robust reprojection error + reprojection_error = gmof(projected_joints - joints_2d, sigma) + reprojection_loss = (joints_conf ** 2) * reprojection_error.sum(dim=-1) + + # Pose prior loss + pose_prior_loss = (pose_prior_weight ** 2) * pose_prior(body_pose, betas) + + # Angle prior for knees and elbows + angle_prior_loss = (angle_prior_weight ** 2) * angle_prior(body_pose).sum(dim=-1) + + # Regularizer to prevent betas from taking large values + shape_prior_loss = (shape_prior_weight ** 2) * (betas ** 2).sum(dim=-1) + + total_loss = reprojection_loss.sum(dim=-1) + pose_prior_loss + angle_prior_loss + shape_prior_loss + + if output == 'sum': + return total_loss.sum() + elif output == 'reprojection': + return reprojection_loss + + +# --- get camera fitting loss ----- +def camera_fitting_loss(model_joints, camera_t, camera_t_est, camera_center, + joints_2d, joints_conf, + focal_length=5000, depth_loss_weight=100): + 
""" + Loss function for camera optimization. + """ + # Project model joints + batch_size = model_joints.shape[0] + rotation = torch.eye(3, device=model_joints.device).unsqueeze(0).expand(batch_size, -1, -1) + projected_joints = perspective_projection(model_joints, rotation, camera_t, + focal_length, camera_center) + + # get the indexed four + op_joints = ['OP RHip', 'OP LHip', 'OP RShoulder', 'OP LShoulder'] + op_joints_ind = [config.JOINT_MAP[joint] for joint in op_joints] + gt_joints = ['RHip', 'LHip', 'RShoulder', 'LShoulder'] + gt_joints_ind = [config.JOINT_MAP[joint] for joint in gt_joints] + + reprojection_error_op = (joints_2d[:, op_joints_ind] - + projected_joints[:, op_joints_ind]) ** 2 + reprojection_error_gt = (joints_2d[:, gt_joints_ind] - + projected_joints[:, gt_joints_ind]) ** 2 + + # Check if for each example in the batch all 4 OpenPose detections are valid, otherwise use the GT detections + # OpenPose joints are more reliable for this task, so we prefer to use them if possible + is_valid = (joints_conf[:, op_joints_ind].min(dim=-1)[0][:, None, None] > 0).float() + reprojection_loss = (is_valid * reprojection_error_op + (1 - is_valid) * reprojection_error_gt).sum(dim=(1, 2)) + + # Loss that penalizes deviation from depth estimate + depth_loss = (depth_loss_weight ** 2) * (camera_t[:, 2] - camera_t_est[:, 2]) ** 2 + + total_loss = reprojection_loss + depth_loss + return total_loss.sum() + + + + # #####--- body fitiing loss ----- +def body_fitting_loss_3d(body_pose, preserve_pose, + betas, model_joints, camera_translation, + j3d, pose_prior, + joints3d_conf, + sigma=100, pose_prior_weight=4.78*1.5, + shape_prior_weight=5.0, angle_prior_weight=15.2, + joint_loss_weight=500.0, + pose_preserve_weight=0.0, + use_collision=False, + model_vertices=None, model_faces=None, + search_tree=None, pen_distance=None, filter_faces=None, + collision_loss_weight=1000 + ): + """ + Loss function for body fitting + """ + batch_size = body_pose.shape[0] + + #joint3d_loss = (joint_loss_weight ** 2) * gmof((model_joints + camera_translation) - j3d, sigma).sum(dim=-1) + + joint3d_error = gmof((model_joints + camera_translation) - j3d, sigma) + + joint3d_loss_part = (joints3d_conf ** 2) * joint3d_error.sum(dim=-1) + joint3d_loss = (joint_loss_weight ** 2) * joint3d_loss_part + + # Pose prior loss + pose_prior_loss = (pose_prior_weight ** 2) * pose_prior(body_pose, betas) + # Angle prior for knees and elbows + angle_prior_loss = (angle_prior_weight ** 2) * angle_prior(body_pose).sum(dim=-1) + # Regularizer to prevent betas from taking large values + shape_prior_loss = (shape_prior_weight ** 2) * (betas ** 2).sum(dim=-1) + + collision_loss = 0.0 + # Calculate the loss due to interpenetration + if use_collision: + triangles = torch.index_select( + model_vertices, 1, + model_faces).view(batch_size, -1, 3, 3) + + with torch.no_grad(): + collision_idxs = search_tree(triangles) + + # Remove unwanted collisions + if filter_faces is not None: + collision_idxs = filter_faces(collision_idxs) + + if collision_idxs.ge(0).sum().item() > 0: + collision_loss = torch.sum(collision_loss_weight * pen_distance(triangles, collision_idxs)) + + pose_preserve_loss = (pose_preserve_weight ** 2) * ((body_pose - preserve_pose) ** 2).sum(dim=-1) + + total_loss = joint3d_loss + pose_prior_loss + angle_prior_loss + shape_prior_loss + collision_loss + pose_preserve_loss + + return total_loss.sum() + + +# #####--- get camera fitting loss ----- +def camera_fitting_loss_3d(model_joints, camera_t, camera_t_est, + j3d, 
joints_category="orig", depth_loss_weight=100.0): + """ + Loss function for camera optimization. + """ + model_joints = model_joints + camera_t + # # get the indexed four + # op_joints = ['OP RHip', 'OP LHip', 'OP RShoulder', 'OP LShoulder'] + # op_joints_ind = [config.JOINT_MAP[joint] for joint in op_joints] + # + # j3d_error_loss = (j3d[:, op_joints_ind] - + # model_joints[:, op_joints_ind]) ** 2 + + gt_joints = ['RHip', 'LHip', 'RShoulder', 'LShoulder'] + gt_joints_ind = [config.JOINT_MAP[joint] for joint in gt_joints] + + if joints_category=="orig": + select_joints_ind = [config.JOINT_MAP[joint] for joint in gt_joints] + elif joints_category=="AMASS": + select_joints_ind = [config.AMASS_JOINT_MAP[joint] for joint in gt_joints] + elif joints_category=="MMM": + select_joints_ind = [config.MMM_JOINT_MAP[joint] for joint in gt_joints] + else: + print("NO SUCH JOINTS CATEGORY!") + + j3d_error_loss = (j3d[:, select_joints_ind] - + model_joints[:, gt_joints_ind]) ** 2 + + # Loss that penalizes deviation from depth estimate + depth_loss = (depth_loss_weight**2) * (camera_t - camera_t_est)**2 + + total_loss = j3d_error_loss + depth_loss + return total_loss.sum() \ No newline at end of file diff --git a/Evaluator_272/mld/transforms/joints2rots/prior.py b/Evaluator_272/mld/transforms/joints2rots/prior.py new file mode 100644 index 0000000000000000000000000000000000000000..d85debddd185d44082f6ac14fdaa606d4deebd40 --- /dev/null +++ b/Evaluator_272/mld/transforms/joints2rots/prior.py @@ -0,0 +1,229 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. 
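The body-fitting losses in customloss.py above pass their joint residuals through gmof, the Geman-McClure robust error, so that a few badly tracked joints cannot dominate the objective. A minimal, self-contained sketch (illustrative only, with made-up residual values; gmof is copied from the file above) of how it saturates compared with a plain squared error:

import torch

def gmof(x, sigma):
    # Geman-McClure error: ~x**2 for small residuals, bounded by sigma**2 for large ones
    x_squared = x ** 2
    sigma_squared = sigma ** 2
    return (sigma_squared * x_squared) / (sigma_squared + x_squared)

residuals = torch.tensor([0.1, 1.0, 10.0, 100.0])
print(residuals ** 2)              # grows without bound: 0.01, 1, 100, 10000
print(gmof(residuals, sigma=1.0))  # saturates: ~0.0099, 0.50, 0.99, 0.9999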
+#
+# Contact: ps-license@tuebingen.mpg.de
+
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+
+import sys
+import os
+
+import time
+import pickle
+
+import numpy as np
+
+import torch
+import torch.nn as nn
+
+DEFAULT_DTYPE = torch.float32
+
+
+def create_prior(prior_type, **kwargs):
+    if prior_type == 'gmm':
+        prior = MaxMixturePrior(**kwargs)
+    elif prior_type == 'l2':
+        return L2Prior(**kwargs)
+    elif prior_type == 'angle':
+        return SMPLifyAnglePrior(**kwargs)
+    elif prior_type == 'none' or prior_type is None:
+        # Don't use any pose prior
+        def no_prior(*args, **kwargs):
+            return 0.0
+        prior = no_prior
+    else:
+        raise ValueError('Prior {}'.format(prior_type) + ' is not implemented')
+    return prior
+
+
+class SMPLifyAnglePrior(nn.Module):
+    def __init__(self, dtype=torch.float32, **kwargs):
+        super(SMPLifyAnglePrior, self).__init__()
+
+        # Indices for the rotation angle of
+        # 55: left elbow, 90deg bend at -np.pi/2
+        # 58: right elbow, 90deg bend at np.pi/2
+        # 12: left knee, 90deg bend at np.pi/2
+        # 15: right knee, 90deg bend at np.pi/2
+        angle_prior_idxs = np.array([55, 58, 12, 15], dtype=np.int64)
+        angle_prior_idxs = torch.tensor(angle_prior_idxs, dtype=torch.long)
+        self.register_buffer('angle_prior_idxs', angle_prior_idxs)
+
+        angle_prior_signs = np.array([1, -1, -1, -1],
+                                     dtype=np.float32 if dtype == torch.float32
+                                     else np.float64)
+        angle_prior_signs = torch.tensor(angle_prior_signs,
+                                         dtype=dtype)
+        self.register_buffer('angle_prior_signs', angle_prior_signs)
+
+    def forward(self, pose, with_global_pose=False):
+        ''' Returns the angle prior loss for the given pose
+        Args:
+            pose: (Bx[23 + 1] * 3) torch tensor with the axis-angle
+            representation of the rotations of the joints of the SMPL model.
+        Kwargs:
+            with_global_pose: Whether the pose vector also contains the global
+            orientation of the SMPL model. If not then the indices must be
+            corrected.
+        Returns:
+            A size (B) tensor containing the angle prior loss for each element
+            in the batch.
+        '''
+        angle_prior_idxs = self.angle_prior_idxs - (not with_global_pose) * 3
+        return torch.exp(pose[:, angle_prior_idxs] *
+                         self.angle_prior_signs).pow(2)
+
+
+class L2Prior(nn.Module):
+    def __init__(self, dtype=DEFAULT_DTYPE, reduction='sum', **kwargs):
+        super(L2Prior, self).__init__()
+
+    def forward(self, module_input, *args):
+        return torch.sum(module_input.pow(2))
+
+
+class MaxMixturePrior(nn.Module):
+
+    def __init__(self, prior_folder='prior',
+                 num_gaussians=6, dtype=DEFAULT_DTYPE, epsilon=1e-16,
+                 use_merged=True,
+                 **kwargs):
+        super(MaxMixturePrior, self).__init__()
+
+        if dtype == DEFAULT_DTYPE:
+            np_dtype = np.float32
+        elif dtype == torch.float64:
+            np_dtype = np.float64
+        else:
+            print('Unknown float type {}, exiting!'.format(dtype))
+            sys.exit(-1)
+
+        self.num_gaussians = num_gaussians
+        self.epsilon = epsilon
+        self.use_merged = use_merged
+        gmm_fn = 'gmm_{:02d}.pkl'.format(num_gaussians)
+
+        full_gmm_fn = os.path.join(prior_folder, gmm_fn)
+        if not os.path.exists(full_gmm_fn):
+            print('The path to the mixture prior "{}"'.format(full_gmm_fn) +
+                  ' does not exist, exiting!')
+            sys.exit(-1)
+
+        with open(full_gmm_fn, 'rb') as f:
+            gmm = pickle.load(f, encoding='latin1')
+
+        if type(gmm) == dict:
+            means = gmm['means'].astype(np_dtype)
+            covs = gmm['covars'].astype(np_dtype)
+            weights = gmm['weights'].astype(np_dtype)
+        elif 'sklearn.mixture.gmm.GMM' in str(type(gmm)):
+            means = gmm.means_.astype(np_dtype)
+            covs = gmm.covars_.astype(np_dtype)
+            weights = gmm.weights_.astype(np_dtype)
+        else:
+            print('Unknown type for the prior: {}, exiting!'.format(type(gmm)))
+            sys.exit(-1)
+
+        self.register_buffer('means', torch.tensor(means, dtype=dtype))
+
+        self.register_buffer('covs', torch.tensor(covs, dtype=dtype))
+
+        precisions = [np.linalg.inv(cov) for cov in covs]
+        precisions = np.stack(precisions).astype(np_dtype)
+
+        self.register_buffer('precisions',
+                             torch.tensor(precisions, dtype=dtype))
+
+        # The constant term:
+        sqrdets = np.array([(np.sqrt(np.linalg.det(c)))
+                            for c in gmm['covars']])
+        const = (2 * np.pi)**(69 / 2.)
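+        # `const` is the (2*pi)^(d/2) factor of a d-dimensional Gaussian
+        # normaliser, with d = 69, i.e. the 23 x 3 axis-angle body-pose
+        # parameters (global orientation excluded) this prior is evaluated on.
+        # The mixture weights below are divided by this constant and by the
+        # sqrt-determinants of the covariances (scaled relative to the smallest
+        # one); merged_log_likelihood then adds -log(nll_weights) to the
+        # 0.5 * (x - mu)^T * precision * (x - mu) term and keeps the minimum
+        # over the Gaussian components.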
+ + nll_weights = np.asarray(gmm['weights'] / (const * + (sqrdets / sqrdets.min()))) + nll_weights = torch.tensor(nll_weights, dtype=dtype).unsqueeze(dim=0) + self.register_buffer('nll_weights', nll_weights) + + weights = torch.tensor(gmm['weights'], dtype=dtype).unsqueeze(dim=0) + self.register_buffer('weights', weights) + + self.register_buffer('pi_term', + torch.log(torch.tensor(2 * np.pi, dtype=dtype))) + + cov_dets = [np.log(np.linalg.det(cov.astype(np_dtype)) + epsilon) + for cov in covs] + self.register_buffer('cov_dets', + torch.tensor(cov_dets, dtype=dtype)) + + # The dimensionality of the random variable + self.random_var_dim = self.means.shape[1] + + def get_mean(self): + ''' Returns the mean of the mixture ''' + mean_pose = torch.matmul(self.weights, self.means) + return mean_pose + + def merged_log_likelihood(self, pose, betas): + diff_from_mean = pose.unsqueeze(dim=1) - self.means + + prec_diff_prod = torch.einsum('mij,bmj->bmi', + [self.precisions, diff_from_mean]) + diff_prec_quadratic = (prec_diff_prod * diff_from_mean).sum(dim=-1) + + curr_loglikelihood = 0.5 * diff_prec_quadratic - \ + torch.log(self.nll_weights) + # curr_loglikelihood = 0.5 * (self.cov_dets.unsqueeze(dim=0) + + # self.random_var_dim * self.pi_term + + # diff_prec_quadratic + # ) - torch.log(self.weights) + + min_likelihood, _ = torch.min(curr_loglikelihood, dim=1) + return min_likelihood + + def log_likelihood(self, pose, betas, *args, **kwargs): + ''' Create graph operation for negative log-likelihood calculation + ''' + likelihoods = [] + + for idx in range(self.num_gaussians): + mean = self.means[idx] + prec = self.precisions[idx] + cov = self.covs[idx] + diff_from_mean = pose - mean + + curr_loglikelihood = torch.einsum('bj,ji->bi', + [diff_from_mean, prec]) + curr_loglikelihood = torch.einsum('bi,bi->b', + [curr_loglikelihood, + diff_from_mean]) + cov_term = torch.log(torch.det(cov) + self.epsilon) + curr_loglikelihood += 0.5 * (cov_term + + self.random_var_dim * + self.pi_term) + likelihoods.append(curr_loglikelihood) + + log_likelihoods = torch.stack(likelihoods, dim=1) + min_idx = torch.argmin(log_likelihoods, dim=1) + weight_component = self.nll_weights[:, min_idx] + weight_component = -torch.log(weight_component) + + return weight_component + log_likelihoods[:, min_idx] + + def forward(self, pose, betas): + if self.use_merged: + return self.merged_log_likelihood(pose, betas) + else: + return self.log_likelihood(pose, betas) diff --git a/Evaluator_272/mld/transforms/joints2rots/smplify.py b/Evaluator_272/mld/transforms/joints2rots/smplify.py new file mode 100644 index 0000000000000000000000000000000000000000..7df51503a4a46a479a508c9fdf362cb063b93742 --- /dev/null +++ b/Evaluator_272/mld/transforms/joints2rots/smplify.py @@ -0,0 +1,284 @@ +import torch +import os, sys +import pickle +import smplx +import numpy as np +from tqdm import tqdm + +sys.path.append(os.path.dirname(__file__)) +from customloss import (camera_fitting_loss, + body_fitting_loss, + camera_fitting_loss_3d, + body_fitting_loss_3d, + ) +from prior import MaxMixturePrior +import config + + + +@torch.no_grad() +def guess_init_3d(model_joints, + j3d, + joints_category="orig"): + """Initialize the camera translation via triangle similarity, by using the torso joints . 
+ :param model_joints: SMPL model with pre joints + :param j3d: 25x3 array of Kinect Joints + :returns: 3D vector corresponding to the estimated camera translation + """ + # get the indexed four + gt_joints = ['RHip', 'LHip', 'RShoulder', 'LShoulder'] + gt_joints_ind = [config.JOINT_MAP[joint] for joint in gt_joints] + + if joints_category=="orig": + joints_ind_category = [config.JOINT_MAP[joint] for joint in gt_joints] + elif joints_category=="AMASS": + joints_ind_category = [config.AMASS_JOINT_MAP[joint] for joint in gt_joints] + elif joints_category=="MMM": + joints_ind_category = [config.MMM_JOINT_MAP[joint] for joint in gt_joints] + else: + print("NO SUCH JOINTS CATEGORY!") + + sum_init_t = (j3d[:, joints_ind_category] - model_joints[:, gt_joints_ind]).sum(dim=1) + init_t = sum_init_t / 4.0 + return init_t + + +# SMPLIfy 3D +class SMPLify3D(): + """Implementation of SMPLify, use 3D joints.""" + + def __init__(self, + smplxmodel, + step_size=1e-2, + batch_size=1, + num_iters=100, + use_collision=False, + use_lbfgs=True, + joints_category="orig", + device=torch.device('cuda:0'), + ): + + # Store options + self.batch_size = batch_size + self.device = device + self.step_size = step_size + + self.num_iters = num_iters + # --- choose optimizer + self.use_lbfgs = use_lbfgs + # GMM pose prior + self.pose_prior = MaxMixturePrior(prior_folder=config.GMM_MODEL_DIR, + num_gaussians=8, + dtype=torch.float32).to(device) + # collision part + self.use_collision = use_collision + if self.use_collision: + self.part_segm_fn = config.Part_Seg_DIR + + # reLoad SMPL-X model + self.smpl = smplxmodel + + self.model_faces = smplxmodel.faces_tensor.view(-1) + + # select joint joint_category + self.joints_category = joints_category + + if joints_category=="orig": + self.smpl_index = config.full_smpl_idx + self.corr_index = config.full_smpl_idx + elif joints_category=="AMASS": + self.smpl_index = config.amass_smpl_idx + self.corr_index = config.amass_idx + # elif joints_category=="MMM": + # self.smpl_index = config.mmm_smpl_dix + # self.corr_index = config.mmm_idx + else: + self.smpl_index = None + self.corr_index = None + print("NO SUCH JOINTS CATEGORY!") + + # ---- get the man function here ------ + def __call__(self, init_pose, init_betas, init_cam_t, j3d, conf_3d=1.0, seq_ind=0): + """Perform body fitting. 
+ Input: + init_pose: SMPL pose estimate + init_betas: SMPL betas estimate + init_cam_t: Camera translation estimate + j3d: joints 3d aka keypoints + conf_3d: confidence for 3d joints + seq_ind: index of the sequence + Returns: + vertices: Vertices of optimized shape + joints: 3D joints of optimized shape + pose: SMPL pose parameters of optimized shape + betas: SMPL beta parameters of optimized shape + camera_translation: Camera translation + """ + + # # # add the mesh inter-section to avoid + search_tree = None + pen_distance = None + filter_faces = None + + if self.use_collision: + from mesh_intersection.bvh_search_tree import BVH + import mesh_intersection.loss as collisions_loss + from mesh_intersection.filter_faces import FilterFaces + + search_tree = BVH(max_collisions=8) + + pen_distance = collisions_loss.DistanceFieldPenetrationLoss( + sigma=0.5, point2plane=False, vectorized=True, penalize_outside=True) + + if self.part_segm_fn: + # Read the part segmentation + part_segm_fn = os.path.expandvars(self.part_segm_fn) + with open(part_segm_fn, 'rb') as faces_parents_file: + face_segm_data = pickle.load(faces_parents_file, encoding='latin1') + faces_segm = face_segm_data['segm'] + faces_parents = face_segm_data['parents'] + # Create the module used to filter invalid collision pairs + filter_faces = FilterFaces( + faces_segm=faces_segm, faces_parents=faces_parents, + ign_part_pairs=None).to(device=self.device) + + + # Split SMPL pose to body pose and global orientation + body_pose = init_pose[:, 3:].detach().clone() + global_orient = init_pose[:, :3].detach().clone() + betas = init_betas.detach().clone() + + # use guess 3d to get the initial + smpl_output = self.smpl(global_orient=global_orient, + body_pose=body_pose, + betas=betas) + model_joints = smpl_output.joints + + init_cam_t = guess_init_3d(model_joints, j3d, self.joints_category).detach() + camera_translation = init_cam_t.clone() + + preserve_pose = init_pose[:, 3:].detach().clone() + # -------------Step 1: Optimize camera translation and body orientation-------- + # Optimize only camera translation and body orientation + body_pose.requires_grad = False + betas.requires_grad = False + global_orient.requires_grad = True + camera_translation.requires_grad = True + + camera_opt_params = [global_orient, camera_translation] + + if self.use_lbfgs: + camera_optimizer = torch.optim.LBFGS(camera_opt_params, max_iter=self.num_iters, + lr=self.step_size, line_search_fn='strong_wolfe') + for i in range(10): + def closure(): + camera_optimizer.zero_grad() + smpl_output = self.smpl(global_orient=global_orient, + body_pose=body_pose, + betas=betas) + model_joints = smpl_output.joints + + loss = camera_fitting_loss_3d(model_joints, camera_translation, + init_cam_t, j3d, self.joints_category) + loss.backward() + return loss + + camera_optimizer.step(closure) + else: + camera_optimizer = torch.optim.Adam(camera_opt_params, lr=self.step_size, betas=(0.9, 0.999)) + + for i in range(20): + smpl_output = self.smpl(global_orient=global_orient, + body_pose=body_pose, + betas=betas) + model_joints = smpl_output.joints + + loss = camera_fitting_loss_3d(model_joints[:, self.smpl_index], camera_translation, + init_cam_t, j3d[:, self.corr_index], self.joints_category) + camera_optimizer.zero_grad() + loss.backward() + camera_optimizer.step() + + # Fix camera translation after optimizing camera + # --------Step 2: Optimize body joints -------------------------- + # Optimize only the body pose and global orientation of the body + body_pose.requires_grad = 
True + global_orient.requires_grad = True + camera_translation.requires_grad = True + + # --- if we use the sequence, fix the shape + if seq_ind == 0: + betas.requires_grad = True + body_opt_params = [body_pose, betas, global_orient, camera_translation] + else: + betas.requires_grad = False + body_opt_params = [body_pose, global_orient, camera_translation] + + if self.use_lbfgs: + body_optimizer = torch.optim.LBFGS(body_opt_params, max_iter=self.num_iters, + lr=self.step_size, line_search_fn='strong_wolfe') + + for i in tqdm(range(self.num_iters), desc=f"LBFGS iter: "): + # for i in range(self.num_iters): + def closure(): + body_optimizer.zero_grad() + smpl_output = self.smpl(global_orient=global_orient, + body_pose=body_pose, + betas=betas) + model_joints = smpl_output.joints + model_vertices = smpl_output.vertices + + loss = body_fitting_loss_3d(body_pose, preserve_pose, betas, model_joints[:, self.smpl_index], camera_translation, + j3d[:, self.corr_index], self.pose_prior, + joints3d_conf=conf_3d, + joint_loss_weight=600.0, + pose_preserve_weight=5.0, + use_collision=self.use_collision, + model_vertices=model_vertices, model_faces=self.model_faces, + search_tree=search_tree, pen_distance=pen_distance, filter_faces=filter_faces) + loss.backward() + return loss + + body_optimizer.step(closure) + else: + body_optimizer = torch.optim.Adam(body_opt_params, lr=self.step_size, betas=(0.9, 0.999)) + + for i in range(self.num_iters): + smpl_output = self.smpl(global_orient=global_orient, + body_pose=body_pose, + betas=betas) + model_joints = smpl_output.joints + model_vertices = smpl_output.vertices + + loss = body_fitting_loss_3d(body_pose, preserve_pose, betas, model_joints[:, self.smpl_index], camera_translation, + j3d[:, self.corr_index], self.pose_prior, + joints3d_conf=conf_3d, + joint_loss_weight=600.0, + use_collision=self.use_collision, + model_vertices=model_vertices, model_faces=self.model_faces, + search_tree=search_tree, pen_distance=pen_distance, filter_faces=filter_faces) + body_optimizer.zero_grad() + loss.backward() + body_optimizer.step() + + # Get final loss value + with torch.no_grad(): + smpl_output = self.smpl(global_orient=global_orient, + body_pose=body_pose, + betas=betas, return_full_pose=True) + model_joints = smpl_output.joints + model_vertices = smpl_output.vertices + + final_loss = body_fitting_loss_3d(body_pose, preserve_pose, betas, model_joints[:, self.smpl_index], camera_translation, + j3d[:, self.corr_index], self.pose_prior, + joints3d_conf=conf_3d, + joint_loss_weight=600.0, + use_collision=self.use_collision, model_vertices=model_vertices, model_faces=self.model_faces, + search_tree=search_tree, pen_distance=pen_distance, filter_faces=filter_faces) + + vertices = smpl_output.vertices.detach() + joints = smpl_output.joints.detach() + pose = torch.cat([global_orient, body_pose], dim=-1).detach() + betas = betas.detach() + + return vertices, joints, pose, betas, camera_translation, final_loss \ No newline at end of file diff --git a/Evaluator_272/mld/transforms/rotation2xyz.py b/Evaluator_272/mld/transforms/rotation2xyz.py new file mode 100644 index 0000000000000000000000000000000000000000..8a62fbe7eff18cae14c4768084a41cc375914198 --- /dev/null +++ b/Evaluator_272/mld/transforms/rotation2xyz.py @@ -0,0 +1,114 @@ +# This code is based on https://github.com/Mathux/ACTOR.git +import torch +import mld.utils.rotation_conversions as geometry + +from .smpl import SMPL, JOINTSTYPE_ROOT +# from .get_model import JOINTSTYPES +JOINTSTYPES = ["a2m", "a2mpl", "smpl", 
"vibe", "vertices"] + + +class Rotation2xyz(torch.nn.Module): + + def __init__(self, smpl_path): + super().__init__() + self.smpl_model = SMPL(smpl_path).eval() + + def __call__(self, + x, + mask, + pose_rep, + translation, + glob, + jointstype, + vertstrans, + betas=None, + beta=0, + glob_rot=None, + get_rotations_back=False, + **kwargs): + if pose_rep == "xyz": + return x + + if mask is None: + mask = torch.ones((x.shape[0], x.shape[-1]), + dtype=bool, + device=x.device) + + if not glob and glob_rot is None: + raise TypeError( + "You must specify global rotation if glob is False") + + if jointstype not in JOINTSTYPES: + raise NotImplementedError("This jointstype is not implemented.") + + if translation: + x_translations = x[:, -1, :3] + x_rotations = x[:, :-1] + else: + x_rotations = x + + x_rotations = x_rotations.permute(0, 3, 1, 2) + nsamples, time, njoints, feats = x_rotations.shape + + # Compute rotations (convert only masked sequences output) + if pose_rep == "rotvec": + rotations = geometry.axis_angle_to_matrix(x_rotations[mask]) + elif pose_rep == "rotmat": + rotations = x_rotations[mask].view(-1, njoints, 3, 3) + elif pose_rep == "rotquat": + rotations = geometry.quaternion_to_matrix(x_rotations[mask]) + elif pose_rep == "rot6d": + rotations = geometry.rotation_6d_to_matrix(x_rotations[mask]) + else: + raise NotImplementedError("No geometry for this one.") + + if not glob: + global_orient = torch.tensor(glob_rot, device=x.device) + global_orient = geometry.axis_angle_to_matrix(global_orient).view( + 1, 1, 3, 3) + global_orient = global_orient.repeat(len(rotations), 1, 1, 1) + else: + global_orient = rotations[:, 0] + rotations = rotations[:, 1:] + + if betas is None: + betas = torch.zeros( + [rotations.shape[0], self.smpl_model.num_betas], + dtype=rotations.dtype, + device=rotations.device) + betas[:, 1] = beta + + out = self.smpl_model(body_pose=rotations, + global_orient=global_orient, + betas=betas) + + # get the desirable joints + joints = out[jointstype] + + x_xyz = torch.empty(nsamples, + time, + joints.shape[1], + 3, + device=x.device, + dtype=x.dtype) + x_xyz[~mask] = 0 + x_xyz[mask] = joints + + x_xyz = x_xyz.permute(0, 2, 3, 1).contiguous() + + # the first translation root at the origin on the prediction + if jointstype != "vertices": + rootindex = JOINTSTYPE_ROOT[jointstype] + x_xyz = x_xyz - x_xyz[:, [rootindex], :, :] + + if translation and vertstrans: + # the first translation root at the origin + x_translations = x_translations - x_translations[:, :, [0]] + + # add the translation to all the joints + x_xyz = x_xyz + x_translations[:, None, :, :] + + if get_rotations_back: + return x_xyz, rotations, global_orient + else: + return x_xyz diff --git a/Evaluator_272/mld/transforms/rots2joints/__init__.py b/Evaluator_272/mld/transforms/rots2joints/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2755abfcdfeccf7866ce8dac0d165c5d13c94d4d --- /dev/null +++ b/Evaluator_272/mld/transforms/rots2joints/__init__.py @@ -0,0 +1,2 @@ +from .base import Rots2Joints +from .smplh import SMPLH diff --git a/Evaluator_272/mld/transforms/rots2joints/base.py b/Evaluator_272/mld/transforms/rots2joints/base.py new file mode 100644 index 0000000000000000000000000000000000000000..d683dd604be755e73ee5906d2bc9bc429216d95a --- /dev/null +++ b/Evaluator_272/mld/transforms/rots2joints/base.py @@ -0,0 +1,34 @@ +from typing import Optional + +import torch +from torch import Tensor, nn +from pathlib import Path + + +class Rots2Joints(nn.Module): + def __init__(self, 
path: Optional[str] = None, + normalization: bool = False, + eps: float = 1e-12, + **kwargs) -> None: + if normalization and path is None: + raise TypeError("You should provide a path if normalization is on.") + + super().__init__() + self.normalization = normalization + self.eps = eps + + if normalization: + mean_path = Path(path) / "mean.pt" + std_path = Path(path) / "std.pt" + self.register_buffer('mean', torch.load(mean_path)) + self.register_buffer('std', torch.load(std_path)) + + def normalize(self, features: Tensor) -> Tensor: + if self.normalization: + features = (features - self.mean)/(self.std + self.eps) + return features + + def unnormalize(self, features: Tensor) -> Tensor: + if self.normalization: + features = features * self.std + self.mean + return features diff --git a/Evaluator_272/mld/transforms/rots2joints/smplh.py b/Evaluator_272/mld/transforms/rots2joints/smplh.py new file mode 100644 index 0000000000000000000000000000000000000000..bab75a20d9eecac1375fbd435788e5c9deaa0b6f --- /dev/null +++ b/Evaluator_272/mld/transforms/rots2joints/smplh.py @@ -0,0 +1,175 @@ +import contextlib +from typing import Optional + +import torch +from einops import rearrange +from torch import Tensor +from .base import Rots2Joints + + +def slice_or_none(data, cslice): + if data is None: + return data + else: + return data[cslice] + + +class SMPLH(Rots2Joints): + + def __init__(self, + path: str, + jointstype: str = "mmm", + input_pose_rep: str = "matrix", + batch_size: int = 512, + gender="neutral", + **kwargs) -> None: + super().__init__(path=None, normalization=False) + self.batch_size = batch_size + self.input_pose_rep = input_pose_rep + self.jointstype = jointstype + self.training = False + + from smplx.body_models import SMPLHLayer + + # Remove annoying print + with contextlib.redirect_stdout(None): + self.smplh = SMPLHLayer(path, ext="npz", gender=gender).eval() + + self.faces = self.smplh.faces + for p in self.parameters(): + p.requires_grad = False + + def train(self, *args, **kwargs): + return self + + def forward(self, + smpl_data: dict, + jointstype: Optional[str] = None, + input_pose_rep: Optional[str] = None, + batch_size: Optional[int] = None) -> Tensor: + + # Take values from init if not specified there + jointstype = self.jointstype if jointstype is None else jointstype + batch_size = self.batch_size if batch_size is None else batch_size + input_pose_rep = self.input_pose_rep if input_pose_rep is None else input_pose_rep + + if input_pose_rep == "xyz": + raise NotImplementedError( + "You should use identity pose2joints instead") + + poses = smpl_data.rots + trans = smpl_data.trans + + from functools import reduce + import operator + save_shape_bs_len = poses.shape[:-3] + nposes = reduce(operator.mul, save_shape_bs_len, 1) + + if poses.shape[-3] == 52: + nohands = False + elif poses.shape[-3] == 22: + nohands = True + else: + raise NotImplementedError("Could not parse the poses.") + + # Convert any rotations to matrix + # from mld.tools.easyconvert import to_matrix + # matrix_poses = to_matrix(input_pose_rep, poses) + matrix_poses = poses + + # Reshaping + matrix_poses = matrix_poses.reshape((nposes, *matrix_poses.shape[-3:])) + global_orient = matrix_poses[:, 0] + + if trans is None: + trans = torch.zeros((*save_shape_bs_len, 3), + dtype=poses.dtype, + device=poses.device) + + trans_all = trans.reshape((nposes, *trans.shape[-1:])) + + body_pose = matrix_poses[:, 1:22] + if nohands: + from mld.tools.easyconvert import to_matrix + # still axis angle + left_hand_pose = 
self.smplh.left_hand_mean.reshape(15, 3) + left_hand_pose = to_matrix("axisangle", left_hand_pose) + left_hand_pose = left_hand_pose[None].repeat((nposes, 1, 1, 1)) + + right_hand_pose = self.smplh.right_hand_mean.reshape(15, 3) + right_hand_pose = to_matrix("axisangle", right_hand_pose) + right_hand_pose = right_hand_pose[None].repeat((nposes, 1, 1, 1)) + else: + hand_pose = matrix_poses[:, 22:] + left_hand_pose = hand_pose[:, :15] + right_hand_pose = hand_pose[:, 15:] + + n = len(body_pose) + outputs = [] + for chunk in range(int((n - 1) / batch_size) + 1): + chunk_slice = slice(chunk * batch_size, (chunk + 1) * batch_size) + smpl_output = self.smplh( + global_orient=slice_or_none(global_orient, chunk_slice), + body_pose=slice_or_none(body_pose, chunk_slice), + left_hand_pose=slice_or_none(left_hand_pose, chunk_slice), + right_hand_pose=slice_or_none(right_hand_pose, chunk_slice), + transl=slice_or_none(trans_all, chunk_slice)) + + if jointstype == "vertices": + output_chunk = smpl_output.vertices + else: + joints = smpl_output.joints + output_chunk = joints + outputs.append(output_chunk) + + outputs = torch.cat(outputs) + outputs = outputs.reshape((*save_shape_bs_len, *outputs.shape[1:])) + + # Change topology if needed + outputs = smplh_to(jointstype, outputs, trans) + return outputs + + def inverse(self, joints: Tensor) -> Tensor: + raise NotImplementedError("Cannot inverse SMPLH layer.") + + +def smplh_to(jointstype, data, trans): + from mld.utils.joints import get_root_idx + + if "mmm" in jointstype: + from mld.utils.joints import smplh2mmm_indexes + indexes = smplh2mmm_indexes + data = data[..., indexes, :] + + # make it compatible with mmm + if jointstype == "mmm": + from mld.utils.joints import smplh_to_mmm_scaling_factor + data *= smplh_to_mmm_scaling_factor + + if jointstype == "smplmmm": + pass + elif jointstype in ["mmm", "mmmns"]: + # swap axis + data = data[..., [1, 2, 0]] + # revert left and right + data[..., 2] = -data[..., 2] + + elif jointstype == "smplnh": + from mld.utils.joints import smplh2smplnh_indexes + indexes = smplh2smplnh_indexes + data = data[..., indexes, :] + elif jointstype == "smplh": + pass + elif jointstype == "vertices": + pass + else: + raise NotImplementedError(f"SMPLH to {jointstype} is not implemented.") + + if jointstype != "vertices": + # shift the output in each batch + # such that it is centered on the pelvis/root on the first frame + root_joint_idx = get_root_idx(jointstype) + shift = trans[..., 0, :] - data[..., 0, root_joint_idx, :] + data += shift[..., None, None, :] + + return data diff --git a/Evaluator_272/mld/transforms/rots2rfeats/__init__.py b/Evaluator_272/mld/transforms/rots2rfeats/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0ff007f2f014b8264971f532ebab44f45f2c7b90 --- /dev/null +++ b/Evaluator_272/mld/transforms/rots2rfeats/__init__.py @@ -0,0 +1,2 @@ +from .base import Rots2Rfeats +from .smplvelp import SMPLVelP diff --git a/Evaluator_272/mld/transforms/rots2rfeats/base.py b/Evaluator_272/mld/transforms/rots2rfeats/base.py new file mode 100644 index 0000000000000000000000000000000000000000..c9a98c598676a2cb731ea843e202b11c4f928ef5 --- /dev/null +++ b/Evaluator_272/mld/transforms/rots2rfeats/base.py @@ -0,0 +1,34 @@ +from typing import Optional + +import torch +from torch import Tensor, nn +from pathlib import Path + + +class Rots2Rfeats(nn.Module): + def __init__(self, path: Optional[str] = None, + normalization: bool = False, + eps: float = 1e-12, + **kwargs) -> None: + if normalization and 
path is None: + raise TypeError("You should provide a path if normalization is on.") + + super().__init__() + self.normalization = normalization + self.eps = eps + + if normalization: + mean_path = Path(path) / "rfeats_mean.pt" + std_path = Path(path) / "rfeats_std.pt" + self.register_buffer('mean', torch.load(mean_path)) + self.register_buffer('std', torch.load(std_path)) + + def normalize(self, features: Tensor) -> Tensor: + if self.normalization: + features = (features - self.mean)/(self.std + self.eps) + return features + + def unnormalize(self, features: Tensor) -> Tensor: + if self.normalization: + features = features * self.std + self.mean + return features diff --git a/Evaluator_272/mld/transforms/rots2rfeats/smplvelp.py b/Evaluator_272/mld/transforms/rots2rfeats/smplvelp.py new file mode 100644 index 0000000000000000000000000000000000000000..0b4355ad868bd37c187dcd57b4ac25cdd6c5f72b --- /dev/null +++ b/Evaluator_272/mld/transforms/rots2rfeats/smplvelp.py @@ -0,0 +1,101 @@ +from typing import Optional + +import torch +from torch import Tensor +from einops import rearrange + +from mld.utils.temos_utils import matrix_to, nfeats_of, to_matrix +import mld.utils.geometry as geometry + +from .base import Rots2Rfeats + + +class SMPLVelP(Rots2Rfeats): + + def __init__(self, + path: Optional[str] = None, + normalization: bool = False, + pose_rep: str = "rot6d", + canonicalize: bool = False, + offset: bool = True, + **kwargs) -> None: + super().__init__(path=path, normalization=normalization) + self.canonicalize = canonicalize + self.pose_rep = pose_rep + self.nfeats = nfeats_of(pose_rep) + self.offset = offset + + def forward(self, data) -> Tensor: + matrix_poses, trans = data.rots, data.trans + # matrix_poses: [nframes, 22, 3, 3] + + # extract the root gravity axis + # for smpl it is the last coordinate + root_y = trans[..., 2] + trajectory = trans[..., [0, 1]] + + # Comoute the difference of trajectory (for X and Y axis) + vel_trajectory = torch.diff(trajectory, dim=-2) + # 0 for the first one => keep the dimentionality + vel_trajectory = torch.cat( + (0 * vel_trajectory[..., [0], :], vel_trajectory), dim=-2) + + # first normalize the data + if self.canonicalize: + global_orient = matrix_poses[..., 0, :, :] + # remove the rotation + rot2d = geometry.matrix_to_axis_angle(global_orient[..., 0, :, :]) + # Remove the fist rotation along the vertical axis + # construct this by extract only the vertical component of the rotation + rot2d[..., :2] = 0 + + if self.offset: + # add a bit more rotation + rot2d[..., 2] += torch.pi / 2 + + rot2d = geometry.axis_angle_to_matrix(rot2d) + + # turn with the same amount all the rotations + global_orient = torch.einsum("...kj,...kl->...jl", rot2d, + global_orient) + + matrix_poses = torch.cat( + (global_orient[..., None, :, :], matrix_poses[..., 1:, :, :]), + dim=-3) + + # Turn the trajectory as well + vel_trajectory = torch.einsum("...kj,...lk->...lj", + rot2d[..., :2, :2], vel_trajectory) + + poses = matrix_to(self.pose_rep, matrix_poses) + features = torch.cat( + (root_y[..., None], vel_trajectory, + rearrange(poses, "... joints rot -> ... (joints rot)")), + dim=-1) + features = self.normalize(features) + return features + + def extract(self, features): + root_y = features[..., 0] + vel_trajectory = features[..., 1:3] + poses_features = features[..., 3:] + poses = rearrange(poses_features, + "... (joints rot) -> ... 
joints rot", + rot=self.nfeats) + return root_y, vel_trajectory, poses + + def inverse(self, features): + features = self.unnormalize(features) + root_y, vel_trajectory, poses = self.extract(features) + + # integrate the trajectory + trajectory = torch.cumsum(vel_trajectory, dim=-2) + # First frame should be 0, but if infered it is better to ensure it + trajectory = trajectory - trajectory[..., [0], :] + + # Get back the translation + trans = torch.cat([trajectory, root_y[..., None]], dim=-1) + matrix_poses = to_matrix(self.pose_rep, poses) + + from temos.transforms.smpl import RotTransDatastruct + return RotTransDatastruct(rots=matrix_poses, trans=trans) diff --git a/Evaluator_272/mld/transforms/smpl.py b/Evaluator_272/mld/transforms/smpl.py new file mode 100644 index 0000000000000000000000000000000000000000..f83c5ff18af981164950728f2bbf7e57214652c9 --- /dev/null +++ b/Evaluator_272/mld/transforms/smpl.py @@ -0,0 +1,253 @@ +from typing import Optional +from torch import Tensor +import numpy as np +import torch +import contextlib +from .base import Datastruct, dataclass, Transform +import os +from .rots2rfeats import Rots2Rfeats +from .rots2joints import Rots2Joints +from .joints2jfeats import Joints2Jfeats + + +class SMPLTransform(Transform): + + def __init__(self, rots2rfeats: Rots2Rfeats, rots2joints: Rots2Joints, + joints2jfeats: Joints2Jfeats, **kwargs): + self.rots2rfeats = rots2rfeats + self.rots2joints = rots2joints + self.joints2jfeats = joints2jfeats + + def Datastruct(self, **kwargs): + return SMPLDatastruct(_rots2rfeats=self.rots2rfeats, + _rots2joints=self.rots2joints, + _joints2jfeats=self.joints2jfeats, + transforms=self, + **kwargs) + + def __repr__(self): + return "SMPLTransform()" + + +class RotIdentityTransform(Transform): + + def __init__(self, **kwargs): + return + + def Datastruct(self, **kwargs): + return RotTransDatastruct(**kwargs) + + def __repr__(self): + return "RotIdentityTransform()" + + +@dataclass +class RotTransDatastruct(Datastruct): + rots: Tensor + trans: Tensor + + transforms: RotIdentityTransform = RotIdentityTransform() + + def __post_init__(self): + self.datakeys = ["rots", "trans"] + + def __len__(self): + return len(self.rots) + + +@dataclass +class SMPLDatastruct(Datastruct): + transforms: SMPLTransform + _rots2rfeats: Rots2Rfeats + _rots2joints: Rots2Joints + _joints2jfeats: Joints2Jfeats + + features: Optional[Tensor] = None + rots_: Optional[RotTransDatastruct] = None + rfeats_: Optional[Tensor] = None + joints_: Optional[Tensor] = None + jfeats_: Optional[Tensor] = None + + def __post_init__(self): + self.datakeys = ["features", "rots_", "rfeats_", "joints_", "jfeats_"] + # starting point + if self.features is not None and self.rfeats_ is None: + self.rfeats_ = self.features + + @property + def rots(self): + # Cached value + if self.rots_ is not None: + return self.rots_ + + # self.rfeats_ should be defined + assert self.rfeats_ is not None + + self._rots2rfeats.to(self.rfeats.device) + self.rots_ = self._rots2rfeats.inverse(self.rfeats) + return self.rots_ + + @property + def rfeats(self): + # Cached value + if self.rfeats_ is not None: + return self.rfeats_ + + # self.rots_ should be defined + assert self.rots_ is not None + + self._rots2rfeats.to(self.rots.device) + self.rfeats_ = self._rots2rfeats(self.rots) + return self.rfeats_ + + @property + def joints(self): + # Cached value + if self.joints_ is not None: + return self.joints_ + + self._rots2joints.to(self.rots.device) + self.joints_ = self._rots2joints(self.rots) + return 
self.joints_ + + @property + def jfeats(self): + # Cached value + if self.jfeats_ is not None: + return self.jfeats_ + + self._joints2jfeats.to(self.joints.device) + self.jfeats_ = self._joints2jfeats(self.joints) + return self.jfeats_ + + def __len__(self): + return len(self.rfeats) + + +# This code is based on https://github.com/Mathux/ACTOR.git +from smplx import SMPLLayer as _SMPLLayer +from smplx.lbs import vertices2joints + +# action2motion_joints = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 21, 24, 38] +# change 0 and 8 +action2motion_joints = [ + 8, 1, 2, 3, 4, 5, 6, 7, 0, 9, 10, 11, 12, 13, 14, 21, 24, 38 +] + +SMPL_DATA_PATH = 'deps/smpl' + +JOINTSTYPE_ROOT = { + "a2m": 0, # action2motion + "smpl": 0, + "a2mpl": 0, # set(smpl, a2m) + "vibe": 8 +} # 0 is the 8 position: OP MidHip below + +JOINT_MAP = { + 'OP Nose': 24, + 'OP Neck': 12, + 'OP RShoulder': 17, + 'OP RElbow': 19, + 'OP RWrist': 21, + 'OP LShoulder': 16, + 'OP LElbow': 18, + 'OP LWrist': 20, + 'OP MidHip': 0, + 'OP RHip': 2, + 'OP RKnee': 5, + 'OP RAnkle': 8, + 'OP LHip': 1, + 'OP LKnee': 4, + 'OP LAnkle': 7, + 'OP REye': 25, + 'OP LEye': 26, + 'OP REar': 27, + 'OP LEar': 28, + 'OP LBigToe': 29, + 'OP LSmallToe': 30, + 'OP LHeel': 31, + 'OP RBigToe': 32, + 'OP RSmallToe': 33, + 'OP RHeel': 34, + 'Right Ankle': 8, + 'Right Knee': 5, + 'Right Hip': 45, + 'Left Hip': 46, + 'Left Knee': 4, + 'Left Ankle': 7, + 'Right Wrist': 21, + 'Right Elbow': 19, + 'Right Shoulder': 17, + 'Left Shoulder': 16, + 'Left Elbow': 18, + 'Left Wrist': 20, + 'Neck (LSP)': 47, + 'Top of Head (LSP)': 48, + 'Pelvis (MPII)': 49, + 'Thorax (MPII)': 50, + 'Spine (H36M)': 51, + 'Jaw (H36M)': 52, + 'Head (H36M)': 53, + 'Nose': 24, + 'Left Eye': 26, + 'Right Eye': 25, + 'Left Ear': 28, + 'Right Ear': 27 +} + +JOINT_NAMES = [ + 'OP Nose', 'OP Neck', 'OP RShoulder', 'OP RElbow', 'OP RWrist', + 'OP LShoulder', 'OP LElbow', 'OP LWrist', 'OP MidHip', 'OP RHip', + 'OP RKnee', 'OP RAnkle', 'OP LHip', 'OP LKnee', 'OP LAnkle', 'OP REye', + 'OP LEye', 'OP REar', 'OP LEar', 'OP LBigToe', 'OP LSmallToe', 'OP LHeel', + 'OP RBigToe', 'OP RSmallToe', 'OP RHeel', 'Right Ankle', 'Right Knee', + 'Right Hip', 'Left Hip', 'Left Knee', 'Left Ankle', 'Right Wrist', + 'Right Elbow', 'Right Shoulder', 'Left Shoulder', 'Left Elbow', + 'Left Wrist', 'Neck (LSP)', 'Top of Head (LSP)', 'Pelvis (MPII)', + 'Thorax (MPII)', 'Spine (H36M)', 'Jaw (H36M)', 'Head (H36M)', 'Nose', + 'Left Eye', 'Right Eye', 'Left Ear', 'Right Ear' +] + + +# adapted from VIBE/SPIN to output smpl_joints, vibe joints and action2motion joints +class SMPL(_SMPLLayer): + """ Extension of the official SMPL implementation to support more joints """ + + def __init__(self, smpl_path=SMPL_DATA_PATH, **kwargs): + model_path = os.path.join(smpl_path, "SMPL_NEUTRAL.pkl") + J_path = os.path.join(smpl_path, 'J_regressor_extra.npy') + kwargs["model_path"] = model_path + + # remove the verbosity for the 10-shapes beta parameters + with contextlib.redirect_stdout(None): + super(SMPL, self).__init__(**kwargs) + + J_regressor_extra = np.load(J_path) + self.register_buffer( + 'J_regressor_extra', + torch.tensor(J_regressor_extra, dtype=torch.float32)) + vibe_indexes = np.array([JOINT_MAP[i] for i in JOINT_NAMES]) + a2m_indexes = vibe_indexes[action2motion_joints] + smpl_indexes = np.arange(24) + a2mpl_indexes = np.unique(np.r_[smpl_indexes, a2m_indexes]) + + self.maps = { + "vibe": vibe_indexes, + "a2m": a2m_indexes, + "smpl": smpl_indexes, + "a2mpl": a2mpl_indexes + } + + def forward(self, *args, **kwargs): + 
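+        # Run the stock SMPLLayer forward pass, regress the extra
+        # H36M/LSP/MPII-style joints from the returned mesh vertices with
+        # J_regressor_extra, and append them to the native SMPL joints so the
+        # JOINT_MAP indices above (up to 53) remain valid. The returned dict
+        # holds the raw vertices plus one joint subset per supported
+        # jointstype ("vibe", "a2m", "smpl", "a2mpl").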
smpl_output = super(SMPL, self).forward(*args, **kwargs) + + extra_joints = vertices2joints(self.J_regressor_extra, + smpl_output.vertices) + all_joints = torch.cat([smpl_output.joints, extra_joints], dim=1) + + output = {"vertices": smpl_output.vertices} + + for joinstype, indexes in self.maps.items(): + output[joinstype] = all_joints[:, indexes] + + return output diff --git a/Evaluator_272/mld/transforms/xyz.py b/Evaluator_272/mld/transforms/xyz.py new file mode 100644 index 0000000000000000000000000000000000000000..f8590ea8f54fbb907cda85a5daa41bd9299ea1db --- /dev/null +++ b/Evaluator_272/mld/transforms/xyz.py @@ -0,0 +1,66 @@ +from typing import Optional +from torch import Tensor + +from .base import Datastruct, dataclass, Transform +from mld.datasets.utils import collate_tensor_with_padding + +from .joints2jfeats import Joints2Jfeats + + +class XYZTransform(Transform): + + def __init__(self, joints2jfeats: Joints2Jfeats, **kwargs): + self.joints2jfeats = joints2jfeats + + def Datastruct(self, **kwargs): + return XYZDatastruct(_joints2jfeats=self.joints2jfeats, + transforms=self, + **kwargs) + + def __repr__(self): + return "XYZTransform()" + + +@dataclass +class XYZDatastruct(Datastruct): + transforms: XYZTransform + _joints2jfeats: Joints2Jfeats + + features: Optional[Tensor] = None + joints_: Optional[Tensor] = None + jfeats_: Optional[Tensor] = None + + def __post_init__(self): + self.datakeys = ["features", "joints_", "jfeats_"] + # starting point + if self.features is not None and self.jfeats_ is None: + self.jfeats_ = self.features + + @property + def joints(self): + # Cached value + if self.joints_ is not None: + return self.joints_ + + # self.jfeats_ should be defined + assert self.jfeats_ is not None + + self._joints2jfeats.to(self.jfeats.device) + self.joints_ = self._joints2jfeats.inverse(self.jfeats) + return self.joints_ + + @property + def jfeats(self): + # Cached value + if self.jfeats_ is not None: + return self.jfeats_ + + # self.joints_ should be defined + assert self.joints_ is not None + + self._joints2jfeats.to(self.joints.device) + self.jfeats_ = self._joints2jfeats(self.joints) + return self.jfeats_ + + def __len__(self): + return len(self.jfeats) diff --git a/Evaluator_272/mld/utils/__init__.py b/Evaluator_272/mld/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Evaluator_272/mld/utils/demo_utils.py b/Evaluator_272/mld/utils/demo_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1a6299332523f4ed07c72740b4eb433a6550022c --- /dev/null +++ b/Evaluator_272/mld/utils/demo_utils.py @@ -0,0 +1,79 @@ +import os +from pathlib import Path + + +# load example data +def load_example_input(txt_path): + file = open(txt_path, "r") + Lines = file.readlines() + count = 0 + texts, lens = [], [] + # Strips the newline character + for line in Lines: + count += 1 + s = line.strip() + s_l = s.split(" ")[0] + s_t = s[(len(s_l) + 1):] + lens.append(int(s_l)) + texts.append(s_t) + print("Length-{}: {}".format(s_l, s_t)) + return texts, lens + + +# render batch +def render_batch(npy_dir, execute_python="./scripts/visualize_motion.sh", mode="sequence"): + os.system(f"{execute_python} {npy_dir} {mode}") + + +# render +def render(execute_python, npy_path, jointtype, cfg_path): + # execute_python = "/apdcephfs/share_1227775/shingxchen/libs/blender_bpy/blender-2.93.2-linux-x64/blender" + # execute_python = 
"/apdcephfs/share_1227775/mingzhenzhu/jiangbiao/libs/blender-2.93.2-linux-x64/blender" + export_scripts = "render.py" + + os.system( + f"{execute_python} --background --python {export_scripts} -- --cfg={cfg_path} --npy={npy_path} --joint_type={jointtype}" + ) + + fig_path = Path(str(npy_path).replace(".npy", ".png")) + return fig_path + + +# origin render +# def render(npy_path, jointtype): +# execute_python = '/apdcephfs/share_1227775/shingxchen/libs/blender_bpy/blender-2.93.2-linux-x64/blender' +# export_scripts = 'render.py' + +# os.system(f"{execute_python} --background --python {export_scripts} -- npy={npy_path} jointstype={jointtype}") + +# fig_path = Path(str(npy_path).replace(".npy",".png")) +# return fig_path + +# export fbx with hand params from pkl files +# refer to /apdcephfs/share_1227775/shingxchen/AIMotion/TMOST/scripts/fbx_output_smplx.py +def export_fbx_hand(pkl_path): + input = pkl_path + output = pkl_path.replace(".pkl", ".fbx") + + execute_python = "/apdcephfs/share_1227775/shingxchen/libs/blender_bpy/blender-2.93.2-linux-x64/blender" + export_scripts = "./scripts/fbx_output_smplx.py" + os.system( + f"{execute_python} -noaudio --background --python {export_scripts}\ + --input {input} \ + --output {output}" + ) + + +# export fbx without hand params from pkl files +# refer to /apdcephfs/share_1227775/shingxchen/AIMotion/TMOST/scripts/fbx_output.py +def export_fbx(pkl_path): + input = pkl_path + output = pkl_path.replace(".pkl", ".fbx") + + execute_python = "/apdcephfs/share_1227775/shingxchen/libs/blender_bpy/blender-2.93.2-linux-x64/blender" + export_scripts = "./scripts/fbx_output.py" + os.system( + f"{execute_python} -noaudio --background --python {export_scripts}\ + --input {input} \ + --output {output}" + ) diff --git a/Evaluator_272/mld/utils/easyconvert.py b/Evaluator_272/mld/utils/easyconvert.py new file mode 100644 index 0000000000000000000000000000000000000000..ba4061c4904d6d5ee807c85adff6fb721f8ed548 --- /dev/null +++ b/Evaluator_272/mld/utils/easyconvert.py @@ -0,0 +1,73 @@ +import mld.utils.geometry as geometry + + +def nfeats_of(rottype): + if rottype in ["rotvec", "axisangle"]: + return 3 + elif rottype in ["rotquat", "quaternion"]: + return 4 + elif rottype in ["rot6d", "6drot", "rotation6d"]: + return 6 + elif rottype in ["rotmat"]: + return 9 + else: + return TypeError("This rotation type doesn't have features.") + + +def axis_angle_to(newtype, rotations): + if newtype in ["matrix"]: + rotations = geometry.axis_angle_to_matrix(rotations) + return rotations + elif newtype in ["rotmat"]: + rotations = geometry.axis_angle_to_matrix(rotations) + rotations = matrix_to("rotmat", rotations) + return rotations + elif newtype in ["rot6d", "6drot", "rotation6d"]: + rotations = geometry.axis_angle_to_matrix(rotations) + rotations = matrix_to("rot6d", rotations) + return rotations + elif newtype in ["rotquat", "quaternion"]: + rotations = geometry.axis_angle_to_quaternion(rotations) + return rotations + elif newtype in ["rotvec", "axisangle"]: + return rotations + else: + raise NotImplementedError + + +def matrix_to(newtype, rotations): + if newtype in ["matrix"]: + return rotations + if newtype in ["rotmat"]: + rotations = rotations.reshape((*rotations.shape[:-2], 9)) + return rotations + elif newtype in ["rot6d", "6drot", "rotation6d"]: + rotations = geometry.matrix_to_rotation_6d(rotations) + return rotations + elif newtype in ["rotquat", "quaternion"]: + rotations = geometry.matrix_to_quaternion(rotations) + return rotations + elif newtype in ["rotvec", 
"axisangle"]: + rotations = geometry.matrix_to_axis_angle(rotations) + return rotations + else: + raise NotImplementedError + + +def to_matrix(oldtype, rotations): + if oldtype in ["matrix"]: + return rotations + if oldtype in ["rotmat"]: + rotations = rotations.reshape((*rotations.shape[:-2], 3, 3)) + return rotations + elif oldtype in ["rot6d", "6drot", "rotation6d"]: + rotations = geometry.rotation_6d_to_matrix(rotations) + return rotations + elif oldtype in ["rotquat", "quaternion"]: + rotations = geometry.quaternion_to_matrix(rotations) + return rotations + elif oldtype in ["rotvec", "axisangle"]: + rotations = geometry.axis_angle_to_matrix(rotations) + return rotations + else: + raise NotImplementedError diff --git a/Evaluator_272/mld/utils/fixseed.py b/Evaluator_272/mld/utils/fixseed.py new file mode 100644 index 0000000000000000000000000000000000000000..a43a273b138c45dccafef4da3628dd4c2a3f84a4 --- /dev/null +++ b/Evaluator_272/mld/utils/fixseed.py @@ -0,0 +1,18 @@ +import numpy as np +import torch +import random + + +def fixseed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + +SEED = 10 +EVALSEED = 0 +# Provoc warning: not fully functionnal yet +# torch.set_deterministic(True) +torch.backends.cudnn.benchmark = False + +fixseed(SEED) diff --git a/Evaluator_272/mld/utils/geometry.py b/Evaluator_272/mld/utils/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..86bf6ae2bcee2580d44281fcae9125f70470e952 --- /dev/null +++ b/Evaluator_272/mld/utils/geometry.py @@ -0,0 +1,473 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. 
+# +# Contact: ps-license@tuebingen.mpg.de + +import torch +import numpy as np +from torch.nn import functional as F + + +def matrix_of_angles(cos, sin, inv=False, dim=2): + assert dim in [2, 3] + sin = -sin if inv else sin + if dim == 2: + row1 = torch.stack((cos, -sin), axis=-1) + row2 = torch.stack((sin, cos), axis=-1) + return torch.stack((row1, row2), axis=-2) + elif dim == 3: + row1 = torch.stack((cos, -sin, 0 * cos), axis=-1) + row2 = torch.stack((sin, cos, 0 * cos), axis=-1) + row3 = torch.stack((0 * sin, 0 * cos, 1 + 0 * cos), axis=-1) + return torch.stack((row1, row2, row3), axis=-2) + + +def matrot2axisangle(matrots): + # This function is borrowed from https://github.com/davrempe/humor/utils/transforms.py + # axisang N x 3 + ''' + :param matrots: N*num_joints*9 + :return: N*num_joints*3 + ''' + import cv2 + batch_size = matrots.shape[0] + matrots = matrots.reshape([batch_size, -1, 9]) + out_axisangle = [] + for mIdx in range(matrots.shape[0]): + cur_axisangle = [] + for jIdx in range(matrots.shape[1]): + a = cv2.Rodrigues(matrots[mIdx, + jIdx:jIdx + 1, :].reshape(3, + 3))[0].reshape( + (1, 3)) + cur_axisangle.append(a) + + out_axisangle.append(np.array(cur_axisangle).reshape([1, -1, 3])) + return np.vstack(out_axisangle) + + +def axisangle2matrots(axisangle): + # This function is borrowed from https://github.com/davrempe/humor/utils/transforms.py + # axisang N x 3 + ''' + :param axisangle: N*num_joints*3 + :return: N*num_joints*9 + ''' + import cv2 + batch_size = axisangle.shape[0] + axisangle = axisangle.reshape([batch_size, -1, 3]) + out_matrot = [] + for mIdx in range(axisangle.shape[0]): + cur_axisangle = [] + for jIdx in range(axisangle.shape[1]): + a = cv2.Rodrigues(axisangle[mIdx, jIdx:jIdx + 1, :].reshape(1, + 3))[0] + cur_axisangle.append(a) + + out_matrot.append(np.array(cur_axisangle).reshape([1, -1, 9])) + return np.vstack(out_matrot) + + +def batch_rodrigues(axisang): + # This function is borrowed from https://github.com/MandyMo/pytorch_HMR/blob/master/src/util.py#L37 + # axisang N x 3 + axisang_norm = torch.norm(axisang + 1e-8, p=2, dim=1) + angle = torch.unsqueeze(axisang_norm, -1) + axisang_normalized = torch.div(axisang, angle) + angle = angle * 0.5 + v_cos = torch.cos(angle) + v_sin = torch.sin(angle) + + quat = torch.cat([v_cos, v_sin * axisang_normalized], dim=1) + rot_mat = quat2mat(quat) + rot_mat = rot_mat.view(rot_mat.shape[0], 9) + return rot_mat + + +def quat2mat(quat): + """ + This function is borrowed from https://github.com/MandyMo/pytorch_HMR/blob/master/src/util.py#L50 + + Convert quaternion coefficients to rotation matrix. 
+ Args: + quat: size = [batch_size, 4] 4 <===>(w, x, y, z) + Returns: + Rotation matrix corresponding to the quaternion -- size = [batch_size, 3, 3] + """ + norm_quat = quat + norm_quat = norm_quat / norm_quat.norm(p=2, dim=1, keepdim=True) + w, x, y, z = norm_quat[:, 0], norm_quat[:, 1], norm_quat[:, + 2], norm_quat[:, + 3] + + batch_size = quat.size(0) + + w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2) + wx, wy, wz = w * x, w * y, w * z + xy, xz, yz = x * y, x * z, y * z + + rotMat = torch.stack([ + w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz, 2 * wz + 2 * xy, + w2 - x2 + y2 - z2, 2 * yz - 2 * wx, 2 * xz - 2 * wy, 2 * wx + 2 * yz, + w2 - x2 - y2 + z2 + ], + dim=1).view(batch_size, 3, 3) + return rotMat + + +def rotation_matrix_to_angle_axis(rotation_matrix): + """ + This function is borrowed from https://github.com/kornia/kornia + + Convert 3x4 rotation matrix to Rodrigues vector + + Args: + rotation_matrix (Tensor): rotation matrix. + + Returns: + Tensor: Rodrigues vector transformation. + + Shape: + - Input: :math:`(N, 3, 4)` + - Output: :math:`(N, 3)` + + Example: + >>> input = torch.rand(2, 3, 4) # Nx4x4 + >>> output = tgm.rotation_matrix_to_angle_axis(input) # Nx3 + """ + if rotation_matrix.shape[1:] == (3, 3): + rot_mat = rotation_matrix.reshape(-1, 3, 3) + hom = torch.tensor([0, 0, 1], + dtype=torch.float32, + device=rotation_matrix.device).reshape( + 1, 3, 1).expand(rot_mat.shape[0], -1, -1) + rotation_matrix = torch.cat([rot_mat, hom], dim=-1) + + quaternion = rotation_matrix_to_quaternion(rotation_matrix) + aa = quaternion_to_angle_axis(quaternion) + aa[torch.isnan(aa)] = 0.0 + return aa + + +def quaternion_to_angle_axis(quaternion: torch.Tensor) -> torch.Tensor: + """ + This function is borrowed from https://github.com/kornia/kornia + + Convert quaternion vector to angle axis of rotation. + + Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h + + Args: + quaternion (torch.Tensor): tensor with quaternions. + + Return: + torch.Tensor: tensor with angle axis of rotation. + + Shape: + - Input: :math:`(*, 4)` where `*` means, any number of dimensions + - Output: :math:`(*, 3)` + + Example: + >>> quaternion = torch.rand(2, 4) # Nx4 + >>> angle_axis = tgm.quaternion_to_angle_axis(quaternion) # Nx3 + """ + if not torch.is_tensor(quaternion): + raise TypeError("Input type is not a torch.Tensor. Got {}".format( + type(quaternion))) + + if not quaternion.shape[-1] == 4: + raise ValueError( + "Input must be a tensor of shape Nx4 or 4. 
Got {}".format( + quaternion.shape)) + # unpack input and compute conversion + q1: torch.Tensor = quaternion[..., 1] + q2: torch.Tensor = quaternion[..., 2] + q3: torch.Tensor = quaternion[..., 3] + sin_squared_theta: torch.Tensor = q1 * q1 + q2 * q2 + q3 * q3 + + sin_theta: torch.Tensor = torch.sqrt(sin_squared_theta) + cos_theta: torch.Tensor = quaternion[..., 0] + two_theta: torch.Tensor = 2.0 * torch.where( + cos_theta < 0.0, torch.atan2(-sin_theta, -cos_theta), + torch.atan2(sin_theta, cos_theta)) + + k_pos: torch.Tensor = two_theta / sin_theta + k_neg: torch.Tensor = 2.0 * torch.ones_like(sin_theta) + k: torch.Tensor = torch.where(sin_squared_theta > 0.0, k_pos, k_neg) + + angle_axis: torch.Tensor = torch.zeros_like(quaternion)[..., :3] + angle_axis[..., 0] += q1 * k + angle_axis[..., 1] += q2 * k + angle_axis[..., 2] += q3 * k + return angle_axis + + +def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6): + """ + This function is borrowed from https://github.com/kornia/kornia + + Convert 3x4 rotation matrix to 4d quaternion vector + + This algorithm is based on algorithm described in + https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L201 + + Args: + rotation_matrix (Tensor): the rotation matrix to convert. + + Return: + Tensor: the rotation in quaternion + + Shape: + - Input: :math:`(N, 3, 4)` + - Output: :math:`(N, 4)` + + Example: + >>> input = torch.rand(4, 3, 4) # Nx3x4 + >>> output = tgm.rotation_matrix_to_quaternion(input) # Nx4 + """ + if not torch.is_tensor(rotation_matrix): + raise TypeError("Input type is not a torch.Tensor. Got {}".format( + type(rotation_matrix))) + + if len(rotation_matrix.shape) > 3: + raise ValueError( + "Input size must be a three dimensional tensor. Got {}".format( + rotation_matrix.shape)) + if not rotation_matrix.shape[-2:] == (3, 4): + raise ValueError( + "Input size must be a N x 3 x 4 tensor. 
Got {}".format( + rotation_matrix.shape)) + + rmat_t = torch.transpose(rotation_matrix, 1, 2) + + mask_d2 = rmat_t[:, 2, 2] < eps + + mask_d0_d1 = rmat_t[:, 0, 0] > rmat_t[:, 1, 1] + mask_d0_nd1 = rmat_t[:, 0, 0] < -rmat_t[:, 1, 1] + + t0 = 1 + rmat_t[:, 0, 0] - rmat_t[:, 1, 1] - rmat_t[:, 2, 2] + q0 = torch.stack([ + rmat_t[:, 1, 2] - rmat_t[:, 2, 1], t0, + rmat_t[:, 0, 1] + rmat_t[:, 1, 0], rmat_t[:, 2, 0] + rmat_t[:, 0, 2] + ], -1) + t0_rep = t0.repeat(4, 1).t() + + t1 = 1 - rmat_t[:, 0, 0] + rmat_t[:, 1, 1] - rmat_t[:, 2, 2] + q1 = torch.stack([ + rmat_t[:, 2, 0] - rmat_t[:, 0, 2], rmat_t[:, 0, 1] + rmat_t[:, 1, 0], + t1, rmat_t[:, 1, 2] + rmat_t[:, 2, 1] + ], -1) + t1_rep = t1.repeat(4, 1).t() + + t2 = 1 - rmat_t[:, 0, 0] - rmat_t[:, 1, 1] + rmat_t[:, 2, 2] + q2 = torch.stack([ + rmat_t[:, 0, 1] - rmat_t[:, 1, 0], rmat_t[:, 2, 0] + rmat_t[:, 0, 2], + rmat_t[:, 1, 2] + rmat_t[:, 2, 1], t2 + ], -1) + t2_rep = t2.repeat(4, 1).t() + + t3 = 1 + rmat_t[:, 0, 0] + rmat_t[:, 1, 1] + rmat_t[:, 2, 2] + q3 = torch.stack([ + t3, rmat_t[:, 1, 2] - rmat_t[:, 2, 1], + rmat_t[:, 2, 0] - rmat_t[:, 0, 2], rmat_t[:, 0, 1] - rmat_t[:, 1, 0] + ], -1) + t3_rep = t3.repeat(4, 1).t() + + mask_c0 = mask_d2 * mask_d0_d1 + mask_c1 = mask_d2 * ~mask_d0_d1 + mask_c2 = ~mask_d2 * mask_d0_nd1 + mask_c3 = ~mask_d2 * ~mask_d0_nd1 + mask_c0 = mask_c0.view(-1, 1).type_as(q0) + mask_c1 = mask_c1.view(-1, 1).type_as(q1) + mask_c2 = mask_c2.view(-1, 1).type_as(q2) + mask_c3 = mask_c3.view(-1, 1).type_as(q3) + + q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3 + q /= torch.sqrt(t0_rep * mask_c0 + t1_rep * mask_c1 + # noqa + t2_rep * mask_c2 + t3_rep * mask_c3) # noqa + q *= 0.5 + return q + + +def estimate_translation_np(S, + joints_2d, + joints_conf, + focal_length=5000., + img_size=224.): + """ + This function is borrowed from https://github.com/nkolot/SPIN/utils/geometry.py + + Find camera translation that brings 3D joints S closest to 2D the corresponding joints_2d. + Input: + S: (25, 3) 3D joint locations + joints: (25, 3) 2D joint locations and confidence + Returns: + (3,) camera translation vector + """ + + num_joints = S.shape[0] + # focal length + f = np.array([focal_length, focal_length]) + # optical center + center = np.array([img_size / 2., img_size / 2.]) + + # transformations + Z = np.reshape(np.tile(S[:, 2], (2, 1)).T, -1) + XY = np.reshape(S[:, 0:2], -1) + O = np.tile(center, num_joints) + F = np.tile(f, num_joints) + weight2 = np.reshape(np.tile(np.sqrt(joints_conf), (2, 1)).T, -1) + + # least squares + Q = np.array([ + F * np.tile(np.array([1, 0]), num_joints), + F * np.tile(np.array([0, 1]), num_joints), + O - np.reshape(joints_2d, -1) + ]).T + c = (np.reshape(joints_2d, -1) - O) * Z - F * XY + + # weighted least squares + W = np.diagflat(weight2) + Q = np.dot(W, Q) + c = np.dot(W, c) + + # square matrix + A = np.dot(Q.T, Q) + b = np.dot(Q.T, c) + + # solution + trans = np.linalg.solve(A, b) + + return trans + + +def estimate_translation(S, joints_2d, focal_length=5000., img_size=224.): + """ + This function is borrowed from https://github.com/nkolot/SPIN/utils/geometry.py + + Find camera translation that brings 3D joints S closest to 2D the corresponding joints_2d. 
+    Input:
+        S: (B, 49, 3) 3D joint locations
+        joints: (B, 49, 3) 2D joint locations and confidence
+    Returns:
+        (B, 3) camera translation vectors
+    """
+
+    device = S.device
+    # Use only joints 25:49 (GT joints)
+    S = S[:, 25:, :].cpu().numpy()
+    joints_2d = joints_2d[:, 25:, :].cpu().numpy()
+    joints_conf = joints_2d[:, :, -1]
+    joints_2d = joints_2d[:, :, :-1]
+    trans = np.zeros((S.shape[0], 3), dtype=np.float32)
+    # Find the translation for each example in the batch
+    for i in range(S.shape[0]):
+        S_i = S[i]
+        joints_i = joints_2d[i]
+        conf_i = joints_conf[i]
+        trans[i] = estimate_translation_np(S_i,
+                                           joints_i,
+                                           conf_i,
+                                           focal_length=focal_length,
+                                           img_size=img_size)
+    return torch.from_numpy(trans).to(device)
+
+
+def rot6d_to_rotmat_spin(x):
+    """Convert 6D rotation representation to 3x3 rotation matrix.
+    Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019
+    Input:
+        (B,6) Batch of 6-D rotation representations
+    Output:
+        (B,3,3) Batch of corresponding rotation matrices
+    """
+    x = x.view(-1, 3, 2)
+    a1 = x[:, :, 0]
+    a2 = x[:, :, 1]
+    b1 = F.normalize(a1)
+    b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1)
+
+    # inp = a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1
+    # denom = inp.pow(2).sum(dim=1).sqrt().unsqueeze(-1) + 1e-8
+    # b2 = inp / denom
+
+    b3 = torch.cross(b1, b2)
+    return torch.stack((b1, b2, b3), dim=-1)
+
+
+def rot6d_to_rotmat(x):
+    x = x.view(-1, 3, 2)
+
+    # Normalize the first vector
+    b1 = F.normalize(x[:, :, 0], dim=1, eps=1e-6)
+
+    dot_prod = torch.sum(b1 * x[:, :, 1], dim=1, keepdim=True)
+    # Compute the second vector by finding the orthogonal complement to it
+    b2 = F.normalize(x[:, :, 1] - dot_prod * b1, dim=-1, eps=1e-6)
+
+    # Finish building the basis by taking the cross product
+    b3 = torch.cross(b1, b2, dim=1)
+    rot_mats = torch.stack([b1, b2, b3], dim=-1)
+
+    return rot_mats
+
+
+import mld.utils.rotation_conversions as rotation_conversions
+
+
+def rot6d(x_rotations, pose_rep):
+    time, njoints, feats = x_rotations.shape
+
+    # Compute rotations (convert only masked sequences output)
+    if pose_rep == "rotvec":
+        rotations = rotation_conversions.axis_angle_to_matrix(x_rotations)
+    elif pose_rep == "rotmat":
+        rotations = x_rotations.view(njoints, 3, 3)
+    elif pose_rep == "rotquat":
+        rotations = rotation_conversions.quaternion_to_matrix(x_rotations)
+    elif pose_rep == "rot6d":
+        rotations = rotation_conversions.rotation_6d_to_matrix(x_rotations)
+    else:
+        raise NotImplementedError("No geometry for this one.")
+
+    rotations_6d = rotation_conversions.matrix_to_rotation_6d(rotations)
+    return rotations_6d
+
+
+def rot6d_batch(x_rotations, pose_rep):
+    nsamples, time, njoints, feats = x_rotations.shape
+
+    # Compute rotations (convert only masked sequences output)
+    if pose_rep == "rotvec":
+        rotations = rotation_conversions.axis_angle_to_matrix(x_rotations)
+    elif pose_rep == "rotmat":
+        rotations = x_rotations.view(-1, njoints, 3, 3)
+    elif pose_rep == "rotquat":
+        rotations = rotation_conversions.quaternion_to_matrix(x_rotations)
+    elif pose_rep == "rot6d":
+        rotations = rotation_conversions.rotation_6d_to_matrix(x_rotations)
+    else:
+        raise NotImplementedError("No geometry for this one.")
+
+    rotations_6d = rotation_conversions.matrix_to_rotation_6d(rotations)
+    return rotations_6d
+
+
+def rot6d_to_rotvec_batch(pose):
+    # nsamples, time, njoints, feats = rot6d.shape
+    bs, nfeats = pose.shape
+    rot6d = pose.reshape(bs, 24, 6)
+    rotations = 
rotation_conversions.rotation_6d_to_matrix(rot6d) + rotvec = rotation_conversions.matrix_to_axis_angle(rotations) + return rotvec.reshape(bs, 24 * 3) diff --git a/Evaluator_272/mld/utils/joints.py b/Evaluator_272/mld/utils/joints.py new file mode 100644 index 0000000000000000000000000000000000000000..ffbddbb79289fa1c71d007a4c37e2bba34e847e0 --- /dev/null +++ b/Evaluator_272/mld/utils/joints.py @@ -0,0 +1,291 @@ +mmm_joints = [ + "root", + "BP", + "BT", + "BLN", + "BUN", + "LS", + "LE", + "LW", + "RS", + "RE", + "RW", + "LH", + "LK", + "LA", + "LMrot", + "LF", + "RH", + "RK", + "RA", + "RMrot", + "RF", +] + +humanml3d_joints = [ + "root", + "RH", + "LH", + "BP", + "RK", + "LK", + "BT", + "RMrot", + "LMrot", + "BLN", + "RF", + "LF", + "BMN", + "RSI", + "LSI", + "BUN", + "RS", + "LS", + "RE", + "LE", + "RW", + "LW", +] + + +motionx_joints = [ + "root", + "RH", + "LH", + "BP", + "RK", + "LK", + "BT", + "RMrot", + "LMrot", + "BLN", + "RF", + "LF", + "BMN", + "RSI", + "LSI", + "BUN", + "RS", + "LS", + "RE", + "LE", + "RW", + "LW", +] + +smplh_joints = [ + "pelvis", + "left_hip", + "right_hip", + "spine1", + "left_knee", + "right_knee", + "spine2", + "left_ankle", + "right_ankle", + "spine3", + "left_foot", + "right_foot", + "neck", + "left_collar", + "right_collar", + "head", + "left_shoulder", + "right_shoulder", + "left_elbow", + "right_elbow", + "left_wrist", + "right_wrist", + "left_index1", + "left_index2", + "left_index3", + "left_middle1", + "left_middle2", + "left_middle3", + "left_pinky1", + "left_pinky2", + "left_pinky3", + "left_ring1", + "left_ring2", + "left_ring3", + "left_thumb1", + "left_thumb2", + "left_thumb3", + "right_index1", + "right_index2", + "right_index3", + "right_middle1", + "right_middle2", + "right_middle3", + "right_pinky1", + "right_pinky2", + "right_pinky3", + "right_ring1", + "right_ring2", + "right_ring3", + "right_thumb1", + "right_thumb2", + "right_thumb3", + "nose", + "right_eye", + "left_eye", + "right_ear", + "left_ear", + "left_big_toe", + "left_small_toe", + "left_heel", + "right_big_toe", + "right_small_toe", + "right_heel", + "left_thumb", + "left_index", + "left_middle", + "left_ring", + "left_pinky", + "right_thumb", + "right_index", + "right_middle", + "right_ring", + "right_pinky", +] + +smplnh_joints = [ + "pelvis", + "left_hip", + "right_hip", + "spine1", + "left_knee", + "right_knee", + "spine2", + "left_ankle", + "right_ankle", + "spine3", + "left_foot", + "right_foot", + "neck", + "left_collar", + "right_collar", + "head", + "left_shoulder", + "right_shoulder", + "left_elbow", + "right_elbow", + "left_wrist", + "right_wrist", +] + + +mmm2smplh_correspondence = { + "root": "pelvis", + "BP": "spine1", + "BT": "spine3", + "BLN": "neck", + "BUN": "head", + "LS": "left_shoulder", + "LE": "left_elbow", + "LW": "left_wrist", + "RS": "right_shoulder", + "RE": "right_elbow", + "RW": "right_wrist", + "LH": "left_hip", + "LK": "left_knee", + "LA": "left_ankle", + "LMrot": "left_heel", + "LF": "left_foot", + "RH": "right_hip", + "RK": "right_knee", + "RA": "right_ankle", + "RMrot": "right_heel", + "RF": "right_foot", +} + +smplh2mmm_correspondence = {val: key for key, val in mmm2smplh_correspondence.items()} +smplh2mmm_indexes = [ + smplh_joints.index(mmm2smplh_correspondence[x]) for x in mmm_joints +] + +smplnh2smplh_correspondence = {key: key for key in smplnh_joints} +smplh2smplnh_correspondence = { + val: key for key, val in smplnh2smplh_correspondence.items() +} + +smplh2smplnh_indexes = [ + smplh_joints.index(smplnh2smplh_correspondence[x]) for x in 
smplnh_joints +] + + +mmm_kinematic_tree = [ + [0, 1, 2, 3, 4], # body + [3, 5, 6, 7], # right arm + [3, 8, 9, 10], # left arm + [0, 11, 12, 13, 14, 15], # right leg + [0, 16, 17, 18, 19, 20], +] # left leg + +humanml3d_kinematic_tree = [ + [0, 3, 6, 9, 12, 15], # body + [9, 14, 17, 19, 21], # right arm + [9, 13, 16, 18, 20], # left arm + [0, 2, 5, 8, 11], # right leg + [0, 1, 4, 7, 10], +] # left leg + +smplh_to_mmm_scaling_factor = 480 / 0.75 +mmm_to_smplh_scaling_factor = 0.75 / 480 + +mmm_joints_info = { + "root": mmm_joints.index("root"), + "feet": [ + mmm_joints.index("LMrot"), + mmm_joints.index("RMrot"), + mmm_joints.index("LF"), + mmm_joints.index("RF"), + ], + "shoulders": [mmm_joints.index("LS"), mmm_joints.index("RS")], + "hips": [mmm_joints.index("LH"), mmm_joints.index("RH")], +} + +smplnh_joints_info = { + "root": smplnh_joints.index("pelvis"), + "feet": [ + smplnh_joints.index("left_ankle"), + smplnh_joints.index("right_ankle"), + smplnh_joints.index("left_foot"), + smplnh_joints.index("right_foot"), + ], + "shoulders": [ + smplnh_joints.index("left_shoulder"), + smplnh_joints.index("right_shoulder"), + ], + "hips": [smplnh_joints.index("left_hip"), smplnh_joints.index("right_hip")], +} + + +infos = {"mmm": mmm_joints_info, "smplnh": smplnh_joints_info} + +smplh_indexes = {"mmm": smplh2mmm_indexes, "smplnh": smplh2smplnh_indexes} + + +root_joints = { + "mmm": mmm_joints_info["root"], + "mmmns": mmm_joints_info["root"], + "smplmmm": mmm_joints_info["root"], + "smplnh": smplnh_joints_info["root"], + "smplh": smplh_joints.index("pelvis"), +} + + +def get_root_idx(joinstype): + return root_joints[joinstype] + + +# def mmm2smpl(joints_mmm): +# mmm2smplnh_indexes = [] +# for x in smplnh_joints: +# if x in smplh2mmm_correspondence: +# mmm2smplnh_indexes.append(mmm_joints.index(smplh2mmm_correspondence[x])) + +# spine2 = 0.5*(joints[mmm_joints.index("spine1")] + joints[mmm_joints.index("spine3")]) + +# joints = joints_mmm[indexes] +# return joints diff --git a/Evaluator_272/mld/utils/logger.py b/Evaluator_272/mld/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..a9eacb06597aafa9973ff63fb75309327622709a --- /dev/null +++ b/Evaluator_272/mld/utils/logger.py @@ -0,0 +1,71 @@ +from pathlib import Path +import os +import time +import logging +from omegaconf import OmegaConf +from pytorch_lightning.utilities.rank_zero import rank_zero_only + + +def create_logger(cfg, phase='train'): + # root dir set by cfg + root_output_dir = Path(cfg.FOLDER) + # set up logger + if not root_output_dir.exists(): + print('=> creating {}'.format(root_output_dir)) + root_output_dir.mkdir() + + cfg_name = cfg.NAME + model = cfg.model.model_type + cfg_name = os.path.basename(cfg_name).split('.')[0] + + final_output_dir = root_output_dir / model / cfg_name + cfg.FOLDER_EXP = str(final_output_dir) + + time_str = time.strftime('%Y-%m-%d-%H-%M-%S') + + new_dir(cfg, phase, time_str, final_output_dir) + + head = '%(asctime)-15s %(message)s' + logger = config_logger(final_output_dir, time_str, phase, head) + if logger is None: + logger = logging.getLogger() + logger.setLevel(logging.CRITICAL) + logging.basicConfig(format=head) + return logger + + +@rank_zero_only +def config_logger(final_output_dir, time_str, phase, head): + log_file = '{}_{}_{}.log'.format('log', time_str, phase) + final_log_file = final_output_dir / log_file + logging.basicConfig(filename=str(final_log_file)) + logger = logging.getLogger() + logger.setLevel(logging.INFO) + console = logging.StreamHandler() + 
formatter = logging.Formatter(head) + console.setFormatter(formatter) + logging.getLogger('').addHandler(console) + file_handler = logging.FileHandler(final_log_file, 'w') + file_handler.setFormatter(logging.Formatter(head)) + file_handler.setLevel(logging.INFO) + logging.getLogger('').addHandler(file_handler) + return logger + + +@rank_zero_only +def new_dir(cfg, phase, time_str, final_output_dir): + # new experiment folder + cfg.TIME = str(time_str) + if os.path.exists( + final_output_dir) and cfg.TRAIN.RESUME is None and not cfg.DEBUG: + file_list = sorted(os.listdir(final_output_dir), reverse=True) + for item in file_list: + if item.endswith('.log'): + os.rename(str(final_output_dir), + str(final_output_dir) + '_' + cfg.TIME) + break + final_output_dir.mkdir(parents=True, exist_ok=True) + # write config yaml + config_file = '{}_{}_{}.yaml'.format('config', time_str, phase) + final_config_file = final_output_dir / config_file + OmegaConf.save(config=cfg, f=final_config_file) diff --git a/Evaluator_272/mld/utils/misc.py b/Evaluator_272/mld/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..4f2a68d68019098e66905e0e21cb96678031bab0 --- /dev/null +++ b/Evaluator_272/mld/utils/misc.py @@ -0,0 +1,29 @@ +import torch + + +def to_numpy(tensor): + if torch.is_tensor(tensor): + return tensor.cpu().numpy() + elif type(tensor).__module__ != 'numpy': + raise ValueError("Cannot convert {} to numpy array".format( + type(tensor))) + return tensor + + +def to_torch(ndarray): + if type(ndarray).__module__ == 'numpy': + return torch.from_numpy(ndarray) + elif not torch.is_tensor(ndarray): + raise ValueError("Cannot convert {} to torch tensor".format( + type(ndarray))) + return ndarray + + +def cleanexit(): + import sys + import os + try: + sys.exit(0) + except SystemExit: + os._exit(0) + diff --git a/Evaluator_272/mld/utils/rotation_conversions.py b/Evaluator_272/mld/utils/rotation_conversions.py new file mode 100644 index 0000000000000000000000000000000000000000..770c3bf36f05fcaf89cbb03e17035357f3c0a4df --- /dev/null +++ b/Evaluator_272/mld/utils/rotation_conversions.py @@ -0,0 +1,551 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# Check PYTORCH3D_LICENCE before use + +import functools +from typing import Optional + +import torch +import torch.nn.functional as F + + +""" +The transformation matrices returned from the functions in this file assume +the points on which the transformation will be applied are column vectors. +i.e. the R matrix is structured as + + R = [ + [Rxx, Rxy, Rxz], + [Ryx, Ryy, Ryz], + [Rzx, Rzy, Rzz], + ] # (3, 3) + +This matrix can be applied to column vectors by post multiplication +by the points e.g. + + points = [[0], [1], [2]] # (3 x 1) xyz coordinates of a point + transformed_points = R * points + +To apply the same matrix to points which are row vectors, the R matrix +can be transposed and pre multiplied by the points: + +e.g. + points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point + transformed_points = points * R.transpose(1, 0) +""" + + +def quaternion_to_matrix(quaternions): + """ + Convert rotations given as quaternions to rotation matrices. + + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). 
+ """ + r, i, j, k = torch.unbind(quaternions, -1) + two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + + +def _copysign(a, b): + """ + Return a tensor where each element has the absolute value taken from the, + corresponding element of a, with sign taken from the corresponding + element of b. This is like the standard copysign floating-point operation, + but is not careful about negative 0 and NaN. + + Args: + a: source tensor. + b: tensor whose signs will be used, of the same shape as a. + + Returns: + Tensor of the same shape as a with the signs of b. + """ + signs_differ = (a < 0) != (b < 0) + return torch.where(signs_differ, -a, a) + + +def _sqrt_positive_part(x): + """ + Returns torch.sqrt(torch.max(0, x)) + but with a zero subgradient where x is 0. + """ + ret = torch.zeros_like(x) + positive_mask = x > 0 + ret[positive_mask] = torch.sqrt(x[positive_mask]) + return ret + + +def matrix_to_quaternion(matrix): + """ + Convert rotations given as rotation matrices to quaternions. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + quaternions with real part first, as tensor of shape (..., 4). + """ + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.") + m00 = matrix[..., 0, 0] + m11 = matrix[..., 1, 1] + m22 = matrix[..., 2, 2] + o0 = 0.5 * _sqrt_positive_part(1 + m00 + m11 + m22) + x = 0.5 * _sqrt_positive_part(1 + m00 - m11 - m22) + y = 0.5 * _sqrt_positive_part(1 - m00 + m11 - m22) + z = 0.5 * _sqrt_positive_part(1 - m00 - m11 + m22) + o1 = _copysign(x, matrix[..., 2, 1] - matrix[..., 1, 2]) + o2 = _copysign(y, matrix[..., 0, 2] - matrix[..., 2, 0]) + o3 = _copysign(z, matrix[..., 1, 0] - matrix[..., 0, 1]) + return torch.stack((o0, o1, o2, o3), -1) + + +def _axis_angle_rotation(axis: str, angle): + """ + Return the rotation matrices for one of the rotations about an axis + of which Euler angles describe, for each value of the angle given. + + Args: + axis: Axis label "X" or "Y or "Z". + angle: any shape tensor of Euler angles in radians + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + + cos = torch.cos(angle) + sin = torch.sin(angle) + one = torch.ones_like(angle) + zero = torch.zeros_like(angle) + + if axis == "X": + R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos) + if axis == "Y": + R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos) + if axis == "Z": + R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one) + + return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3)) + + +def euler_angles_to_matrix(euler_angles, convention: str): + """ + Convert rotations given as Euler angles in radians to rotation matrices. + + Args: + euler_angles: Euler angles in radians as tensor of shape (..., 3). + convention: Convention string of three uppercase letters from + {"X", "Y", and "Z"}. + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). 
+ """ + if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3: + raise ValueError("Invalid input euler angles.") + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + matrices = map(_axis_angle_rotation, convention, torch.unbind(euler_angles, -1)) + return functools.reduce(torch.matmul, matrices) + + +def _angle_from_tan( + axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool +): + """ + Extract the first or third Euler angle from the two members of + the matrix which are positive constant times its sine and cosine. + + Args: + axis: Axis label "X" or "Y or "Z" for the angle we are finding. + other_axis: Axis label "X" or "Y or "Z" for the middle axis in the + convention. + data: Rotation matrices as tensor of shape (..., 3, 3). + horizontal: Whether we are looking for the angle for the third axis, + which means the relevant entries are in the same row of the + rotation matrix. If not, they are in the same column. + tait_bryan: Whether the first and third axes in the convention differ. + + Returns: + Euler Angles in radians for each matrix in data as a tensor + of shape (...). + """ + + i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis] + if horizontal: + i2, i1 = i1, i2 + even = (axis + other_axis) in ["XY", "YZ", "ZX"] + if horizontal == even: + return torch.atan2(data[..., i1], data[..., i2]) + if tait_bryan: + return torch.atan2(-data[..., i2], data[..., i1]) + return torch.atan2(data[..., i2], -data[..., i1]) + + +def _index_from_letter(letter: str): + if letter == "X": + return 0 + if letter == "Y": + return 1 + if letter == "Z": + return 2 + + +def matrix_to_euler_angles(matrix, convention: str): + """ + Convert rotations given as rotation matrices to Euler angles in radians. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + convention: Convention string of three uppercase letters. + + Returns: + Euler angles in radians as tensor of shape (..., 3). + """ + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.") + i0 = _index_from_letter(convention[0]) + i2 = _index_from_letter(convention[2]) + tait_bryan = i0 != i2 + if tait_bryan: + central_angle = torch.asin( + matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0) + ) + else: + central_angle = torch.acos(matrix[..., i0, i0]) + + o = ( + _angle_from_tan( + convention[0], convention[1], matrix[..., i2], False, tait_bryan + ), + central_angle, + _angle_from_tan( + convention[2], convention[1], matrix[..., i0, :], True, tait_bryan + ), + ) + return torch.stack(o, -1) + + +def random_quaternions( + n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False +): + """ + Generate random quaternions representing rotations, + i.e. versors with nonnegative real part. + + Args: + n: Number of quaternions in a batch to return. + dtype: Type to return. + device: Desired device of returned tensor. 
Default: + uses the current device for the default tensor type. + requires_grad: Whether the resulting tensor should have the gradient + flag set. + + Returns: + Quaternions as tensor of shape (N, 4). + """ + o = torch.randn((n, 4), dtype=dtype, device=device, requires_grad=requires_grad) + s = (o * o).sum(1) + o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None] + return o + + +def random_rotations( + n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False +): + """ + Generate random rotations as 3x3 rotation matrices. + + Args: + n: Number of rotation matrices in a batch to return. + dtype: Type to return. + device: Device of returned tensor. Default: if None, + uses the current device for the default tensor type. + requires_grad: Whether the resulting tensor should have the gradient + flag set. + + Returns: + Rotation matrices as tensor of shape (n, 3, 3). + """ + quaternions = random_quaternions( + n, dtype=dtype, device=device, requires_grad=requires_grad + ) + return quaternion_to_matrix(quaternions) + + +def random_rotation( + dtype: Optional[torch.dtype] = None, device=None, requires_grad=False +): + """ + Generate a single random 3x3 rotation matrix. + + Args: + dtype: Type to return + device: Device of returned tensor. Default: if None, + uses the current device for the default tensor type + requires_grad: Whether the resulting tensor should have the gradient + flag set + + Returns: + Rotation matrix as tensor of shape (3, 3). + """ + return random_rotations(1, dtype, device, requires_grad)[0] + + +def standardize_quaternion(quaternions): + """ + Convert a unit quaternion to a standard form: one in which the real + part is non negative. + + Args: + quaternions: Quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Standardized quaternions as tensor of shape (..., 4). + """ + return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions) + + +def quaternion_raw_multiply(a, b): + """ + Multiply two quaternions. + Usual torch rules for broadcasting apply. + + Args: + a: Quaternions as tensor of shape (..., 4), real part first. + b: Quaternions as tensor of shape (..., 4), real part first. + + Returns: + The product of a and b, a tensor of quaternions shape (..., 4). + """ + aw, ax, ay, az = torch.unbind(a, -1) + bw, bx, by, bz = torch.unbind(b, -1) + ow = aw * bw - ax * bx - ay * by - az * bz + ox = aw * bx + ax * bw + ay * bz - az * by + oy = aw * by - ax * bz + ay * bw + az * bx + oz = aw * bz + ax * by - ay * bx + az * bw + return torch.stack((ow, ox, oy, oz), -1) + + +def quaternion_multiply(a, b): + """ + Multiply two quaternions representing rotations, returning the quaternion + representing their composition, i.e. the versor with nonnegative real part. + Usual torch rules for broadcasting apply. + + Args: + a: Quaternions as tensor of shape (..., 4), real part first. + b: Quaternions as tensor of shape (..., 4), real part first. + + Returns: + The product of a and b, a tensor of quaternions of shape (..., 4). + """ + ab = quaternion_raw_multiply(a, b) + return standardize_quaternion(ab) + + +def quaternion_invert(quaternion): + """ + Given a quaternion representing rotation, get the quaternion representing + its inverse. + + Args: + quaternion: Quaternions as tensor of shape (..., 4), with real part + first, which must be versors (unit quaternions). + + Returns: + The inverse, a tensor of quaternions of shape (..., 4). 
+ """ + + return quaternion * quaternion.new_tensor([1, -1, -1, -1]) + + +def quaternion_apply(quaternion, point): + """ + Apply the rotation given by a quaternion to a 3D point. + Usual torch rules for broadcasting apply. + + Args: + quaternion: Tensor of quaternions, real part first, of shape (..., 4). + point: Tensor of 3D points of shape (..., 3). + + Returns: + Tensor of rotated points of shape (..., 3). + """ + if point.size(-1) != 3: + raise ValueError(f"Points are not in 3D, f{point.shape}.") + real_parts = point.new_zeros(point.shape[:-1] + (1,)) + point_as_quaternion = torch.cat((real_parts, point), -1) + out = quaternion_raw_multiply( + quaternion_raw_multiply(quaternion, point_as_quaternion), + quaternion_invert(quaternion), + ) + return out[..., 1:] + + +def axis_angle_to_matrix(axis_angle): + """ + Convert rotations given as axis/angle to rotation matrices. + + Args: + axis_angle: Rotations given as a vector in axis angle form, + as a tensor of shape (..., 3), where the magnitude is + the angle turned anticlockwise in radians around the + vector's direction. + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle)) + + +def matrix_to_axis_angle(matrix): + """ + Convert rotations given as rotation matrices to axis/angle. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + Rotations given as a vector in axis angle form, as a tensor + of shape (..., 3), where the magnitude is the angle + turned anticlockwise in radians around the vector's + direction. + """ + return quaternion_to_axis_angle(matrix_to_quaternion(matrix)) + + +def axis_angle_to_quaternion(axis_angle): + """ + Convert rotations given as axis/angle to quaternions. + + Args: + axis_angle: Rotations given as a vector in axis angle form, + as a tensor of shape (..., 3), where the magnitude is + the angle turned anticlockwise in radians around the + vector's direction. + + Returns: + quaternions with real part first, as tensor of shape (..., 4). + """ + angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True) + half_angles = 0.5 * angles + eps = 1e-6 + small_angles = angles.abs() < eps + sin_half_angles_over_angles = torch.empty_like(angles) + sin_half_angles_over_angles[~small_angles] = ( + torch.sin(half_angles[~small_angles]) / angles[~small_angles] + ) + # for x small, sin(x/2) is about x/2 - (x/2)^3/6 + # so sin(x/2)/x is about 1/2 - (x*x)/48 + sin_half_angles_over_angles[small_angles] = ( + 0.5 - (angles[small_angles] * angles[small_angles]) / 48 + ) + quaternions = torch.cat( + [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1 + ) + return quaternions + + +def quaternion_to_axis_angle(quaternions): + """ + Convert rotations given as quaternions to axis/angle. + + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Rotations given as a vector in axis angle form, as a tensor + of shape (..., 3), where the magnitude is the angle + turned anticlockwise in radians around the vector's + direction. 
+ """ + norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True) + half_angles = torch.atan2(norms, quaternions[..., :1]) + angles = 2 * half_angles + eps = 1e-6 + small_angles = angles.abs() < eps + sin_half_angles_over_angles = torch.empty_like(angles) + sin_half_angles_over_angles[~small_angles] = ( + torch.sin(half_angles[~small_angles]) / angles[~small_angles] + ) + # for x small, sin(x/2) is about x/2 - (x/2)^3/6 + # so sin(x/2)/x is about 1/2 - (x*x)/48 + sin_half_angles_over_angles[small_angles] = ( + 0.5 - (angles[small_angles] * angles[small_angles]) / 48 + ) + return quaternions[..., 1:] / sin_half_angles_over_angles + + +def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor: + """ + Converts 6D rotation representation by Zhou et al. [1] to rotation matrix + using Gram--Schmidt orthogonalisation per Section B of [1]. + Args: + d6: 6D rotation representation, of size (*, 6) + + Returns: + batch of rotation matrices of size (*, 3, 3) + + [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. + On the Continuity of Rotation Representations in Neural Networks. + IEEE Conference on Computer Vision and Pattern Recognition, 2019. + Retrieved from http://arxiv.org/abs/1812.07035 + """ + + a1, a2 = d6[..., :3], d6[..., 3:] + b1 = F.normalize(a1, dim=-1) + b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1 + b2 = F.normalize(b2, dim=-1) + b3 = torch.cross(b1, b2, dim=-1) + return torch.stack((b1, b2, b3), dim=-2) + + +def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor: + """ + Converts rotation matrices to 6D rotation representation by Zhou et al. [1] + by dropping the last row. Note that 6D representation is not unique. + Args: + matrix: batch of rotation matrices of size (*, 3, 3) + + Returns: + 6D rotation representation, of size (*, 6) + + [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. + On the Continuity of Rotation Representations in Neural Networks. + IEEE Conference on Computer Vision and Pattern Recognition, 2019. + Retrieved from http://arxiv.org/abs/1812.07035 + """ + return matrix[..., :2, :].clone().reshape(*matrix.size()[:-2], 6) diff --git a/Evaluator_272/mld/utils/sample_utils.py b/Evaluator_272/mld/utils/sample_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..724b109af5ecd08e20cf0af681df317b390d4d44 --- /dev/null +++ b/Evaluator_272/mld/utils/sample_utils.py @@ -0,0 +1,18 @@ +import logging +from pathlib import Path +logger = logging.getLogger(__name__) + +def cfg_mean_nsamples_resolution(cfg): + if cfg.mean and cfg.number_of_samples > 1: + logger.error("All the samples will be the mean.. 
cfg.number_of_samples=1 will be forced.")
+        cfg.number_of_samples = 1
+
+    return cfg.number_of_samples == 1
+
+
+def get_path(sample_path: Path, is_amass: bool, gender: str, split: str, onesample: bool, mean: bool, fact: float):
+    extra_str = ("_mean" if mean else "") if onesample else "_multi"
+    fact_str = "" if fact == 1 else f"{fact}_"
+    gender_str = gender + "_" if is_amass else ""
+    path = sample_path / f"{fact_str}{gender_str}{split}{extra_str}"
+    return path
diff --git a/Evaluator_272/mld/utils/temos_utils.py b/Evaluator_272/mld/utils/temos_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..2fd47eb1437a39c79c645d0d1f6e558a4fe5109f
--- /dev/null
+++ b/Evaluator_272/mld/utils/temos_utils.py
@@ -0,0 +1,133 @@
+from typing import Dict, List
+
+import numpy as np
+import torch
+from torch import Tensor
+
+import mld.utils.geometry as geometry
+
+
+def lengths_to_mask(lengths: List[int],
+                    device: torch.device,
+                    max_len: int = None) -> Tensor:
+    lengths = torch.tensor(lengths, device=device)
+    max_len = max_len if max_len else max(lengths)
+    mask = torch.arange(max_len, device=device).expand(
+        len(lengths), max_len) < lengths.unsqueeze(1)
+    return mask
+
+
+def detach_to_numpy(tensor):
+    return tensor.detach().cpu().numpy()
+
+
+def remove_padding(tensors, lengths):
+    return [
+        tensor[:tensor_length]
+        for tensor, tensor_length in zip(tensors, lengths)
+    ]
+
+
+def nfeats_of(rottype):
+    if rottype in ["rotvec", "axisangle"]:
+        return 3
+    elif rottype in ["rotquat", "quaternion"]:
+        return 4
+    elif rottype in ["rot6d", "6drot", "rotation6d"]:
+        return 6
+    elif rottype in ["rotmat"]:
+        return 9
+    else:
+        raise TypeError("This rotation type doesn't have features.")
+
+
+def axis_angle_to(newtype, rotations):
+    if newtype in ["matrix"]:
+        rotations = geometry.axis_angle_to_matrix(rotations)
+        return rotations
+    elif newtype in ["rotmat"]:
+        rotations = geometry.axis_angle_to_matrix(rotations)
+        rotations = matrix_to("rotmat", rotations)
+        return rotations
+    elif newtype in ["rot6d", "6drot", "rotation6d"]:
+        rotations = geometry.axis_angle_to_matrix(rotations)
+        rotations = matrix_to("rot6d", rotations)
+        return rotations
+    elif newtype in ["rotquat", "quaternion"]:
+        rotations = geometry.axis_angle_to_quaternion(rotations)
+        return rotations
+    elif newtype in ["rotvec", "axisangle"]:
+        return rotations
+    else:
+        raise NotImplementedError
+
+
+def matrix_to(newtype, rotations):
+    if newtype in ["matrix"]:
+        return rotations
+    if newtype in ["rotmat"]:
+        rotations = rotations.reshape((*rotations.shape[:-2], 9))
+        return rotations
+    elif newtype in ["rot6d", "6drot", "rotation6d"]:
+        rotations = geometry.matrix_to_rotation_6d(rotations)
+        return rotations
+    elif newtype in ["rotquat", "quaternion"]:
+        rotations = geometry.matrix_to_quaternion(rotations)
+        return rotations
+    elif newtype in ["rotvec", "axisangle"]:
+        rotations = geometry.matrix_to_axis_angle(rotations)
+        return rotations
+    else:
+        raise NotImplementedError
+
+
+def to_matrix(oldtype, rotations):
+    if oldtype in ["matrix"]:
+        return rotations
+    if oldtype in ["rotmat"]:
+        rotations = rotations.reshape((*rotations.shape[:-2], 3, 3))
+        return rotations
+    elif oldtype in ["rot6d", "6drot", "rotation6d"]:
+        rotations = geometry.rotation_6d_to_matrix(rotations)
+        return rotations
+    elif oldtype in ["rotquat", "quaternion"]:
+        rotations = geometry.quaternion_to_matrix(rotations)
+        return rotations
+    elif oldtype in ["rotvec", "axisangle"]:
+        rotations = geometry.axis_angle_to_matrix(rotations)
+        return 
rotations + else: + raise NotImplementedError + + +# TODO: use a real subsampler.. +def subsample(num_frames, last_framerate, new_framerate): + step = int(last_framerate / new_framerate) + assert step >= 1 + frames = np.arange(0, num_frames, step) + return frames + + +# TODO: use a real upsampler.. +def upsample(motion, last_framerate, new_framerate): + step = int(new_framerate / last_framerate) + assert step >= 1 + + # Alpha blending => interpolation + alpha = np.linspace(0, 1, step + 1) + last = np.einsum("l,...->l...", 1 - alpha, motion[:-1]) + new = np.einsum("l,...->l...", alpha, motion[1:]) + + chuncks = (last + new)[:-1] + output = np.concatenate(chuncks.swapaxes(1, 0)) + # Don't forget the last one + output = np.concatenate((output, motion[[-1]])) + return output + + +if __name__ == "__main__": + motion = np.arange(105) + submotion = motion[subsample(len(motion), 100.0, 12.5)] + newmotion = upsample(submotion, 12.5, 100) + + print(newmotion) diff --git a/Evaluator_272/mld/utils/tensors.py b/Evaluator_272/mld/utils/tensors.py new file mode 100644 index 0000000000000000000000000000000000000000..166143893e5ad1494e3bdf8a9a12261f61e77335 --- /dev/null +++ b/Evaluator_272/mld/utils/tensors.py @@ -0,0 +1,74 @@ +import torch + + +def lengths_to_mask(lengths): + max_len = max(lengths) + mask = torch.arange(max_len, device=lengths.device).expand( + len(lengths), max_len) < lengths.unsqueeze(1) + return mask + + +def collate_tensors(batch): + dims = batch[0].dim() + max_size = [max([b.size(i) for b in batch]) for i in range(dims)] + size = (len(batch),) + tuple(max_size) + canvas = batch[0].new_zeros(size=size) + for i, b in enumerate(batch): + sub_tensor = canvas[i] + for d in range(dims): + sub_tensor = sub_tensor.narrow(d, 0, b.size(d)) + sub_tensor.add_(b) + return canvas + + +def collate(batch): + databatch = [b[0] for b in batch] + labelbatch = [b[1] for b in batch] + lenbatch = [len(b[0][0][0]) for b in batch] + + databatchTensor = collate_tensors(databatch) + labelbatchTensor = torch.as_tensor(labelbatch) + lenbatchTensor = torch.as_tensor(lenbatch) + + maskbatchTensor = lengths_to_mask(lenbatchTensor) + # x - [bs, njoints, nfeats, lengths] + # - nfeats, the representation of a joint + # y - [bs] + # mask - [bs, lengths] + # lengths - [bs] + batch = {"x": databatchTensor, "y": labelbatchTensor, + "mask": maskbatchTensor, 'lengths': lenbatchTensor} + return batch + + +# slow version with padding +def collate_data3d_slow(batch): + batchTensor = {} + for key in batch[0].keys(): + databatch = [b[key] for b in batch] + batchTensor[key] = collate_tensors(databatch) + batch = batchTensor + # theta - [bs, lengths, 85], theta shape (85,) + # - (np.array([1., 0., 0.]), pose(72), shape(10)), axis=0) + # kp_2d - [bs, lengths, njoints, nfeats], nfeats (x,y,weight) + # kp_3d - [bs, lengths, njoints, nfeats], nfeats (x,y,z) + # w_smpl - [bs, lengths] zeros + # w_3d - [bs, lengths] zeros + return batch + +def collate_data3d(batch): + batchTensor = {} + for key in batch[0].keys(): + databatch = [b[key] for b in batch] + if key == "paths": + batchTensor[key] = databatch + else: + batchTensor[key] = torch.stack(databatch,axis=0) + batch = batchTensor + # theta - [bs, lengths, 85], theta shape (85,) + # - (np.array([1., 0., 0.]), pose(72), shape(10)), axis=0) + # kp_2d - [bs, lengths, njoints, nfeats], nfeats (x,y,weight) + # kp_3d - [bs, lengths, njoints, nfeats], nfeats (x,y,z) + # w_smpl - [bs, lengths] zeros + # w_3d - [bs, lengths] zeros + return batch diff --git a/Evaluator_272/train.py 
b/Evaluator_272/train.py new file mode 100644 index 0000000000000000000000000000000000000000..2c1936d781699860c5bd921909cceea787b93019 --- /dev/null +++ b/Evaluator_272/train.py @@ -0,0 +1,255 @@ +import os +from pprint import pformat + +import pytorch_lightning as pl +import torch +from omegaconf import OmegaConf +from pytorch_lightning import loggers as pl_loggers +from pytorch_lightning.callbacks import ModelCheckpoint +# from pytorch_lightning.strategies.ddp import DDPStrategy + +from mld.callback import ProgressLogger +from mld.config import parse_args +from mld.data.get_data import get_datasets +from mld.models.get_model import get_model +from mld.utils.logger import create_logger + + +def main(): + cfg = parse_args() # parse config file + + # create logger + logger = create_logger(cfg, phase="train") + + # resume + if cfg.TRAIN.RESUME: + resume = cfg.TRAIN.RESUME + backcfg = cfg.TRAIN.copy() + if os.path.exists(resume): + file_list = sorted(os.listdir(resume), reverse=True) + for item in file_list: + if item.endswith(".yaml"): + cfg = OmegaConf.load(os.path.join(resume, item)) + cfg.TRAIN = backcfg + break + checkpoints = sorted(os.listdir(os.path.join( + resume, "checkpoints")), + key=lambda x: int(x[6:-5]), + reverse=True) + for checkpoint in checkpoints: + if "epoch=" in checkpoint: + cfg.TRAIN.PRETRAINED = os.path.join( + resume, "checkpoints", checkpoint) + break + if os.path.exists(os.path.join(resume, "wandb")): + wandb_list = sorted(os.listdir(os.path.join(resume, "wandb")), + reverse=True) + for item in wandb_list: + if "run-" in item: + cfg.LOGGER.WANDB.RESUME_ID = item.split("-")[-1] + + else: + raise ValueError("Resume path is not right.") + # set seed + pl.seed_everything(cfg.SEED_VALUE) + + # gpu setting + if cfg.ACCELERATOR == "gpu": + # os.environ["PYTHONWARNINGS"] = "ignore" + os.environ["TOKENIZERS_PARALLELISM"] = "false" + # os.environ['CUDA_VISIBLE_DEVICES'] = ",".join(str(x) for x in cfg.DEVICE) + + # tensorboard logger and wandb logger + loggers = [] + if cfg.LOGGER.WANDB.PROJECT: + wandb_logger = pl_loggers.WandbLogger( + project=cfg.LOGGER.WANDB.PROJECT, + offline=cfg.LOGGER.WANDB.OFFLINE, + id=cfg.LOGGER.WANDB.RESUME_ID, + save_dir=cfg.FOLDER_EXP, + version="", + name=cfg.NAME, + anonymous=False, + log_model=False, + ) + loggers.append(wandb_logger) + if cfg.LOGGER.TENSORBOARD: + tb_logger = pl_loggers.TensorBoardLogger(save_dir=cfg.FOLDER_EXP, + sub_dir="tensorboard", + version="", + name="") + loggers.append(tb_logger) + logger.info(OmegaConf.to_yaml(cfg)) + + # create dataset + datasets = get_datasets(cfg, logger=logger) + logger.info("datasets module {} initialized".format("".join( + cfg.TRAIN.DATASETS))) + + # create model + model = get_model(cfg, datasets[0]) + logger.info("model {} loaded".format(cfg.model.model_type)) + + + + if cfg.TRAIN.STAGE in ['gpt']: + logger.info("Loading pretrain vae from {}".format( + cfg.TRAIN.PRETRAINED_VAE)) + state_dict = torch.load(cfg.TRAIN.PRETRAINED_VAE, + map_location="cpu")["state_dict"] + # extract encoder/decoder + from collections import OrderedDict + vae_dict = OrderedDict() + for k, v in state_dict.items(): + if k.split(".")[0] == "vae": + name = k.replace("vae.vqvae", "vqvae") + vae_dict[name] = v + model.vae.load_state_dict(vae_dict, strict=True) + else: + if cfg.TRAIN.PRETRAINED_VAE: + logger.info("Loading pretrain vae from {}".format( + cfg.TRAIN.PRETRAINED_VAE)) + state_dict = torch.load(cfg.TRAIN.PRETRAINED_VAE, + map_location="cpu")["state_dict"] + # extract encoder/decoder + from collections import 
OrderedDict + vae_dict = OrderedDict() + for k, v in state_dict.items(): + if k.split(".")[0] == "vae": + name = k.replace("vae.", "") + vae_dict[name] = v + model.vae.load_state_dict(vae_dict, strict=True) + + + # optimizer + metric_monitor = { + "Train_jf": "recons/text2jfeats/train", + "Val_jf": "recons/text2jfeats/val", + "Train_rf": "recons/text2rfeats/train", + "Val_rf": "recons/text2rfeats/val", + "APE root": "Metrics/APE_root", + "APE mean pose": "Metrics/APE_mean_pose", + "AVE root": "Metrics/AVE_root", + "AVE mean pose": "Metrics/AVE_mean_pose", + "R_TOP_1": "Metrics/R_precision_top_1", + "R_TOP_2": "Metrics/R_precision_top_2", + "R_TOP_3": "Metrics/R_precision_top_3", + "gt_R_TOP_1": "Metrics/gt_R_precision_top_1", + "gt_R_TOP_2": "Metrics/gt_R_precision_top_2", + "gt_R_TOP_3": "Metrics/gt_R_precision_top_3", + "FID": "Metrics/FID", + "gt_FID": "Metrics/gt_FID", + "Diversity": "Metrics/Diversity", + "gt_Diversity": "Metrics/gt_Diversity", + "MM dist": "Metrics/Matching_score", + "Accuracy": "Metrics/accuracy", + "gt_Accuracy": "Metrics/gt_accuracy", + } + + # callbacks + callbacks = [ + pl.callbacks.RichProgressBar(), + ProgressLogger(metric_monitor=metric_monitor), + # ModelCheckpoint(dirpath=os.path.join(cfg.FOLDER_EXP,'checkpoints'),filename='latest-{epoch}',every_n_epochs=1,save_top_k=1,save_last=True,save_on_train_epoch_end=True), + ModelCheckpoint( + dirpath=os.path.join(cfg.FOLDER_EXP, "checkpoints"), + filename="{epoch}", + monitor="step", + mode="max", + every_n_epochs=cfg.LOGGER.SAVE_CHECKPOINT_EPOCH, + save_top_k=-1, + save_last=False, + save_on_train_epoch_end=True, + ), + ] + logger.info("Callbacks initialized") + + if len(cfg.DEVICE) > 1: + # ddp_strategy = DDPStrategy(find_unused_parameters=False) + ddp_strategy = "ddp" + else: + ddp_strategy = None + + # trainer + trainer = pl.Trainer( + benchmark=False, + max_epochs=cfg.TRAIN.END_EPOCH, + accelerator=cfg.ACCELERATOR, + devices=cfg.DEVICE, + # gpus=2, + strategy=ddp_strategy, + # move_metrics_to_cpu=True, + default_root_dir=cfg.FOLDER_EXP, + log_every_n_steps=cfg.LOGGER.VAL_EVERY_STEPS, + deterministic=False, + detect_anomaly=False, + enable_progress_bar=True, + logger=loggers, + callbacks=callbacks, + check_val_every_n_epoch=cfg.LOGGER.VAL_EVERY_STEPS, + ) + logger.info("Trainer initialized") + + if cfg.TRAIN.STAGE == 'temos': + vae_type = 'temos' + else: + vae_type = cfg.model.motion_vae.target.split(".")[-1].lower().replace( + "vae", "") + + + if cfg.TRAIN.PRETRAINED_MLD: + logger.info("Loading pretrain mld from {}".format( + cfg.TRAIN.PRETRAINED_MLD)) + + state_dict = torch.load(cfg.TRAIN.PRETRAINED_MLD, + map_location="cpu")["state_dict"] + + + from collections import OrderedDict + vae_dict = OrderedDict() + for k, v in state_dict.items(): + if k.split(".")[0] == "denoiser": + name = k.replace("denoiser.", "") + vae_dict[name] = v + model.denoiser.load_state_dict(vae_dict, strict=True) + + + + if cfg.TRAIN.PRETRAINED: + + logger.info("Loading pretrain mode from {}".format( + cfg.TRAIN.PRETRAINED)) + logger.info("Attention! 
VAE will be recovered") + state_dict = torch.load(cfg.TRAIN.PRETRAINED, + map_location="cpu")["state_dict"] + # remove mismatched and unused params + from collections import OrderedDict + + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + if k not in ["denoiser.sequence_pos_encoding.pe"]: + new_state_dict[k] = v + model.load_state_dict(new_state_dict, strict=False) + + # fitting + + if cfg.TRAIN.RESUME: + trainer.validate(model, datamodule=datasets[0], ckpt_path=cfg.TRAIN.PRETRAINED) + trainer.fit(model, + datamodule=datasets[0], + ckpt_path=cfg.TRAIN.PRETRAINED) + else: + trainer.fit(model, datamodule=datasets[0]) + + # checkpoint + checkpoint_folder = trainer.checkpoint_callback.dirpath + logger.info(f"The checkpoints are stored in {checkpoint_folder}") + logger.info( + f"The outputs of this experiment are stored in {cfg.FOLDER_EXP}") + + # end + logger.info("Training ends!") + + +if __name__ == "__main__": + main() diff --git a/Experiments/motionstreamer_model/.ipynb_checkpoints/run-checkpoint.log b/Experiments/motionstreamer_model/.ipynb_checkpoints/run-checkpoint.log new file mode 100644 index 0000000000000000000000000000000000000000..d818583599df9479e6ee08b99620dab3bc4fae3f --- /dev/null +++ b/Experiments/motionstreamer_model/.ipynb_checkpoints/run-checkpoint.log @@ -0,0 +1,362 @@ +2025-10-12 22:06:52,573 INFO { + "batch_size": 256, + "dataname": "t2m_babel_272", + "decay_option": "all", + "depth": 3, + "dilation_growth_rate": 3, + "down_t": 2, + "exp_name": "motionstreamer_model", + "gamma": 0.05, + "hidden_size": 1024, + "latent_dim": 16, + "latent_dir": "babel_272_stream/t2m_babel_latents", + "lr": 0.0001, + "mode": "pos", + "num_diffusion_head_layers": 9, + "num_gpus": 1, + "optimizer": "adamw", + "out_dir": "Experiments/motionstreamer_model", + "resume_pth": null, + "resume_trans": null, + "seed": 123, + "stride_t": 2, + "text": "A man is jogging around.", + "total_iter": 100000, + "weight_decay": 1e-06 +} +2025-10-12 22:08:56,596 INFO { + "batch_size": 32, + "dataname": "t2m_babel_272", + "decay_option": "all", + "depth": 3, + "dilation_growth_rate": 3, + "down_t": 2, + "exp_name": "motionstreamer_model", + "gamma": 0.05, + "hidden_size": 1024, + "latent_dim": 16, + "latent_dir": "babel_272_stream/t2m_babel_latents", + "lr": 0.0001, + "mode": "pos", + "num_diffusion_head_layers": 9, + "num_gpus": 1, + "optimizer": "adamw", + "out_dir": "Experiments/motionstreamer_model", + "resume_pth": null, + "resume_trans": null, + "seed": 123, + "stride_t": 2, + "text": "A man is jogging around.", + "total_iter": 100000, + "weight_decay": 1e-06 +} +2025-10-12 22:09:45,974 INFO Train. Iter 100 : Loss. 0.98718 +2025-10-12 22:10:12,703 INFO Train. Iter 200 : Loss. 0.89334 +2025-10-12 22:10:39,800 INFO Train. Iter 300 : Loss. 0.58750 +2025-10-12 22:11:06,778 INFO Train. Iter 400 : Loss. 0.16893 +2025-10-12 22:11:34,048 INFO Train. Iter 500 : Loss. 0.09410 +2025-10-12 22:12:01,408 INFO Train. Iter 600 : Loss. 0.08094 +2025-10-12 22:12:28,205 INFO Train. Iter 700 : Loss. 0.07417 +2025-10-12 22:12:55,164 INFO Train. Iter 800 : Loss. 0.06324 +2025-10-12 22:13:22,827 INFO Train. Iter 900 : Loss. 0.05340 +2025-10-12 22:13:49,569 INFO Train. Iter 1000 : Loss. 0.04717 +2025-10-12 22:14:16,954 INFO Train. Iter 1100 : Loss. 0.04301 +2025-10-12 22:14:44,300 INFO Train. Iter 1200 : Loss. 0.04186 +2025-10-12 22:15:11,118 INFO Train. Iter 1300 : Loss. 
0.03830 +2025-10-12 22:17:43,928 INFO { + "batch_size": 32, + "dataname": "t2m_babel_272", + "decay_option": "all", + "depth": 3, + "dilation_growth_rate": 3, + "down_t": 2, + "exp_name": "motionstreamer_model", + "gamma": 0.05, + "hidden_size": 1024, + "latent_dim": 16, + "latent_dir": "babel_272_stream/t2m_babel_latents", + "lr": 0.0001, + "mode": "pos", + "num_diffusion_head_layers": 9, + "num_gpus": 1, + "optimizer": "adamw", + "out_dir": "Experiments/motionstreamer_model", + "resume_pth": null, + "resume_trans": null, + "seed": 123, + "stride_t": 2, + "text": "A man is jogging around.", + "total_iter": 100000, + "weight_decay": 1e-06 +} +2025-10-12 22:24:52,238 INFO { + "batch_size": 32, + "dataname": "t2m_babel_272", + "decay_option": "all", + "depth": 3, + "dilation_growth_rate": 3, + "down_t": 2, + "exp_name": "motionstreamer_model", + "gamma": 0.05, + "hidden_size": 1024, + "latent_dim": 16, + "latent_dir": "babel_272_stream/t2m_babel_latents", + "lr": 0.0001, + "mode": "pos", + "num_diffusion_head_layers": 9, + "num_gpus": 1, + "optimizer": "adamw", + "out_dir": "Experiments/motionstreamer_model", + "resume_pth": null, + "resume_trans": null, + "seed": 123, + "stride_t": 2, + "text": "A man is jogging around.", + "total_iter": 100000, + "weight_decay": 1e-06 +} +2025-10-12 22:29:30,975 INFO { + "batch_size": 32, + "dataname": "t2m_babel_272", + "decay_option": "all", + "depth": 3, + "dilation_growth_rate": 3, + "down_t": 2, + "exp_name": "motionstreamer_model", + "gamma": 0.05, + "hidden_size": 1024, + "latent_dim": 16, + "latent_dir": "babel_272_stream/t2m_babel_latents", + "lr": 0.0001, + "mode": "pos", + "num_diffusion_head_layers": 9, + "num_gpus": 1, + "optimizer": "adamw", + "out_dir": "Experiments/motionstreamer_model", + "resume_pth": null, + "resume_trans": null, + "seed": 123, + "stride_t": 2, + "text": "A man is jogging around.", + "total_iter": 100000, + "weight_decay": 1e-06 +} +2025-10-12 22:33:25,314 INFO { + "batch_size": 32, + "dataname": "t2m_babel_272", + "decay_option": "all", + "depth": 3, + "dilation_growth_rate": 3, + "down_t": 2, + "exp_name": "motionstreamer_model", + "gamma": 0.05, + "hidden_size": 1024, + "latent_dim": 16, + "latent_dir": "babel_272_stream/t2m_babel_latents", + "lr": 0.0001, + "mode": "pos", + "num_diffusion_head_layers": 9, + "num_gpus": 1, + "optimizer": "adamw", + "out_dir": "Experiments/motionstreamer_model", + "resume_pth": null, + "resume_trans": null, + "seed": 123, + "stride_t": 2, + "text": "A man is jogging around.", + "total_iter": 100000, + "weight_decay": 1e-06 +} +2025-10-12 22:37:35,801 INFO { + "batch_size": 32, + "dataname": "t2m_babel_272", + "decay_option": "all", + "depth": 3, + "dilation_growth_rate": 3, + "down_t": 2, + "exp_name": "motionstreamer_model", + "gamma": 0.05, + "hidden_size": 1024, + "latent_dim": 16, + "latent_dir": "babel_272_stream/t2m_babel_latents", + "lr": 0.0001, + "mode": "pos", + "num_diffusion_head_layers": 9, + "num_gpus": 1, + "optimizer": "adamw", + "out_dir": "Experiments/motionstreamer_model", + "resume_pth": null, + "resume_trans": null, + "seed": 123, + "stride_t": 2, + "text": "A man is jogging around.", + "total_iter": 100000, + "weight_decay": 1e-06 +} +2025-10-12 22:38:03,752 INFO Train. Iter 10 : Loss. 1.00243 +2025-10-12 22:38:06,503 INFO Train. Iter 20 : Loss. 0.99949 +2025-10-12 22:38:09,317 INFO Train. Iter 30 : Loss. 0.99516 +2025-10-12 22:38:12,126 INFO Train. Iter 40 : Loss. 0.99321 +2025-10-12 22:38:14,819 INFO Train. Iter 50 : Loss. 
0.99326 +2025-10-12 22:38:17,512 INFO Train. Iter 60 : Loss. 0.99138 +2025-10-12 22:38:20,225 INFO Train. Iter 70 : Loss. 0.98142 +2025-10-12 22:38:22,968 INFO Train. Iter 80 : Loss. 0.97872 +2025-10-12 22:38:25,701 INFO Train. Iter 90 : Loss. 0.97327 +2025-10-12 22:38:28,447 INFO Train. Iter 100 : Loss. 0.96485 +2025-10-12 22:38:31,147 INFO Train. Iter 110 : Loss. 0.95915 +2025-10-12 22:38:33,870 INFO Train. Iter 120 : Loss. 0.94824 +2025-10-12 22:38:36,578 INFO Train. Iter 130 : Loss. 0.94099 +2025-10-12 22:38:39,344 INFO Train. Iter 140 : Loss. 0.92432 +2025-10-12 22:38:42,005 INFO Train. Iter 150 : Loss. 0.91036 +2025-10-12 22:38:44,765 INFO Train. Iter 160 : Loss. 0.89669 +2025-10-12 22:38:47,478 INFO Train. Iter 170 : Loss. 0.87722 +2025-10-12 22:38:50,169 INFO Train. Iter 180 : Loss. 0.86069 +2025-10-12 22:38:52,908 INFO Train. Iter 190 : Loss. 0.83729 +2025-10-12 22:38:55,613 INFO Train. Iter 200 : Loss. 0.81117 +2025-10-12 22:38:58,383 INFO Train. Iter 210 : Loss. 0.78479 +2025-10-12 22:39:01,148 INFO Train. Iter 220 : Loss. 0.75250 +2025-10-12 22:39:03,820 INFO Train. Iter 230 : Loss. 0.72087 +2025-10-12 22:39:06,594 INFO Train. Iter 240 : Loss. 0.68616 +2025-10-12 22:39:09,353 INFO Train. Iter 250 : Loss. 0.64506 +2025-10-12 22:39:12,141 INFO Train. Iter 260 : Loss. 0.60450 +2025-10-12 22:39:14,858 INFO Train. Iter 270 : Loss. 0.55745 +2025-10-12 22:39:17,682 INFO Train. Iter 280 : Loss. 0.50806 +2025-10-12 22:39:20,413 INFO Train. Iter 290 : Loss. 0.45532 +2025-10-12 22:39:23,134 INFO Train. Iter 300 : Loss. 0.39957 +2025-10-12 22:39:25,919 INFO Train. Iter 310 : Loss. 0.34181 +2025-10-12 22:39:28,826 INFO Train. Iter 320 : Loss. 0.29091 +2025-10-12 22:39:31,567 INFO Train. Iter 330 : Loss. 0.24968 +2025-10-12 22:39:34,344 INFO Train. Iter 340 : Loss. 0.20833 +2025-10-12 22:39:37,030 INFO Train. Iter 350 : Loss. 0.17297 +2025-10-12 22:39:39,835 INFO Train. Iter 360 : Loss. 0.14701 +2025-10-12 22:39:42,520 INFO Train. Iter 370 : Loss. 0.13512 +2025-10-12 22:39:45,304 INFO Train. Iter 380 : Loss. 0.12805 +2025-10-12 22:39:48,034 INFO Train. Iter 390 : Loss. 0.11589 +2025-10-12 22:39:50,771 INFO Train. Iter 400 : Loss. 0.11061 +2025-10-12 22:39:53,562 INFO Train. Iter 410 : Loss. 0.10903 +2025-10-12 22:39:56,303 INFO Train. Iter 420 : Loss. 0.10401 +2025-10-12 22:39:58,996 INFO Train. Iter 430 : Loss. 0.10301 +2025-10-12 22:40:02,118 INFO Train. Iter 440 : Loss. 0.09834 +2025-10-12 22:40:04,872 INFO Train. Iter 450 : Loss. 0.10153 +2025-10-12 22:40:07,568 INFO Train. Iter 460 : Loss. 0.09574 +2025-10-12 22:40:10,349 INFO Train. Iter 470 : Loss. 0.09539 +2025-10-12 22:40:13,119 INFO Train. Iter 480 : Loss. 0.09524 +2025-10-12 22:40:15,881 INFO Train. Iter 490 : Loss. 0.09011 +2025-10-12 22:40:18,610 INFO Train. Iter 500 : Loss. 0.08879 +2025-10-12 22:40:21,548 INFO Train. Iter 510 : Loss. 0.08935 +2025-10-12 22:40:24,269 INFO Train. Iter 520 : Loss. 0.08596 +2025-10-12 22:40:27,034 INFO Train. Iter 530 : Loss. 0.08763 +2025-10-12 22:40:29,761 INFO Train. Iter 540 : Loss. 0.08654 +2025-10-12 22:40:32,805 INFO Train. Iter 550 : Loss. 0.08733 +2025-10-12 22:40:35,532 INFO Train. Iter 560 : Loss. 0.08198 +2025-10-12 22:40:38,339 INFO Train. Iter 570 : Loss. 0.08665 +2025-10-12 22:40:41,118 INFO Train. Iter 580 : Loss. 0.08642 +2025-10-12 22:40:43,867 INFO Train. Iter 590 : Loss. 0.08730 +2025-10-12 22:40:46,668 INFO Train. Iter 600 : Loss. 0.07833 +2025-10-12 22:40:49,456 INFO Train. Iter 610 : Loss. 0.08199 +2025-10-12 22:40:52,142 INFO Train. Iter 620 : Loss. 
0.08393 +2025-10-12 22:40:54,825 INFO Train. Iter 630 : Loss. 0.07858 +2025-10-12 22:40:57,447 INFO Train. Iter 640 : Loss. 0.08227 +2025-10-12 22:41:00,132 INFO Train. Iter 650 : Loss. 0.07588 +2025-10-12 22:41:02,921 INFO Train. Iter 660 : Loss. 0.08195 +2025-10-12 22:41:05,672 INFO Train. Iter 670 : Loss. 0.08222 +2025-10-12 22:41:08,472 INFO Train. Iter 680 : Loss. 0.07408 +2025-10-12 22:41:11,288 INFO Train. Iter 690 : Loss. 0.07727 +2025-10-12 22:41:14,040 INFO Train. Iter 700 : Loss. 0.07344 +2025-10-12 22:41:16,922 INFO Train. Iter 710 : Loss. 0.07518 +2025-10-12 22:41:19,616 INFO Train. Iter 720 : Loss. 0.07710 +2025-10-12 22:41:22,356 INFO Train. Iter 730 : Loss. 0.07323 +2025-10-12 22:41:25,088 INFO Train. Iter 740 : Loss. 0.07110 +2025-10-12 22:41:27,728 INFO Train. Iter 750 : Loss. 0.06784 +2025-10-12 22:41:30,499 INFO Train. Iter 760 : Loss. 0.06789 +2025-10-12 22:41:33,384 INFO Train. Iter 770 : Loss. 0.06421 +2025-10-12 22:41:36,151 INFO Train. Iter 780 : Loss. 0.06638 +2025-10-12 22:41:38,870 INFO Train. Iter 790 : Loss. 0.05871 +2025-10-12 22:41:41,628 INFO Train. Iter 800 : Loss. 0.06345 +2025-10-12 22:41:44,297 INFO Train. Iter 810 : Loss. 0.06169 +2025-10-12 22:41:47,133 INFO Train. Iter 820 : Loss. 0.06290 +2025-10-12 22:41:49,901 INFO Train. Iter 830 : Loss. 0.05927 +2025-10-12 22:41:52,724 INFO Train. Iter 840 : Loss. 0.05768 +2025-10-12 22:41:55,454 INFO Train. Iter 850 : Loss. 0.06435 +2025-10-12 22:41:58,225 INFO Train. Iter 860 : Loss. 0.05915 +2025-10-12 22:42:01,090 INFO Train. Iter 870 : Loss. 0.05240 +2025-10-12 22:42:04,404 INFO Train. Iter 880 : Loss. 0.05182 +2025-10-12 22:42:07,083 INFO Train. Iter 890 : Loss. 0.05887 +2025-10-12 22:42:09,896 INFO Train. Iter 900 : Loss. 0.04940 +2025-10-12 22:42:12,547 INFO Train. Iter 910 : Loss. 0.05817 +2025-10-12 22:42:15,403 INFO Train. Iter 920 : Loss. 0.05284 +2025-10-12 22:42:18,069 INFO Train. Iter 930 : Loss. 0.04915 +2025-10-12 22:42:20,674 INFO Train. Iter 940 : Loss. 0.05672 +2025-10-12 22:42:23,404 INFO Train. Iter 950 : Loss. 0.04910 +2025-10-12 22:42:26,122 INFO Train. Iter 960 : Loss. 0.05498 +2025-10-12 22:42:28,819 INFO Train. Iter 970 : Loss. 0.05035 +2025-10-12 22:42:31,603 INFO Train. Iter 980 : Loss. 0.04843 +2025-10-12 22:42:34,300 INFO Train. Iter 990 : Loss. 0.05129 +2025-10-12 22:42:37,022 INFO Train. Iter 1000 : Loss. 0.04424 +2025-10-12 22:42:39,818 INFO Train. Iter 1010 : Loss. 0.04696 +2025-10-12 22:42:42,518 INFO Train. Iter 1020 : Loss. 0.04665 +2025-10-12 22:42:45,262 INFO Train. Iter 1030 : Loss. 0.04747 +2025-10-12 22:42:48,043 INFO Train. Iter 1040 : Loss. 0.04422 +2025-10-12 22:42:50,820 INFO Train. Iter 1050 : Loss. 0.04651 +2025-10-12 22:42:53,519 INFO Train. Iter 1060 : Loss. 0.04664 +2025-10-12 22:42:56,318 INFO Train. Iter 1070 : Loss. 0.04701 +2025-10-12 22:42:59,036 INFO Train. Iter 1080 : Loss. 0.04410 +2025-10-12 22:43:01,879 INFO Train. Iter 1090 : Loss. 0.04406 +2025-10-12 22:43:04,681 INFO Train. Iter 1100 : Loss. 0.04581 +2025-10-12 22:43:07,386 INFO Train. Iter 1110 : Loss. 0.04522 +2025-10-12 22:43:10,289 INFO Train. Iter 1120 : Loss. 0.04364 +2025-10-12 22:43:13,092 INFO Train. Iter 1130 : Loss. 
0.04148 +2025-10-12 22:47:51,657 INFO { + "batch_size": 40, + "dataname": "t2m_babel_272", + "decay_option": "all", + "depth": 3, + "dilation_growth_rate": 3, + "down_t": 2, + "exp_name": "motionstreamer_model", + "gamma": 0.05, + "hidden_size": 1024, + "latent_dim": 16, + "latent_dir": "babel_272_stream/t2m_babel_latents", + "lr": 0.0001, + "mode": "pos", + "num_diffusion_head_layers": 9, + "num_gpus": 1, + "optimizer": "adamw", + "out_dir": "Experiments/motionstreamer_model", + "resume_pth": null, + "resume_trans": null, + "seed": 123, + "stride_t": 2, + "text": "A man is jogging around.", + "total_iter": 100000, + "weight_decay": 1e-06 +} +2025-10-12 22:48:49,611 INFO Train. Iter 100 : Loss. 0.98768 +2025-10-12 22:51:50,886 INFO { + "batch_size": 30, + "dataname": "t2m_babel_272", + "decay_option": "all", + "depth": 3, + "dilation_growth_rate": 3, + "down_t": 2, + "exp_name": "motionstreamer_model", + "gamma": 0.05, + "hidden_size": 1024, + "latent_dim": 16, + "latent_dir": "babel_272_stream/t2m_babel_latents", + "lr": 0.0001, + "mode": "pos", + "num_diffusion_head_layers": 9, + "num_gpus": 1, + "optimizer": "adamw", + "out_dir": "Experiments/motionstreamer_model", + "resume_pth": null, + "resume_trans": null, + "seed": 123, + "stride_t": 2, + "text": "A man is jogging around.", + "total_iter": 100000, + "weight_decay": 1e-06 +} +2025-10-12 22:52:41,451 INFO Train. Iter 100 : Loss. 0.98724 diff --git a/Experiments/motionstreamer_model/100k.pth b/Experiments/motionstreamer_model/100k.pth new file mode 100644 index 0000000000000000000000000000000000000000..5274e91b99933d2ccc660e238339023aa0add175 --- /dev/null +++ b/Experiments/motionstreamer_model/100k.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e7692bc2d024c37ab6421354dc6237a4d19221e7d09cf761ccdc96a993a8b578 +size 969286216 diff --git a/Experiments/motionstreamer_model/events.out.tfevents.1760344059.86fdc3d9f180.32065.0 b/Experiments/motionstreamer_model/events.out.tfevents.1760344059.86fdc3d9f180.32065.0 new file mode 100644 index 0000000000000000000000000000000000000000..608d27a15076f21b69a54e07c56a5f1967b27f82 --- /dev/null +++ b/Experiments/motionstreamer_model/events.out.tfevents.1760344059.86fdc3d9f180.32065.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b7319831fb43dc83c613200e054011ed47a803141fcc81ac2e7f56df5797b903 +size 1108 diff --git a/Experiments/motionstreamer_model/events.out.tfevents.1760344453.86fdc3d9f180.1417.0 b/Experiments/motionstreamer_model/events.out.tfevents.1760344453.86fdc3d9f180.1417.0 new file mode 100644 index 0000000000000000000000000000000000000000..686605a5803d4624e2c3160df3bd412a15798790 --- /dev/null +++ b/Experiments/motionstreamer_model/events.out.tfevents.1760344453.86fdc3d9f180.1417.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9817406a18dd624ebfb0e0434c3482571814a72a48a1005ed7b3842f317cb1de +size 1618 diff --git a/Experiments/motionstreamer_model/events.out.tfevents.1760345184.86fdc3d9f180.4590.0 b/Experiments/motionstreamer_model/events.out.tfevents.1760345184.86fdc3d9f180.4590.0 new file mode 100644 index 0000000000000000000000000000000000000000..d40f9deb9f83178d5ef8751ddbef8dbc856593ab --- /dev/null +++ b/Experiments/motionstreamer_model/events.out.tfevents.1760345184.86fdc3d9f180.4590.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:777b1762235b1953b2cd0c3537c7a23fd02dea8f14aa83f03be80761243bf1af +size 386 diff --git 
a/Experiments/motionstreamer_model/events.out.tfevents.1760345891.86fdc3d9f180.6638.0 b/Experiments/motionstreamer_model/events.out.tfevents.1760345891.86fdc3d9f180.6638.0 new file mode 100644 index 0000000000000000000000000000000000000000..f9a6de3258f3efa69487fc1becffcd0ec4729a43 --- /dev/null +++ b/Experiments/motionstreamer_model/events.out.tfevents.1760345891.86fdc3d9f180.6638.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d294c830b4dcc2994d40df665c3c8a927ba3c052fd1f637867cabae182be1413 +size 101760 diff --git a/Experiments/motionstreamer_model/latest.pth b/Experiments/motionstreamer_model/latest.pth new file mode 100644 index 0000000000000000000000000000000000000000..e563ec64a9a60aa8de4a09cca6393730fd1472a8 --- /dev/null +++ b/Experiments/motionstreamer_model/latest.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c235ac7d1c1566a5d16325045cf42db6f8867d1000d538c5a08bc0c9eb8c7971 +size 2907872033 diff --git a/Experiments/motionstreamer_model/run.log b/Experiments/motionstreamer_model/run.log new file mode 100644 index 0000000000000000000000000000000000000000..c6a371507045e7b1478f845257c33550f135c4ce --- /dev/null +++ b/Experiments/motionstreamer_model/run.log @@ -0,0 +1,2529 @@ +2025-10-12 22:06:52,573 INFO { + "batch_size": 256, + "dataname": "t2m_babel_272", + "decay_option": "all", + "depth": 3, + "dilation_growth_rate": 3, + "down_t": 2, + "exp_name": "motionstreamer_model", + "gamma": 0.05, + "hidden_size": 1024, + "latent_dim": 16, + "latent_dir": "babel_272_stream/t2m_babel_latents", + "lr": 0.0001, + "mode": "pos", + "num_diffusion_head_layers": 9, + "num_gpus": 1, + "optimizer": "adamw", + "out_dir": "Experiments/motionstreamer_model", + "resume_pth": null, + "resume_trans": null, + "seed": 123, + "stride_t": 2, + "text": "A man is jogging around.", + "total_iter": 100000, + "weight_decay": 1e-06 +} +2025-10-12 22:08:56,596 INFO { + "batch_size": 32, + "dataname": "t2m_babel_272", + "decay_option": "all", + "depth": 3, + "dilation_growth_rate": 3, + "down_t": 2, + "exp_name": "motionstreamer_model", + "gamma": 0.05, + "hidden_size": 1024, + "latent_dim": 16, + "latent_dir": "babel_272_stream/t2m_babel_latents", + "lr": 0.0001, + "mode": "pos", + "num_diffusion_head_layers": 9, + "num_gpus": 1, + "optimizer": "adamw", + "out_dir": "Experiments/motionstreamer_model", + "resume_pth": null, + "resume_trans": null, + "seed": 123, + "stride_t": 2, + "text": "A man is jogging around.", + "total_iter": 100000, + "weight_decay": 1e-06 +} +2025-10-12 22:09:45,974 INFO Train. Iter 100 : Loss. 0.98718 +2025-10-12 22:10:12,703 INFO Train. Iter 200 : Loss. 0.89334 +2025-10-12 22:10:39,800 INFO Train. Iter 300 : Loss. 0.58750 +2025-10-12 22:11:06,778 INFO Train. Iter 400 : Loss. 0.16893 +2025-10-12 22:11:34,048 INFO Train. Iter 500 : Loss. 0.09410 +2025-10-12 22:12:01,408 INFO Train. Iter 600 : Loss. 0.08094 +2025-10-12 22:12:28,205 INFO Train. Iter 700 : Loss. 0.07417 +2025-10-12 22:12:55,164 INFO Train. Iter 800 : Loss. 0.06324 +2025-10-12 22:13:22,827 INFO Train. Iter 900 : Loss. 0.05340 +2025-10-12 22:13:49,569 INFO Train. Iter 1000 : Loss. 0.04717 +2025-10-12 22:14:16,954 INFO Train. Iter 1100 : Loss. 0.04301 +2025-10-12 22:14:44,300 INFO Train. Iter 1200 : Loss. 0.04186 +2025-10-12 22:15:11,118 INFO Train. Iter 1300 : Loss. 
0.03830 +2025-10-12 22:17:43,928 INFO { + "batch_size": 32, + "dataname": "t2m_babel_272", + "decay_option": "all", + "depth": 3, + "dilation_growth_rate": 3, + "down_t": 2, + "exp_name": "motionstreamer_model", + "gamma": 0.05, + "hidden_size": 1024, + "latent_dim": 16, + "latent_dir": "babel_272_stream/t2m_babel_latents", + "lr": 0.0001, + "mode": "pos", + "num_diffusion_head_layers": 9, + "num_gpus": 1, + "optimizer": "adamw", + "out_dir": "Experiments/motionstreamer_model", + "resume_pth": null, + "resume_trans": null, + "seed": 123, + "stride_t": 2, + "text": "A man is jogging around.", + "total_iter": 100000, + "weight_decay": 1e-06 +} +2025-10-12 22:24:52,238 INFO { + "batch_size": 32, + "dataname": "t2m_babel_272", + "decay_option": "all", + "depth": 3, + "dilation_growth_rate": 3, + "down_t": 2, + "exp_name": "motionstreamer_model", + "gamma": 0.05, + "hidden_size": 1024, + "latent_dim": 16, + "latent_dir": "babel_272_stream/t2m_babel_latents", + "lr": 0.0001, + "mode": "pos", + "num_diffusion_head_layers": 9, + "num_gpus": 1, + "optimizer": "adamw", + "out_dir": "Experiments/motionstreamer_model", + "resume_pth": null, + "resume_trans": null, + "seed": 123, + "stride_t": 2, + "text": "A man is jogging around.", + "total_iter": 100000, + "weight_decay": 1e-06 +} +2025-10-12 22:29:30,975 INFO { + "batch_size": 32, + "dataname": "t2m_babel_272", + "decay_option": "all", + "depth": 3, + "dilation_growth_rate": 3, + "down_t": 2, + "exp_name": "motionstreamer_model", + "gamma": 0.05, + "hidden_size": 1024, + "latent_dim": 16, + "latent_dir": "babel_272_stream/t2m_babel_latents", + "lr": 0.0001, + "mode": "pos", + "num_diffusion_head_layers": 9, + "num_gpus": 1, + "optimizer": "adamw", + "out_dir": "Experiments/motionstreamer_model", + "resume_pth": null, + "resume_trans": null, + "seed": 123, + "stride_t": 2, + "text": "A man is jogging around.", + "total_iter": 100000, + "weight_decay": 1e-06 +} +2025-10-12 22:33:25,314 INFO { + "batch_size": 32, + "dataname": "t2m_babel_272", + "decay_option": "all", + "depth": 3, + "dilation_growth_rate": 3, + "down_t": 2, + "exp_name": "motionstreamer_model", + "gamma": 0.05, + "hidden_size": 1024, + "latent_dim": 16, + "latent_dir": "babel_272_stream/t2m_babel_latents", + "lr": 0.0001, + "mode": "pos", + "num_diffusion_head_layers": 9, + "num_gpus": 1, + "optimizer": "adamw", + "out_dir": "Experiments/motionstreamer_model", + "resume_pth": null, + "resume_trans": null, + "seed": 123, + "stride_t": 2, + "text": "A man is jogging around.", + "total_iter": 100000, + "weight_decay": 1e-06 +} +2025-10-12 22:37:35,801 INFO { + "batch_size": 32, + "dataname": "t2m_babel_272", + "decay_option": "all", + "depth": 3, + "dilation_growth_rate": 3, + "down_t": 2, + "exp_name": "motionstreamer_model", + "gamma": 0.05, + "hidden_size": 1024, + "latent_dim": 16, + "latent_dir": "babel_272_stream/t2m_babel_latents", + "lr": 0.0001, + "mode": "pos", + "num_diffusion_head_layers": 9, + "num_gpus": 1, + "optimizer": "adamw", + "out_dir": "Experiments/motionstreamer_model", + "resume_pth": null, + "resume_trans": null, + "seed": 123, + "stride_t": 2, + "text": "A man is jogging around.", + "total_iter": 100000, + "weight_decay": 1e-06 +} +2025-10-12 22:38:03,752 INFO Train. Iter 10 : Loss. 1.00243 +2025-10-12 22:38:06,503 INFO Train. Iter 20 : Loss. 0.99949 +2025-10-12 22:38:09,317 INFO Train. Iter 30 : Loss. 0.99516 +2025-10-12 22:38:12,126 INFO Train. Iter 40 : Loss. 0.99321 +2025-10-12 22:38:14,819 INFO Train. Iter 50 : Loss. 
0.99326 +2025-10-12 22:38:17,512 INFO Train. Iter 60 : Loss. 0.99138 +2025-10-12 22:38:20,225 INFO Train. Iter 70 : Loss. 0.98142 +2025-10-12 22:38:22,968 INFO Train. Iter 80 : Loss. 0.97872 +2025-10-12 22:38:25,701 INFO Train. Iter 90 : Loss. 0.97327 +2025-10-12 22:38:28,447 INFO Train. Iter 100 : Loss. 0.96485 +2025-10-12 22:38:31,147 INFO Train. Iter 110 : Loss. 0.95915 +2025-10-12 22:38:33,870 INFO Train. Iter 120 : Loss. 0.94824 +2025-10-12 22:38:36,578 INFO Train. Iter 130 : Loss. 0.94099 +2025-10-12 22:38:39,344 INFO Train. Iter 140 : Loss. 0.92432 +2025-10-12 22:38:42,005 INFO Train. Iter 150 : Loss. 0.91036 +2025-10-12 22:38:44,765 INFO Train. Iter 160 : Loss. 0.89669 +2025-10-12 22:38:47,478 INFO Train. Iter 170 : Loss. 0.87722 +2025-10-12 22:38:50,169 INFO Train. Iter 180 : Loss. 0.86069 +2025-10-12 22:38:52,908 INFO Train. Iter 190 : Loss. 0.83729 +2025-10-12 22:38:55,613 INFO Train. Iter 200 : Loss. 0.81117 +2025-10-12 22:38:58,383 INFO Train. Iter 210 : Loss. 0.78479 +2025-10-12 22:39:01,148 INFO Train. Iter 220 : Loss. 0.75250 +2025-10-12 22:39:03,820 INFO Train. Iter 230 : Loss. 0.72087 +2025-10-12 22:39:06,594 INFO Train. Iter 240 : Loss. 0.68616 +2025-10-12 22:39:09,353 INFO Train. Iter 250 : Loss. 0.64506 +2025-10-12 22:39:12,141 INFO Train. Iter 260 : Loss. 0.60450 +2025-10-12 22:39:14,858 INFO Train. Iter 270 : Loss. 0.55745 +2025-10-12 22:39:17,682 INFO Train. Iter 280 : Loss. 0.50806 +2025-10-12 22:39:20,413 INFO Train. Iter 290 : Loss. 0.45532 +2025-10-12 22:39:23,134 INFO Train. Iter 300 : Loss. 0.39957 +2025-10-12 22:39:25,919 INFO Train. Iter 310 : Loss. 0.34181 +2025-10-12 22:39:28,826 INFO Train. Iter 320 : Loss. 0.29091 +2025-10-12 22:39:31,567 INFO Train. Iter 330 : Loss. 0.24968 +2025-10-12 22:39:34,344 INFO Train. Iter 340 : Loss. 0.20833 +2025-10-12 22:39:37,030 INFO Train. Iter 350 : Loss. 0.17297 +2025-10-12 22:39:39,835 INFO Train. Iter 360 : Loss. 0.14701 +2025-10-12 22:39:42,520 INFO Train. Iter 370 : Loss. 0.13512 +2025-10-12 22:39:45,304 INFO Train. Iter 380 : Loss. 0.12805 +2025-10-12 22:39:48,034 INFO Train. Iter 390 : Loss. 0.11589 +2025-10-12 22:39:50,771 INFO Train. Iter 400 : Loss. 0.11061 +2025-10-12 22:39:53,562 INFO Train. Iter 410 : Loss. 0.10903 +2025-10-12 22:39:56,303 INFO Train. Iter 420 : Loss. 0.10401 +2025-10-12 22:39:58,996 INFO Train. Iter 430 : Loss. 0.10301 +2025-10-12 22:40:02,118 INFO Train. Iter 440 : Loss. 0.09834 +2025-10-12 22:40:04,872 INFO Train. Iter 450 : Loss. 0.10153 +2025-10-12 22:40:07,568 INFO Train. Iter 460 : Loss. 0.09574 +2025-10-12 22:40:10,349 INFO Train. Iter 470 : Loss. 0.09539 +2025-10-12 22:40:13,119 INFO Train. Iter 480 : Loss. 0.09524 +2025-10-12 22:40:15,881 INFO Train. Iter 490 : Loss. 0.09011 +2025-10-12 22:40:18,610 INFO Train. Iter 500 : Loss. 0.08879 +2025-10-12 22:40:21,548 INFO Train. Iter 510 : Loss. 0.08935 +2025-10-12 22:40:24,269 INFO Train. Iter 520 : Loss. 0.08596 +2025-10-12 22:40:27,034 INFO Train. Iter 530 : Loss. 0.08763 +2025-10-12 22:40:29,761 INFO Train. Iter 540 : Loss. 0.08654 +2025-10-12 22:40:32,805 INFO Train. Iter 550 : Loss. 0.08733 +2025-10-12 22:40:35,532 INFO Train. Iter 560 : Loss. 0.08198 +2025-10-12 22:40:38,339 INFO Train. Iter 570 : Loss. 0.08665 +2025-10-12 22:40:41,118 INFO Train. Iter 580 : Loss. 0.08642 +2025-10-12 22:40:43,867 INFO Train. Iter 590 : Loss. 0.08730 +2025-10-12 22:40:46,668 INFO Train. Iter 600 : Loss. 0.07833 +2025-10-12 22:40:49,456 INFO Train. Iter 610 : Loss. 0.08199 +2025-10-12 22:40:52,142 INFO Train. Iter 620 : Loss. 
0.08393 +2025-10-12 22:40:54,825 INFO Train. Iter 630 : Loss. 0.07858 +2025-10-12 22:40:57,447 INFO Train. Iter 640 : Loss. 0.08227 +2025-10-12 22:41:00,132 INFO Train. Iter 650 : Loss. 0.07588 +2025-10-12 22:41:02,921 INFO Train. Iter 660 : Loss. 0.08195 +2025-10-12 22:41:05,672 INFO Train. Iter 670 : Loss. 0.08222 +2025-10-12 22:41:08,472 INFO Train. Iter 680 : Loss. 0.07408 +2025-10-12 22:41:11,288 INFO Train. Iter 690 : Loss. 0.07727 +2025-10-12 22:41:14,040 INFO Train. Iter 700 : Loss. 0.07344 +2025-10-12 22:41:16,922 INFO Train. Iter 710 : Loss. 0.07518 +2025-10-12 22:41:19,616 INFO Train. Iter 720 : Loss. 0.07710 +2025-10-12 22:41:22,356 INFO Train. Iter 730 : Loss. 0.07323 +2025-10-12 22:41:25,088 INFO Train. Iter 740 : Loss. 0.07110 +2025-10-12 22:41:27,728 INFO Train. Iter 750 : Loss. 0.06784 +2025-10-12 22:41:30,499 INFO Train. Iter 760 : Loss. 0.06789 +2025-10-12 22:41:33,384 INFO Train. Iter 770 : Loss. 0.06421 +2025-10-12 22:41:36,151 INFO Train. Iter 780 : Loss. 0.06638 +2025-10-12 22:41:38,870 INFO Train. Iter 790 : Loss. 0.05871 +2025-10-12 22:41:41,628 INFO Train. Iter 800 : Loss. 0.06345 +2025-10-12 22:41:44,297 INFO Train. Iter 810 : Loss. 0.06169 +2025-10-12 22:41:47,133 INFO Train. Iter 820 : Loss. 0.06290 +2025-10-12 22:41:49,901 INFO Train. Iter 830 : Loss. 0.05927 +2025-10-12 22:41:52,724 INFO Train. Iter 840 : Loss. 0.05768 +2025-10-12 22:41:55,454 INFO Train. Iter 850 : Loss. 0.06435 +2025-10-12 22:41:58,225 INFO Train. Iter 860 : Loss. 0.05915 +2025-10-12 22:42:01,090 INFO Train. Iter 870 : Loss. 0.05240 +2025-10-12 22:42:04,404 INFO Train. Iter 880 : Loss. 0.05182 +2025-10-12 22:42:07,083 INFO Train. Iter 890 : Loss. 0.05887 +2025-10-12 22:42:09,896 INFO Train. Iter 900 : Loss. 0.04940 +2025-10-12 22:42:12,547 INFO Train. Iter 910 : Loss. 0.05817 +2025-10-12 22:42:15,403 INFO Train. Iter 920 : Loss. 0.05284 +2025-10-12 22:42:18,069 INFO Train. Iter 930 : Loss. 0.04915 +2025-10-12 22:42:20,674 INFO Train. Iter 940 : Loss. 0.05672 +2025-10-12 22:42:23,404 INFO Train. Iter 950 : Loss. 0.04910 +2025-10-12 22:42:26,122 INFO Train. Iter 960 : Loss. 0.05498 +2025-10-12 22:42:28,819 INFO Train. Iter 970 : Loss. 0.05035 +2025-10-12 22:42:31,603 INFO Train. Iter 980 : Loss. 0.04843 +2025-10-12 22:42:34,300 INFO Train. Iter 990 : Loss. 0.05129 +2025-10-12 22:42:37,022 INFO Train. Iter 1000 : Loss. 0.04424 +2025-10-12 22:42:39,818 INFO Train. Iter 1010 : Loss. 0.04696 +2025-10-12 22:42:42,518 INFO Train. Iter 1020 : Loss. 0.04665 +2025-10-12 22:42:45,262 INFO Train. Iter 1030 : Loss. 0.04747 +2025-10-12 22:42:48,043 INFO Train. Iter 1040 : Loss. 0.04422 +2025-10-12 22:42:50,820 INFO Train. Iter 1050 : Loss. 0.04651 +2025-10-12 22:42:53,519 INFO Train. Iter 1060 : Loss. 0.04664 +2025-10-12 22:42:56,318 INFO Train. Iter 1070 : Loss. 0.04701 +2025-10-12 22:42:59,036 INFO Train. Iter 1080 : Loss. 0.04410 +2025-10-12 22:43:01,879 INFO Train. Iter 1090 : Loss. 0.04406 +2025-10-12 22:43:04,681 INFO Train. Iter 1100 : Loss. 0.04581 +2025-10-12 22:43:07,386 INFO Train. Iter 1110 : Loss. 0.04522 +2025-10-12 22:43:10,289 INFO Train. Iter 1120 : Loss. 0.04364 +2025-10-12 22:43:13,092 INFO Train. Iter 1130 : Loss. 
0.04148 +2025-10-12 22:47:51,657 INFO { + "batch_size": 40, + "dataname": "t2m_babel_272", + "decay_option": "all", + "depth": 3, + "dilation_growth_rate": 3, + "down_t": 2, + "exp_name": "motionstreamer_model", + "gamma": 0.05, + "hidden_size": 1024, + "latent_dim": 16, + "latent_dir": "babel_272_stream/t2m_babel_latents", + "lr": 0.0001, + "mode": "pos", + "num_diffusion_head_layers": 9, + "num_gpus": 1, + "optimizer": "adamw", + "out_dir": "Experiments/motionstreamer_model", + "resume_pth": null, + "resume_trans": null, + "seed": 123, + "stride_t": 2, + "text": "A man is jogging around.", + "total_iter": 100000, + "weight_decay": 1e-06 +} +2025-10-12 22:48:49,611 INFO Train. Iter 100 : Loss. 0.98768 +2025-10-12 22:51:50,886 INFO { + "batch_size": 30, + "dataname": "t2m_babel_272", + "decay_option": "all", + "depth": 3, + "dilation_growth_rate": 3, + "down_t": 2, + "exp_name": "motionstreamer_model", + "gamma": 0.05, + "hidden_size": 1024, + "latent_dim": 16, + "latent_dir": "babel_272_stream/t2m_babel_latents", + "lr": 0.0001, + "mode": "pos", + "num_diffusion_head_layers": 9, + "num_gpus": 1, + "optimizer": "adamw", + "out_dir": "Experiments/motionstreamer_model", + "resume_pth": null, + "resume_trans": null, + "seed": 123, + "stride_t": 2, + "text": "A man is jogging around.", + "total_iter": 100000, + "weight_decay": 1e-06 +} +2025-10-12 22:52:41,451 INFO Train. Iter 100 : Loss. 0.98724 +2025-10-12 22:53:07,122 INFO Train. Iter 200 : Loss. 0.89317 +2025-10-12 22:53:33,119 INFO Train. Iter 300 : Loss. 0.58815 +2025-10-12 22:53:59,064 INFO Train. Iter 400 : Loss. 0.17107 +2025-10-12 22:54:25,435 INFO Train. Iter 500 : Loss. 0.09429 +2025-10-12 22:54:51,725 INFO Train. Iter 600 : Loss. 0.08183 +2025-10-12 22:55:17,999 INFO Train. Iter 700 : Loss. 0.07371 +2025-10-12 22:55:43,883 INFO Train. Iter 800 : Loss. 0.06298 +2025-10-12 22:56:10,014 INFO Train. Iter 900 : Loss. 0.05263 +2025-10-12 22:56:36,366 INFO Train. Iter 1000 : Loss. 0.04678 +2025-10-12 22:58:46,980 INFO { + "batch_size": 30, + "dataname": "t2m_babel_272", + "decay_option": "all", + "depth": 3, + "dilation_growth_rate": 3, + "down_t": 2, + "exp_name": "motionstreamer_model", + "gamma": 0.05, + "hidden_size": 1024, + "latent_dim": 16, + "latent_dir": "babel_272_stream/t2m_babel_latents", + "lr": 0.0001, + "mode": "pos", + "num_diffusion_head_layers": 9, + "num_gpus": 1, + "optimizer": "adamw", + "out_dir": "Experiments/motionstreamer_model", + "resume_pth": null, + "resume_trans": null, + "seed": 123, + "stride_t": 2, + "text": "A man is jogging around.", + "total_iter": 100000, + "weight_decay": 1e-06 +} +2025-10-12 22:59:37,098 INFO Train. Iter 100 : Loss. 0.98724 +2025-10-12 23:00:02,782 INFO Train. Iter 200 : Loss. 0.89317 +2025-10-12 23:00:28,767 INFO Train. Iter 300 : Loss. 0.58813 +2025-10-12 23:00:54,590 INFO Train. Iter 400 : Loss. 0.17133 +2025-10-12 23:01:20,995 INFO Train. Iter 500 : Loss. 0.09427 +2025-10-12 23:01:47,400 INFO Train. Iter 600 : Loss. 0.08201 +2025-10-12 23:02:13,759 INFO Train. Iter 700 : Loss. 0.07402 +2025-10-12 23:02:39,648 INFO Train. Iter 800 : Loss. 0.06515 +2025-10-12 23:03:05,801 INFO Train. Iter 900 : Loss. 0.05208 +2025-10-12 23:03:32,122 INFO Train. Iter 1000 : Loss. 0.04672 +2025-10-12 23:03:58,427 INFO Train. Iter 1100 : Loss. 0.04410 +2025-10-12 23:04:24,395 INFO Train. Iter 1200 : Loss. 0.04132 +2025-10-12 23:04:50,470 INFO Train. Iter 1300 : Loss. 0.03959 +2025-10-12 23:05:17,050 INFO Train. Iter 1400 : Loss. 0.03692 +2025-10-12 23:05:43,089 INFO Train. Iter 1500 : Loss. 
0.03639 +2025-10-12 23:06:09,147 INFO Train. Iter 1600 : Loss. 0.03472 +2025-10-12 23:06:35,209 INFO Train. Iter 1700 : Loss. 0.03402 +2025-10-12 23:07:01,188 INFO Train. Iter 1800 : Loss. 0.03401 +2025-10-12 23:07:27,885 INFO Train. Iter 1900 : Loss. 0.03347 +2025-10-12 23:07:53,973 INFO Train. Iter 2000 : Loss. 0.03137 +2025-10-12 23:08:20,079 INFO Train. Iter 2100 : Loss. 0.03233 +2025-10-12 23:08:46,352 INFO Train. Iter 2200 : Loss. 0.03073 +2025-10-12 23:09:12,403 INFO Train. Iter 2300 : Loss. 0.03189 +2025-10-12 23:09:39,072 INFO Train. Iter 2400 : Loss. 0.03067 +2025-10-12 23:10:05,092 INFO Train. Iter 2500 : Loss. 0.03099 +2025-10-12 23:10:31,206 INFO Train. Iter 2600 : Loss. 0.02977 +2025-10-12 23:10:57,383 INFO Train. Iter 2700 : Loss. 0.02998 +2025-10-12 23:11:23,867 INFO Train. Iter 2800 : Loss. 0.02858 +2025-10-12 23:11:50,165 INFO Train. Iter 2900 : Loss. 0.02895 +2025-10-12 23:12:16,408 INFO Train. Iter 3000 : Loss. 0.02939 +2025-10-12 23:12:42,450 INFO Train. Iter 3100 : Loss. 0.02959 +2025-10-12 23:13:08,569 INFO Train. Iter 3200 : Loss. 0.02867 +2025-10-12 23:13:35,264 INFO Train. Iter 3300 : Loss. 0.02833 +2025-10-12 23:14:01,313 INFO Train. Iter 3400 : Loss. 0.02846 +2025-10-12 23:14:27,540 INFO Train. Iter 3500 : Loss. 0.02785 +2025-10-12 23:14:53,909 INFO Train. Iter 3600 : Loss. 0.02890 +2025-10-12 23:15:19,993 INFO Train. Iter 3700 : Loss. 0.02836 +2025-10-12 23:15:46,793 INFO Train. Iter 3800 : Loss. 0.02771 +2025-10-12 23:16:12,908 INFO Train. Iter 3900 : Loss. 0.02818 +2025-10-12 23:16:39,191 INFO Train. Iter 4000 : Loss. 0.02782 +2025-10-12 23:17:05,217 INFO Train. Iter 4100 : Loss. 0.02858 +2025-10-12 23:17:31,831 INFO Train. Iter 4200 : Loss. 0.02808 +2025-10-12 23:17:57,921 INFO Train. Iter 4300 : Loss. 0.02697 +2025-10-12 23:18:23,998 INFO Train. Iter 4400 : Loss. 0.02790 +2025-10-12 23:18:49,958 INFO Train. Iter 4500 : Loss. 0.02782 +2025-10-12 23:19:16,275 INFO Train. Iter 4600 : Loss. 0.02704 +2025-10-12 23:19:42,934 INFO Train. Iter 4700 : Loss. 0.02670 +2025-10-12 23:20:09,070 INFO Train. Iter 4800 : Loss. 0.02720 +2025-10-12 23:20:35,048 INFO Train. Iter 4900 : Loss. 0.02715 +2025-10-12 23:21:01,395 INFO Train. Iter 5000 : Loss. 0.02735 +2025-10-12 23:21:27,492 INFO Train. Iter 5100 : Loss. 0.02704 +2025-10-12 23:21:54,317 INFO Train. Iter 5200 : Loss. 0.02593 +2025-10-12 23:22:20,493 INFO Train. Iter 5300 : Loss. 0.02795 +2025-10-12 23:22:46,665 INFO Train. Iter 5400 : Loss. 0.02746 +2025-10-12 23:23:12,986 INFO Train. Iter 5500 : Loss. 0.02599 +2025-10-12 23:23:39,459 INFO Train. Iter 5600 : Loss. 0.02620 +2025-10-12 23:24:05,427 INFO Train. Iter 5700 : Loss. 0.02644 +2025-10-12 23:24:31,636 INFO Train. Iter 5800 : Loss. 0.02626 +2025-10-12 23:24:57,685 INFO Train. Iter 5900 : Loss. 0.02631 +2025-10-12 23:25:23,933 INFO Train. Iter 6000 : Loss. 0.02667 +2025-10-12 23:25:50,696 INFO Train. Iter 6100 : Loss. 0.02630 +2025-10-12 23:26:17,108 INFO Train. Iter 6200 : Loss. 0.02631 +2025-10-12 23:26:43,400 INFO Train. Iter 6300 : Loss. 0.02567 +2025-10-12 23:27:09,790 INFO Train. Iter 6400 : Loss. 0.02622 +2025-10-12 23:27:36,027 INFO Train. Iter 6500 : Loss. 0.02549 +2025-10-12 23:28:03,118 INFO Train. Iter 6600 : Loss. 0.02654 +2025-10-12 23:28:29,581 INFO Train. Iter 6700 : Loss. 0.02672 +2025-10-12 23:28:55,681 INFO Train. Iter 6800 : Loss. 0.02601 +2025-10-12 23:29:21,879 INFO Train. Iter 6900 : Loss. 0.02587 +2025-10-12 23:29:48,790 INFO Train. Iter 7000 : Loss. 0.02496 +2025-10-12 23:30:15,174 INFO Train. Iter 7100 : Loss. 
0.02601 +2025-10-12 23:30:41,341 INFO Train. Iter 7200 : Loss. 0.02545 +2025-10-12 23:31:07,606 INFO Train. Iter 7300 : Loss. 0.02532 +2025-10-12 23:31:33,654 INFO Train. Iter 7400 : Loss. 0.02573 +2025-10-12 23:32:00,528 INFO Train. Iter 7500 : Loss. 0.02593 +2025-10-12 23:32:26,678 INFO Train. Iter 7600 : Loss. 0.02647 +2025-10-12 23:32:52,925 INFO Train. Iter 7700 : Loss. 0.02530 +2025-10-12 23:33:19,081 INFO Train. Iter 7800 : Loss. 0.02613 +2025-10-12 23:33:45,180 INFO Train. Iter 7900 : Loss. 0.02586 +2025-10-12 23:34:12,056 INFO Train. Iter 8000 : Loss. 0.02593 +2025-10-12 23:34:38,279 INFO Train. Iter 8100 : Loss. 0.02593 +2025-10-12 23:35:04,492 INFO Train. Iter 8200 : Loss. 0.02518 +2025-10-12 23:35:30,531 INFO Train. Iter 8300 : Loss. 0.02556 +2025-10-12 23:35:56,721 INFO Train. Iter 8400 : Loss. 0.02501 +2025-10-12 23:36:22,817 INFO Train. Iter 8500 : Loss. 0.02516 +2025-10-12 23:36:49,267 INFO Train. Iter 8600 : Loss. 0.02532 +2025-10-12 23:37:15,315 INFO Train. Iter 8700 : Loss. 0.02514 +2025-10-12 23:37:41,855 INFO Train. Iter 8800 : Loss. 0.02511 +2025-10-12 23:38:08,578 INFO Train. Iter 8900 : Loss. 0.02568 +2025-10-12 23:38:34,861 INFO Train. Iter 9000 : Loss. 0.02448 +2025-10-12 23:39:01,312 INFO Train. Iter 9100 : Loss. 0.02413 +2025-10-12 23:39:27,397 INFO Train. Iter 9200 : Loss. 0.02524 +2025-10-12 23:39:53,376 INFO Train. Iter 9300 : Loss. 0.02534 +2025-10-12 23:40:19,957 INFO Train. Iter 9400 : Loss. 0.02498 +2025-10-12 23:40:46,163 INFO Train. Iter 9500 : Loss. 0.02418 +2025-10-12 23:41:12,487 INFO Train. Iter 9600 : Loss. 0.02472 +2025-10-12 23:41:38,591 INFO Train. Iter 9700 : Loss. 0.02564 +2025-10-12 23:42:05,161 INFO Train. Iter 9800 : Loss. 0.02526 +2025-10-12 23:42:31,431 INFO Train. Iter 9900 : Loss. 0.02408 +2025-10-12 23:42:57,218 INFO Train. Iter 10000 : Loss. 0.02424 +2025-10-12 23:43:24,467 INFO Train. Iter 10100 : Loss. 0.02408 +2025-10-12 23:43:50,653 INFO Train. Iter 10200 : Loss. 0.02442 +2025-10-12 23:44:17,367 INFO Train. Iter 10300 : Loss. 0.02427 +2025-10-12 23:44:43,467 INFO Train. Iter 10400 : Loss. 0.02405 +2025-10-12 23:45:09,722 INFO Train. Iter 10500 : Loss. 0.02363 +2025-10-12 23:45:35,830 INFO Train. Iter 10600 : Loss. 0.02513 +2025-10-12 23:46:01,838 INFO Train. Iter 10700 : Loss. 0.02455 +2025-10-12 23:46:28,887 INFO Train. Iter 10800 : Loss. 0.02363 +2025-10-12 23:46:55,031 INFO Train. Iter 10900 : Loss. 0.02385 +2025-10-12 23:47:20,979 INFO Train. Iter 11000 : Loss. 0.02410 +2025-10-12 23:47:47,217 INFO Train. Iter 11100 : Loss. 0.02338 +2025-10-12 23:48:13,922 INFO Train. Iter 11200 : Loss. 0.02445 +2025-10-12 23:48:39,973 INFO Train. Iter 11300 : Loss. 0.02308 +2025-10-12 23:49:05,800 INFO Train. Iter 11400 : Loss. 0.02423 +2025-10-12 23:49:32,205 INFO Train. Iter 11500 : Loss. 0.02288 +2025-10-12 23:49:58,537 INFO Train. Iter 11600 : Loss. 0.02278 +2025-10-12 23:50:25,661 INFO Train. Iter 11700 : Loss. 0.02384 +2025-10-12 23:50:51,821 INFO Train. Iter 11800 : Loss. 0.02339 +2025-10-12 23:51:17,894 INFO Train. Iter 11900 : Loss. 0.02311 +2025-10-12 23:51:44,214 INFO Train. Iter 12000 : Loss. 0.02209 +2025-10-12 23:52:10,351 INFO Train. Iter 12100 : Loss. 0.02298 +2025-10-12 23:52:37,064 INFO Train. Iter 12200 : Loss. 0.02214 +2025-10-12 23:53:03,024 INFO Train. Iter 12300 : Loss. 0.02283 +2025-10-12 23:53:29,417 INFO Train. Iter 12400 : Loss. 0.02225 +2025-10-12 23:53:55,585 INFO Train. Iter 12500 : Loss. 0.02237 +2025-10-12 23:54:22,260 INFO Train. Iter 12600 : Loss. 0.02203 +2025-10-12 23:54:48,323 INFO Train. 
Iter 12700 : Loss. 0.02295 +2025-10-12 23:55:14,655 INFO Train. Iter 12800 : Loss. 0.02234 +2025-10-12 23:55:40,566 INFO Train. Iter 12900 : Loss. 0.02218 +2025-10-12 23:56:06,867 INFO Train. Iter 13000 : Loss. 0.02242 +2025-10-12 23:56:33,801 INFO Train. Iter 13100 : Loss. 0.02222 +2025-10-12 23:56:59,552 INFO Train. Iter 13200 : Loss. 0.02166 +2025-10-12 23:57:25,973 INFO Train. Iter 13300 : Loss. 0.02162 +2025-10-12 23:57:52,112 INFO Train. Iter 13400 : Loss. 0.02299 +2025-10-12 23:58:18,388 INFO Train. Iter 13500 : Loss. 0.02238 +2025-10-12 23:58:45,114 INFO Train. Iter 13600 : Loss. 0.02121 +2025-10-12 23:59:11,278 INFO Train. Iter 13700 : Loss. 0.02206 +2025-10-12 23:59:37,278 INFO Train. Iter 13800 : Loss. 0.02220 +2025-10-13 00:00:03,363 INFO Train. Iter 13900 : Loss. 0.02110 +2025-10-13 00:00:30,025 INFO Train. Iter 14000 : Loss. 0.02196 +2025-10-13 00:00:56,208 INFO Train. Iter 14100 : Loss. 0.02138 +2025-10-13 00:01:22,361 INFO Train. Iter 14200 : Loss. 0.02091 +2025-10-13 00:01:48,679 INFO Train. Iter 14300 : Loss. 0.02178 +2025-10-13 00:02:14,792 INFO Train. Iter 14400 : Loss. 0.02068 +2025-10-13 00:02:41,290 INFO Train. Iter 14500 : Loss. 0.02058 +2025-10-13 00:03:07,528 INFO Train. Iter 14600 : Loss. 0.02144 +2025-10-13 00:03:33,659 INFO Train. Iter 14700 : Loss. 0.02131 +2025-10-13 00:03:59,892 INFO Train. Iter 14800 : Loss. 0.02157 +2025-10-13 00:04:26,197 INFO Train. Iter 14900 : Loss. 0.02129 +2025-10-13 00:04:52,867 INFO Train. Iter 15000 : Loss. 0.02123 +2025-10-13 00:05:18,871 INFO Train. Iter 15100 : Loss. 0.02091 +2025-10-13 00:05:45,042 INFO Train. Iter 15200 : Loss. 0.02084 +2025-10-13 00:06:11,041 INFO Train. Iter 15300 : Loss. 0.02126 +2025-10-13 00:06:37,651 INFO Train. Iter 15400 : Loss. 0.02050 +2025-10-13 00:07:03,929 INFO Train. Iter 15500 : Loss. 0.02083 +2025-10-13 00:07:29,867 INFO Train. Iter 15600 : Loss. 0.02098 +2025-10-13 00:07:55,946 INFO Train. Iter 15700 : Loss. 0.02111 +2025-10-13 00:08:22,232 INFO Train. Iter 15800 : Loss. 0.02080 +2025-10-13 00:08:48,823 INFO Train. Iter 15900 : Loss. 0.02020 +2025-10-13 00:09:15,246 INFO Train. Iter 16000 : Loss. 0.02037 +2025-10-13 00:09:41,415 INFO Train. Iter 16100 : Loss. 0.02043 +2025-10-13 00:10:07,684 INFO Train. Iter 16200 : Loss. 0.02086 +2025-10-13 00:10:33,826 INFO Train. Iter 16300 : Loss. 0.02005 +2025-10-13 00:11:00,332 INFO Train. Iter 16400 : Loss. 0.02048 +2025-10-13 00:11:26,590 INFO Train. Iter 16500 : Loss. 0.02055 +2025-10-13 00:11:52,792 INFO Train. Iter 16600 : Loss. 0.02086 +2025-10-13 00:12:18,914 INFO Train. Iter 16700 : Loss. 0.01986 +2025-10-13 00:12:45,606 INFO Train. Iter 16800 : Loss. 0.02017 +2025-10-13 00:13:11,684 INFO Train. Iter 16900 : Loss. 0.02016 +2025-10-13 00:13:37,837 INFO Train. Iter 17000 : Loss. 0.02062 +2025-10-13 00:14:04,227 INFO Train. Iter 17100 : Loss. 0.02023 +2025-10-13 00:14:30,208 INFO Train. Iter 17200 : Loss. 0.02023 +2025-10-13 00:14:56,614 INFO Train. Iter 17300 : Loss. 0.01943 +2025-10-13 00:15:22,810 INFO Train. Iter 17400 : Loss. 0.02024 +2025-10-13 00:15:48,873 INFO Train. Iter 17500 : Loss. 0.01939 +2025-10-13 00:16:15,061 INFO Train. Iter 17600 : Loss. 0.01960 +2025-10-13 00:16:41,127 INFO Train. Iter 17700 : Loss. 0.02005 +2025-10-13 00:17:07,770 INFO Train. Iter 17800 : Loss. 0.01979 +2025-10-13 00:17:34,203 INFO Train. Iter 17900 : Loss. 0.01958 +2025-10-13 00:18:00,401 INFO Train. Iter 18000 : Loss. 0.01961 +2025-10-13 00:18:27,409 INFO Train. Iter 18100 : Loss. 0.01961 +2025-10-13 00:18:53,667 INFO Train. Iter 18200 : Loss. 
0.01998 +2025-10-13 00:19:19,877 INFO Train. Iter 18300 : Loss. 0.01944 +2025-10-13 00:19:46,041 INFO Train. Iter 18400 : Loss. 0.01938 +2025-10-13 00:20:12,111 INFO Train. Iter 18500 : Loss. 0.01939 +2025-10-13 00:20:38,383 INFO Train. Iter 18600 : Loss. 0.01983 +2025-10-13 00:21:05,177 INFO Train. Iter 18700 : Loss. 0.01939 +2025-10-13 00:21:31,105 INFO Train. Iter 18800 : Loss. 0.01936 +2025-10-13 00:21:57,215 INFO Train. Iter 18900 : Loss. 0.01930 +2025-10-13 00:22:23,316 INFO Train. Iter 19000 : Loss. 0.01921 +2025-10-13 00:22:49,546 INFO Train. Iter 19100 : Loss. 0.01923 +2025-10-13 00:23:16,383 INFO Train. Iter 19200 : Loss. 0.01881 +2025-10-13 00:23:42,581 INFO Train. Iter 19300 : Loss. 0.01859 +2025-10-13 00:24:08,916 INFO Train. Iter 19400 : Loss. 0.01854 +2025-10-13 00:24:34,918 INFO Train. Iter 19500 : Loss. 0.01900 +2025-10-13 00:25:01,503 INFO Train. Iter 19600 : Loss. 0.01972 +2025-10-13 00:25:27,934 INFO Train. Iter 19700 : Loss. 0.01884 +2025-10-13 00:25:53,953 INFO Train. Iter 19800 : Loss. 0.01943 +2025-10-13 00:26:20,192 INFO Train. Iter 19900 : Loss. 0.01876 +2025-10-13 00:26:46,340 INFO Train. Iter 20000 : Loss. 0.01900 +2025-10-13 00:27:14,419 INFO Train. Iter 20100 : Loss. 0.01881 +2025-10-13 00:27:40,716 INFO Train. Iter 20200 : Loss. 0.01838 +2025-10-13 00:28:06,860 INFO Train. Iter 20300 : Loss. 0.01878 +2025-10-13 00:28:33,087 INFO Train. Iter 20400 : Loss. 0.01888 +2025-10-13 00:28:59,215 INFO Train. Iter 20500 : Loss. 0.01911 +2025-10-13 00:29:26,106 INFO Train. Iter 20600 : Loss. 0.01815 +2025-10-13 00:29:52,448 INFO Train. Iter 20700 : Loss. 0.01869 +2025-10-13 00:30:18,616 INFO Train. Iter 20800 : Loss. 0.01832 +2025-10-13 00:30:44,574 INFO Train. Iter 20900 : Loss. 0.01813 +2025-10-13 00:31:11,174 INFO Train. Iter 21000 : Loss. 0.01902 +2025-10-13 00:31:37,390 INFO Train. Iter 21100 : Loss. 0.01821 +2025-10-13 00:32:03,401 INFO Train. Iter 21200 : Loss. 0.01876 +2025-10-13 00:32:30,076 INFO Train. Iter 21300 : Loss. 0.01789 +2025-10-13 00:32:56,205 INFO Train. Iter 21400 : Loss. 0.01934 +2025-10-13 00:33:22,998 INFO Train. Iter 21500 : Loss. 0.01801 +2025-10-13 00:33:49,117 INFO Train. Iter 21600 : Loss. 0.01841 +2025-10-13 00:34:15,153 INFO Train. Iter 21700 : Loss. 0.01867 +2025-10-13 00:34:41,467 INFO Train. Iter 21800 : Loss. 0.01833 +2025-10-13 00:35:07,434 INFO Train. Iter 21900 : Loss. 0.01809 +2025-10-13 00:35:34,182 INFO Train. Iter 22000 : Loss. 0.01784 +2025-10-13 00:36:00,093 INFO Train. Iter 22100 : Loss. 0.01748 +2025-10-13 00:36:26,249 INFO Train. Iter 22200 : Loss. 0.01829 +2025-10-13 00:36:52,416 INFO Train. Iter 22300 : Loss. 0.01833 +2025-10-13 00:37:19,184 INFO Train. Iter 22400 : Loss. 0.01830 +2025-10-13 00:37:45,910 INFO Train. Iter 22500 : Loss. 0.01798 +2025-10-13 00:38:12,067 INFO Train. Iter 22600 : Loss. 0.01783 +2025-10-13 00:38:38,237 INFO Train. Iter 22700 : Loss. 0.01843 +2025-10-13 00:39:04,474 INFO Train. Iter 22800 : Loss. 0.01778 +2025-10-13 00:39:31,276 INFO Train. Iter 22900 : Loss. 0.01789 +2025-10-13 00:39:57,576 INFO Train. Iter 23000 : Loss. 0.01761 +2025-10-13 00:40:23,835 INFO Train. Iter 23100 : Loss. 0.01734 +2025-10-13 00:40:50,340 INFO Train. Iter 23200 : Loss. 0.01758 +2025-10-13 00:41:16,772 INFO Train. Iter 23300 : Loss. 0.01804 +2025-10-13 00:41:43,571 INFO Train. Iter 23400 : Loss. 0.01696 +2025-10-13 00:42:09,901 INFO Train. Iter 23500 : Loss. 0.01814 +2025-10-13 00:42:36,550 INFO Train. Iter 23600 : Loss. 0.01681 +2025-10-13 00:43:02,724 INFO Train. Iter 23700 : Loss. 
0.01761 +2025-10-13 00:43:29,413 INFO Train. Iter 23800 : Loss. 0.01826 +2025-10-13 00:43:55,525 INFO Train. Iter 23900 : Loss. 0.01731 +2025-10-13 00:44:21,814 INFO Train. Iter 24000 : Loss. 0.01738 +2025-10-13 00:44:47,967 INFO Train. Iter 24100 : Loss. 0.01796 +2025-10-13 00:45:14,384 INFO Train. Iter 24200 : Loss. 0.01695 +2025-10-13 00:45:40,941 INFO Train. Iter 24300 : Loss. 0.01733 +2025-10-13 00:46:07,306 INFO Train. Iter 24400 : Loss. 0.01712 +2025-10-13 00:46:33,542 INFO Train. Iter 24500 : Loss. 0.01750 +2025-10-13 00:46:59,863 INFO Train. Iter 24600 : Loss. 0.01687 +2025-10-13 00:47:26,992 INFO Train. Iter 24700 : Loss. 0.01719 +2025-10-13 00:47:53,079 INFO Train. Iter 24800 : Loss. 0.01697 +2025-10-13 00:48:19,448 INFO Train. Iter 24900 : Loss. 0.01759 +2025-10-13 00:48:45,621 INFO Train. Iter 25000 : Loss. 0.01682 +2025-10-13 00:49:11,640 INFO Train. Iter 25100 : Loss. 0.01745 +2025-10-13 00:49:38,206 INFO Train. Iter 25200 : Loss. 0.01706 +2025-10-13 00:50:04,612 INFO Train. Iter 25300 : Loss. 0.01693 +2025-10-13 00:50:30,782 INFO Train. Iter 25400 : Loss. 0.01698 +2025-10-13 00:50:57,167 INFO Train. Iter 25500 : Loss. 0.01693 +2025-10-13 00:51:23,548 INFO Train. Iter 25600 : Loss. 0.01678 +2025-10-13 00:51:50,068 INFO Train. Iter 25700 : Loss. 0.01662 +2025-10-13 00:52:16,744 INFO Train. Iter 25800 : Loss. 0.01712 +2025-10-13 00:52:43,057 INFO Train. Iter 25900 : Loss. 0.01690 +2025-10-13 00:53:09,261 INFO Train. Iter 26000 : Loss. 0.01693 +2025-10-13 00:53:36,491 INFO Train. Iter 26100 : Loss. 0.01666 +2025-10-13 00:54:02,959 INFO Train. Iter 26200 : Loss. 0.01664 +2025-10-13 00:54:29,052 INFO Train. Iter 26300 : Loss. 0.01678 +2025-10-13 00:54:55,155 INFO Train. Iter 26400 : Loss. 0.01653 +2025-10-13 00:55:21,479 INFO Train. Iter 26500 : Loss. 0.01619 +2025-10-13 00:55:48,238 INFO Train. Iter 26600 : Loss. 0.01645 +2025-10-13 00:56:14,502 INFO Train. Iter 26700 : Loss. 0.01626 +2025-10-13 00:56:40,785 INFO Train. Iter 26800 : Loss. 0.01655 +2025-10-13 00:57:07,093 INFO Train. Iter 26900 : Loss. 0.01666 +2025-10-13 00:57:33,240 INFO Train. Iter 27000 : Loss. 0.01660 +2025-10-13 00:57:59,601 INFO Train. Iter 27100 : Loss. 0.01627 +2025-10-13 00:58:25,897 INFO Train. Iter 27200 : Loss. 0.01613 +2025-10-13 00:58:52,076 INFO Train. Iter 27300 : Loss. 0.01652 +2025-10-13 00:59:18,304 INFO Train. Iter 27400 : Loss. 0.01696 +2025-10-13 00:59:45,210 INFO Train. Iter 27500 : Loss. 0.01636 +2025-10-13 01:00:11,528 INFO Train. Iter 27600 : Loss. 0.01641 +2025-10-13 01:00:38,106 INFO Train. Iter 27700 : Loss. 0.01619 +2025-10-13 01:01:04,147 INFO Train. Iter 27800 : Loss. 0.01639 +2025-10-13 01:01:30,301 INFO Train. Iter 27900 : Loss. 0.01598 +2025-10-13 01:01:56,671 INFO Train. Iter 28000 : Loss. 0.01617 +2025-10-13 01:02:23,003 INFO Train. Iter 28100 : Loss. 0.01527 +2025-10-13 01:02:49,203 INFO Train. Iter 28200 : Loss. 0.01648 +2025-10-13 01:03:15,316 INFO Train. Iter 28300 : Loss. 0.01660 +2025-10-13 01:03:41,769 INFO Train. Iter 28400 : Loss. 0.01649 +2025-10-13 01:04:08,886 INFO Train. Iter 28500 : Loss. 0.01585 +2025-10-13 01:04:34,962 INFO Train. Iter 28600 : Loss. 0.01607 +2025-10-13 01:05:01,409 INFO Train. Iter 28700 : Loss. 0.01584 +2025-10-13 01:05:27,715 INFO Train. Iter 28800 : Loss. 0.01557 +2025-10-13 01:05:54,446 INFO Train. Iter 28900 : Loss. 0.01623 +2025-10-13 01:06:20,696 INFO Train. Iter 29000 : Loss. 0.01562 +2025-10-13 01:06:46,907 INFO Train. Iter 29100 : Loss. 0.01603 +2025-10-13 01:07:13,188 INFO Train. Iter 29200 : Loss. 
0.01591 +2025-10-13 01:07:39,351 INFO Train. Iter 29300 : Loss. 0.01585 +2025-10-13 01:08:06,215 INFO Train. Iter 29400 : Loss. 0.01608 +2025-10-13 01:08:32,408 INFO Train. Iter 29500 : Loss. 0.01530 +2025-10-13 01:08:58,693 INFO Train. Iter 29600 : Loss. 0.01569 +2025-10-13 01:09:25,045 INFO Train. Iter 29700 : Loss. 0.01573 +2025-10-13 01:09:50,886 INFO Train. Iter 29800 : Loss. 0.01579 +2025-10-13 01:10:17,590 INFO Train. Iter 29900 : Loss. 0.01559 +2025-10-13 01:10:43,854 INFO Train. Iter 30000 : Loss. 0.01548 +2025-10-13 01:11:10,916 INFO Train. Iter 30100 : Loss. 0.01544 +2025-10-13 01:11:37,075 INFO Train. Iter 30200 : Loss. 0.01597 +2025-10-13 01:12:03,813 INFO Train. Iter 30300 : Loss. 0.01626 +2025-10-13 01:12:30,007 INFO Train. Iter 30400 : Loss. 0.01523 +2025-10-13 01:12:56,394 INFO Train. Iter 30500 : Loss. 0.01588 +2025-10-13 01:13:22,721 INFO Train. Iter 30600 : Loss. 0.01595 +2025-10-13 01:13:48,829 INFO Train. Iter 30700 : Loss. 0.01538 +2025-10-13 01:14:16,265 INFO Train. Iter 30800 : Loss. 0.01561 +2025-10-13 01:14:42,747 INFO Train. Iter 30900 : Loss. 0.01538 +2025-10-13 01:15:09,123 INFO Train. Iter 31000 : Loss. 0.01505 +2025-10-13 01:15:35,122 INFO Train. Iter 31100 : Loss. 0.01551 +2025-10-13 01:16:01,398 INFO Train. Iter 31200 : Loss. 0.01547 +2025-10-13 01:16:27,928 INFO Train. Iter 31300 : Loss. 0.01562 +2025-10-13 01:16:54,316 INFO Train. Iter 31400 : Loss. 0.01499 +2025-10-13 01:17:20,369 INFO Train. Iter 31500 : Loss. 0.01531 +2025-10-13 01:17:46,657 INFO Train. Iter 31600 : Loss. 0.01529 +2025-10-13 01:18:13,270 INFO Train. Iter 31700 : Loss. 0.01525 +2025-10-13 01:18:39,427 INFO Train. Iter 31800 : Loss. 0.01512 +2025-10-13 01:19:05,800 INFO Train. Iter 31900 : Loss. 0.01514 +2025-10-13 01:19:31,871 INFO Train. Iter 32000 : Loss. 0.01519 +2025-10-13 01:19:58,053 INFO Train. Iter 32100 : Loss. 0.01482 +2025-10-13 01:20:24,975 INFO Train. Iter 32200 : Loss. 0.01502 +2025-10-13 01:20:51,299 INFO Train. Iter 32300 : Loss. 0.01494 +2025-10-13 01:21:18,151 INFO Train. Iter 32400 : Loss. 0.01520 +2025-10-13 01:21:44,376 INFO Train. Iter 32500 : Loss. 0.01525 +2025-10-13 01:22:10,468 INFO Train. Iter 32600 : Loss. 0.01509 +2025-10-13 01:22:37,360 INFO Train. Iter 32700 : Loss. 0.01497 +2025-10-13 01:23:03,510 INFO Train. Iter 32800 : Loss. 0.01420 +2025-10-13 01:23:29,751 INFO Train. Iter 32900 : Loss. 0.01520 +2025-10-13 01:23:55,927 INFO Train. Iter 33000 : Loss. 0.01481 +2025-10-13 01:24:22,577 INFO Train. Iter 33100 : Loss. 0.01469 +2025-10-13 01:24:48,937 INFO Train. Iter 33200 : Loss. 0.01469 +2025-10-13 01:25:15,196 INFO Train. Iter 33300 : Loss. 0.01492 +2025-10-13 01:25:41,660 INFO Train. Iter 33400 : Loss. 0.01459 +2025-10-13 01:26:07,777 INFO Train. Iter 33500 : Loss. 0.01502 +2025-10-13 01:26:34,731 INFO Train. Iter 33600 : Loss. 0.01480 +2025-10-13 01:27:01,069 INFO Train. Iter 33700 : Loss. 0.01496 +2025-10-13 01:27:27,373 INFO Train. Iter 33800 : Loss. 0.01438 +2025-10-13 01:27:53,594 INFO Train. Iter 33900 : Loss. 0.01460 +2025-10-13 01:28:20,118 INFO Train. Iter 34000 : Loss. 0.01485 +2025-10-13 01:28:46,746 INFO Train. Iter 34100 : Loss. 0.01457 +2025-10-13 01:29:13,202 INFO Train. Iter 34200 : Loss. 0.01446 +2025-10-13 01:29:39,548 INFO Train. Iter 34300 : Loss. 0.01456 +2025-10-13 01:30:05,717 INFO Train. Iter 34400 : Loss. 0.01449 +2025-10-13 01:30:32,645 INFO Train. Iter 34500 : Loss. 0.01480 +2025-10-13 01:30:58,900 INFO Train. Iter 34600 : Loss. 0.01415 +2025-10-13 01:31:25,114 INFO Train. Iter 34700 : Loss. 
0.01476 +2025-10-13 01:31:51,496 INFO Train. Iter 34800 : Loss. 0.01441 +2025-10-13 01:32:17,556 INFO Train. Iter 34900 : Loss. 0.01401 +2025-10-13 01:32:44,254 INFO Train. Iter 35000 : Loss. 0.01401 +2025-10-13 01:33:10,520 INFO Train. Iter 35100 : Loss. 0.01462 +2025-10-13 01:33:37,013 INFO Train. Iter 35200 : Loss. 0.01471 +2025-10-13 01:34:03,089 INFO Train. Iter 35300 : Loss. 0.01428 +2025-10-13 01:34:29,363 INFO Train. Iter 35400 : Loss. 0.01447 +2025-10-13 01:34:55,993 INFO Train. Iter 35500 : Loss. 0.01430 +2025-10-13 01:35:22,352 INFO Train. Iter 35600 : Loss. 0.01379 +2025-10-13 01:35:48,668 INFO Train. Iter 35700 : Loss. 0.01393 +2025-10-13 01:36:14,636 INFO Train. Iter 35800 : Loss. 0.01428 +2025-10-13 01:36:41,383 INFO Train. Iter 35900 : Loss. 0.01503 +2025-10-13 01:37:07,644 INFO Train. Iter 36000 : Loss. 0.01417 +2025-10-13 01:37:33,844 INFO Train. Iter 36100 : Loss. 0.01370 +2025-10-13 01:37:59,834 INFO Train. Iter 36200 : Loss. 0.01464 +2025-10-13 01:38:26,145 INFO Train. Iter 36300 : Loss. 0.01455 +2025-10-13 01:38:53,153 INFO Train. Iter 36400 : Loss. 0.01415 +2025-10-13 01:39:19,372 INFO Train. Iter 36500 : Loss. 0.01358 +2025-10-13 01:39:45,653 INFO Train. Iter 36600 : Loss. 0.01430 +2025-10-13 01:40:11,918 INFO Train. Iter 36700 : Loss. 0.01380 +2025-10-13 01:40:38,202 INFO Train. Iter 36800 : Loss. 0.01402 +2025-10-13 01:41:04,687 INFO Train. Iter 36900 : Loss. 0.01350 +2025-10-13 01:41:30,870 INFO Train. Iter 37000 : Loss. 0.01386 +2025-10-13 01:41:57,078 INFO Train. Iter 37100 : Loss. 0.01424 +2025-10-13 01:42:23,700 INFO Train. Iter 37200 : Loss. 0.01409 +2025-10-13 01:42:50,384 INFO Train. Iter 37300 : Loss. 0.01437 +2025-10-13 01:43:16,570 INFO Train. Iter 37400 : Loss. 0.01354 +2025-10-13 01:43:42,751 INFO Train. Iter 37500 : Loss. 0.01371 +2025-10-13 01:44:08,903 INFO Train. Iter 37600 : Loss. 0.01386 +2025-10-13 01:44:35,414 INFO Train. Iter 37700 : Loss. 0.01346 +2025-10-13 01:45:02,194 INFO Train. Iter 37800 : Loss. 0.01414 +2025-10-13 01:45:28,630 INFO Train. Iter 37900 : Loss. 0.01378 +2025-10-13 01:45:55,204 INFO Train. Iter 38000 : Loss. 0.01397 +2025-10-13 01:46:22,045 INFO Train. Iter 38100 : Loss. 0.01352 +2025-10-13 01:46:48,422 INFO Train. Iter 38200 : Loss. 0.01353 +2025-10-13 01:47:15,554 INFO Train. Iter 38300 : Loss. 0.01370 +2025-10-13 01:47:41,896 INFO Train. Iter 38400 : Loss. 0.01361 +2025-10-13 01:48:08,618 INFO Train. Iter 38500 : Loss. 0.01338 +2025-10-13 01:48:34,851 INFO Train. Iter 38600 : Loss. 0.01376 +2025-10-13 01:49:01,839 INFO Train. Iter 38700 : Loss. 0.01393 +2025-10-13 01:49:28,435 INFO Train. Iter 38800 : Loss. 0.01349 +2025-10-13 01:49:54,916 INFO Train. Iter 38900 : Loss. 0.01336 +2025-10-13 01:50:21,157 INFO Train. Iter 39000 : Loss. 0.01342 +2025-10-13 01:50:47,390 INFO Train. Iter 39100 : Loss. 0.01358 +2025-10-13 01:51:14,315 INFO Train. Iter 39200 : Loss. 0.01354 +2025-10-13 01:51:40,553 INFO Train. Iter 39300 : Loss. 0.01333 +2025-10-13 01:52:06,713 INFO Train. Iter 39400 : Loss. 0.01337 +2025-10-13 01:52:32,949 INFO Train. Iter 39500 : Loss. 0.01338 +2025-10-13 01:52:59,443 INFO Train. Iter 39600 : Loss. 0.01335 +2025-10-13 01:53:26,208 INFO Train. Iter 39700 : Loss. 0.01299 +2025-10-13 01:53:52,156 INFO Train. Iter 39800 : Loss. 0.01354 +2025-10-13 01:54:18,497 INFO Train. Iter 39900 : Loss. 0.01379 +2025-10-13 01:54:44,896 INFO Train. Iter 40000 : Loss. 0.01319 +2025-10-13 01:55:12,968 INFO Train. Iter 40100 : Loss. 0.01341 +2025-10-13 01:55:39,080 INFO Train. Iter 40200 : Loss. 
0.01300 +2025-10-13 01:56:05,496 INFO Train. Iter 40300 : Loss. 0.01328 +2025-10-13 01:56:31,889 INFO Train. Iter 40400 : Loss. 0.01298 +2025-10-13 01:56:58,239 INFO Train. Iter 40500 : Loss. 0.01348 +2025-10-13 01:57:25,085 INFO Train. Iter 40600 : Loss. 0.01343 +2025-10-13 01:57:51,367 INFO Train. Iter 40700 : Loss. 0.01305 +2025-10-13 01:58:17,620 INFO Train. Iter 40800 : Loss. 0.01291 +2025-10-13 01:58:43,806 INFO Train. Iter 40900 : Loss. 0.01302 +2025-10-13 01:59:10,044 INFO Train. Iter 41000 : Loss. 0.01333 +2025-10-13 01:59:36,660 INFO Train. Iter 41100 : Loss. 0.01291 +2025-10-13 02:00:02,821 INFO Train. Iter 41200 : Loss. 0.01291 +2025-10-13 02:00:29,098 INFO Train. Iter 41300 : Loss. 0.01291 +2025-10-13 02:00:55,307 INFO Train. Iter 41400 : Loss. 0.01319 +2025-10-13 02:01:21,999 INFO Train. Iter 41500 : Loss. 0.01291 +2025-10-13 02:01:48,242 INFO Train. Iter 41600 : Loss. 0.01302 +2025-10-13 02:02:14,532 INFO Train. Iter 41700 : Loss. 0.01328 +2025-10-13 02:02:40,662 INFO Train. Iter 41800 : Loss. 0.01268 +2025-10-13 02:03:06,983 INFO Train. Iter 41900 : Loss. 0.01296 +2025-10-13 02:03:34,051 INFO Train. Iter 42000 : Loss. 0.01267 +2025-10-13 02:04:00,581 INFO Train. Iter 42100 : Loss. 0.01294 +2025-10-13 02:04:26,704 INFO Train. Iter 42200 : Loss. 0.01291 +2025-10-13 02:04:52,749 INFO Train. Iter 42300 : Loss. 0.01269 +2025-10-13 02:05:18,862 INFO Train. Iter 42400 : Loss. 0.01300 +2025-10-13 02:05:45,670 INFO Train. Iter 42500 : Loss. 0.01273 +2025-10-13 02:06:11,714 INFO Train. Iter 42600 : Loss. 0.01251 +2025-10-13 02:06:38,008 INFO Train. Iter 42700 : Loss. 0.01286 +2025-10-13 02:07:04,320 INFO Train. Iter 42800 : Loss. 0.01244 +2025-10-13 02:07:31,127 INFO Train. Iter 42900 : Loss. 0.01280 +2025-10-13 02:07:57,244 INFO Train. Iter 43000 : Loss. 0.01263 +2025-10-13 02:08:23,619 INFO Train. Iter 43100 : Loss. 0.01235 +2025-10-13 02:08:49,857 INFO Train. Iter 43200 : Loss. 0.01280 +2025-10-13 02:09:16,276 INFO Train. Iter 43300 : Loss. 0.01273 +2025-10-13 02:09:43,200 INFO Train. Iter 43400 : Loss. 0.01239 +2025-10-13 02:10:09,663 INFO Train. Iter 43500 : Loss. 0.01223 +2025-10-13 02:10:35,951 INFO Train. Iter 43600 : Loss. 0.01257 +2025-10-13 02:11:02,472 INFO Train. Iter 43700 : Loss. 0.01242 +2025-10-13 02:11:28,685 INFO Train. Iter 43800 : Loss. 0.01273 +2025-10-13 02:11:55,484 INFO Train. Iter 43900 : Loss. 0.01226 +2025-10-13 02:12:22,124 INFO Train. Iter 44000 : Loss. 0.01247 +2025-10-13 02:12:48,646 INFO Train. Iter 44100 : Loss. 0.01238 +2025-10-13 02:13:15,187 INFO Train. Iter 44200 : Loss. 0.01256 +2025-10-13 02:13:41,904 INFO Train. Iter 44300 : Loss. 0.01278 +2025-10-13 02:14:08,325 INFO Train. Iter 44400 : Loss. 0.01219 +2025-10-13 02:14:35,089 INFO Train. Iter 44500 : Loss. 0.01239 +2025-10-13 02:15:01,490 INFO Train. Iter 44600 : Loss. 0.01212 +2025-10-13 02:15:27,852 INFO Train. Iter 44700 : Loss. 0.01262 +2025-10-13 02:15:54,567 INFO Train. Iter 44800 : Loss. 0.01216 +2025-10-13 02:16:20,620 INFO Train. Iter 44900 : Loss. 0.01202 +2025-10-13 02:16:46,871 INFO Train. Iter 45000 : Loss. 0.01222 +2025-10-13 02:17:12,838 INFO Train. Iter 45100 : Loss. 0.01236 +2025-10-13 02:17:39,655 INFO Train. Iter 45200 : Loss. 0.01235 +2025-10-13 02:18:06,148 INFO Train. Iter 45300 : Loss. 0.01192 +2025-10-13 02:18:32,237 INFO Train. Iter 45400 : Loss. 0.01199 +2025-10-13 02:18:58,285 INFO Train. Iter 45500 : Loss. 0.01226 +2025-10-13 02:19:24,503 INFO Train. Iter 45600 : Loss. 0.01213 +2025-10-13 02:19:51,465 INFO Train. Iter 45700 : Loss. 
0.01231 +2025-10-13 02:20:17,775 INFO Train. Iter 45800 : Loss. 0.01200 +2025-10-13 02:20:44,227 INFO Train. Iter 45900 : Loss. 0.01176 +2025-10-13 02:21:10,462 INFO Train. Iter 46000 : Loss. 0.01221 +2025-10-13 02:21:36,682 INFO Train. Iter 46100 : Loss. 0.01220 +2025-10-13 02:22:03,464 INFO Train. Iter 46200 : Loss. 0.01202 +2025-10-13 02:22:29,692 INFO Train. Iter 46300 : Loss. 0.01173 +2025-10-13 02:22:55,832 INFO Train. Iter 46400 : Loss. 0.01197 +2025-10-13 02:23:21,949 INFO Train. Iter 46500 : Loss. 0.01177 +2025-10-13 02:23:48,209 INFO Train. Iter 46600 : Loss. 0.01202 +2025-10-13 02:24:15,029 INFO Train. Iter 46700 : Loss. 0.01187 +2025-10-13 02:24:41,453 INFO Train. Iter 46800 : Loss. 0.01216 +2025-10-13 02:25:07,569 INFO Train. Iter 46900 : Loss. 0.01179 +2025-10-13 02:25:33,795 INFO Train. Iter 47000 : Loss. 0.01152 +2025-10-13 02:26:00,670 INFO Train. Iter 47100 : Loss. 0.01165 +2025-10-13 02:26:26,945 INFO Train. Iter 47200 : Loss. 0.01160 +2025-10-13 02:26:53,041 INFO Train. Iter 47300 : Loss. 0.01178 +2025-10-13 02:27:19,577 INFO Train. Iter 47400 : Loss. 0.01187 +2025-10-13 02:27:45,654 INFO Train. Iter 47500 : Loss. 0.01163 +2025-10-13 02:28:12,557 INFO Train. Iter 47600 : Loss. 0.01174 +2025-10-13 02:28:38,908 INFO Train. Iter 47700 : Loss. 0.01166 +2025-10-13 02:29:05,227 INFO Train. Iter 47800 : Loss. 0.01157 +2025-10-13 02:29:31,307 INFO Train. Iter 47900 : Loss. 0.01187 +2025-10-13 02:29:57,951 INFO Train. Iter 48000 : Loss. 0.01171 +2025-10-13 02:30:24,573 INFO Train. Iter 48100 : Loss. 0.01144 +2025-10-13 02:30:50,664 INFO Train. Iter 48200 : Loss. 0.01153 +2025-10-13 02:31:17,028 INFO Train. Iter 48300 : Loss. 0.01143 +2025-10-13 02:31:43,528 INFO Train. Iter 48400 : Loss. 0.01166 +2025-10-13 02:32:10,377 INFO Train. Iter 48500 : Loss. 0.01146 +2025-10-13 02:32:36,745 INFO Train. Iter 48600 : Loss. 0.01119 +2025-10-13 02:33:02,824 INFO Train. Iter 48700 : Loss. 0.01128 +2025-10-13 02:33:29,043 INFO Train. Iter 48800 : Loss. 0.01168 +2025-10-13 02:33:55,290 INFO Train. Iter 48900 : Loss. 0.01165 +2025-10-13 02:34:22,066 INFO Train. Iter 49000 : Loss. 0.01107 +2025-10-13 02:34:48,253 INFO Train. Iter 49100 : Loss. 0.01125 +2025-10-13 02:35:14,419 INFO Train. Iter 49200 : Loss. 0.01120 +2025-10-13 02:35:41,462 INFO Train. Iter 49300 : Loss. 0.01123 +2025-10-13 02:36:08,260 INFO Train. Iter 49400 : Loss. 0.01131 +2025-10-13 02:36:34,819 INFO Train. Iter 49500 : Loss. 0.01102 +2025-10-13 02:37:00,953 INFO Train. Iter 49600 : Loss. 0.01115 +2025-10-13 02:37:27,342 INFO Train. Iter 49700 : Loss. 0.01137 +2025-10-13 02:37:53,538 INFO Train. Iter 49800 : Loss. 0.01143 +2025-10-13 02:38:20,436 INFO Train. Iter 49900 : Loss. 0.01133 +2025-10-13 02:38:46,867 INFO Train. Iter 50000 : Loss. 0.01127 +2025-10-13 02:39:14,811 INFO Train. Iter 50100 : Loss. 0.01093 +2025-10-13 02:39:41,013 INFO Train. Iter 50200 : Loss. 0.01096 +2025-10-13 02:40:07,018 INFO Train. Iter 50300 : Loss. 0.01101 +2025-10-13 02:40:33,605 INFO Train. Iter 50400 : Loss. 0.01097 +2025-10-13 02:40:59,996 INFO Train. Iter 50500 : Loss. 0.01103 +2025-10-13 02:41:26,364 INFO Train. Iter 50600 : Loss. 0.01109 +2025-10-13 02:41:52,566 INFO Train. Iter 50700 : Loss. 0.01097 +2025-10-13 02:42:19,617 INFO Train. Iter 50800 : Loss. 0.01115 +2025-10-13 02:42:45,982 INFO Train. Iter 50900 : Loss. 0.01061 +2025-10-13 02:43:12,240 INFO Train. Iter 51000 : Loss. 0.01095 +2025-10-13 02:43:38,562 INFO Train. Iter 51100 : Loss. 0.01095 +2025-10-13 02:44:04,813 INFO Train. Iter 51200 : Loss. 
0.01129 +2025-10-13 02:44:31,672 INFO Train. Iter 51300 : Loss. 0.01086 +2025-10-13 02:44:58,028 INFO Train. Iter 51400 : Loss. 0.01076 +2025-10-13 02:45:24,782 INFO Train. Iter 51500 : Loss. 0.01084 +2025-10-13 02:45:51,320 INFO Train. Iter 51600 : Loss. 0.01062 +2025-10-13 02:46:17,595 INFO Train. Iter 51700 : Loss. 0.01110 +2025-10-13 02:46:44,630 INFO Train. Iter 51800 : Loss. 0.01061 +2025-10-13 02:47:10,999 INFO Train. Iter 51900 : Loss. 0.01051 +2025-10-13 02:47:37,188 INFO Train. Iter 52000 : Loss. 0.01078 +2025-10-13 02:48:03,550 INFO Train. Iter 52100 : Loss. 0.01079 +2025-10-13 02:48:30,857 INFO Train. Iter 52200 : Loss. 0.01101 +2025-10-13 02:48:57,197 INFO Train. Iter 52300 : Loss. 0.01048 +2025-10-13 02:49:23,926 INFO Train. Iter 52400 : Loss. 0.01052 +2025-10-13 02:49:50,146 INFO Train. Iter 52500 : Loss. 0.01090 +2025-10-13 02:50:16,591 INFO Train. Iter 52600 : Loss. 0.01061 +2025-10-13 02:50:43,563 INFO Train. Iter 52700 : Loss. 0.01046 +2025-10-13 02:51:09,873 INFO Train. Iter 52800 : Loss. 0.01028 +2025-10-13 02:51:36,216 INFO Train. Iter 52900 : Loss. 0.01081 +2025-10-13 02:52:02,604 INFO Train. Iter 53000 : Loss. 0.01067 +2025-10-13 02:52:29,033 INFO Train. Iter 53100 : Loss. 0.01063 +2025-10-13 02:52:55,900 INFO Train. Iter 53200 : Loss. 0.01026 +2025-10-13 02:53:22,670 INFO Train. Iter 53300 : Loss. 0.01025 +2025-10-13 02:53:48,863 INFO Train. Iter 53400 : Loss. 0.01062 +2025-10-13 02:54:15,034 INFO Train. Iter 53500 : Loss. 0.01045 +2025-10-13 02:54:41,843 INFO Train. Iter 53600 : Loss. 0.01043 +2025-10-13 02:55:08,308 INFO Train. Iter 53700 : Loss. 0.00990 +2025-10-13 02:55:34,651 INFO Train. Iter 53800 : Loss. 0.01028 +2025-10-13 02:56:01,127 INFO Train. Iter 53900 : Loss. 0.01035 +2025-10-13 02:56:27,329 INFO Train. Iter 54000 : Loss. 0.01078 +2025-10-13 02:56:53,865 INFO Train. Iter 54100 : Loss. 0.01029 +2025-10-13 02:57:20,280 INFO Train. Iter 54200 : Loss. 0.01049 +2025-10-13 02:57:46,955 INFO Train. Iter 54300 : Loss. 0.00997 +2025-10-13 02:58:12,891 INFO Train. Iter 54400 : Loss. 0.01001 +2025-10-13 02:58:38,893 INFO Train. Iter 54500 : Loss. 0.01001 +2025-10-13 02:59:05,449 INFO Train. Iter 54600 : Loss. 0.01016 +2025-10-13 02:59:32,120 INFO Train. Iter 54700 : Loss. 0.01022 +2025-10-13 02:59:58,315 INFO Train. Iter 54800 : Loss. 0.01008 +2025-10-13 03:00:24,727 INFO Train. Iter 54900 : Loss. 0.01021 +2025-10-13 03:00:51,408 INFO Train. Iter 55000 : Loss. 0.01034 +2025-10-13 03:01:17,286 INFO Train. Iter 55100 : Loss. 0.01009 +2025-10-13 03:01:43,429 INFO Train. Iter 55200 : Loss. 0.00984 +2025-10-13 03:02:09,711 INFO Train. Iter 55300 : Loss. 0.00999 +2025-10-13 03:02:35,871 INFO Train. Iter 55400 : Loss. 0.01010 +2025-10-13 03:03:02,906 INFO Train. Iter 55500 : Loss. 0.00974 +2025-10-13 03:03:29,171 INFO Train. Iter 55600 : Loss. 0.01000 +2025-10-13 03:03:55,430 INFO Train. Iter 55700 : Loss. 0.01010 +2025-10-13 03:04:21,533 INFO Train. Iter 55800 : Loss. 0.01002 +2025-10-13 03:04:47,866 INFO Train. Iter 55900 : Loss. 0.00990 +2025-10-13 03:05:14,533 INFO Train. Iter 56000 : Loss. 0.00946 +2025-10-13 03:05:40,637 INFO Train. Iter 56100 : Loss. 0.00960 +2025-10-13 03:06:06,880 INFO Train. Iter 56200 : Loss. 0.00988 +2025-10-13 03:06:33,421 INFO Train. Iter 56300 : Loss. 0.00984 +2025-10-13 03:07:00,277 INFO Train. Iter 56400 : Loss. 0.00975 +2025-10-13 03:07:26,224 INFO Train. Iter 56500 : Loss. 0.00948 +2025-10-13 03:07:52,665 INFO Train. Iter 56600 : Loss. 0.00965 +2025-10-13 03:08:19,176 INFO Train. Iter 56700 : Loss. 
0.00988 +2025-10-13 03:08:45,326 INFO Train. Iter 56800 : Loss. 0.00963 +2025-10-13 03:09:11,953 INFO Train. Iter 56900 : Loss. 0.00974 +2025-10-13 03:09:38,451 INFO Train. Iter 57000 : Loss. 0.00947 +2025-10-13 03:10:05,256 INFO Train. Iter 57100 : Loss. 0.00933 +2025-10-13 03:10:32,000 INFO Train. Iter 57200 : Loss. 0.00954 +2025-10-13 03:10:58,139 INFO Train. Iter 57300 : Loss. 0.00975 +2025-10-13 03:11:25,160 INFO Train. Iter 57400 : Loss. 0.00931 +2025-10-13 03:11:51,553 INFO Train. Iter 57500 : Loss. 0.00940 +2025-10-13 03:12:17,912 INFO Train. Iter 57600 : Loss. 0.00941 +2025-10-13 03:12:44,276 INFO Train. Iter 57700 : Loss. 0.00958 +2025-10-13 03:13:11,180 INFO Train. Iter 57800 : Loss. 0.00978 +2025-10-13 03:13:37,993 INFO Train. Iter 57900 : Loss. 0.00926 +2025-10-13 03:14:04,039 INFO Train. Iter 58000 : Loss. 0.00938 +2025-10-13 03:14:30,593 INFO Train. Iter 58100 : Loss. 0.00912 +2025-10-13 03:14:56,745 INFO Train. Iter 58200 : Loss. 0.00929 +2025-10-13 03:15:23,589 INFO Train. Iter 58300 : Loss. 0.00927 +2025-10-13 03:15:49,894 INFO Train. Iter 58400 : Loss. 0.00916 +2025-10-13 03:16:16,154 INFO Train. Iter 58500 : Loss. 0.00913 +2025-10-13 03:16:42,319 INFO Train. Iter 58600 : Loss. 0.00900 +2025-10-13 03:17:08,426 INFO Train. Iter 58700 : Loss. 0.00954 +2025-10-13 03:17:35,110 INFO Train. Iter 58800 : Loss. 0.00908 +2025-10-13 03:18:01,376 INFO Train. Iter 58900 : Loss. 0.00928 +2025-10-13 03:18:27,783 INFO Train. Iter 59000 : Loss. 0.00907 +2025-10-13 03:18:53,917 INFO Train. Iter 59100 : Loss. 0.00938 +2025-10-13 03:19:20,761 INFO Train. Iter 59200 : Loss. 0.00904 +2025-10-13 03:19:47,190 INFO Train. Iter 59300 : Loss. 0.00891 +2025-10-13 03:20:13,489 INFO Train. Iter 59400 : Loss. 0.00886 +2025-10-13 03:20:40,072 INFO Train. Iter 59500 : Loss. 0.00897 +2025-10-13 03:21:06,327 INFO Train. Iter 59600 : Loss. 0.00923 +2025-10-13 03:21:32,944 INFO Train. Iter 59700 : Loss. 0.00886 +2025-10-13 03:21:59,532 INFO Train. Iter 59800 : Loss. 0.00902 +2025-10-13 03:22:25,687 INFO Train. Iter 59900 : Loss. 0.00894 +2025-10-13 03:22:51,913 INFO Train. Iter 60000 : Loss. 0.00908 +2025-10-13 03:23:19,764 INFO Train. Iter 60100 : Loss. 0.00891 +2025-10-13 03:23:46,797 INFO Train. Iter 60200 : Loss. 0.00894 +2025-10-13 03:24:12,671 INFO Train. Iter 60300 : Loss. 0.00874 +2025-10-13 03:24:38,988 INFO Train. Iter 60400 : Loss. 0.00897 +2025-10-13 03:25:05,320 INFO Train. Iter 60500 : Loss. 0.00904 +2025-10-13 03:25:32,113 INFO Train. Iter 60600 : Loss. 0.00877 +2025-10-13 03:25:58,496 INFO Train. Iter 60700 : Loss. 0.00864 +2025-10-13 03:26:24,727 INFO Train. Iter 60800 : Loss. 0.00857 +2025-10-13 03:26:50,824 INFO Train. Iter 60900 : Loss. 0.00871 +2025-10-13 03:27:17,300 INFO Train. Iter 61000 : Loss. 0.00911 +2025-10-13 03:27:43,884 INFO Train. Iter 61100 : Loss. 0.00858 +2025-10-13 03:28:10,437 INFO Train. Iter 61200 : Loss. 0.00854 +2025-10-13 03:28:36,566 INFO Train. Iter 61300 : Loss. 0.00852 +2025-10-13 03:29:03,019 INFO Train. Iter 61400 : Loss. 0.00843 +2025-10-13 03:29:29,129 INFO Train. Iter 61500 : Loss. 0.00860 +2025-10-13 03:29:56,355 INFO Train. Iter 61600 : Loss. 0.00859 +2025-10-13 03:30:22,735 INFO Train. Iter 61700 : Loss. 0.00858 +2025-10-13 03:30:48,818 INFO Train. Iter 61800 : Loss. 0.00855 +2025-10-13 03:31:14,841 INFO Train. Iter 61900 : Loss. 0.00860 +2025-10-13 03:31:41,551 INFO Train. Iter 62000 : Loss. 0.00862 +2025-10-13 03:32:07,690 INFO Train. Iter 62100 : Loss. 0.00831 +2025-10-13 03:32:33,906 INFO Train. Iter 62200 : Loss. 
0.00831 +2025-10-13 03:32:59,977 INFO Train. Iter 62300 : Loss. 0.00860 +2025-10-13 03:33:26,267 INFO Train. Iter 62400 : Loss. 0.00819 +2025-10-13 03:33:53,030 INFO Train. Iter 62500 : Loss. 0.00855 +2025-10-13 03:34:19,492 INFO Train. Iter 62600 : Loss. 0.00838 +2025-10-13 03:34:46,003 INFO Train. Iter 62700 : Loss. 0.00834 +2025-10-13 03:35:12,182 INFO Train. Iter 62800 : Loss. 0.00828 +2025-10-13 03:35:38,441 INFO Train. Iter 62900 : Loss. 0.00827 +2025-10-13 03:36:05,180 INFO Train. Iter 63000 : Loss. 0.00795 +2025-10-13 03:36:31,746 INFO Train. Iter 63100 : Loss. 0.00800 +2025-10-13 03:36:58,084 INFO Train. Iter 63200 : Loss. 0.00813 +2025-10-13 03:37:24,277 INFO Train. Iter 63300 : Loss. 0.00839 +2025-10-13 03:37:51,119 INFO Train. Iter 63400 : Loss. 0.00822 +2025-10-13 03:38:17,329 INFO Train. Iter 63500 : Loss. 0.00780 +2025-10-13 03:38:43,538 INFO Train. Iter 63600 : Loss. 0.00805 +2025-10-13 03:39:09,899 INFO Train. Iter 63700 : Loss. 0.00807 +2025-10-13 03:39:35,727 INFO Train. Iter 63800 : Loss. 0.00817 +2025-10-13 03:40:02,786 INFO Train. Iter 63900 : Loss. 0.00795 +2025-10-13 03:40:29,041 INFO Train. Iter 64000 : Loss. 0.00790 +2025-10-13 03:40:55,183 INFO Train. Iter 64100 : Loss. 0.00785 +2025-10-13 03:41:21,583 INFO Train. Iter 64200 : Loss. 0.00818 +2025-10-13 03:41:47,752 INFO Train. Iter 64300 : Loss. 0.00788 +2025-10-13 03:42:14,558 INFO Train. Iter 64400 : Loss. 0.00789 +2025-10-13 03:42:40,757 INFO Train. Iter 64500 : Loss. 0.00804 +2025-10-13 03:43:06,899 INFO Train. Iter 64600 : Loss. 0.00784 +2025-10-13 03:43:33,218 INFO Train. Iter 64700 : Loss. 0.00776 +2025-10-13 03:43:59,903 INFO Train. Iter 64800 : Loss. 0.00782 +2025-10-13 03:44:26,435 INFO Train. Iter 64900 : Loss. 0.00785 +2025-10-13 03:44:53,288 INFO Train. Iter 65000 : Loss. 0.00778 +2025-10-13 03:45:19,781 INFO Train. Iter 65100 : Loss. 0.00773 +2025-10-13 03:45:46,040 INFO Train. Iter 65200 : Loss. 0.00773 +2025-10-13 03:46:13,228 INFO Train. Iter 65300 : Loss. 0.00771 +2025-10-13 03:46:39,686 INFO Train. Iter 65400 : Loss. 0.00772 +2025-10-13 03:47:05,879 INFO Train. Iter 65500 : Loss. 0.00760 +2025-10-13 03:47:32,514 INFO Train. Iter 65600 : Loss. 0.00766 +2025-10-13 03:47:59,147 INFO Train. Iter 65700 : Loss. 0.00740 +2025-10-13 03:48:26,327 INFO Train. Iter 65800 : Loss. 0.00754 +2025-10-13 03:48:52,494 INFO Train. Iter 65900 : Loss. 0.00755 +2025-10-13 03:49:18,923 INFO Train. Iter 66000 : Loss. 0.00726 +2025-10-13 03:49:45,127 INFO Train. Iter 66100 : Loss. 0.00778 +2025-10-13 03:50:11,746 INFO Train. Iter 66200 : Loss. 0.00740 +2025-10-13 03:50:38,110 INFO Train. Iter 66300 : Loss. 0.00761 +2025-10-13 03:51:04,243 INFO Train. Iter 66400 : Loss. 0.00765 +2025-10-13 03:51:30,621 INFO Train. Iter 66500 : Loss. 0.00738 +2025-10-13 03:51:57,031 INFO Train. Iter 66600 : Loss. 0.00757 +2025-10-13 03:52:23,730 INFO Train. Iter 66700 : Loss. 0.00742 +2025-10-13 03:52:50,234 INFO Train. Iter 66800 : Loss. 0.00734 +2025-10-13 03:53:16,406 INFO Train. Iter 66900 : Loss. 0.00746 +2025-10-13 03:53:42,687 INFO Train. Iter 67000 : Loss. 0.00741 +2025-10-13 03:54:09,028 INFO Train. Iter 67100 : Loss. 0.00743 +2025-10-13 03:54:35,998 INFO Train. Iter 67200 : Loss. 0.00726 +2025-10-13 03:55:02,144 INFO Train. Iter 67300 : Loss. 0.00738 +2025-10-13 03:55:28,992 INFO Train. Iter 67400 : Loss. 0.00730 +2025-10-13 03:55:55,368 INFO Train. Iter 67500 : Loss. 0.00729 +2025-10-13 03:56:22,328 INFO Train. Iter 67600 : Loss. 0.00717 +2025-10-13 03:56:48,512 INFO Train. Iter 67700 : Loss. 
0.00699 +2025-10-13 03:57:14,654 INFO Train. Iter 67800 : Loss. 0.00710 +2025-10-13 03:57:41,288 INFO Train. Iter 67900 : Loss. 0.00739 +2025-10-13 03:58:07,342 INFO Train. Iter 68000 : Loss. 0.00735 +2025-10-13 03:58:33,896 INFO Train. Iter 68100 : Loss. 0.00706 +2025-10-13 03:59:00,095 INFO Train. Iter 68200 : Loss. 0.00702 +2025-10-13 03:59:26,460 INFO Train. Iter 68300 : Loss. 0.00713 +2025-10-13 03:59:53,141 INFO Train. Iter 68400 : Loss. 0.00698 +2025-10-13 04:00:19,261 INFO Train. Iter 68500 : Loss. 0.00716 +2025-10-13 04:00:46,012 INFO Train. Iter 68600 : Loss. 0.00674 +2025-10-13 04:01:12,164 INFO Train. Iter 68700 : Loss. 0.00686 +2025-10-13 04:01:38,229 INFO Train. Iter 68800 : Loss. 0.00694 +2025-10-13 04:02:04,740 INFO Train. Iter 68900 : Loss. 0.00699 +2025-10-13 04:02:31,546 INFO Train. Iter 69000 : Loss. 0.00687 +2025-10-13 04:02:57,821 INFO Train. Iter 69100 : Loss. 0.00679 +2025-10-13 04:03:24,597 INFO Train. Iter 69200 : Loss. 0.00691 +2025-10-13 04:03:51,043 INFO Train. Iter 69300 : Loss. 0.00680 +2025-10-13 04:04:17,333 INFO Train. Iter 69400 : Loss. 0.00657 +2025-10-13 04:04:43,942 INFO Train. Iter 69500 : Loss. 0.00673 +2025-10-13 04:05:10,412 INFO Train. Iter 69600 : Loss. 0.00651 +2025-10-13 04:05:36,552 INFO Train. Iter 69700 : Loss. 0.00671 +2025-10-13 04:06:02,865 INFO Train. Iter 69800 : Loss. 0.00672 +2025-10-13 04:06:29,089 INFO Train. Iter 69900 : Loss. 0.00673 +2025-10-13 04:06:55,674 INFO Train. Iter 70000 : Loss. 0.00663 +2025-10-13 04:07:23,328 INFO Train. Iter 70100 : Loss. 0.00657 +2025-10-13 04:07:49,566 INFO Train. Iter 70200 : Loss. 0.00646 +2025-10-13 04:08:15,659 INFO Train. Iter 70300 : Loss. 0.00675 +2025-10-13 04:08:42,556 INFO Train. Iter 70400 : Loss. 0.00676 +2025-10-13 04:09:08,585 INFO Train. Iter 70500 : Loss. 0.00648 +2025-10-13 04:09:35,164 INFO Train. Iter 70600 : Loss. 0.00665 +2025-10-13 04:10:01,437 INFO Train. Iter 70700 : Loss. 0.00663 +2025-10-13 04:10:28,252 INFO Train. Iter 70800 : Loss. 0.00641 +2025-10-13 04:10:55,344 INFO Train. Iter 70900 : Loss. 0.00632 +2025-10-13 04:11:21,741 INFO Train. Iter 71000 : Loss. 0.00642 +2025-10-13 04:11:48,213 INFO Train. Iter 71100 : Loss. 0.00650 +2025-10-13 04:12:14,598 INFO Train. Iter 71200 : Loss. 0.00636 +2025-10-13 04:12:41,473 INFO Train. Iter 71300 : Loss. 0.00638 +2025-10-13 04:13:07,710 INFO Train. Iter 71400 : Loss. 0.00635 +2025-10-13 04:13:34,455 INFO Train. Iter 71500 : Loss. 0.00630 +2025-10-13 04:14:00,867 INFO Train. Iter 71600 : Loss. 0.00633 +2025-10-13 04:14:27,041 INFO Train. Iter 71700 : Loss. 0.00640 +2025-10-13 04:14:53,715 INFO Train. Iter 71800 : Loss. 0.00636 +2025-10-13 04:15:20,033 INFO Train. Iter 71900 : Loss. 0.00626 +2025-10-13 04:15:46,221 INFO Train. Iter 72000 : Loss. 0.00610 +2025-10-13 04:16:12,322 INFO Train. Iter 72100 : Loss. 0.00644 +2025-10-13 04:16:38,862 INFO Train. Iter 72200 : Loss. 0.00627 +2025-10-13 04:17:05,709 INFO Train. Iter 72300 : Loss. 0.00615 +2025-10-13 04:17:31,729 INFO Train. Iter 72400 : Loss. 0.00609 +2025-10-13 04:17:58,140 INFO Train. Iter 72500 : Loss. 0.00620 +2025-10-13 04:18:24,366 INFO Train. Iter 72600 : Loss. 0.00595 +2025-10-13 04:18:51,178 INFO Train. Iter 72700 : Loss. 0.00623 +2025-10-13 04:19:17,631 INFO Train. Iter 72800 : Loss. 0.00593 +2025-10-13 04:19:43,835 INFO Train. Iter 72900 : Loss. 0.00606 +2025-10-13 04:20:10,253 INFO Train. Iter 73000 : Loss. 0.00602 +2025-10-13 04:20:36,561 INFO Train. Iter 73100 : Loss. 0.00617 +2025-10-13 04:21:03,462 INFO Train. Iter 73200 : Loss. 
0.00604 +2025-10-13 04:21:29,838 INFO Train. Iter 73300 : Loss. 0.00578 +2025-10-13 04:21:56,127 INFO Train. Iter 73400 : Loss. 0.00603 +2025-10-13 04:22:22,500 INFO Train. Iter 73500 : Loss. 0.00596 +2025-10-13 04:22:48,936 INFO Train. Iter 73600 : Loss. 0.00598 +2025-10-13 04:23:16,189 INFO Train. Iter 73700 : Loss. 0.00581 +2025-10-13 04:23:43,138 INFO Train. Iter 73800 : Loss. 0.00581 +2025-10-13 04:24:09,647 INFO Train. Iter 73900 : Loss. 0.00588 +2025-10-13 04:24:36,100 INFO Train. Iter 74000 : Loss. 0.00579 +2025-10-13 04:25:03,003 INFO Train. Iter 74100 : Loss. 0.00589 +2025-10-13 04:25:29,379 INFO Train. Iter 74200 : Loss. 0.00582 +2025-10-13 04:25:55,701 INFO Train. Iter 74300 : Loss. 0.00578 +2025-10-13 04:26:22,304 INFO Train. Iter 74400 : Loss. 0.00574 +2025-10-13 04:26:48,750 INFO Train. Iter 74500 : Loss. 0.00572 +2025-10-13 04:27:15,723 INFO Train. Iter 74600 : Loss. 0.00572 +2025-10-13 04:27:41,926 INFO Train. Iter 74700 : Loss. 0.00571 +2025-10-13 04:28:08,204 INFO Train. Iter 74800 : Loss. 0.00564 +2025-10-13 04:28:34,398 INFO Train. Iter 74900 : Loss. 0.00569 +2025-10-13 04:29:00,399 INFO Train. Iter 75000 : Loss. 0.00552 +2025-10-13 04:29:27,052 INFO Train. Iter 75100 : Loss. 0.00570 +2025-10-13 04:29:53,162 INFO Train. Iter 75200 : Loss. 0.00549 +2025-10-13 04:30:19,305 INFO Train. Iter 75300 : Loss. 0.00564 +2025-10-13 04:30:46,035 INFO Train. Iter 75400 : Loss. 0.00561 +2025-10-13 04:31:13,296 INFO Train. Iter 75500 : Loss. 0.00559 +2025-10-13 04:31:39,569 INFO Train. Iter 75600 : Loss. 0.00533 +2025-10-13 04:32:06,207 INFO Train. Iter 75700 : Loss. 0.00555 +2025-10-13 04:32:32,350 INFO Train. Iter 75800 : Loss. 0.00545 +2025-10-13 04:32:58,627 INFO Train. Iter 75900 : Loss. 0.00553 +2025-10-13 04:33:25,409 INFO Train. Iter 76000 : Loss. 0.00543 +2025-10-13 04:33:51,475 INFO Train. Iter 76100 : Loss. 0.00524 +2025-10-13 04:34:17,876 INFO Train. Iter 76200 : Loss. 0.00541 +2025-10-13 04:34:44,061 INFO Train. Iter 76300 : Loss. 0.00542 +2025-10-13 04:35:10,435 INFO Train. Iter 76400 : Loss. 0.00539 +2025-10-13 04:35:37,181 INFO Train. Iter 76500 : Loss. 0.00532 +2025-10-13 04:36:03,651 INFO Train. Iter 76600 : Loss. 0.00540 +2025-10-13 04:36:30,032 INFO Train. Iter 76700 : Loss. 0.00532 +2025-10-13 04:36:56,321 INFO Train. Iter 76800 : Loss. 0.00519 +2025-10-13 04:37:23,598 INFO Train. Iter 76900 : Loss. 0.00528 +2025-10-13 04:37:49,874 INFO Train. Iter 77000 : Loss. 0.00521 +2025-10-13 04:38:16,465 INFO Train. Iter 77100 : Loss. 0.00526 +2025-10-13 04:38:42,745 INFO Train. Iter 77200 : Loss. 0.00524 +2025-10-13 04:39:09,071 INFO Train. Iter 77300 : Loss. 0.00529 +2025-10-13 04:39:35,727 INFO Train. Iter 77400 : Loss. 0.00505 +2025-10-13 04:40:02,330 INFO Train. Iter 77500 : Loss. 0.00506 +2025-10-13 04:40:28,543 INFO Train. Iter 77600 : Loss. 0.00514 +2025-10-13 04:40:54,599 INFO Train. Iter 77700 : Loss. 0.00508 +2025-10-13 04:41:20,670 INFO Train. Iter 77800 : Loss. 0.00518 +2025-10-13 04:41:47,628 INFO Train. Iter 77900 : Loss. 0.00497 +2025-10-13 04:42:13,898 INFO Train. Iter 78000 : Loss. 0.00509 +2025-10-13 04:42:40,113 INFO Train. Iter 78100 : Loss. 0.00518 +2025-10-13 04:43:06,352 INFO Train. Iter 78200 : Loss. 0.00506 +2025-10-13 04:43:33,051 INFO Train. Iter 78300 : Loss. 0.00505 +2025-10-13 04:43:59,313 INFO Train. Iter 78400 : Loss. 0.00486 +2025-10-13 04:44:25,792 INFO Train. Iter 78500 : Loss. 0.00485 +2025-10-13 04:44:51,791 INFO Train. Iter 78600 : Loss. 0.00513 +2025-10-13 04:45:18,151 INFO Train. Iter 78700 : Loss. 
0.00499 +2025-10-13 04:45:45,485 INFO Train. Iter 78800 : Loss. 0.00495 +2025-10-13 04:46:11,673 INFO Train. Iter 78900 : Loss. 0.00513 +2025-10-13 04:46:37,881 INFO Train. Iter 79000 : Loss. 0.00485 +2025-10-13 04:47:03,771 INFO Train. Iter 79100 : Loss. 0.00476 +2025-10-13 04:47:30,177 INFO Train. Iter 79200 : Loss. 0.00490 +2025-10-13 04:47:56,961 INFO Train. Iter 79300 : Loss. 0.00479 +2025-10-13 04:48:23,166 INFO Train. Iter 79400 : Loss. 0.00478 +2025-10-13 04:48:49,612 INFO Train. Iter 79500 : Loss. 0.00493 +2025-10-13 04:49:15,858 INFO Train. Iter 79600 : Loss. 0.00489 +2025-10-13 04:49:42,563 INFO Train. Iter 79700 : Loss. 0.00483 +2025-10-13 04:50:08,565 INFO Train. Iter 79800 : Loss. 0.00469 +2025-10-13 04:50:35,167 INFO Train. Iter 79900 : Loss. 0.00468 +2025-10-13 04:51:01,541 INFO Train. Iter 80000 : Loss. 0.00478 +2025-10-13 04:51:29,338 INFO Train. Iter 80100 : Loss. 0.00480 +2025-10-13 04:51:56,268 INFO Train. Iter 80200 : Loss. 0.00470 +2025-10-13 04:52:22,306 INFO Train. Iter 80300 : Loss. 0.00464 +2025-10-13 04:52:48,266 INFO Train. Iter 80400 : Loss. 0.00470 +2025-10-13 04:53:14,444 INFO Train. Iter 80500 : Loss. 0.00461 +2025-10-13 04:53:40,938 INFO Train. Iter 80600 : Loss. 0.00470 +2025-10-13 04:54:07,814 INFO Train. Iter 80700 : Loss. 0.00470 +2025-10-13 04:54:34,113 INFO Train. Iter 80800 : Loss. 0.00459 +2025-10-13 04:55:00,305 INFO Train. Iter 80900 : Loss. 0.00457 +2025-10-13 04:55:26,688 INFO Train. Iter 81000 : Loss. 0.00437 +2025-10-13 04:55:53,627 INFO Train. Iter 81100 : Loss. 0.00456 +2025-10-13 04:56:20,175 INFO Train. Iter 81200 : Loss. 0.00450 +2025-10-13 04:56:46,799 INFO Train. Iter 81300 : Loss. 0.00455 +2025-10-13 04:57:13,068 INFO Train. Iter 81400 : Loss. 0.00450 +2025-10-13 04:57:39,222 INFO Train. Iter 81500 : Loss. 0.00452 +2025-10-13 04:58:05,796 INFO Train. Iter 81600 : Loss. 0.00448 +2025-10-13 04:58:32,202 INFO Train. Iter 81700 : Loss. 0.00440 +2025-10-13 04:58:58,563 INFO Train. Iter 81800 : Loss. 0.00454 +2025-10-13 04:59:24,791 INFO Train. Iter 81900 : Loss. 0.00439 +2025-10-13 04:59:51,278 INFO Train. Iter 82000 : Loss. 0.00432 +2025-10-13 05:00:18,047 INFO Train. Iter 82100 : Loss. 0.00445 +2025-10-13 05:00:44,526 INFO Train. Iter 82200 : Loss. 0.00434 +2025-10-13 05:01:10,999 INFO Train. Iter 82300 : Loss. 0.00425 +2025-10-13 05:01:37,573 INFO Train. Iter 82400 : Loss. 0.00448 +2025-10-13 05:02:04,146 INFO Train. Iter 82500 : Loss. 0.00451 +2025-10-13 05:02:30,732 INFO Train. Iter 82600 : Loss. 0.00445 +2025-10-13 05:02:57,404 INFO Train. Iter 82700 : Loss. 0.00424 +2025-10-13 05:03:23,788 INFO Train. Iter 82800 : Loss. 0.00437 +2025-10-13 05:03:50,362 INFO Train. Iter 82900 : Loss. 0.00442 +2025-10-13 05:04:17,228 INFO Train. Iter 83000 : Loss. 0.00424 +2025-10-13 05:04:44,150 INFO Train. Iter 83100 : Loss. 0.00417 +2025-10-13 05:05:10,488 INFO Train. Iter 83200 : Loss. 0.00419 +2025-10-13 05:05:36,810 INFO Train. Iter 83300 : Loss. 0.00436 +2025-10-13 05:06:02,873 INFO Train. Iter 83400 : Loss. 0.00437 +2025-10-13 05:06:29,461 INFO Train. Iter 83500 : Loss. 0.00424 +2025-10-13 05:06:55,615 INFO Train. Iter 83600 : Loss. 0.00420 +2025-10-13 05:07:21,993 INFO Train. Iter 83700 : Loss. 0.00412 +2025-10-13 05:07:48,522 INFO Train. Iter 83800 : Loss. 0.00416 +2025-10-13 05:08:15,299 INFO Train. Iter 83900 : Loss. 0.00422 +2025-10-13 05:08:41,848 INFO Train. Iter 84000 : Loss. 0.00418 +2025-10-13 05:09:08,188 INFO Train. Iter 84100 : Loss. 0.00416 +2025-10-13 05:09:34,453 INFO Train. Iter 84200 : Loss. 
0.00413 +2025-10-13 05:10:00,820 INFO Train. Iter 84300 : Loss. 0.00420 +2025-10-13 05:10:27,661 INFO Train. Iter 84400 : Loss. 0.00411 +2025-10-13 05:10:53,789 INFO Train. Iter 84500 : Loss. 0.00402 +2025-10-13 05:11:20,302 INFO Train. Iter 84600 : Loss. 0.00414 +2025-10-13 05:11:46,670 INFO Train. Iter 84700 : Loss. 0.00421 +2025-10-13 05:12:13,087 INFO Train. Iter 84800 : Loss. 0.00408 +2025-10-13 05:12:40,142 INFO Train. Iter 84900 : Loss. 0.00404 +2025-10-13 05:13:06,586 INFO Train. Iter 85000 : Loss. 0.00403 +2025-10-13 05:13:33,122 INFO Train. Iter 85100 : Loss. 0.00403 +2025-10-13 05:13:59,414 INFO Train. Iter 85200 : Loss. 0.00409 +2025-10-13 05:14:26,333 INFO Train. Iter 85300 : Loss. 0.00404 +2025-10-13 05:14:52,608 INFO Train. Iter 85400 : Loss. 0.00391 +2025-10-13 05:15:18,755 INFO Train. Iter 85500 : Loss. 0.00403 +2025-10-13 05:15:44,984 INFO Train. Iter 85600 : Loss. 0.00395 +2025-10-13 05:16:11,417 INFO Train. Iter 85700 : Loss. 0.00403 +2025-10-13 05:16:38,580 INFO Train. Iter 85800 : Loss. 0.00396 +2025-10-13 05:17:04,997 INFO Train. Iter 85900 : Loss. 0.00393 +2025-10-13 05:17:31,177 INFO Train. Iter 86000 : Loss. 0.00390 +2025-10-13 05:17:57,621 INFO Train. Iter 86100 : Loss. 0.00398 +2025-10-13 05:18:23,823 INFO Train. Iter 86200 : Loss. 0.00393 +2025-10-13 05:18:50,361 INFO Train. Iter 86300 : Loss. 0.00391 +2025-10-13 05:19:16,999 INFO Train. Iter 86400 : Loss. 0.00386 +2025-10-13 05:19:43,823 INFO Train. Iter 86500 : Loss. 0.00391 +2025-10-13 05:20:10,163 INFO Train. Iter 86600 : Loss. 0.00384 +2025-10-13 05:20:37,112 INFO Train. Iter 86700 : Loss. 0.00396 +2025-10-13 05:21:03,677 INFO Train. Iter 86800 : Loss. 0.00373 +2025-10-13 05:21:29,957 INFO Train. Iter 86900 : Loss. 0.00382 +2025-10-13 05:21:56,201 INFO Train. Iter 87000 : Loss. 0.00388 +2025-10-13 05:22:22,615 INFO Train. Iter 87100 : Loss. 0.00396 +2025-10-13 05:22:49,269 INFO Train. Iter 87200 : Loss. 0.00379 +2025-10-13 05:23:15,414 INFO Train. Iter 87300 : Loss. 0.00373 +2025-10-13 05:23:41,693 INFO Train. Iter 87400 : Loss. 0.00381 +2025-10-13 05:24:08,278 INFO Train. Iter 87500 : Loss. 0.00377 +2025-10-13 05:24:34,659 INFO Train. Iter 87600 : Loss. 0.00376 +2025-10-13 05:25:01,630 INFO Train. Iter 87700 : Loss. 0.00371 +2025-10-13 05:25:27,949 INFO Train. Iter 87800 : Loss. 0.00387 +2025-10-13 05:25:54,178 INFO Train. Iter 87900 : Loss. 0.00369 +2025-10-13 05:26:20,505 INFO Train. Iter 88000 : Loss. 0.00389 +2025-10-13 05:26:47,610 INFO Train. Iter 88100 : Loss. 0.00374 +2025-10-13 05:27:13,973 INFO Train. Iter 88200 : Loss. 0.00379 +2025-10-13 05:27:40,358 INFO Train. Iter 88300 : Loss. 0.00375 +2025-10-13 05:28:06,292 INFO Train. Iter 88400 : Loss. 0.00367 +2025-10-13 05:28:32,547 INFO Train. Iter 88500 : Loss. 0.00366 +2025-10-13 05:28:59,512 INFO Train. Iter 88600 : Loss. 0.00364 +2025-10-13 05:29:25,681 INFO Train. Iter 88700 : Loss. 0.00370 +2025-10-13 05:29:52,280 INFO Train. Iter 88800 : Loss. 0.00367 +2025-10-13 05:30:18,611 INFO Train. Iter 88900 : Loss. 0.00371 +2025-10-13 05:30:45,084 INFO Train. Iter 89000 : Loss. 0.00365 +2025-10-13 05:31:12,253 INFO Train. Iter 89100 : Loss. 0.00365 +2025-10-13 05:31:38,816 INFO Train. Iter 89200 : Loss. 0.00353 +2025-10-13 05:32:04,991 INFO Train. Iter 89300 : Loss. 0.00373 +2025-10-13 05:32:31,142 INFO Train. Iter 89400 : Loss. 0.00362 +2025-10-13 05:32:57,966 INFO Train. Iter 89500 : Loss. 0.00373 +2025-10-13 05:33:24,468 INFO Train. Iter 89600 : Loss. 0.00361 +2025-10-13 05:33:50,934 INFO Train. Iter 89700 : Loss. 
0.00356 +2025-10-13 05:34:17,350 INFO Train. Iter 89800 : Loss. 0.00369 +2025-10-13 05:34:43,727 INFO Train. Iter 89900 : Loss. 0.00361 +2025-10-13 05:35:10,779 INFO Train. Iter 90000 : Loss. 0.00352 +2025-10-13 05:35:38,551 INFO Train. Iter 90100 : Loss. 0.00359 +2025-10-13 05:36:04,548 INFO Train. Iter 90200 : Loss. 0.00360 +2025-10-13 05:36:31,056 INFO Train. Iter 90300 : Loss. 0.00357 +2025-10-13 05:36:57,631 INFO Train. Iter 90400 : Loss. 0.00371 +2025-10-13 05:37:24,455 INFO Train. Iter 90500 : Loss. 0.00361 +2025-10-13 05:37:50,773 INFO Train. Iter 90600 : Loss. 0.00364 +2025-10-13 05:38:17,005 INFO Train. Iter 90700 : Loss. 0.00350 +2025-10-13 05:38:43,340 INFO Train. Iter 90800 : Loss. 0.00351 +2025-10-13 05:39:10,135 INFO Train. Iter 90900 : Loss. 0.00350 +2025-10-13 05:39:36,376 INFO Train. Iter 91000 : Loss. 0.00347 +2025-10-13 05:40:02,581 INFO Train. Iter 91100 : Loss. 0.00358 +2025-10-13 05:40:29,146 INFO Train. Iter 91200 : Loss. 0.00347 +2025-10-13 05:40:55,551 INFO Train. Iter 91300 : Loss. 0.00352 +2025-10-13 05:41:22,280 INFO Train. Iter 91400 : Loss. 0.00342 +2025-10-13 05:41:48,792 INFO Train. Iter 91500 : Loss. 0.00355 +2025-10-13 05:42:15,006 INFO Train. Iter 91600 : Loss. 0.00347 +2025-10-13 05:42:41,114 INFO Train. Iter 91700 : Loss. 0.00347 +2025-10-13 05:43:07,616 INFO Train. Iter 91800 : Loss. 0.00344 +2025-10-13 05:43:34,785 INFO Train. Iter 91900 : Loss. 0.00346 +2025-10-13 05:44:01,063 INFO Train. Iter 92000 : Loss. 0.00350 +2025-10-13 05:44:27,486 INFO Train. Iter 92100 : Loss. 0.00348 +2025-10-13 05:44:53,791 INFO Train. Iter 92200 : Loss. 0.00344 +2025-10-13 05:45:20,735 INFO Train. Iter 92300 : Loss. 0.00348 +2025-10-13 05:45:47,165 INFO Train. Iter 92400 : Loss. 0.00344 +2025-10-13 05:46:13,283 INFO Train. Iter 92500 : Loss. 0.00339 +2025-10-13 05:46:39,747 INFO Train. Iter 92600 : Loss. 0.00346 +2025-10-13 05:47:06,297 INFO Train. Iter 92700 : Loss. 0.00339 +2025-10-13 05:47:33,542 INFO Train. Iter 92800 : Loss. 0.00346 +2025-10-13 05:47:59,535 INFO Train. Iter 92900 : Loss. 0.00343 +2025-10-13 05:48:25,965 INFO Train. Iter 93000 : Loss. 0.00358 +2025-10-13 05:48:52,350 INFO Train. Iter 93100 : Loss. 0.00339 +2025-10-13 05:49:18,823 INFO Train. Iter 93200 : Loss. 0.00344 +2025-10-13 05:49:45,657 INFO Train. Iter 93300 : Loss. 0.00341 +2025-10-13 05:50:12,139 INFO Train. Iter 93400 : Loss. 0.00352 +2025-10-13 05:50:38,386 INFO Train. Iter 93500 : Loss. 0.00345 +2025-10-13 05:51:04,823 INFO Train. Iter 93600 : Loss. 0.00344 +2025-10-13 05:51:31,817 INFO Train. Iter 93700 : Loss. 0.00348 +2025-10-13 05:51:58,475 INFO Train. Iter 93800 : Loss. 0.00337 +2025-10-13 05:52:24,828 INFO Train. Iter 93900 : Loss. 0.00332 +2025-10-13 05:52:51,071 INFO Train. Iter 94000 : Loss. 0.00342 +2025-10-13 05:53:17,531 INFO Train. Iter 94100 : Loss. 0.00336 +2025-10-13 05:53:44,395 INFO Train. Iter 94200 : Loss. 0.00339 +2025-10-13 05:54:10,853 INFO Train. Iter 94300 : Loss. 0.00344 +2025-10-13 05:54:37,485 INFO Train. Iter 94400 : Loss. 0.00333 +2025-10-13 05:55:03,796 INFO Train. Iter 94500 : Loss. 0.00344 +2025-10-13 05:55:30,531 INFO Train. Iter 94600 : Loss. 0.00334 +2025-10-13 05:55:57,207 INFO Train. Iter 94700 : Loss. 0.00339 +2025-10-13 05:56:23,482 INFO Train. Iter 94800 : Loss. 0.00332 +2025-10-13 05:56:49,662 INFO Train. Iter 94900 : Loss. 0.00332 +2025-10-13 05:57:15,746 INFO Train. Iter 95000 : Loss. 0.00344 +2025-10-13 05:57:42,578 INFO Train. Iter 95100 : Loss. 0.00344 +2025-10-13 05:58:08,757 INFO Train. Iter 95200 : Loss. 
0.00339 +2025-10-13 05:58:35,064 INFO Train. Iter 95300 : Loss. 0.00326 +2025-10-13 05:59:01,390 INFO Train. Iter 95400 : Loss. 0.00342 +2025-10-13 05:59:27,609 INFO Train. Iter 95500 : Loss. 0.00340 +2025-10-13 05:59:54,637 INFO Train. Iter 95600 : Loss. 0.00339 +2025-10-13 06:00:20,839 INFO Train. Iter 95700 : Loss. 0.00338 +2025-10-13 06:00:47,265 INFO Train. Iter 95800 : Loss. 0.00334 +2025-10-13 06:01:13,520 INFO Train. Iter 95900 : Loss. 0.00340 +2025-10-13 06:01:40,541 INFO Train. Iter 96000 : Loss. 0.00336 +2025-10-13 06:02:06,627 INFO Train. Iter 96100 : Loss. 0.00338 +2025-10-13 06:02:33,073 INFO Train. Iter 96200 : Loss. 0.00334 +2025-10-13 06:02:59,353 INFO Train. Iter 96300 : Loss. 0.00328 +2025-10-13 06:03:25,718 INFO Train. Iter 96400 : Loss. 0.00340 +2025-10-13 06:03:52,664 INFO Train. Iter 96500 : Loss. 0.00345 +2025-10-13 06:04:19,020 INFO Train. Iter 96600 : Loss. 0.00329 +2025-10-13 06:04:45,329 INFO Train. Iter 96700 : Loss. 0.00332 +2025-10-13 06:05:11,550 INFO Train. Iter 96800 : Loss. 0.00326 +2025-10-13 06:05:38,057 INFO Train. Iter 96900 : Loss. 0.00327 +2025-10-13 06:06:04,779 INFO Train. Iter 97000 : Loss. 0.00331 +2025-10-13 06:06:31,219 INFO Train. Iter 97100 : Loss. 0.00328 +2025-10-13 06:06:57,572 INFO Train. Iter 97200 : Loss. 0.00332 +2025-10-13 06:07:24,165 INFO Train. Iter 97300 : Loss. 0.00339 +2025-10-13 06:07:50,925 INFO Train. Iter 97400 : Loss. 0.00333 +2025-10-13 06:08:17,286 INFO Train. Iter 97500 : Loss. 0.00331 +2025-10-13 06:08:43,747 INFO Train. Iter 97600 : Loss. 0.00327 +2025-10-13 06:09:10,423 INFO Train. Iter 97700 : Loss. 0.00331 +2025-10-13 06:09:36,609 INFO Train. Iter 97800 : Loss. 0.00335 +2025-10-13 06:10:03,203 INFO Train. Iter 97900 : Loss. 0.00334 +2025-10-13 06:10:29,413 INFO Train. Iter 98000 : Loss. 0.00338 +2025-10-13 06:10:55,861 INFO Train. Iter 98100 : Loss. 0.00322 +2025-10-13 06:11:22,277 INFO Train. Iter 98200 : Loss. 0.00342 +2025-10-13 06:11:48,631 INFO Train. Iter 98300 : Loss. 0.00341 +2025-10-13 06:12:15,487 INFO Train. Iter 98400 : Loss. 0.00330 +2025-10-13 06:12:41,857 INFO Train. Iter 98500 : Loss. 0.00335 +2025-10-13 06:13:08,322 INFO Train. Iter 98600 : Loss. 0.00325 +2025-10-13 06:13:34,748 INFO Train. Iter 98700 : Loss. 0.00336 +2025-10-13 06:14:01,436 INFO Train. Iter 98800 : Loss. 0.00340 +2025-10-13 06:14:27,765 INFO Train. Iter 98900 : Loss. 0.00335 +2025-10-13 06:14:54,189 INFO Train. Iter 99000 : Loss. 0.00330 +2025-10-13 06:15:20,469 INFO Train. Iter 99100 : Loss. 0.00342 +2025-10-13 06:15:47,189 INFO Train. Iter 99200 : Loss. 0.00324 +2025-10-13 06:16:13,856 INFO Train. Iter 99300 : Loss. 0.00335 +2025-10-13 06:16:40,437 INFO Train. Iter 99400 : Loss. 0.00334 +2025-10-13 06:17:06,688 INFO Train. Iter 99500 : Loss. 0.00330 +2025-10-13 06:17:33,000 INFO Train. Iter 99600 : Loss. 0.00341 +2025-10-13 06:17:59,101 INFO Train. Iter 99700 : Loss. 0.00338 +2025-10-13 06:18:25,857 INFO Train. Iter 99800 : Loss. 0.00327 +2025-10-13 06:18:52,350 INFO Train. Iter 99900 : Loss. 0.00334 +2025-10-13 06:19:18,673 INFO Train. Iter 100000 : Loss. 
0.00342 +2025-10-13 08:27:39,548 INFO { + "batch_size": 30, + "dataname": "t2m_babel_272", + "decay_option": "all", + "depth": 3, + "dilation_growth_rate": 3, + "down_t": 2, + "exp_name": "motionstreamer_model", + "gamma": 0.05, + "hidden_size": 1024, + "latent_dim": 16, + "latent_dir": "babel_272_stream/t2m_babel_latents", + "lr": 0.0001, + "mode": "pos", + "num_diffusion_head_layers": 9, + "num_gpus": 1, + "optimizer": "adamw", + "out_dir": "Experiments/motionstreamer_model", + "resume_pth": null, + "resume_trans": null, + "seed": 123, + "stride_t": 2, + "text": "A man is jogging around.", + "total_iter": 200000, + "weight_decay": 1e-06 +} +2025-10-13 08:28:30,140 INFO Train. Iter 100100 : Loss. 0.99922 +2025-10-13 08:28:55,680 INFO Train. Iter 100200 : Loss. 1.00001 +2025-10-13 08:29:21,523 INFO Train. Iter 100300 : Loss. 0.99960 +2025-10-13 08:29:47,211 INFO Train. Iter 100400 : Loss. 1.00013 +2025-10-13 08:30:13,394 INFO Train. Iter 100500 : Loss. 0.99954 +2025-10-13 08:30:39,729 INFO Train. Iter 100600 : Loss. 1.00068 +2025-10-13 08:31:05,794 INFO Train. Iter 100700 : Loss. 0.99887 +2025-10-13 08:31:31,478 INFO Train. Iter 100800 : Loss. 0.99917 +2025-10-13 08:31:57,410 INFO Train. Iter 100900 : Loss. 1.00094 +2025-10-13 08:32:23,758 INFO Train. Iter 101000 : Loss. 0.99987 +2025-10-13 08:34:13,861 INFO { + "batch_size": 30, + "dataname": "t2m_babel_272", + "decay_option": "all", + "depth": 3, + "dilation_growth_rate": 3, + "down_t": 2, + "exp_name": "motionstreamer_model", + "gamma": 0.05, + "hidden_size": 1024, + "latent_dim": 16, + "latent_dir": "babel_272_stream/t2m_babel_latents", + "lr": 0.0001, + "mode": "pos", + "num_diffusion_head_layers": 9, + "num_gpus": 1, + "optimizer": "adamw", + "out_dir": "Experiments/motionstreamer_model", + "resume_pth": null, + "resume_trans": null, + "seed": 123, + "stride_t": 2, + "text": "A man is jogging around.", + "total_iter": 800000, + "weight_decay": 1e-06 +} +2025-10-13 08:35:02,621 INFO Train. Iter 100100 : Loss. 0.99922 +2025-10-13 08:35:28,017 INFO Train. Iter 100200 : Loss. 1.00001 +2025-10-13 08:35:53,697 INFO Train. Iter 100300 : Loss. 0.99960 +2025-10-13 08:36:19,324 INFO Train. Iter 100400 : Loss. 1.00013 +2025-10-13 08:36:45,416 INFO Train. Iter 100500 : Loss. 0.99954 +2025-10-13 08:37:11,487 INFO Train. Iter 100600 : Loss. 1.00068 +2025-10-13 08:37:37,562 INFO Train. Iter 100700 : Loss. 0.99887 +2025-10-13 08:38:03,262 INFO Train. Iter 100800 : Loss. 0.99917 +2025-10-13 08:38:29,113 INFO Train. Iter 100900 : Loss. 1.00094 +2025-10-13 08:38:55,158 INFO Train. Iter 101000 : Loss. 0.99987 +2025-10-13 08:39:21,165 INFO Train. Iter 101100 : Loss. 1.00088 +2025-10-13 08:39:46,836 INFO Train. Iter 101200 : Loss. 1.00036 +2025-10-13 08:40:12,641 INFO Train. Iter 101300 : Loss. 1.00060 +2025-10-13 08:40:38,905 INFO Train. Iter 101400 : Loss. 0.99962 +2025-10-13 08:41:04,724 INFO Train. Iter 101500 : Loss. 
0.99980 +2025-10-13 08:46:24,817 INFO { + "batch_size": 30, + "dataname": "t2m_babel_272", + "decay_option": "all", + "depth": 3, + "dilation_growth_rate": 3, + "down_t": 2, + "exp_name": "motionstreamer_model", + "gamma": 0.05, + "hidden_size": 1024, + "latent_dim": 16, + "latent_dir": "babel_272_stream/t2m_babel_latents", + "lr": 0.0001, + "mode": "pos", + "num_diffusion_head_layers": 9, + "num_gpus": 1, + "optimizer": "adamw", + "out_dir": "Experiments/motionstreamer_model", + "resume_pth": null, + "resume_trans": null, + "seed": 123, + "stride_t": 2, + "text": "A man is jogging around.", + "total_iter": 100000, + "weight_decay": 1e-06 +} +2025-10-13 08:47:14,327 INFO Train. Iter 100 : Loss. 0.12488 +2025-10-13 08:47:38,442 INFO Train. Iter 200 : Loss. 0.12486 +2025-10-13 08:48:02,813 INFO Train. Iter 300 : Loss. 0.12454 +2025-10-13 08:58:11,833 INFO { + "batch_size": 30, + "dataname": "t2m_babel_272", + "decay_option": "all", + "depth": 3, + "dilation_growth_rate": 3, + "down_t": 2, + "exp_name": "motionstreamer_model", + "gamma": 0.05, + "hidden_size": 1024, + "latent_dim": 16, + "latent_dir": "babel_272_stream/t2m_babel_latents", + "lr": 0.0001, + "mode": "pos", + "num_diffusion_head_layers": 9, + "num_gpus": 1, + "optimizer": "adamw", + "out_dir": "Experiments/motionstreamer_model", + "resume_pth": null, + "resume_trans": null, + "seed": 123, + "stride_t": 2, + "text": "A man is jogging around.", + "total_iter": 100000, + "weight_decay": 1e-06 +} +2025-10-13 08:59:01,292 INFO Train. Iter 100 : Loss. 0.12488 +2025-10-13 08:59:25,372 INFO Train. Iter 200 : Loss. 0.12486 +2025-10-13 08:59:49,727 INFO Train. Iter 300 : Loss. 0.12454 +2025-10-13 09:00:14,132 INFO Train. Iter 400 : Loss. 0.12420 +2025-10-13 09:00:38,928 INFO Train. Iter 500 : Loss. 0.12355 +2025-10-13 09:01:03,792 INFO Train. Iter 600 : Loss. 0.12294 +2025-10-13 09:01:28,661 INFO Train. Iter 700 : Loss. 0.12178 +2025-10-13 09:01:52,985 INFO Train. Iter 800 : Loss. 0.12067 +2025-10-13 09:02:17,620 INFO Train. Iter 900 : Loss. 0.11950 +2025-10-13 09:02:42,495 INFO Train. Iter 1000 : Loss. 0.11774 +2025-10-13 09:03:07,324 INFO Train. Iter 1100 : Loss. 0.11594 +2025-10-13 09:03:31,728 INFO Train. Iter 1200 : Loss. 0.11364 +2025-10-13 09:03:56,324 INFO Train. Iter 1300 : Loss. 0.11103 +2025-10-13 09:04:21,484 INFO Train. Iter 1400 : Loss. 0.10788 +2025-10-13 09:04:46,218 INFO Train. Iter 1500 : Loss. 0.10439 +2025-10-13 09:05:10,989 INFO Train. Iter 1600 : Loss. 0.10042 +2025-10-13 09:05:35,748 INFO Train. Iter 1700 : Loss. 0.09561 +2025-10-13 09:06:00,774 INFO Train. Iter 1800 : Loss. 0.09025 +2025-10-13 09:06:26,072 INFO Train. Iter 1900 : Loss. 0.08403 +2025-10-13 09:06:50,672 INFO Train. Iter 2000 : Loss. 0.07706 +2025-10-13 09:07:15,383 INFO Train. Iter 2100 : Loss. 0.06929 +2025-10-13 09:07:40,711 INFO Train. Iter 2200 : Loss. 0.06081 +2025-10-13 09:08:05,682 INFO Train. Iter 2300 : Loss. 0.05216 +2025-10-13 09:08:30,997 INFO Train. Iter 2400 : Loss. 0.04340 +2025-10-13 09:08:55,472 INFO Train. Iter 2500 : Loss. 0.03512 +2025-10-13 09:09:20,116 INFO Train. Iter 2600 : Loss. 0.02773 +2025-10-13 09:09:45,001 INFO Train. Iter 2700 : Loss. 0.02181 +2025-10-13 09:10:10,041 INFO Train. Iter 2800 : Loss. 0.01712 +2025-10-13 09:10:34,615 INFO Train. Iter 2900 : Loss. 0.01426 +2025-10-13 09:10:59,275 INFO Train. Iter 3000 : Loss. 0.01273 +2025-10-13 09:11:23,912 INFO Train. Iter 3100 : Loss. 0.01148 +2025-10-13 09:11:48,670 INFO Train. Iter 3200 : Loss. 0.01102 +2025-10-13 09:12:13,746 INFO Train. Iter 3300 : Loss. 
0.01066 +2025-10-13 09:12:38,330 INFO Train. Iter 3400 : Loss. 0.01040 +2025-10-13 09:13:03,168 INFO Train. Iter 3500 : Loss. 0.01036 +2025-10-13 09:13:27,839 INFO Train. Iter 3600 : Loss. 0.01017 +2025-10-13 09:13:52,309 INFO Train. Iter 3700 : Loss. 0.01018 +2025-10-13 09:14:17,831 INFO Train. Iter 3800 : Loss. 0.00998 +2025-10-13 09:14:42,326 INFO Train. Iter 3900 : Loss. 0.00984 +2025-10-13 09:15:07,101 INFO Train. Iter 4000 : Loss. 0.00962 +2025-10-13 09:15:31,913 INFO Train. Iter 4100 : Loss. 0.00983 +2025-10-13 09:15:57,216 INFO Train. Iter 4200 : Loss. 0.00959 +2025-10-13 09:16:21,715 INFO Train. Iter 4300 : Loss. 0.00953 +2025-10-13 09:16:46,296 INFO Train. Iter 4400 : Loss. 0.00935 +2025-10-13 09:17:10,759 INFO Train. Iter 4500 : Loss. 0.00928 +2025-10-13 09:17:35,578 INFO Train. Iter 4600 : Loss. 0.00919 +2025-10-13 09:18:00,713 INFO Train. Iter 4700 : Loss. 0.00909 +2025-10-13 09:18:25,283 INFO Train. Iter 4800 : Loss. 0.00897 +2025-10-13 09:18:49,889 INFO Train. Iter 4900 : Loss. 0.00878 +2025-10-13 09:19:14,821 INFO Train. Iter 5000 : Loss. 0.00940 +2025-10-13 09:19:39,471 INFO Train. Iter 5100 : Loss. 0.00848 +2025-10-13 09:20:04,896 INFO Train. Iter 5200 : Loss. 0.00820 +2025-10-13 09:20:29,641 INFO Train. Iter 5300 : Loss. 0.00843 +2025-10-13 09:20:54,555 INFO Train. Iter 5400 : Loss. 0.00817 +2025-10-13 09:21:19,484 INFO Train. Iter 5500 : Loss. 0.00775 +2025-10-13 09:21:44,546 INFO Train. Iter 5600 : Loss. 0.00861 +2025-10-13 09:22:08,987 INFO Train. Iter 5700 : Loss. 0.00780 +2025-10-13 09:22:33,750 INFO Train. Iter 5800 : Loss. 0.00731 +2025-10-13 09:22:58,384 INFO Train. Iter 5900 : Loss. 0.00716 +2025-10-13 09:23:23,189 INFO Train. Iter 6000 : Loss. 0.00812 +2025-10-13 09:23:48,388 INFO Train. Iter 6100 : Loss. 0.00710 +2025-10-13 09:24:13,217 INFO Train. Iter 6200 : Loss. 0.00645 +2025-10-13 09:24:37,964 INFO Train. Iter 6300 : Loss. 0.00669 +2025-10-13 09:25:02,669 INFO Train. Iter 6400 : Loss. 0.00648 +2025-10-13 09:25:27,335 INFO Train. Iter 6500 : Loss. 0.00617 +2025-10-13 09:25:52,980 INFO Train. Iter 6600 : Loss. 0.00828 +2025-10-13 09:26:17,901 INFO Train. Iter 6700 : Loss. 0.00672 +2025-10-13 09:26:42,422 INFO Train. Iter 6800 : Loss. 0.00600 +2025-10-13 09:27:07,052 INFO Train. Iter 6900 : Loss. 0.00618 +2025-10-13 09:27:32,285 INFO Train. Iter 7000 : Loss. 0.00577 +2025-10-13 09:27:57,041 INFO Train. Iter 7100 : Loss. 0.00596 +2025-10-13 09:28:21,551 INFO Train. Iter 7200 : Loss. 0.00567 +2025-10-13 09:28:46,169 INFO Train. Iter 7300 : Loss. 0.00583 +2025-10-13 09:29:10,788 INFO Train. Iter 7400 : Loss. 0.00548 +2025-10-13 09:29:36,166 INFO Train. Iter 7500 : Loss. 0.00598 +2025-10-13 09:30:01,198 INFO Train. Iter 7600 : Loss. 0.00675 +2025-10-13 09:30:26,221 INFO Train. Iter 7700 : Loss. 0.00566 +2025-10-13 09:30:51,159 INFO Train. Iter 7800 : Loss. 0.00584 +2025-10-13 09:31:15,845 INFO Train. Iter 7900 : Loss. 0.00545 +2025-10-13 09:31:41,478 INFO Train. Iter 8000 : Loss. 0.00531 +2025-10-13 09:32:06,219 INFO Train. Iter 8100 : Loss. 0.00535 +2025-10-13 09:32:30,918 INFO Train. Iter 8200 : Loss. 0.00518 +2025-10-13 09:32:55,505 INFO Train. Iter 8300 : Loss. 0.00546 +2025-10-13 09:33:20,467 INFO Train. Iter 8400 : Loss. 0.00517 +2025-10-13 09:33:45,012 INFO Train. Iter 8500 : Loss. 0.00525 +2025-10-13 09:34:09,978 INFO Train. Iter 8600 : Loss. 0.00510 +2025-10-13 09:34:34,515 INFO Train. Iter 8700 : Loss. 0.00517 +2025-10-13 09:34:59,429 INFO Train. Iter 8800 : Loss. 0.00532 +2025-10-13 09:35:24,515 INFO Train. Iter 8900 : Loss. 
0.00507 +2025-10-13 09:35:49,158 INFO Train. Iter 9000 : Loss. 0.00487 +2025-10-13 09:36:13,974 INFO Train. Iter 9100 : Loss. 0.00479 +2025-10-13 09:36:38,541 INFO Train. Iter 9200 : Loss. 0.00493 +2025-10-13 09:37:02,947 INFO Train. Iter 9300 : Loss. 0.00484 +2025-10-13 09:37:28,125 INFO Train. Iter 9400 : Loss. 0.00489 +2025-10-13 09:37:52,851 INFO Train. Iter 9500 : Loss. 0.00470 +2025-10-13 09:38:17,609 INFO Train. Iter 9600 : Loss. 0.00479 +2025-10-13 09:38:42,052 INFO Train. Iter 9700 : Loss. 0.00476 +2025-10-13 09:39:07,101 INFO Train. Iter 9800 : Loss. 0.00472 +2025-10-13 09:39:31,912 INFO Train. Iter 9900 : Loss. 0.00502 +2025-10-13 09:39:56,240 INFO Train. Iter 10000 : Loss. 0.00454 +2025-10-13 09:40:23,269 INFO Train. Iter 10100 : Loss. 0.00449 +2025-10-13 09:40:48,011 INFO Train. Iter 10200 : Loss. 0.00447 +2025-10-13 09:41:13,274 INFO Train. Iter 10300 : Loss. 0.00444 +2025-10-13 09:41:37,886 INFO Train. Iter 10400 : Loss. 0.00433 +2025-10-13 09:42:02,633 INFO Train. Iter 10500 : Loss. 0.00472 +2025-10-13 09:42:27,268 INFO Train. Iter 10600 : Loss. 0.00462 +2025-10-13 09:42:51,769 INFO Train. Iter 10700 : Loss. 0.00444 +2025-10-13 09:43:17,270 INFO Train. Iter 10800 : Loss. 0.00428 +2025-10-13 09:43:41,795 INFO Train. Iter 10900 : Loss. 0.00423 +2025-10-13 09:44:06,202 INFO Train. Iter 11000 : Loss. 0.00420 +2025-10-13 09:44:30,892 INFO Train. Iter 11100 : Loss. 0.00429 +2025-10-13 09:44:56,219 INFO Train. Iter 11200 : Loss. 0.00419 +2025-10-13 09:45:20,840 INFO Train. Iter 11300 : Loss. 0.00416 +2025-10-13 09:45:45,114 INFO Train. Iter 11400 : Loss. 0.00426 +2025-10-13 09:46:10,040 INFO Train. Iter 11500 : Loss. 0.00421 +2025-10-13 09:46:34,827 INFO Train. Iter 11600 : Loss. 0.00402 +2025-10-13 09:47:00,023 INFO Train. Iter 11700 : Loss. 0.00418 +2025-10-13 09:47:24,827 INFO Train. Iter 11800 : Loss. 0.00408 +2025-10-13 09:47:49,550 INFO Train. Iter 11900 : Loss. 0.00405 +2025-10-13 09:48:14,509 INFO Train. Iter 12000 : Loss. 0.00390 +2025-10-13 09:48:39,306 INFO Train. Iter 12100 : Loss. 0.00395 +2025-10-13 09:49:04,730 INFO Train. Iter 12200 : Loss. 0.00398 +2025-10-13 09:49:29,522 INFO Train. Iter 12300 : Loss. 0.00388 +2025-10-13 09:49:54,567 INFO Train. Iter 12400 : Loss. 0.00403 +2025-10-13 09:50:19,300 INFO Train. Iter 12500 : Loss. 0.00385 +2025-10-13 09:50:44,408 INFO Train. Iter 12600 : Loss. 0.00383 +2025-10-13 09:51:09,005 INFO Train. Iter 12700 : Loss. 0.00398 +2025-10-13 09:51:33,749 INFO Train. Iter 12800 : Loss. 0.00389 +2025-10-13 09:51:58,275 INFO Train. Iter 12900 : Loss. 0.00373 +2025-10-13 09:52:23,045 INFO Train. Iter 13000 : Loss. 0.00378 +2025-10-13 09:52:48,255 INFO Train. Iter 13100 : Loss. 0.00382 +2025-10-13 09:53:12,477 INFO Train. Iter 13200 : Loss. 0.00377 +2025-10-13 09:53:37,171 INFO Train. Iter 13300 : Loss. 0.00365 +2025-10-13 09:54:02,242 INFO Train. Iter 13400 : Loss. 0.00404 +2025-10-13 09:54:26,997 INFO Train. Iter 13500 : Loss. 0.00377 +2025-10-13 09:54:52,435 INFO Train. Iter 13600 : Loss. 0.00363 +2025-10-13 09:55:17,178 INFO Train. Iter 13700 : Loss. 0.00374 +2025-10-13 09:55:41,770 INFO Train. Iter 13800 : Loss. 0.00373 +2025-10-13 09:56:06,410 INFO Train. Iter 13900 : Loss. 0.00366 +2025-10-13 09:56:31,395 INFO Train. Iter 14000 : Loss. 0.00371 +2025-10-13 09:56:56,184 INFO Train. Iter 14100 : Loss. 0.00360 +2025-10-13 09:57:20,924 INFO Train. Iter 14200 : Loss. 0.00360 +2025-10-13 09:57:45,721 INFO Train. Iter 14300 : Loss. 0.00369 +2025-10-13 09:58:10,336 INFO Train. Iter 14400 : Loss. 0.00356 +2025-10-13 09:58:35,371 INFO Train. 
Iter 14500 : Loss. 0.00352 +2025-10-13 09:59:00,109 INFO Train. Iter 14600 : Loss. 0.00361 +2025-10-13 09:59:24,752 INFO Train. Iter 14700 : Loss. 0.00364 +2025-10-13 09:59:49,515 INFO Train. Iter 14800 : Loss. 0.00367 +2025-10-13 10:00:14,133 INFO Train. Iter 14900 : Loss. 0.00363 +2025-10-13 10:00:39,716 INFO Train. Iter 15000 : Loss. 0.00361 +2025-10-13 10:01:04,128 INFO Train. Iter 15100 : Loss. 0.00353 +2025-10-13 10:01:28,936 INFO Train. Iter 15200 : Loss. 0.00352 +2025-10-13 10:01:53,591 INFO Train. Iter 15300 : Loss. 0.00365 +2025-10-13 10:02:18,909 INFO Train. Iter 15400 : Loss. 0.00355 +2025-10-13 10:02:43,634 INFO Train. Iter 15500 : Loss. 0.00357 +2025-10-13 10:03:08,209 INFO Train. Iter 15600 : Loss. 0.00353 +2025-10-13 10:03:32,843 INFO Train. Iter 15700 : Loss. 0.00352 +2025-10-13 10:03:57,742 INFO Train. Iter 15800 : Loss. 0.00357 +2025-10-13 10:04:23,015 INFO Train. Iter 15900 : Loss. 0.00340 +2025-10-13 10:04:47,935 INFO Train. Iter 16000 : Loss. 0.00347 +2025-10-13 10:05:12,456 INFO Train. Iter 16100 : Loss. 0.00353 +2025-10-13 10:05:37,447 INFO Train. Iter 16200 : Loss. 0.00354 +2025-10-13 10:06:02,232 INFO Train. Iter 16300 : Loss. 0.00341 +2025-10-13 10:06:27,501 INFO Train. Iter 16400 : Loss. 0.00356 +2025-10-13 10:06:51,985 INFO Train. Iter 16500 : Loss. 0.00347 +2025-10-13 10:07:16,944 INFO Train. Iter 16600 : Loss. 0.00361 +2025-10-13 10:07:41,579 INFO Train. Iter 16700 : Loss. 0.00342 +2025-10-13 10:08:06,910 INFO Train. Iter 16800 : Loss. 0.00346 +2025-10-13 10:08:31,599 INFO Train. Iter 16900 : Loss. 0.00354 +2025-10-13 10:08:56,228 INFO Train. Iter 17000 : Loss. 0.00349 +2025-10-13 10:09:21,225 INFO Train. Iter 17100 : Loss. 0.00345 +2025-10-13 10:09:45,921 INFO Train. Iter 17200 : Loss. 0.00341 +2025-10-13 10:10:11,109 INFO Train. Iter 17300 : Loss. 0.00332 +2025-10-13 10:10:36,125 INFO Train. Iter 17400 : Loss. 0.00357 +2025-10-13 10:11:00,848 INFO Train. Iter 17500 : Loss. 0.00336 +2025-10-13 10:11:25,884 INFO Train. Iter 17600 : Loss. 0.00335 +2025-10-13 10:11:50,701 INFO Train. Iter 17700 : Loss. 0.00341 +2025-10-13 10:12:15,840 INFO Train. Iter 17800 : Loss. 0.00347 +2025-10-13 10:12:40,675 INFO Train. Iter 17900 : Loss. 0.00334 +2025-10-13 10:13:05,282 INFO Train. Iter 18000 : Loss. 0.00341 +2025-10-13 10:13:30,224 INFO Train. Iter 18100 : Loss. 0.00339 +2025-10-13 10:13:55,392 INFO Train. Iter 18200 : Loss. 0.00343 +2025-10-13 10:14:20,141 INFO Train. Iter 18300 : Loss. 0.00342 +2025-10-13 10:14:44,903 INFO Train. Iter 18400 : Loss. 0.00336 +2025-10-13 10:15:09,458 INFO Train. Iter 18500 : Loss. 0.00342 +2025-10-13 10:15:34,115 INFO Train. Iter 18600 : Loss. 0.00346 +2025-10-13 10:15:59,422 INFO Train. Iter 18700 : Loss. 0.00339 +2025-10-13 10:16:23,974 INFO Train. Iter 18800 : Loss. 0.00336 +2025-10-13 10:16:48,542 INFO Train. Iter 18900 : Loss. 0.00339 +2025-10-13 10:17:13,048 INFO Train. Iter 19000 : Loss. 0.00326 +2025-10-13 10:17:37,799 INFO Train. Iter 19100 : Loss. 0.00332 +2025-10-13 10:18:03,170 INFO Train. Iter 19200 : Loss. 0.00333 +2025-10-13 10:18:28,037 INFO Train. Iter 19300 : Loss. 0.00332 +2025-10-13 10:18:52,990 INFO Train. Iter 19400 : Loss. 0.00330 +2025-10-13 10:19:17,516 INFO Train. Iter 19500 : Loss. 0.00338 +2025-10-13 10:19:42,707 INFO Train. Iter 19600 : Loss. 0.00345 +2025-10-13 10:20:07,425 INFO Train. Iter 19700 : Loss. 0.00338 +2025-10-13 10:20:32,696 INFO Train. Iter 19800 : Loss. 0.00343 +2025-10-13 10:20:57,636 INFO Train. Iter 19900 : Loss. 0.00335 +2025-10-13 10:21:22,398 INFO Train. Iter 20000 : Loss. 
0.00332 +2025-10-13 10:21:51,529 INFO Train. Iter 20100 : Loss. 0.00336 +2025-10-13 10:22:16,361 INFO Train. Iter 20200 : Loss. 0.00329 +2025-10-13 10:22:41,042 INFO Train. Iter 20300 : Loss. 0.00336 +2025-10-13 10:23:05,808 INFO Train. Iter 20400 : Loss. 0.00331 +2025-10-13 10:23:30,481 INFO Train. Iter 20500 : Loss. 0.00343 +2025-10-13 10:23:55,662 INFO Train. Iter 20600 : Loss. 0.00328 +2025-10-13 10:24:20,636 INFO Train. Iter 20700 : Loss. 0.00334 +2025-10-13 10:24:45,360 INFO Train. Iter 20800 : Loss. 0.00328 +2025-10-13 10:25:09,732 INFO Train. Iter 20900 : Loss. 0.00324 +2025-10-13 10:25:34,875 INFO Train. Iter 21000 : Loss. 0.00339 +2025-10-13 10:25:59,585 INFO Train. Iter 21100 : Loss. 0.00334 +2025-10-13 10:26:23,998 INFO Train. Iter 21200 : Loss. 0.00334 +2025-10-13 10:26:48,874 INFO Train. Iter 21300 : Loss. 0.00328 +2025-10-13 10:27:13,834 INFO Train. Iter 21400 : Loss. 0.00348 +2025-10-13 10:27:39,038 INFO Train. Iter 21500 : Loss. 0.00328 +2025-10-13 10:28:03,909 INFO Train. Iter 21600 : Loss. 0.00333 +2025-10-13 10:28:28,560 INFO Train. Iter 21700 : Loss. 0.00341 +2025-10-13 10:28:53,416 INFO Train. Iter 21800 : Loss. 0.00328 +2025-10-13 10:29:17,962 INFO Train. Iter 21900 : Loss. 0.00325 +2025-10-13 10:29:43,264 INFO Train. Iter 22000 : Loss. 0.00331 +2025-10-13 10:30:07,694 INFO Train. Iter 22100 : Loss. 0.00317 +2025-10-13 10:30:32,384 INFO Train. Iter 22200 : Loss. 0.00336 +2025-10-13 10:30:57,068 INFO Train. Iter 22300 : Loss. 0.00340 +2025-10-13 10:31:22,386 INFO Train. Iter 22400 : Loss. 0.00332 +2025-10-13 10:31:47,349 INFO Train. Iter 22500 : Loss. 0.00330 +2025-10-13 10:32:12,031 INFO Train. Iter 22600 : Loss. 0.00327 +2025-10-13 10:32:36,755 INFO Train. Iter 22700 : Loss. 0.00336 +2025-10-13 10:33:01,426 INFO Train. Iter 22800 : Loss. 0.00324 +2025-10-13 10:33:26,514 INFO Train. Iter 22900 : Loss. 0.00329 +2025-10-13 10:33:51,348 INFO Train. Iter 23000 : Loss. 0.00335 +2025-10-13 10:34:15,880 INFO Train. Iter 23100 : Loss. 0.00323 +2025-10-13 10:34:40,868 INFO Train. Iter 23200 : Loss. 0.00329 +2025-10-13 10:35:05,690 INFO Train. Iter 23300 : Loss. 0.00336 +2025-10-13 10:35:30,761 INFO Train. Iter 23400 : Loss. 0.00318 +2025-10-13 10:35:55,270 INFO Train. Iter 23500 : Loss. 0.00341 +2025-10-13 10:36:20,251 INFO Train. Iter 23600 : Loss. 0.00315 +2025-10-13 10:36:44,783 INFO Train. Iter 23700 : Loss. 0.00332 +2025-10-13 10:37:09,831 INFO Train. Iter 23800 : Loss. 0.00340 +2025-10-13 10:37:34,456 INFO Train. Iter 23900 : Loss. 0.00332 +2025-10-13 10:37:59,190 INFO Train. Iter 24000 : Loss. 0.00328 +2025-10-13 10:38:23,877 INFO Train. Iter 24100 : Loss. 0.00333 +2025-10-13 10:38:48,527 INFO Train. Iter 24200 : Loss. 0.00323 +2025-10-13 10:39:13,642 INFO Train. Iter 24300 : Loss. 0.00336 +2025-10-13 10:39:38,478 INFO Train. Iter 24400 : Loss. 0.00328 +2025-10-13 10:40:02,929 INFO Train. Iter 24500 : Loss. 0.00331 +2025-10-13 10:40:27,800 INFO Train. Iter 24600 : Loss. 0.00319 +2025-10-13 10:40:53,331 INFO Train. Iter 24700 : Loss. 0.00325 +2025-10-13 10:41:17,882 INFO Train. Iter 24800 : Loss. 0.00329 +2025-10-13 10:41:42,562 INFO Train. Iter 24900 : Loss. 0.00337 +2025-10-13 10:42:07,252 INFO Train. Iter 25000 : Loss. 0.00322 +2025-10-13 10:42:31,512 INFO Train. Iter 25100 : Loss. 0.00332 +2025-10-13 10:42:56,637 INFO Train. Iter 25200 : Loss. 0.00328 +2025-10-13 10:43:21,392 INFO Train. Iter 25300 : Loss. 0.00330 +2025-10-13 10:43:45,990 INFO Train. Iter 25400 : Loss. 0.00329 +2025-10-13 10:44:10,612 INFO Train. Iter 25500 : Loss. 
0.00321 +2025-10-13 10:44:35,340 INFO Train. Iter 25600 : Loss. 0.00321 +2025-10-13 10:45:00,213 INFO Train. Iter 25700 : Loss. 0.00327 +2025-10-13 10:45:25,037 INFO Train. Iter 25800 : Loss. 0.00330 +2025-10-13 10:45:49,704 INFO Train. Iter 25900 : Loss. 0.00324 +2025-10-13 10:46:14,282 INFO Train. Iter 26000 : Loss. 0.00330 +2025-10-13 10:46:39,429 INFO Train. Iter 26100 : Loss. 0.00328 +2025-10-13 10:47:04,596 INFO Train. Iter 26200 : Loss. 0.00328 +2025-10-13 10:47:29,063 INFO Train. Iter 26300 : Loss. 0.00329 +2025-10-13 10:47:53,649 INFO Train. Iter 26400 : Loss. 0.00331 +2025-10-13 10:48:18,468 INFO Train. Iter 26500 : Loss. 0.00319 +2025-10-13 10:48:43,653 INFO Train. Iter 26600 : Loss. 0.00322 +2025-10-13 10:49:08,302 INFO Train. Iter 26700 : Loss. 0.00328 +2025-10-13 10:49:32,879 INFO Train. Iter 26800 : Loss. 0.00330 +2025-10-13 10:49:57,595 INFO Train. Iter 26900 : Loss. 0.00328 +2025-10-13 10:50:22,131 INFO Train. Iter 27000 : Loss. 0.00326 +2025-10-13 10:50:46,750 INFO Train. Iter 27100 : Loss. 0.00326 +2025-10-13 10:51:11,533 INFO Train. Iter 27200 : Loss. 0.00326 +2025-10-13 10:51:36,076 INFO Train. Iter 27300 : Loss. 0.00329 +2025-10-13 10:52:00,702 INFO Train. Iter 27400 : Loss. 0.00339 +2025-10-13 10:52:25,882 INFO Train. Iter 27500 : Loss. 0.00332 +2025-10-13 10:52:50,374 INFO Train. Iter 27600 : Loss. 0.00331 +2025-10-13 10:53:15,386 INFO Train. Iter 27700 : Loss. 0.00328 +2025-10-13 10:53:40,195 INFO Train. Iter 27800 : Loss. 0.00329 +2025-10-13 10:54:04,768 INFO Train. Iter 27900 : Loss. 0.00322 +2025-10-13 10:54:29,562 INFO Train. Iter 28000 : Loss. 0.00325 +2025-10-13 10:54:54,267 INFO Train. Iter 28100 : Loss. 0.00319 +2025-10-13 10:55:18,920 INFO Train. Iter 28200 : Loss. 0.00337 +2025-10-13 10:55:43,487 INFO Train. Iter 28300 : Loss. 0.00337 +2025-10-13 10:56:08,360 INFO Train. Iter 28400 : Loss. 0.00331 +2025-10-13 10:56:33,855 INFO Train. Iter 28500 : Loss. 0.00322 +2025-10-13 10:56:58,342 INFO Train. Iter 28600 : Loss. 0.00326 +2025-10-13 10:57:23,286 INFO Train. Iter 28700 : Loss. 0.00329 +2025-10-13 10:57:48,085 INFO Train. Iter 28800 : Loss. 0.00319 +2025-10-13 10:58:13,020 INFO Train. Iter 28900 : Loss. 0.00327 +2025-10-13 10:58:37,779 INFO Train. Iter 29000 : Loss. 0.00325 +2025-10-13 10:59:02,444 INFO Train. Iter 29100 : Loss. 0.00333 +2025-10-13 10:59:26,917 INFO Train. Iter 29200 : Loss. 0.00330 +2025-10-13 10:59:51,539 INFO Train. Iter 29300 : Loss. 0.00324 +2025-10-13 11:00:17,212 INFO Train. Iter 29400 : Loss. 0.00328 +2025-10-13 11:00:41,912 INFO Train. Iter 29500 : Loss. 0.00321 +2025-10-13 11:01:06,684 INFO Train. Iter 29600 : Loss. 0.00328 +2025-10-13 11:01:31,234 INFO Train. Iter 29700 : Loss. 0.00325 +2025-10-13 11:01:55,430 INFO Train. Iter 29800 : Loss. 0.00326 +2025-10-13 11:02:20,453 INFO Train. Iter 29900 : Loss. 0.00326 +2025-10-13 11:02:45,104 INFO Train. Iter 30000 : Loss. 0.00328 +2025-10-13 11:03:12,372 INFO Train. Iter 30100 : Loss. 0.00321 +2025-10-13 11:03:37,019 INFO Train. Iter 30200 : Loss. 0.00330 +2025-10-13 11:04:02,081 INFO Train. Iter 30300 : Loss. 0.00335 +2025-10-13 11:04:26,633 INFO Train. Iter 30400 : Loss. 0.00323 +2025-10-13 11:04:51,449 INFO Train. Iter 30500 : Loss. 0.00327 +2025-10-13 11:05:16,163 INFO Train. Iter 30600 : Loss. 0.00328 +2025-10-13 11:05:40,830 INFO Train. Iter 30700 : Loss. 0.00327 +2025-10-13 11:06:06,345 INFO Train. Iter 30800 : Loss. 0.00332 +2025-10-13 11:06:31,244 INFO Train. Iter 30900 : Loss. 0.00323 +2025-10-13 11:06:56,436 INFO Train. Iter 31000 : Loss. 
0.00320 +2025-10-13 11:07:20,734 INFO Train. Iter 31100 : Loss. 0.00326 +2025-10-13 11:07:45,408 INFO Train. Iter 31200 : Loss. 0.00324 +2025-10-13 11:08:10,703 INFO Train. Iter 31300 : Loss. 0.00333 +2025-10-13 11:08:35,372 INFO Train. Iter 31400 : Loss. 0.00329 +2025-10-13 11:08:59,859 INFO Train. Iter 31500 : Loss. 0.00325 +2025-10-13 11:09:24,604 INFO Train. Iter 31600 : Loss. 0.00325 +2025-10-13 11:09:49,565 INFO Train. Iter 31700 : Loss. 0.00322 +2025-10-13 11:10:14,201 INFO Train. Iter 31800 : Loss. 0.00331 +2025-10-13 11:10:39,008 INFO Train. Iter 31900 : Loss. 0.00326 +2025-10-13 11:11:03,513 INFO Train. Iter 32000 : Loss. 0.00321 +2025-10-13 11:11:28,031 INFO Train. Iter 32100 : Loss. 0.00316 +2025-10-13 11:11:53,604 INFO Train. Iter 32200 : Loss. 0.00323 +2025-10-13 11:12:18,222 INFO Train. Iter 32300 : Loss. 0.00327 +2025-10-13 11:12:43,135 INFO Train. Iter 32400 : Loss. 0.00329 +2025-10-13 11:13:07,645 INFO Train. Iter 32500 : Loss. 0.00326 +2025-10-13 11:13:32,296 INFO Train. Iter 32600 : Loss. 0.00325 +2025-10-13 11:13:57,560 INFO Train. Iter 32700 : Loss. 0.00325 +2025-10-13 11:14:22,175 INFO Train. Iter 32800 : Loss. 0.00310 +2025-10-13 11:14:46,635 INFO Train. Iter 32900 : Loss. 0.00326 +2025-10-13 11:15:11,113 INFO Train. Iter 33000 : Loss. 0.00327 +2025-10-13 11:15:36,163 INFO Train. Iter 33100 : Loss. 0.00328 +2025-10-13 11:16:00,814 INFO Train. Iter 33200 : Loss. 0.00326 +2025-10-13 11:16:25,339 INFO Train. Iter 33300 : Loss. 0.00327 +2025-10-13 11:16:50,028 INFO Train. Iter 33400 : Loss. 0.00321 +2025-10-13 11:17:14,377 INFO Train. Iter 33500 : Loss. 0.00328 +2025-10-13 11:17:39,686 INFO Train. Iter 33600 : Loss. 0.00328 +2025-10-13 11:18:04,420 INFO Train. Iter 33700 : Loss. 0.00323 +2025-10-13 11:18:28,931 INFO Train. Iter 33800 : Loss. 0.00320 +2025-10-13 11:18:53,485 INFO Train. Iter 33900 : Loss. 0.00327 +2025-10-13 11:19:18,278 INFO Train. Iter 34000 : Loss. 0.00326 +2025-10-13 11:19:43,365 INFO Train. Iter 34100 : Loss. 0.00324 +2025-10-13 11:20:08,332 INFO Train. Iter 34200 : Loss. 0.00325 +2025-10-13 11:20:32,942 INFO Train. Iter 34300 : Loss. 0.00323 +2025-10-13 11:20:57,421 INFO Train. Iter 34400 : Loss. 0.00318 +2025-10-13 11:21:22,623 INFO Train. Iter 34500 : Loss. 0.00323 +2025-10-13 11:21:47,229 INFO Train. Iter 34600 : Loss. 0.00321 +2025-10-13 11:22:11,874 INFO Train. Iter 34700 : Loss. 0.00332 +2025-10-13 11:22:36,524 INFO Train. Iter 34800 : Loss. 0.00321 +2025-10-13 11:23:00,934 INFO Train. Iter 34900 : Loss. 0.00317 +2025-10-13 11:23:25,953 INFO Train. Iter 35000 : Loss. 0.00315 +2025-10-13 11:23:50,569 INFO Train. Iter 35100 : Loss. 0.00331 +2025-10-13 11:24:15,393 INFO Train. Iter 35200 : Loss. 0.00330 +2025-10-13 11:24:39,729 INFO Train. Iter 35300 : Loss. 0.00325 +2025-10-13 11:25:04,373 INFO Train. Iter 35400 : Loss. 0.00327 +2025-10-13 11:25:29,695 INFO Train. Iter 35500 : Loss. 0.00324 +2025-10-13 11:25:54,148 INFO Train. Iter 35600 : Loss. 0.00315 +2025-10-13 11:26:18,957 INFO Train. Iter 35700 : Loss. 0.00315 +2025-10-13 11:26:43,601 INFO Train. Iter 35800 : Loss. 0.00323 +2025-10-13 11:27:08,665 INFO Train. Iter 35900 : Loss. 0.00333 +2025-10-13 11:27:33,347 INFO Train. Iter 36000 : Loss. 0.00322 +2025-10-13 11:27:57,902 INFO Train. Iter 36100 : Loss. 0.00317 +2025-10-13 11:28:22,152 INFO Train. Iter 36200 : Loss. 0.00328 +2025-10-13 11:28:46,711 INFO Train. Iter 36300 : Loss. 0.00330 +2025-10-13 11:29:11,912 INFO Train. Iter 36400 : Loss. 0.00322 +2025-10-13 11:29:36,456 INFO Train. Iter 36500 : Loss. 
0.00319 +2025-10-13 11:30:00,992 INFO Train. Iter 36600 : Loss. 0.00334 +2025-10-13 11:30:25,543 INFO Train. Iter 36700 : Loss. 0.00319 +2025-10-13 11:30:50,178 INFO Train. Iter 36800 : Loss. 0.00325 +2025-10-13 11:31:15,329 INFO Train. Iter 36900 : Loss. 0.00320 +2025-10-13 11:31:39,979 INFO Train. Iter 37000 : Loss. 0.00322 +2025-10-13 11:32:04,467 INFO Train. Iter 37100 : Loss. 0.00324 +2025-10-13 11:32:29,046 INFO Train. Iter 37200 : Loss. 0.00326 +2025-10-13 11:32:54,061 INFO Train. Iter 37300 : Loss. 0.00326 +2025-10-13 11:33:18,882 INFO Train. Iter 37400 : Loss. 0.00321 +2025-10-13 11:33:43,537 INFO Train. Iter 37500 : Loss. 0.00319 +2025-10-13 11:34:08,033 INFO Train. Iter 37600 : Loss. 0.00329 +2025-10-13 11:34:32,801 INFO Train. Iter 37700 : Loss. 0.00315 +2025-10-13 11:34:57,886 INFO Train. Iter 37800 : Loss. 0.00326 +2025-10-13 11:35:22,501 INFO Train. Iter 37900 : Loss. 0.00322 +2025-10-13 11:35:46,916 INFO Train. Iter 38000 : Loss. 0.00328 +2025-10-13 11:36:11,477 INFO Train. Iter 38100 : Loss. 0.00321 +2025-10-13 11:36:35,840 INFO Train. Iter 38200 : Loss. 0.00317 +2025-10-13 11:37:00,916 INFO Train. Iter 38300 : Loss. 0.00326 +2025-10-13 11:37:25,299 INFO Train. Iter 38400 : Loss. 0.00323 +2025-10-13 11:37:50,115 INFO Train. Iter 38500 : Loss. 0.00319 +2025-10-13 11:38:14,509 INFO Train. Iter 38600 : Loss. 0.00328 +2025-10-13 11:38:39,745 INFO Train. Iter 38700 : Loss. 0.00335 +2025-10-13 11:39:04,193 INFO Train. Iter 38800 : Loss. 0.00326 +2025-10-13 11:39:28,846 INFO Train. Iter 38900 : Loss. 0.00319 +2025-10-13 11:39:53,336 INFO Train. Iter 39000 : Loss. 0.00318 +2025-10-13 11:40:18,212 INFO Train. Iter 39100 : Loss. 0.00328 +2025-10-13 11:40:43,328 INFO Train. Iter 39200 : Loss. 0.00322 +2025-10-13 11:41:07,942 INFO Train. Iter 39300 : Loss. 0.00322 +2025-10-13 11:41:32,372 INFO Train. Iter 39400 : Loss. 0.00322 +2025-10-13 11:41:56,831 INFO Train. Iter 39500 : Loss. 0.00322 +2025-10-13 11:42:21,600 INFO Train. Iter 39600 : Loss. 0.00326 +2025-10-13 11:42:46,882 INFO Train. Iter 39700 : Loss. 0.00321 +2025-10-13 11:43:11,347 INFO Train. Iter 39800 : Loss. 0.00327 +2025-10-13 11:43:36,022 INFO Train. Iter 39900 : Loss. 0.00332 +2025-10-13 11:44:00,736 INFO Train. Iter 40000 : Loss. 0.00318 +2025-10-13 11:44:28,097 INFO Train. Iter 40100 : Loss. 0.00328 +2025-10-13 11:44:52,581 INFO Train. Iter 40200 : Loss. 0.00318 +2025-10-13 11:45:17,455 INFO Train. Iter 40300 : Loss. 0.00326 +2025-10-13 11:45:41,862 INFO Train. Iter 40400 : Loss. 0.00324 +2025-10-13 11:46:06,545 INFO Train. Iter 40500 : Loss. 0.00327 +2025-10-13 11:46:31,974 INFO Train. Iter 40600 : Loss. 0.00327 +2025-10-13 11:46:56,669 INFO Train. Iter 40700 : Loss. 0.00322 +2025-10-13 11:47:21,176 INFO Train. Iter 40800 : Loss. 0.00321 +2025-10-13 11:47:45,620 INFO Train. Iter 40900 : Loss. 0.00326 +2025-10-13 11:48:10,304 INFO Train. Iter 41000 : Loss. 0.00328 +2025-10-13 11:48:35,238 INFO Train. Iter 41100 : Loss. 0.00323 +2025-10-13 11:48:59,843 INFO Train. Iter 41200 : Loss. 0.00326 +2025-10-13 11:49:24,490 INFO Train. Iter 41300 : Loss. 0.00322 +2025-10-13 11:49:49,219 INFO Train. Iter 41400 : Loss. 0.00327 +2025-10-13 11:50:14,362 INFO Train. Iter 41500 : Loss. 0.00322 +2025-10-13 11:50:39,110 INFO Train. Iter 41600 : Loss. 0.00329 +2025-10-13 11:51:03,744 INFO Train. Iter 41700 : Loss. 0.00335 +2025-10-13 11:51:28,267 INFO Train. Iter 41800 : Loss. 0.00319 +2025-10-13 11:51:53,011 INFO Train. Iter 41900 : Loss. 0.00323 +2025-10-13 11:52:18,155 INFO Train. Iter 42000 : Loss. 
0.00322 +2025-10-13 11:52:43,046 INFO Train. Iter 42100 : Loss. 0.00326 +2025-10-13 11:53:07,836 INFO Train. Iter 42200 : Loss. 0.00327 +2025-10-13 11:53:32,261 INFO Train. Iter 42300 : Loss. 0.00323 +2025-10-13 11:53:56,776 INFO Train. Iter 42400 : Loss. 0.00325 +2025-10-13 11:54:21,959 INFO Train. Iter 42500 : Loss. 0.00329 +2025-10-13 11:54:46,503 INFO Train. Iter 42600 : Loss. 0.00328 +2025-10-13 11:55:11,117 INFO Train. Iter 42700 : Loss. 0.00327 +2025-10-13 11:55:35,718 INFO Train. Iter 42800 : Loss. 0.00318 +2025-10-13 11:56:00,830 INFO Train. Iter 42900 : Loss. 0.00321 +2025-10-13 11:56:25,251 INFO Train. Iter 43000 : Loss. 0.00320 +2025-10-13 11:56:49,815 INFO Train. Iter 43100 : Loss. 0.00322 +2025-10-13 11:57:14,466 INFO Train. Iter 43200 : Loss. 0.00324 +2025-10-13 11:57:39,156 INFO Train. Iter 43300 : Loss. 0.00330 +2025-10-13 11:58:04,181 INFO Train. Iter 43400 : Loss. 0.00316 +2025-10-13 11:58:28,827 INFO Train. Iter 43500 : Loss. 0.00320 +2025-10-13 11:58:53,179 INFO Train. Iter 43600 : Loss. 0.00324 +2025-10-13 11:59:17,863 INFO Train. Iter 43700 : Loss. 0.00324 +2025-10-13 11:59:42,093 INFO Train. Iter 43800 : Loss. 0.00326 +2025-10-13 12:00:07,455 INFO Train. Iter 43900 : Loss. 0.00325 +2025-10-13 12:00:32,081 INFO Train. Iter 44000 : Loss. 0.00326 +2025-10-13 12:00:56,716 INFO Train. Iter 44100 : Loss. 0.00322 +2025-10-13 12:01:21,519 INFO Train. Iter 44200 : Loss. 0.00333 +2025-10-13 12:01:46,335 INFO Train. Iter 44300 : Loss. 0.00328 +2025-10-13 12:02:10,857 INFO Train. Iter 44400 : Loss. 0.00323 +2025-10-13 12:02:35,747 INFO Train. Iter 44500 : Loss. 0.00326 +2025-10-13 12:03:00,188 INFO Train. Iter 44600 : Loss. 0.00314 +2025-10-13 12:03:24,563 INFO Train. Iter 44700 : Loss. 0.00334 +2025-10-13 12:03:49,502 INFO Train. Iter 44800 : Loss. 0.00329 +2025-10-13 12:04:13,861 INFO Train. Iter 44900 : Loss. 0.00324 +2025-10-13 12:04:38,404 INFO Train. Iter 45000 : Loss. 0.00327 +2025-10-13 12:05:02,898 INFO Train. Iter 45100 : Loss. 0.00327 +2025-10-13 12:05:27,538 INFO Train. Iter 45200 : Loss. 0.00327 +2025-10-13 12:05:52,497 INFO Train. Iter 45300 : Loss. 0.00327 +2025-10-13 12:06:16,742 INFO Train. Iter 45400 : Loss. 0.00321 +2025-10-13 12:06:41,467 INFO Train. Iter 45500 : Loss. 0.00328 +2025-10-13 12:07:06,045 INFO Train. Iter 45600 : Loss. 0.00323 +2025-10-13 12:07:31,235 INFO Train. Iter 45700 : Loss. 0.00331 +2025-10-13 12:07:55,904 INFO Train. Iter 45800 : Loss. 0.00325 +2025-10-13 12:08:20,580 INFO Train. Iter 45900 : Loss. 0.00315 +2025-10-13 12:08:45,153 INFO Train. Iter 46000 : Loss. 0.00332 +2025-10-13 12:09:09,669 INFO Train. Iter 46100 : Loss. 0.00329 +2025-10-13 12:09:34,827 INFO Train. Iter 46200 : Loss. 0.00330 +2025-10-13 12:09:59,416 INFO Train. Iter 46300 : Loss. 0.00325 +2025-10-13 12:10:23,967 INFO Train. Iter 46400 : Loss. 0.00324 +2025-10-13 12:10:48,461 INFO Train. Iter 46500 : Loss. 0.00325 +2025-10-13 12:11:13,315 INFO Train. Iter 46600 : Loss. 0.00331 +2025-10-13 12:11:38,445 INFO Train. Iter 46700 : Loss. 0.00325 +2025-10-13 12:12:03,103 INFO Train. Iter 46800 : Loss. 0.00335 +2025-10-13 12:12:27,727 INFO Train. Iter 46900 : Loss. 0.00326 +2025-10-13 12:12:52,345 INFO Train. Iter 47000 : Loss. 0.00321 +2025-10-13 12:13:18,136 INFO Train. Iter 47100 : Loss. 0.00325 +2025-10-13 12:13:42,855 INFO Train. Iter 47200 : Loss. 0.00325 +2025-10-13 12:14:07,361 INFO Train. Iter 47300 : Loss. 0.00329 +2025-10-13 12:14:32,166 INFO Train. Iter 47400 : Loss. 0.00329 +2025-10-13 12:14:56,766 INFO Train. Iter 47500 : Loss. 
0.00323 +2025-10-13 12:15:22,099 INFO Train. Iter 47600 : Loss. 0.00330 +2025-10-13 12:15:46,812 INFO Train. Iter 47700 : Loss. 0.00334 +2025-10-13 12:16:11,489 INFO Train. Iter 47800 : Loss. 0.00325 +2025-10-13 12:16:35,960 INFO Train. Iter 47900 : Loss. 0.00331 +2025-10-13 12:17:00,895 INFO Train. Iter 48000 : Loss. 0.00329 +2025-10-13 12:17:25,862 INFO Train. Iter 48100 : Loss. 0.00325 +2025-10-13 12:17:50,370 INFO Train. Iter 48200 : Loss. 0.00326 +2025-10-13 12:18:14,789 INFO Train. Iter 48300 : Loss. 0.00325 +2025-10-13 12:18:39,703 INFO Train. Iter 48400 : Loss. 0.00326 +2025-10-13 12:19:04,773 INFO Train. Iter 48500 : Loss. 0.00327 +2025-10-13 12:19:29,622 INFO Train. Iter 48600 : Loss. 0.00323 +2025-10-13 12:19:54,400 INFO Train. Iter 48700 : Loss. 0.00324 +2025-10-13 12:20:18,970 INFO Train. Iter 48800 : Loss. 0.00331 +2025-10-13 12:20:43,688 INFO Train. Iter 48900 : Loss. 0.00330 +2025-10-13 12:21:08,780 INFO Train. Iter 49000 : Loss. 0.00317 +2025-10-13 12:21:33,445 INFO Train. Iter 49100 : Loss. 0.00322 +2025-10-13 12:21:57,983 INFO Train. Iter 49200 : Loss. 0.00328 +2025-10-13 12:22:22,710 INFO Train. Iter 49300 : Loss. 0.00325 +2025-10-13 12:22:47,974 INFO Train. Iter 49400 : Loss. 0.00333 +2025-10-13 12:23:13,104 INFO Train. Iter 49500 : Loss. 0.00322 +2025-10-13 12:23:37,689 INFO Train. Iter 49600 : Loss. 0.00330 +2025-10-13 12:24:02,488 INFO Train. Iter 49700 : Loss. 0.00333 +2025-10-13 12:24:27,016 INFO Train. Iter 49800 : Loss. 0.00334 +2025-10-13 12:24:51,931 INFO Train. Iter 49900 : Loss. 0.00331 +2025-10-13 12:25:16,644 INFO Train. Iter 50000 : Loss. 0.00331 +2025-10-13 12:25:43,670 INFO Train. Iter 50100 : Loss. 0.00328 +2025-10-13 12:26:08,208 INFO Train. Iter 50200 : Loss. 0.00325 +2025-10-13 12:26:32,895 INFO Train. Iter 50300 : Loss. 0.00322 +2025-10-13 12:26:57,930 INFO Train. Iter 50400 : Loss. 0.00324 +2025-10-13 12:27:22,403 INFO Train. Iter 50500 : Loss. 0.00329 +2025-10-13 12:27:46,934 INFO Train. Iter 50600 : Loss. 0.00331 +2025-10-13 12:28:11,281 INFO Train. Iter 50700 : Loss. 0.00331 +2025-10-13 12:28:36,615 INFO Train. Iter 50800 : Loss. 0.00324 +2025-10-13 12:29:01,356 INFO Train. Iter 50900 : Loss. 0.00321 +2025-10-13 12:29:25,946 INFO Train. Iter 51000 : Loss. 0.00329 +2025-10-13 12:29:50,467 INFO Train. Iter 51100 : Loss. 0.00328 +2025-10-13 12:30:15,112 INFO Train. Iter 51200 : Loss. 0.00336 +2025-10-13 12:30:40,281 INFO Train. Iter 51300 : Loss. 0.00328 +2025-10-13 12:31:04,880 INFO Train. Iter 51400 : Loss. 0.00327 +2025-10-13 12:31:29,496 INFO Train. Iter 51500 : Loss. 0.00319 +2025-10-13 12:31:54,133 INFO Train. Iter 51600 : Loss. 0.00329 +2025-10-13 12:32:18,509 INFO Train. Iter 51700 : Loss. 0.00336 +2025-10-13 12:32:43,583 INFO Train. Iter 51800 : Loss. 0.00326 +2025-10-13 12:33:08,311 INFO Train. Iter 51900 : Loss. 0.00324 +2025-10-13 12:33:32,733 INFO Train. Iter 52000 : Loss. 0.00328 +2025-10-13 12:33:57,651 INFO Train. Iter 52100 : Loss. 0.00333 +2025-10-13 12:34:23,354 INFO Train. Iter 52200 : Loss. 0.00339 +2025-10-13 12:34:48,156 INFO Train. Iter 52300 : Loss. 0.00325 +2025-10-13 12:35:13,076 INFO Train. Iter 52400 : Loss. 0.00325 +2025-10-13 12:35:37,645 INFO Train. Iter 52500 : Loss. 0.00330 +2025-10-13 12:36:02,670 INFO Train. Iter 52600 : Loss. 0.00326 +2025-10-13 12:36:28,246 INFO Train. Iter 52700 : Loss. 0.00326 +2025-10-13 12:36:53,047 INFO Train. Iter 52800 : Loss. 0.00320 +2025-10-13 12:37:17,819 INFO Train. Iter 52900 : Loss. 0.00334 +2025-10-13 12:37:42,584 INFO Train. Iter 53000 : Loss. 
0.00335 +2025-10-13 12:38:07,091 INFO Train. Iter 53100 : Loss. 0.00330 +2025-10-13 12:38:32,272 INFO Train. Iter 53200 : Loss. 0.00324 +2025-10-13 12:38:57,268 INFO Train. Iter 53300 : Loss. 0.00323 +2025-10-13 12:39:21,697 INFO Train. Iter 53400 : Loss. 0.00332 +2025-10-13 12:39:46,433 INFO Train. Iter 53500 : Loss. 0.00329 +2025-10-13 12:40:11,576 INFO Train. Iter 53600 : Loss. 0.00327 +2025-10-13 12:40:36,245 INFO Train. Iter 53700 : Loss. 0.00320 +2025-10-13 12:41:00,811 INFO Train. Iter 53800 : Loss. 0.00326 +2025-10-13 12:41:25,573 INFO Train. Iter 53900 : Loss. 0.00329 +2025-10-13 12:41:49,978 INFO Train. Iter 54000 : Loss. 0.00341 +2025-10-13 12:42:14,948 INFO Train. Iter 54100 : Loss. 0.00330 +2025-10-13 12:42:39,835 INFO Train. Iter 54200 : Loss. 0.00334 +2025-10-13 12:43:04,728 INFO Train. Iter 54300 : Loss. 0.00327 +2025-10-13 12:43:29,050 INFO Train. Iter 54400 : Loss. 0.00324 +2025-10-13 12:43:53,376 INFO Train. Iter 54500 : Loss. 0.00317 +2025-10-13 12:44:18,409 INFO Train. Iter 54600 : Loss. 0.00325 +2025-10-13 12:44:43,082 INFO Train. Iter 54700 : Loss. 0.00334 +2025-10-13 12:45:07,480 INFO Train. Iter 54800 : Loss. 0.00330 +2025-10-13 12:45:32,228 INFO Train. Iter 54900 : Loss. 0.00339 +2025-10-13 12:45:57,177 INFO Train. Iter 55000 : Loss. 0.00334 +2025-10-13 12:46:21,634 INFO Train. Iter 55100 : Loss. 0.00329 +2025-10-13 12:46:46,106 INFO Train. Iter 55200 : Loss. 0.00329 +2025-10-13 12:47:10,720 INFO Train. Iter 55300 : Loss. 0.00326 +2025-10-13 12:47:35,256 INFO Train. Iter 55400 : Loss. 0.00329 +2025-10-13 12:48:00,601 INFO Train. Iter 55500 : Loss. 0.00321 +2025-10-13 12:48:25,267 INFO Train. Iter 55600 : Loss. 0.00330 +2025-10-13 12:48:49,848 INFO Train. Iter 55700 : Loss. 0.00328 +2025-10-13 12:49:14,285 INFO Train. Iter 55800 : Loss. 0.00333 +2025-10-13 12:49:38,918 INFO Train. Iter 55900 : Loss. 0.00333 +2025-10-13 12:50:04,138 INFO Train. Iter 56000 : Loss. 0.00322 +2025-10-13 12:50:28,506 INFO Train. Iter 56100 : Loss. 0.00326 +2025-10-13 12:50:53,038 INFO Train. Iter 56200 : Loss. 0.00337 +2025-10-13 12:51:17,654 INFO Train. Iter 56300 : Loss. 0.00332 +2025-10-13 12:51:42,976 INFO Train. Iter 56400 : Loss. 0.00327 +2025-10-13 12:52:07,252 INFO Train. Iter 56500 : Loss. 0.00326 +2025-10-13 12:52:32,034 INFO Train. Iter 56600 : Loss. 0.00329 +2025-10-13 12:52:57,346 INFO Train. Iter 56700 : Loss. 0.00337 +2025-10-13 12:53:22,022 INFO Train. Iter 56800 : Loss. 0.00331 +2025-10-13 12:53:46,994 INFO Train. Iter 56900 : Loss. 0.00330 +2025-10-13 12:54:11,548 INFO Train. Iter 57000 : Loss. 0.00329 +2025-10-13 12:54:36,350 INFO Train. Iter 57100 : Loss. 0.00333 +2025-10-13 12:55:00,986 INFO Train. Iter 57200 : Loss. 0.00328 +2025-10-13 12:55:25,203 INFO Train. Iter 57300 : Loss. 0.00337 +2025-10-13 12:55:50,273 INFO Train. Iter 57400 : Loss. 0.00322 +2025-10-13 12:56:14,813 INFO Train. Iter 57500 : Loss. 0.00321 +2025-10-13 12:56:39,284 INFO Train. Iter 57600 : Loss. 0.00325 +2025-10-13 12:57:03,794 INFO Train. Iter 57700 : Loss. 0.00330 +2025-10-13 12:57:29,146 INFO Train. Iter 57800 : Loss. 0.00334 +2025-10-13 12:57:53,830 INFO Train. Iter 57900 : Loss. 0.00326 +2025-10-13 12:58:18,142 INFO Train. Iter 58000 : Loss. 0.00334 +2025-10-13 12:58:43,047 INFO Train. Iter 58100 : Loss. 0.00322 +2025-10-13 12:59:07,476 INFO Train. Iter 58200 : Loss. 0.00330 +2025-10-13 12:59:32,932 INFO Train. Iter 58300 : Loss. 0.00332 +2025-10-13 12:59:57,583 INFO Train. Iter 58400 : Loss. 0.00332 +2025-10-13 13:00:22,129 INFO Train. Iter 58500 : Loss. 
0.00327 +2025-10-13 13:00:46,594 INFO Train. Iter 58600 : Loss. 0.00320 +2025-10-13 13:01:11,048 INFO Train. Iter 58700 : Loss. 0.00339 +2025-10-13 13:01:36,248 INFO Train. Iter 58800 : Loss. 0.00320 +2025-10-13 13:02:00,932 INFO Train. Iter 58900 : Loss. 0.00332 +2025-10-13 13:02:25,679 INFO Train. Iter 59000 : Loss. 0.00325 +2025-10-13 13:02:50,260 INFO Train. Iter 59100 : Loss. 0.00339 +2025-10-13 13:03:15,544 INFO Train. Iter 59200 : Loss. 0.00332 +2025-10-13 13:03:40,354 INFO Train. Iter 59300 : Loss. 0.00329 +2025-10-13 13:04:05,064 INFO Train. Iter 59400 : Loss. 0.00322 +2025-10-13 13:04:29,678 INFO Train. Iter 59500 : Loss. 0.00332 +2025-10-13 13:04:54,235 INFO Train. Iter 59600 : Loss. 0.00335 +2025-10-13 13:05:19,146 INFO Train. Iter 59700 : Loss. 0.00328 +2025-10-13 13:05:43,859 INFO Train. Iter 59800 : Loss. 0.00332 +2025-10-13 13:06:08,612 INFO Train. Iter 59900 : Loss. 0.00328 +2025-10-13 13:06:33,051 INFO Train. Iter 60000 : Loss. 0.00328 +2025-10-13 13:07:01,736 INFO Train. Iter 60100 : Loss. 0.00332 +2025-10-13 13:07:26,928 INFO Train. Iter 60200 : Loss. 0.00329 +2025-10-13 13:07:51,131 INFO Train. Iter 60300 : Loss. 0.00324 +2025-10-13 13:08:15,667 INFO Train. Iter 60400 : Loss. 0.00328 +2025-10-13 13:08:40,271 INFO Train. Iter 60500 : Loss. 0.00335 +2025-10-13 13:09:05,396 INFO Train. Iter 60600 : Loss. 0.00329 +2025-10-13 13:09:30,203 INFO Train. Iter 60700 : Loss. 0.00320 +2025-10-13 13:09:54,741 INFO Train. Iter 60800 : Loss. 0.00323 +2025-10-13 13:10:19,149 INFO Train. Iter 60900 : Loss. 0.00334 +2025-10-13 13:10:43,679 INFO Train. Iter 61000 : Loss. 0.00343 +2025-10-13 13:11:08,811 INFO Train. Iter 61100 : Loss. 0.00330 +2025-10-13 13:11:33,685 INFO Train. Iter 61200 : Loss. 0.00326 +2025-10-13 13:11:58,109 INFO Train. Iter 61300 : Loss. 0.00329 +2025-10-13 13:12:22,924 INFO Train. Iter 61400 : Loss. 0.00322 +2025-10-13 13:12:47,663 INFO Train. Iter 61500 : Loss. 0.00331 +2025-10-13 13:13:13,046 INFO Train. Iter 61600 : Loss. 0.00329 +2025-10-13 13:13:37,744 INFO Train. Iter 61700 : Loss. 0.00325 +2025-10-13 13:14:02,058 INFO Train. Iter 61800 : Loss. 0.00332 +2025-10-13 13:14:26,293 INFO Train. Iter 61900 : Loss. 0.00331 +2025-10-13 13:14:51,299 INFO Train. Iter 62000 : Loss. 0.00339 +2025-10-13 13:15:16,137 INFO Train. Iter 62100 : Loss. 0.00328 +2025-10-13 13:15:41,267 INFO Train. Iter 62200 : Loss. 0.00325 +2025-10-13 13:16:06,380 INFO Train. Iter 62300 : Loss. 0.00338 +2025-10-13 13:16:31,126 INFO Train. Iter 62400 : Loss. 0.00321 +2025-10-13 13:16:56,971 INFO Train. Iter 62500 : Loss. 0.00340 +2025-10-13 13:17:22,183 INFO Train. Iter 62600 : Loss. 0.00335 +2025-10-13 13:17:47,689 INFO Train. Iter 62700 : Loss. 0.00331 +2025-10-13 13:18:12,945 INFO Train. Iter 62800 : Loss. 0.00332 +2025-10-13 13:18:37,923 INFO Train. Iter 62900 : Loss. 0.00333 +2025-10-13 13:19:03,622 INFO Train. Iter 63000 : Loss. 0.00323 +2025-10-13 13:19:29,759 INFO Train. Iter 63100 : Loss. 0.00322 +2025-10-13 13:19:54,829 INFO Train. Iter 63200 : Loss. 0.00327 +2025-10-13 13:20:20,105 INFO Train. Iter 63300 : Loss. 0.00329 +2025-10-13 13:20:45,532 INFO Train. Iter 63400 : Loss. 0.00329 +2025-10-13 13:21:10,290 INFO Train. Iter 63500 : Loss. 0.00316 +2025-10-13 13:21:34,839 INFO Train. Iter 63600 : Loss. 0.00325 +2025-10-13 13:21:59,415 INFO Train. Iter 63700 : Loss. 0.00332 +2025-10-13 13:22:23,504 INFO Train. Iter 63800 : Loss. 0.00331 +2025-10-13 13:22:48,763 INFO Train. Iter 63900 : Loss. 0.00330 +2025-10-13 13:23:13,507 INFO Train. Iter 64000 : Loss. 
0.00330 +2025-10-13 13:23:37,957 INFO Train. Iter 64100 : Loss. 0.00327 +2025-10-13 13:24:02,553 INFO Train. Iter 64200 : Loss. 0.00332 +2025-10-13 13:24:27,039 INFO Train. Iter 64300 : Loss. 0.00325 +2025-10-13 13:24:52,255 INFO Train. Iter 64400 : Loss. 0.00330 +2025-10-13 13:25:16,796 INFO Train. Iter 64500 : Loss. 0.00334 +2025-10-13 13:25:41,517 INFO Train. Iter 64600 : Loss. 0.00325 +2025-10-13 13:26:06,473 INFO Train. Iter 64700 : Loss. 0.00329 +2025-10-13 13:26:31,449 INFO Train. Iter 64800 : Loss. 0.00327 +2025-10-13 13:26:56,165 INFO Train. Iter 64900 : Loss. 0.00333 +2025-10-13 13:27:20,937 INFO Train. Iter 65000 : Loss. 0.00332 +2025-10-13 13:27:45,478 INFO Train. Iter 65100 : Loss. 0.00326 +2025-10-13 13:28:10,079 INFO Train. Iter 65200 : Loss. 0.00327 +2025-10-13 13:28:36,107 INFO Train. Iter 65300 : Loss. 0.00323 +2025-10-13 13:29:01,625 INFO Train. Iter 65400 : Loss. 0.00329 +2025-10-13 13:29:26,655 INFO Train. Iter 65500 : Loss. 0.00325 +2025-10-13 13:29:51,676 INFO Train. Iter 65600 : Loss. 0.00332 +2025-10-13 13:30:16,645 INFO Train. Iter 65700 : Loss. 0.00322 +2025-10-13 13:30:42,189 INFO Train. Iter 65800 : Loss. 0.00326 +2025-10-13 13:31:06,553 INFO Train. Iter 65900 : Loss. 0.00329 +2025-10-13 13:31:31,068 INFO Train. Iter 66000 : Loss. 0.00323 +2025-10-13 13:31:55,455 INFO Train. Iter 66100 : Loss. 0.00331 +2025-10-13 13:32:20,423 INFO Train. Iter 66200 : Loss. 0.00320 +2025-10-13 13:32:45,259 INFO Train. Iter 66300 : Loss. 0.00331 +2025-10-13 13:33:09,704 INFO Train. Iter 66400 : Loss. 0.00337 +2025-10-13 13:33:34,341 INFO Train. Iter 66500 : Loss. 0.00331 +2025-10-13 13:33:58,922 INFO Train. Iter 66600 : Loss. 0.00334 +2025-10-13 13:34:23,887 INFO Train. Iter 66700 : Loss. 0.00331 +2025-10-13 13:34:48,612 INFO Train. Iter 66800 : Loss. 0.00325 +2025-10-13 13:35:12,956 INFO Train. Iter 66900 : Loss. 0.00332 +2025-10-13 13:35:37,562 INFO Train. Iter 67000 : Loss. 0.00325 +2025-10-13 13:36:02,102 INFO Train. Iter 67100 : Loss. 0.00331 +2025-10-13 13:36:27,040 INFO Train. Iter 67200 : Loss. 0.00324 +2025-10-13 13:36:51,360 INFO Train. Iter 67300 : Loss. 0.00329 +2025-10-13 13:37:16,000 INFO Train. Iter 67400 : Loss. 0.00322 +2025-10-13 13:37:40,593 INFO Train. Iter 67500 : Loss. 0.00332 +2025-10-13 13:38:05,646 INFO Train. Iter 67600 : Loss. 0.00326 +2025-10-13 13:38:30,156 INFO Train. Iter 67700 : Loss. 0.00323 +2025-10-13 13:38:54,503 INFO Train. Iter 67800 : Loss. 0.00327 +2025-10-13 13:39:19,416 INFO Train. Iter 67900 : Loss. 0.00330 +2025-10-13 13:39:43,803 INFO Train. Iter 68000 : Loss. 0.00329 +2025-10-13 13:40:08,609 INFO Train. Iter 68100 : Loss. 0.00326 +2025-10-13 13:40:33,190 INFO Train. Iter 68200 : Loss. 0.00329 +2025-10-13 13:40:57,735 INFO Train. Iter 68300 : Loss. 0.00327 +2025-10-13 13:41:22,575 INFO Train. Iter 68400 : Loss. 0.00322 +2025-10-13 13:41:47,011 INFO Train. Iter 68500 : Loss. 0.00334 +2025-10-13 13:42:12,361 INFO Train. Iter 68600 : Loss. 0.00322 +2025-10-13 13:42:36,818 INFO Train. Iter 68700 : Loss. 0.00327 +2025-10-13 13:43:01,209 INFO Train. Iter 68800 : Loss. 0.00326 +2025-10-13 13:43:26,090 INFO Train. Iter 68900 : Loss. 0.00328 +2025-10-13 13:43:50,933 INFO Train. Iter 69000 : Loss. 0.00332 +2025-10-13 13:44:15,455 INFO Train. Iter 69100 : Loss. 0.00324 +2025-10-13 13:44:40,422 INFO Train. Iter 69200 : Loss. 0.00330 +2025-10-13 13:45:05,055 INFO Train. Iter 69300 : Loss. 0.00324 +2025-10-13 13:45:29,655 INFO Train. Iter 69400 : Loss. 0.00317 +2025-10-13 13:45:54,881 INFO Train. Iter 69500 : Loss. 
0.00332 +2025-10-13 13:46:19,843 INFO Train. Iter 69600 : Loss. 0.00320 +2025-10-13 13:46:44,390 INFO Train. Iter 69700 : Loss. 0.00325 +2025-10-13 13:47:09,011 INFO Train. Iter 69800 : Loss. 0.00327 +2025-10-13 13:47:33,466 INFO Train. Iter 69900 : Loss. 0.00321 +2025-10-13 13:47:58,483 INFO Train. Iter 70000 : Loss. 0.00325 +2025-10-13 13:48:26,891 INFO Train. Iter 70100 : Loss. 0.00325 +2025-10-13 13:48:51,452 INFO Train. Iter 70200 : Loss. 0.00315 +2025-10-13 13:49:15,785 INFO Train. Iter 70300 : Loss. 0.00334 +2025-10-13 13:49:41,082 INFO Train. Iter 70400 : Loss. 0.00332 +2025-10-13 13:50:05,335 INFO Train. Iter 70500 : Loss. 0.00322 +2025-10-13 13:50:29,741 INFO Train. Iter 70600 : Loss. 0.00330 +2025-10-13 13:50:54,193 INFO Train. Iter 70700 : Loss. 0.00328 +2025-10-13 13:51:18,850 INFO Train. Iter 70800 : Loss. 0.00324 +2025-10-13 13:51:43,919 INFO Train. Iter 70900 : Loss. 0.00325 +2025-10-13 13:52:08,560 INFO Train. Iter 71000 : Loss. 0.00324 +2025-10-13 13:52:33,059 INFO Train. Iter 71100 : Loss. 0.00325 +2025-10-13 13:52:57,572 INFO Train. Iter 71200 : Loss. 0.00326 +2025-10-13 13:53:22,626 INFO Train. Iter 71300 : Loss. 0.00328 +2025-10-13 13:53:47,067 INFO Train. Iter 71400 : Loss. 0.00324 +2025-10-13 13:54:11,798 INFO Train. Iter 71500 : Loss. 0.00320 +2025-10-13 13:54:36,292 INFO Train. Iter 71600 : Loss. 0.00324 +2025-10-13 13:55:00,497 INFO Train. Iter 71700 : Loss. 0.00330 +2025-10-13 13:55:25,649 INFO Train. Iter 71800 : Loss. 0.00328 +2025-10-13 13:55:50,297 INFO Train. Iter 71900 : Loss. 0.00329 +2025-10-13 13:56:14,829 INFO Train. Iter 72000 : Loss. 0.00323 +2025-10-13 13:56:39,269 INFO Train. Iter 72100 : Loss. 0.00332 +2025-10-13 13:57:03,875 INFO Train. Iter 72200 : Loss. 0.00326 +2025-10-13 13:57:28,913 INFO Train. Iter 72300 : Loss. 0.00324 +2025-10-13 13:57:53,216 INFO Train. Iter 72400 : Loss. 0.00328 +2025-10-13 13:58:17,882 INFO Train. Iter 72500 : Loss. 0.00327 +2025-10-13 13:58:42,611 INFO Train. Iter 72600 : Loss. 0.00320 +2025-10-13 13:59:07,743 INFO Train. Iter 72700 : Loss. 0.00330 +2025-10-13 13:59:32,451 INFO Train. Iter 72800 : Loss. 0.00322 +2025-10-13 13:59:56,839 INFO Train. Iter 72900 : Loss. 0.00334 +2025-10-13 14:00:21,724 INFO Train. Iter 73000 : Loss. 0.00327 +2025-10-13 14:00:46,391 INFO Train. Iter 73100 : Loss. 0.00330 +2025-10-13 14:01:11,566 INFO Train. Iter 73200 : Loss. 0.00319 +2025-10-13 14:01:36,259 INFO Train. Iter 73300 : Loss. 0.00318 +2025-10-13 14:02:00,705 INFO Train. Iter 73400 : Loss. 0.00331 +2025-10-13 14:02:25,017 INFO Train. Iter 73500 : Loss. 0.00327 +2025-10-13 14:02:49,554 INFO Train. Iter 73600 : Loss. 0.00322 +2025-10-13 14:03:14,922 INFO Train. Iter 73700 : Loss. 0.00318 +2025-10-13 14:03:39,780 INFO Train. Iter 73800 : Loss. 0.00321 +2025-10-13 14:04:04,331 INFO Train. Iter 73900 : Loss. 0.00321 +2025-10-13 14:04:28,939 INFO Train. Iter 74000 : Loss. 0.00323 +2025-10-13 14:04:54,038 INFO Train. Iter 74100 : Loss. 0.00322 +2025-10-13 14:05:18,729 INFO Train. Iter 74200 : Loss. 0.00324 +2025-10-13 14:05:43,603 INFO Train. Iter 74300 : Loss. 0.00331 +2025-10-13 14:06:08,183 INFO Train. Iter 74400 : Loss. 0.00326 +2025-10-13 14:06:32,649 INFO Train. Iter 74500 : Loss. 0.00332 +2025-10-13 14:06:57,876 INFO Train. Iter 74600 : Loss. 0.00325 +2025-10-13 14:07:22,399 INFO Train. Iter 74700 : Loss. 0.00326 +2025-10-13 14:07:46,978 INFO Train. Iter 74800 : Loss. 0.00326 +2025-10-13 14:08:11,565 INFO Train. Iter 74900 : Loss. 0.00321 +2025-10-13 14:08:35,909 INFO Train. Iter 75000 : Loss. 
0.00315 +2025-10-13 14:09:01,002 INFO Train. Iter 75100 : Loss. 0.00323 +2025-10-13 14:09:25,392 INFO Train. Iter 75200 : Loss. 0.00315 +2025-10-13 14:09:49,866 INFO Train. Iter 75300 : Loss. 0.00325 +2025-10-13 14:10:14,689 INFO Train. Iter 75400 : Loss. 0.00327 +2025-10-13 14:10:40,178 INFO Train. Iter 75500 : Loss. 0.00324 +2025-10-13 14:11:04,871 INFO Train. Iter 75600 : Loss. 0.00314 +2025-10-13 14:11:29,851 INFO Train. Iter 75700 : Loss. 0.00322 +2025-10-13 14:11:54,236 INFO Train. Iter 75800 : Loss. 0.00324 +2025-10-13 14:12:18,947 INFO Train. Iter 75900 : Loss. 0.00324 +2025-10-13 14:12:44,038 INFO Train. Iter 76000 : Loss. 0.00323 +2025-10-13 14:13:08,356 INFO Train. Iter 76100 : Loss. 0.00311 +2025-10-13 14:13:33,128 INFO Train. Iter 76200 : Loss. 0.00323 +2025-10-13 14:13:57,658 INFO Train. Iter 76300 : Loss. 0.00326 +2025-10-13 14:14:22,279 INFO Train. Iter 76400 : Loss. 0.00329 +2025-10-13 14:14:47,435 INFO Train. Iter 76500 : Loss. 0.00316 +2025-10-13 14:15:12,103 INFO Train. Iter 76600 : Loss. 0.00322 +2025-10-13 14:15:36,618 INFO Train. Iter 76700 : Loss. 0.00323 +2025-10-13 14:16:01,267 INFO Train. Iter 76800 : Loss. 0.00320 +2025-10-13 14:16:26,335 INFO Train. Iter 76900 : Loss. 0.00316 +2025-10-13 14:16:50,788 INFO Train. Iter 77000 : Loss. 0.00319 +2025-10-13 14:17:15,527 INFO Train. Iter 77100 : Loss. 0.00325 +2025-10-13 14:17:40,026 INFO Train. Iter 77200 : Loss. 0.00321 +2025-10-13 14:18:04,607 INFO Train. Iter 77300 : Loss. 0.00324 +2025-10-13 14:18:29,625 INFO Train. Iter 77400 : Loss. 0.00307 +2025-10-13 14:18:54,696 INFO Train. Iter 77500 : Loss. 0.00310 +2025-10-13 14:19:19,219 INFO Train. Iter 77600 : Loss. 0.00321 +2025-10-13 14:19:43,625 INFO Train. Iter 77700 : Loss. 0.00322 +2025-10-13 14:20:08,084 INFO Train. Iter 77800 : Loss. 0.00321 +2025-10-13 14:20:33,347 INFO Train. Iter 77900 : Loss. 0.00310 +2025-10-13 14:20:57,883 INFO Train. Iter 78000 : Loss. 0.00326 +2025-10-13 14:21:22,437 INFO Train. Iter 78100 : Loss. 0.00322 +2025-10-13 14:21:47,123 INFO Train. Iter 78200 : Loss. 0.00323 +2025-10-13 14:22:12,345 INFO Train. Iter 78300 : Loss. 0.00318 +2025-10-13 14:22:36,783 INFO Train. Iter 78400 : Loss. 0.00315 +2025-10-13 14:23:01,313 INFO Train. Iter 78500 : Loss. 0.00314 +2025-10-13 14:23:25,644 INFO Train. Iter 78600 : Loss. 0.00329 +2025-10-13 14:23:50,327 INFO Train. Iter 78700 : Loss. 0.00324 +2025-10-13 14:24:15,440 INFO Train. Iter 78800 : Loss. 0.00316 +2025-10-13 14:24:39,984 INFO Train. Iter 78900 : Loss. 0.00319 +2025-10-13 14:25:04,569 INFO Train. Iter 79000 : Loss. 0.00314 +2025-10-13 14:25:29,228 INFO Train. Iter 79100 : Loss. 0.00319 +2025-10-13 14:25:53,816 INFO Train. Iter 79200 : Loss. 0.00319 +2025-10-13 14:26:18,912 INFO Train. Iter 79300 : Loss. 0.00311 +2025-10-13 14:26:43,493 INFO Train. Iter 79400 : Loss. 0.00318 +2025-10-13 14:27:08,145 INFO Train. Iter 79500 : Loss. 0.00320 +2025-10-13 14:27:32,560 INFO Train. Iter 79600 : Loss. 0.00320 +2025-10-13 14:27:57,681 INFO Train. Iter 79700 : Loss. 0.00318 +2025-10-13 14:28:21,993 INFO Train. Iter 79800 : Loss. 0.00320 +2025-10-13 14:28:46,965 INFO Train. Iter 79900 : Loss. 0.00311 +2025-10-13 14:29:11,684 INFO Train. Iter 80000 : Loss. 0.00319 +2025-10-13 14:29:40,113 INFO Train. Iter 80100 : Loss. 0.00320 +2025-10-13 14:30:05,257 INFO Train. Iter 80200 : Loss. 0.00318 +2025-10-13 14:30:29,596 INFO Train. Iter 80300 : Loss. 0.00318 +2025-10-13 14:30:53,889 INFO Train. Iter 80400 : Loss. 0.00318 +2025-10-13 14:31:18,378 INFO Train. Iter 80500 : Loss. 
0.00308 +2025-10-13 14:31:43,225 INFO Train. Iter 80600 : Loss. 0.00312 +2025-10-13 14:32:09,005 INFO Train. Iter 80700 : Loss. 0.00320 +2025-10-13 14:32:33,625 INFO Train. Iter 80800 : Loss. 0.00311 +2025-10-13 14:32:58,026 INFO Train. Iter 80900 : Loss. 0.00316 +2025-10-13 14:33:22,570 INFO Train. Iter 81000 : Loss. 0.00307 +2025-10-13 14:33:47,554 INFO Train. Iter 81100 : Loss. 0.00315 +2025-10-13 14:34:12,322 INFO Train. Iter 81200 : Loss. 0.00309 +2025-10-13 14:34:37,189 INFO Train. Iter 81300 : Loss. 0.00315 +2025-10-13 14:35:01,726 INFO Train. Iter 81400 : Loss. 0.00317 +2025-10-13 14:35:26,129 INFO Train. Iter 81500 : Loss. 0.00321 +2025-10-13 14:35:50,981 INFO Train. Iter 81600 : Loss. 0.00313 +2025-10-13 14:36:15,500 INFO Train. Iter 81700 : Loss. 0.00316 +2025-10-13 14:36:40,071 INFO Train. Iter 81800 : Loss. 0.00318 +2025-10-13 14:37:04,447 INFO Train. Iter 81900 : Loss. 0.00310 +2025-10-13 14:37:29,178 INFO Train. Iter 82000 : Loss. 0.00311 +2025-10-13 14:37:54,152 INFO Train. Iter 82100 : Loss. 0.00317 +2025-10-13 14:38:18,777 INFO Train. Iter 82200 : Loss. 0.00311 +2025-10-13 14:38:43,675 INFO Train. Iter 82300 : Loss. 0.00306 +2025-10-13 14:39:08,498 INFO Train. Iter 82400 : Loss. 0.00319 +2025-10-13 14:39:33,266 INFO Train. Iter 82500 : Loss. 0.00323 +2025-10-13 14:39:58,018 INFO Train. Iter 82600 : Loss. 0.00317 +2025-10-13 14:40:22,777 INFO Train. Iter 82700 : Loss. 0.00309 +2025-10-13 14:40:47,424 INFO Train. Iter 82800 : Loss. 0.00310 +2025-10-13 14:41:12,135 INFO Train. Iter 82900 : Loss. 0.00321 +2025-10-13 14:41:37,237 INFO Train. Iter 83000 : Loss. 0.00308 +2025-10-13 14:42:02,343 INFO Train. Iter 83100 : Loss. 0.00306 +2025-10-13 14:42:26,986 INFO Train. Iter 83200 : Loss. 0.00303 +2025-10-13 14:42:51,442 INFO Train. Iter 83300 : Loss. 0.00319 +2025-10-13 14:43:15,875 INFO Train. Iter 83400 : Loss. 0.00322 +2025-10-13 14:43:40,807 INFO Train. Iter 83500 : Loss. 0.00315 +2025-10-13 14:44:05,311 INFO Train. Iter 83600 : Loss. 0.00317 +2025-10-13 14:44:30,006 INFO Train. Iter 83700 : Loss. 0.00308 +2025-10-13 14:44:54,790 INFO Train. Iter 83800 : Loss. 0.00303 +2025-10-13 14:45:20,130 INFO Train. Iter 83900 : Loss. 0.00311 +2025-10-13 14:45:45,104 INFO Train. Iter 84000 : Loss. 0.00311 +2025-10-13 14:46:10,290 INFO Train. Iter 84100 : Loss. 0.00313 +2025-10-13 14:46:35,199 INFO Train. Iter 84200 : Loss. 0.00312 +2025-10-13 14:47:00,019 INFO Train. Iter 84300 : Loss. 0.00304 +2025-10-13 14:47:25,163 INFO Train. Iter 84400 : Loss. 0.00312 +2025-10-13 14:47:49,519 INFO Train. Iter 84500 : Loss. 0.00311 +2025-10-13 14:48:14,299 INFO Train. Iter 84600 : Loss. 0.00315 +2025-10-13 14:48:38,808 INFO Train. Iter 84700 : Loss. 0.00316 +2025-10-13 14:49:03,284 INFO Train. Iter 84800 : Loss. 0.00311 +2025-10-13 14:49:28,326 INFO Train. Iter 84900 : Loss. 0.00308 +2025-10-13 14:49:52,847 INFO Train. Iter 85000 : Loss. 0.00311 +2025-10-13 14:50:17,551 INFO Train. Iter 85100 : Loss. 0.00303 +2025-10-13 14:50:41,911 INFO Train. Iter 85200 : Loss. 0.00319 +2025-10-13 14:51:07,126 INFO Train. Iter 85300 : Loss. 0.00316 +2025-10-13 14:51:31,684 INFO Train. Iter 85400 : Loss. 0.00308 +2025-10-13 14:51:56,390 INFO Train. Iter 85500 : Loss. 0.00313 +2025-10-13 14:52:20,749 INFO Train. Iter 85600 : Loss. 0.00308 +2025-10-13 14:52:45,404 INFO Train. Iter 85700 : Loss. 0.00304 +2025-10-13 14:53:11,027 INFO Train. Iter 85800 : Loss. 0.00303 +2025-10-13 14:53:35,646 INFO Train. Iter 85900 : Loss. 0.00297 +2025-10-13 14:53:59,978 INFO Train. Iter 86000 : Loss. 
0.00305 +2025-10-13 14:54:24,766 INFO Train. Iter 86100 : Loss. 0.00311 +2025-10-13 14:54:49,231 INFO Train. Iter 86200 : Loss. 0.00306 +2025-10-13 14:55:14,177 INFO Train. Iter 86300 : Loss. 0.00305 +2025-10-13 14:55:39,197 INFO Train. Iter 86400 : Loss. 0.00307 +2025-10-13 14:56:04,168 INFO Train. Iter 86500 : Loss. 0.00303 +2025-10-13 14:56:28,972 INFO Train. Iter 86600 : Loss. 0.00301 +2025-10-13 14:56:54,342 INFO Train. Iter 86700 : Loss. 0.00310 +2025-10-13 14:57:19,347 INFO Train. Iter 86800 : Loss. 0.00296 +2025-10-13 14:57:44,153 INFO Train. Iter 86900 : Loss. 0.00300 +2025-10-13 14:58:08,704 INFO Train. Iter 87000 : Loss. 0.00306 +2025-10-13 14:58:34,002 INFO Train. Iter 87100 : Loss. 0.00312 +2025-10-13 14:58:59,334 INFO Train. Iter 87200 : Loss. 0.00306 +2025-10-13 14:59:24,138 INFO Train. Iter 87300 : Loss. 0.00302 +2025-10-13 14:59:48,841 INFO Train. Iter 87400 : Loss. 0.00306 +2025-10-13 15:00:13,977 INFO Train. Iter 87500 : Loss. 0.00318 +2025-10-13 15:00:39,122 INFO Train. Iter 87600 : Loss. 0.00309 +2025-10-13 15:01:04,688 INFO Train. Iter 87700 : Loss. 0.00298 +2025-10-13 15:01:29,368 INFO Train. Iter 87800 : Loss. 0.00311 +2025-10-13 15:01:53,995 INFO Train. Iter 87900 : Loss. 0.00295 +2025-10-13 15:02:18,598 INFO Train. Iter 88000 : Loss. 0.00306 +2025-10-13 15:02:43,763 INFO Train. Iter 88100 : Loss. 0.00304 +2025-10-13 15:03:08,515 INFO Train. Iter 88200 : Loss. 0.00306 +2025-10-13 15:03:33,343 INFO Train. Iter 88300 : Loss. 0.00309 +2025-10-13 15:03:57,774 INFO Train. Iter 88400 : Loss. 0.00302 +2025-10-13 15:04:22,392 INFO Train. Iter 88500 : Loss. 0.00307 +2025-10-13 15:04:47,828 INFO Train. Iter 88600 : Loss. 0.00297 +2025-10-13 15:05:12,488 INFO Train. Iter 88700 : Loss. 0.00307 +2025-10-13 15:05:37,344 INFO Train. Iter 88800 : Loss. 0.00299 +2025-10-13 15:06:02,020 INFO Train. Iter 88900 : Loss. 0.00306 +2025-10-13 15:06:26,759 INFO Train. Iter 89000 : Loss. 0.00303 +2025-10-13 15:06:52,231 INFO Train. Iter 89100 : Loss. 0.00301 +2025-10-13 15:07:17,290 INFO Train. Iter 89200 : Loss. 0.00293 +2025-10-13 15:07:41,797 INFO Train. Iter 89300 : Loss. 0.00304 +2025-10-13 15:08:06,519 INFO Train. Iter 89400 : Loss. 0.00302 +2025-10-13 15:08:31,921 INFO Train. Iter 89500 : Loss. 0.00307 +2025-10-13 15:08:56,681 INFO Train. Iter 89600 : Loss. 0.00300 +2025-10-13 15:09:20,950 INFO Train. Iter 89700 : Loss. 0.00299 +2025-10-13 15:09:45,698 INFO Train. Iter 89800 : Loss. 0.00310 +2025-10-13 15:10:10,484 INFO Train. Iter 89900 : Loss. 0.00305 +2025-10-13 15:10:35,901 INFO Train. Iter 90000 : Loss. 0.00302 +2025-10-13 15:11:04,199 INFO Train. Iter 90100 : Loss. 0.00301 +2025-10-13 15:11:28,510 INFO Train. Iter 90200 : Loss. 0.00302 +2025-10-13 15:11:53,405 INFO Train. Iter 90300 : Loss. 0.00298 +2025-10-13 15:12:18,088 INFO Train. Iter 90400 : Loss. 0.00309 +2025-10-13 15:12:43,168 INFO Train. Iter 90500 : Loss. 0.00308 +2025-10-13 15:13:07,676 INFO Train. Iter 90600 : Loss. 0.00299 +2025-10-13 15:13:32,165 INFO Train. Iter 90700 : Loss. 0.00295 +2025-10-13 15:13:56,737 INFO Train. Iter 90800 : Loss. 0.00300 +2025-10-13 15:14:21,837 INFO Train. Iter 90900 : Loss. 0.00290 +2025-10-13 15:14:46,310 INFO Train. Iter 91000 : Loss. 0.00290 +2025-10-13 15:15:10,820 INFO Train. Iter 91100 : Loss. 0.00301 +2025-10-13 15:15:35,445 INFO Train. Iter 91200 : Loss. 0.00303 +2025-10-13 15:16:00,135 INFO Train. Iter 91300 : Loss. 0.00305 +2025-10-13 15:16:25,190 INFO Train. Iter 91400 : Loss. 0.00301 +2025-10-13 15:16:49,942 INFO Train. Iter 91500 : Loss. 
0.00299 +2025-10-13 15:17:14,523 INFO Train. Iter 91600 : Loss. 0.00298 +2025-10-13 15:17:38,935 INFO Train. Iter 91700 : Loss. 0.00298 +2025-10-13 15:18:03,769 INFO Train. Iter 91800 : Loss. 0.00291 +2025-10-13 15:18:29,157 INFO Train. Iter 91900 : Loss. 0.00294 +2025-10-13 15:18:53,535 INFO Train. Iter 92000 : Loss. 0.00301 +2025-10-13 15:19:18,244 INFO Train. Iter 92100 : Loss. 0.00295 +2025-10-13 15:19:42,995 INFO Train. Iter 92200 : Loss. 0.00293 +2025-10-13 15:20:08,205 INFO Train. Iter 92300 : Loss. 0.00299 +2025-10-13 15:20:32,920 INFO Train. Iter 92400 : Loss. 0.00299 +2025-10-13 15:20:57,252 INFO Train. Iter 92500 : Loss. 0.00288 +2025-10-13 15:21:22,120 INFO Train. Iter 92600 : Loss. 0.00297 +2025-10-13 15:21:46,789 INFO Train. Iter 92700 : Loss. 0.00291 +2025-10-13 15:22:11,709 INFO Train. Iter 92800 : Loss. 0.00291 +2025-10-13 15:22:35,798 INFO Train. Iter 92900 : Loss. 0.00294 +2025-10-13 15:23:00,415 INFO Train. Iter 93000 : Loss. 0.00305 +2025-10-13 15:23:25,026 INFO Train. Iter 93100 : Loss. 0.00294 +2025-10-13 15:23:49,718 INFO Train. Iter 93200 : Loss. 0.00300 +2025-10-13 15:24:14,751 INFO Train. Iter 93300 : Loss. 0.00291 +2025-10-13 15:24:39,501 INFO Train. Iter 93400 : Loss. 0.00297 +2025-10-13 15:25:04,183 INFO Train. Iter 93500 : Loss. 0.00293 +2025-10-13 15:25:28,980 INFO Train. Iter 93600 : Loss. 0.00300 +2025-10-13 15:25:54,250 INFO Train. Iter 93700 : Loss. 0.00295 +2025-10-13 15:26:18,886 INFO Train. Iter 93800 : Loss. 0.00287 +2025-10-13 15:26:43,478 INFO Train. Iter 93900 : Loss. 0.00285 +2025-10-13 15:27:07,867 INFO Train. Iter 94000 : Loss. 0.00293 +2025-10-13 15:27:32,408 INFO Train. Iter 94100 : Loss. 0.00291 +2025-10-13 15:27:57,302 INFO Train. Iter 94200 : Loss. 0.00291 +2025-10-13 15:28:21,941 INFO Train. Iter 94300 : Loss. 0.00295 +2025-10-13 15:28:46,423 INFO Train. Iter 94400 : Loss. 0.00288 +2025-10-13 15:29:10,840 INFO Train. Iter 94500 : Loss. 0.00295 +2025-10-13 15:29:35,939 INFO Train. Iter 94600 : Loss. 0.00290 +2025-10-13 15:30:01,006 INFO Train. Iter 94700 : Loss. 0.00292 +2025-10-13 15:30:25,681 INFO Train. Iter 94800 : Loss. 0.00283 +2025-10-13 15:30:50,308 INFO Train. Iter 94900 : Loss. 0.00289 +2025-10-13 15:31:14,740 INFO Train. Iter 95000 : Loss. 0.00295 +2025-10-13 15:31:40,028 INFO Train. Iter 95100 : Loss. 0.00295 +2025-10-13 15:32:04,483 INFO Train. Iter 95200 : Loss. 0.00287 +2025-10-13 15:32:29,017 INFO Train. Iter 95300 : Loss. 0.00284 +2025-10-13 15:32:53,646 INFO Train. Iter 95400 : Loss. 0.00292 +2025-10-13 15:33:18,141 INFO Train. Iter 95500 : Loss. 0.00290 +2025-10-13 15:33:43,342 INFO Train. Iter 95600 : Loss. 0.00290 +2025-10-13 15:34:07,672 INFO Train. Iter 95700 : Loss. 0.00288 +2025-10-13 15:34:32,338 INFO Train. Iter 95800 : Loss. 0.00287 +2025-10-13 15:34:56,768 INFO Train. Iter 95900 : Loss. 0.00294 +2025-10-13 15:35:21,682 INFO Train. Iter 96000 : Loss. 0.00289 +2025-10-13 15:35:46,159 INFO Train. Iter 96100 : Loss. 0.00286 +2025-10-13 15:36:10,930 INFO Train. Iter 96200 : Loss. 0.00292 +2025-10-13 15:36:35,527 INFO Train. Iter 96300 : Loss. 0.00290 +2025-10-13 15:37:00,089 INFO Train. Iter 96400 : Loss. 0.00287 +2025-10-13 15:37:25,255 INFO Train. Iter 96500 : Loss. 0.00289 +2025-10-13 15:37:49,989 INFO Train. Iter 96600 : Loss. 0.00285 +2025-10-13 15:38:14,590 INFO Train. Iter 96700 : Loss. 0.00287 +2025-10-13 15:38:39,040 INFO Train. Iter 96800 : Loss. 0.00280 +2025-10-13 15:39:03,603 INFO Train. Iter 96900 : Loss. 0.00283 +2025-10-13 15:39:28,468 INFO Train. Iter 97000 : Loss. 
0.00287 +2025-10-13 15:39:52,965 INFO Train. Iter 97100 : Loss. 0.00284 +2025-10-13 15:40:17,621 INFO Train. Iter 97200 : Loss. 0.00284 +2025-10-13 15:40:42,320 INFO Train. Iter 97300 : Loss. 0.00287 +2025-10-13 15:41:07,282 INFO Train. Iter 97400 : Loss. 0.00291 +2025-10-13 15:41:31,837 INFO Train. Iter 97500 : Loss. 0.00282 +2025-10-13 15:41:56,256 INFO Train. Iter 97600 : Loss. 0.00283 +2025-10-13 15:42:21,150 INFO Train. Iter 97700 : Loss. 0.00283 +2025-10-13 15:42:45,541 INFO Train. Iter 97800 : Loss. 0.00295 +2025-10-13 15:43:10,420 INFO Train. Iter 97900 : Loss. 0.00283 +2025-10-13 15:43:34,807 INFO Train. Iter 98000 : Loss. 0.00287 +2025-10-13 15:43:59,333 INFO Train. Iter 98100 : Loss. 0.00281 +2025-10-13 15:44:23,941 INFO Train. Iter 98200 : Loss. 0.00290 +2025-10-13 15:44:48,747 INFO Train. Iter 98300 : Loss. 0.00290 +2025-10-13 15:45:13,778 INFO Train. Iter 98400 : Loss. 0.00285 +2025-10-13 15:45:38,297 INFO Train. Iter 98500 : Loss. 0.00287 +2025-10-13 15:46:02,942 INFO Train. Iter 98600 : Loss. 0.00271 +2025-10-13 15:46:27,490 INFO Train. Iter 98700 : Loss. 0.00282 +2025-10-13 15:46:52,338 INFO Train. Iter 98800 : Loss. 0.00285 +2025-10-13 15:47:16,722 INFO Train. Iter 98900 : Loss. 0.00280 +2025-10-13 15:47:41,420 INFO Train. Iter 99000 : Loss. 0.00278 +2025-10-13 15:48:05,908 INFO Train. Iter 99100 : Loss. 0.00283 +2025-10-13 15:48:30,512 INFO Train. Iter 99200 : Loss. 0.00280 +2025-10-13 15:48:55,385 INFO Train. Iter 99300 : Loss. 0.00284 +2025-10-13 15:49:20,026 INFO Train. Iter 99400 : Loss. 0.00282 +2025-10-13 15:49:44,510 INFO Train. Iter 99500 : Loss. 0.00278 +2025-10-13 15:50:08,970 INFO Train. Iter 99600 : Loss. 0.00288 +2025-10-13 15:50:33,498 INFO Train. Iter 99700 : Loss. 0.00284 +2025-10-13 15:50:58,419 INFO Train. Iter 99800 : Loss. 0.00271 +2025-10-13 15:51:23,295 INFO Train. Iter 99900 : Loss. 0.00281 +2025-10-13 15:51:47,702 INFO Train. Iter 100000 : Loss. 
0.00286 diff --git a/Experiments/t2m_model/latest.pth b/Experiments/t2m_model/latest.pth new file mode 100644 index 0000000000000000000000000000000000000000..b21ec056c10322a2de3b1a783363caf224cb4165 --- /dev/null +++ b/Experiments/t2m_model/latest.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:07cd1dd03f703ca7aaa2b817052903bd3b666a4d87806103108c9bef1b4503aa +size 969288637 diff --git a/output/exp/events.out.tfevents.1760305753.86fdc3d9f180.4764.0 b/output/exp/events.out.tfevents.1760305753.86fdc3d9f180.4764.0 new file mode 100644 index 0000000000000000000000000000000000000000..2774f5a2abcbe9004334fd33fd201a6bde767fbd --- /dev/null +++ b/output/exp/events.out.tfevents.1760305753.86fdc3d9f180.4764.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:abb1f28a4c8a404ebc66b06f41818f82dd1f0cadf07f587dfe454559d8cc923d +size 88 diff --git a/output/exp/events.out.tfevents.1760306000.86fdc3d9f180.4858.0 b/output/exp/events.out.tfevents.1760306000.86fdc3d9f180.4858.0 new file mode 100644 index 0000000000000000000000000000000000000000..0f54314bab4c99d239e1659ad6d4255575b07d28 --- /dev/null +++ b/output/exp/events.out.tfevents.1760306000.86fdc3d9f180.4858.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4c69c1e475a93d69f36531ed13043ec7dac88efac794068980c96651dc813d44 +size 88 diff --git a/output/exp/run.log b/output/exp/run.log new file mode 100644 index 0000000000000000000000000000000000000000..0c779e19415b9d32b96a07ab0e995c0a0069a893 --- /dev/null +++ b/output/exp/run.log @@ -0,0 +1,69 @@ +2025-10-12 21:49:13,879 INFO { + "batch_size": 128, + "dataname": "t2m_babel_272", + "depth": 3, + "dilation_growth_rate": 3, + "down_t": 2, + "eval_iter": 20000, + "exp_name": "exp", + "gamma": 0.05, + "hidden_size": 1024, + "latent_dim": 16, + "latent_dir": "babel_272_stream/t2m_babel_latents", + "lr": 5e-05, + "lr_scheduler": [ + 50000, + 400000 + ], + "nb_joints": 22, + "nb_vis": 20, + "num_gpus": 1, + "out_dir": "output/exp", + "print_iter": 200, + "results_dir": "visual_results/", + "resume_pth": "Causal_TAE_t2m_babel/net_last.pth", + "root_loss": 7.0, + "seed": 123, + "stride_t": 2, + "total_iter": 2000000, + "vis_gt": false, + "visual_name": "vis", + "warm_up_iter": 1000, + "weight_decay": 0.0, + "window_size": 64 +} +2025-10-12 21:53:20,954 INFO { + "batch_size": 128, + "dataname": "t2m_babel_272", + "depth": 3, + "dilation_growth_rate": 3, + "down_t": 2, + "eval_iter": 20000, + "exp_name": "exp", + "gamma": 0.05, + "hidden_size": 1024, + "latent_dim": 16, + "latent_dir": "babel_272_stream/t2m_babel_latents", + "lr": 5e-05, + "lr_scheduler": [ + 50000, + 400000 + ], + "nb_joints": 22, + "nb_vis": 20, + "num_gpus": 1, + "out_dir": "output/exp", + "print_iter": 200, + "results_dir": "visual_results/", + "resume_pth": "Causal_TAE_t2m_babel/net_last.pth", + "root_loss": 7.0, + "seed": 123, + "stride_t": 2, + "total_iter": 2000000, + "vis_gt": false, + "visual_name": "vis", + "warm_up_iter": 1000, + "weight_decay": 0.0, + "window_size": 64 +} +2025-10-12 21:53:32,366 INFO loading checkpoint from Causal_TAE_t2m_babel/net_last.pth
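Note on the committed logs above: each training entry in output/exp/run.log follows the pattern "Train. Iter N : Loss. X". The snippet below is a minimal, hypothetical parsing sketch (Python, standard library only) for pulling those (iteration, loss) pairs out of the log for inspection; the entry format and the output/exp/run.log path are taken from the log above, while the regex and helper names are illustrative and not part of the repository.

import re

# Assumed entry format, e.g. "2025-10-13 11:52:43,046 INFO Train. Iter 42100 : Loss. 0.00326"
ITER_LOSS = re.compile(r"Train\. Iter (\d+) : Loss\. ([0-9.]+)")

def parse_losses(log_path="output/exp/run.log"):
    """Return a list of (iteration, loss) tuples found in the training log."""
    with open(log_path) as f:
        return [(int(it), float(loss)) for it, loss in ITER_LOSS.findall(f.read())]

if __name__ == "__main__":
    pairs = parse_losses()
    if pairs:
        last_it, last_loss = pairs[-1]
        print(f"{len(pairs)} entries; last: iter {last_it}, loss {last_loss:.5f}")

For the log committed here, the final pair would be iteration 100000 with loss 0.00286, which can then be plotted or summarized as needed.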