import os
from pprint import pformat

import pytorch_lightning as pl
import torch
from omegaconf import OmegaConf
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import ModelCheckpoint

from mld.callback import ProgressLogger
from mld.config import parse_args
from mld.data.get_data import get_datasets
from mld.models.get_model import get_model
from mld.utils.logger import create_logger


def main():
    # parse options
    cfg = parse_args()

    # create logger
    logger = create_logger(cfg, phase="train")
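
    # resume training: restore the saved config, latest checkpoint and wandb run id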
    if cfg.TRAIN.RESUME:
        resume = cfg.TRAIN.RESUME
        backcfg = cfg.TRAIN.copy()
        if os.path.exists(resume):
            # reload the config dumped in the experiment folder,
            # but keep the TRAIN options passed for this run
            file_list = sorted(os.listdir(resume), reverse=True)
            for item in file_list:
                if item.endswith(".yaml"):
                    cfg = OmegaConf.load(os.path.join(resume, item))
                    cfg.TRAIN = backcfg
                    break
            # pick the most recent "epoch=*.ckpt" checkpoint
            checkpoints = sorted(os.listdir(os.path.join(resume, "checkpoints")),
                                 key=lambda x: int(x[6:-5]),
                                 reverse=True)
            for checkpoint in checkpoints:
                if "epoch=" in checkpoint:
                    cfg.TRAIN.PRETRAINED = os.path.join(
                        resume, "checkpoints", checkpoint)
                    break
            # resume the previous wandb run if one exists
            if os.path.exists(os.path.join(resume, "wandb")):
                wandb_list = sorted(os.listdir(os.path.join(resume, "wandb")),
                                    reverse=True)
                for item in wandb_list:
                    if "run-" in item:
                        cfg.LOGGER.WANDB.RESUME_ID = item.split("-")[-1]
        else:
            raise ValueError(f"Resume path {resume} does not exist.")
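
    # set seed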
    pl.seed_everything(cfg.SEED_VALUE)
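
    # silence HuggingFace tokenizers fork-parallelism warnings on GPU runs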
    if cfg.ACCELERATOR == "gpu":
        os.environ["TOKENIZERS_PARALLELISM"] = "false"
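
    # set up experiment loggers (wandb and/or tensorboard)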
    loggers = []
    if cfg.LOGGER.WANDB.PROJECT:
        wandb_logger = pl_loggers.WandbLogger(
            project=cfg.LOGGER.WANDB.PROJECT,
            offline=cfg.LOGGER.WANDB.OFFLINE,
            id=cfg.LOGGER.WANDB.RESUME_ID,
            save_dir=cfg.FOLDER_EXP,
            version="",
            name=cfg.NAME,
            anonymous=False,
            log_model=False,
        )
        loggers.append(wandb_logger)
    if cfg.LOGGER.TENSORBOARD:
        tb_logger = pl_loggers.TensorBoardLogger(save_dir=cfg.FOLDER_EXP,
                                                 sub_dir="tensorboard",
                                                 version="",
                                                 name="")
        loggers.append(tb_logger)
    logger.info(OmegaConf.to_yaml(cfg))
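
    # create datasets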
    datasets = get_datasets(cfg, logger=logger)
    logger.info("datasets module {} initialized".format("".join(
        cfg.TRAIN.DATASETS)))
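
    # create model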
    model = get_model(cfg, datasets[0])
    logger.info("model {} loaded".format(cfg.model.model_type))
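
    # optionally initialise the motion VAE from a pretrained checkpoint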
    if cfg.TRAIN.STAGE in ['gpt']:
        logger.info("Loading pretrained vae from {}".format(
            cfg.TRAIN.PRETRAINED_VAE))
        state_dict = torch.load(cfg.TRAIN.PRETRAINED_VAE,
                                map_location="cpu")["state_dict"]

        from collections import OrderedDict
        # keep only the VAE weights and rename "vae.vqvae" keys to "vqvae"
        vae_dict = OrderedDict()
        for k, v in state_dict.items():
            if k.split(".")[0] == "vae":
                name = k.replace("vae.vqvae", "vqvae")
                vae_dict[name] = v
        model.vae.load_state_dict(vae_dict, strict=True)
    else:
        if cfg.TRAIN.PRETRAINED_VAE:
            logger.info("Loading pretrained vae from {}".format(
                cfg.TRAIN.PRETRAINED_VAE))
            state_dict = torch.load(cfg.TRAIN.PRETRAINED_VAE,
                                    map_location="cpu")["state_dict"]

            from collections import OrderedDict
            # keep only the VAE weights and strip the "vae." prefix
            vae_dict = OrderedDict()
            for k, v in state_dict.items():
                if k.split(".")[0] == "vae":
                    name = k.replace("vae.", "")
                    vae_dict[name] = v
            model.vae.load_state_dict(vae_dict, strict=True)
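
    # metrics to monitor during training and validation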
    metric_monitor = {
        "Train_jf": "recons/text2jfeats/train",
        "Val_jf": "recons/text2jfeats/val",
        "Train_rf": "recons/text2rfeats/train",
        "Val_rf": "recons/text2rfeats/val",
        "APE root": "Metrics/APE_root",
        "APE mean pose": "Metrics/APE_mean_pose",
        "AVE root": "Metrics/AVE_root",
        "AVE mean pose": "Metrics/AVE_mean_pose",
        "R_TOP_1": "Metrics/R_precision_top_1",
        "R_TOP_2": "Metrics/R_precision_top_2",
        "R_TOP_3": "Metrics/R_precision_top_3",
        "gt_R_TOP_1": "Metrics/gt_R_precision_top_1",
        "gt_R_TOP_2": "Metrics/gt_R_precision_top_2",
        "gt_R_TOP_3": "Metrics/gt_R_precision_top_3",
        "FID": "Metrics/FID",
        "gt_FID": "Metrics/gt_FID",
        "Diversity": "Metrics/Diversity",
        "gt_Diversity": "Metrics/gt_Diversity",
        "MM dist": "Metrics/Matching_score",
        "Accuracy": "Metrics/accuracy",
        "gt_Accuracy": "Metrics/gt_accuracy",
    }
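
    # callbacks: progress bar, metric logging and periodic checkpointing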
    callbacks = [
        pl.callbacks.RichProgressBar(),
        ProgressLogger(metric_monitor=metric_monitor),
        ModelCheckpoint(
            dirpath=os.path.join(cfg.FOLDER_EXP, "checkpoints"),
            filename="{epoch}",
            monitor="step",
            mode="max",
            every_n_epochs=cfg.LOGGER.SAVE_CHECKPOINT_EPOCH,
            save_top_k=-1,
            save_last=False,
            save_on_train_epoch_end=True,
        ),
    ]
    logger.info("Callbacks initialized")
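
    # use DDP when more than one device is requested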
    if len(cfg.DEVICE) > 1:
        ddp_strategy = "ddp"
    else:
        ddp_strategy = None
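
    # build the Lightning trainer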
    trainer = pl.Trainer(
        benchmark=False,
        max_epochs=cfg.TRAIN.END_EPOCH,
        accelerator=cfg.ACCELERATOR,
        devices=cfg.DEVICE,
        strategy=ddp_strategy,
        default_root_dir=cfg.FOLDER_EXP,
        log_every_n_steps=cfg.LOGGER.VAL_EVERY_STEPS,
        deterministic=False,
        detect_anomaly=False,
        enable_progress_bar=True,
        logger=loggers,
        callbacks=callbacks,
        check_val_every_n_epoch=cfg.LOGGER.VAL_EVERY_STEPS,
    )
    logger.info("Trainer initialized")
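
    # derive the vae type name from the configured motion_vae target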
    if cfg.TRAIN.STAGE == 'temos':
        vae_type = 'temos'
    else:
        vae_type = cfg.model.motion_vae.target.split(".")[-1].lower().replace(
            "vae", "")
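
    # optionally load a pretrained denoiser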
    if cfg.TRAIN.PRETRAINED_MLD:
        logger.info("Loading pretrained mld from {}".format(
            cfg.TRAIN.PRETRAINED_MLD))
        state_dict = torch.load(cfg.TRAIN.PRETRAINED_MLD,
                                map_location="cpu")["state_dict"]

        from collections import OrderedDict
        # keep only the denoiser weights and strip the "denoiser." prefix
        denoiser_dict = OrderedDict()
        for k, v in state_dict.items():
            if k.split(".")[0] == "denoiser":
                name = k.replace("denoiser.", "")
                denoiser_dict[name] = v
        model.denoiser.load_state_dict(denoiser_dict, strict=True)
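
    # optionally load a full pretrained model (this also overwrites the VAE weights)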
    if cfg.TRAIN.PRETRAINED:
        logger.info("Loading pretrained model from {}".format(
            cfg.TRAIN.PRETRAINED))
        logger.info("Attention! VAE will be recovered")
        state_dict = torch.load(cfg.TRAIN.PRETRAINED,
                                map_location="cpu")["state_dict"]

        from collections import OrderedDict
        # drop the positional-encoding buffer so a different sequence length
        # can be used, then load the remaining weights non-strictly
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            if k not in ["denoiser.sequence_pos_encoding.pe"]:
                new_state_dict[k] = v
        model.load_state_dict(new_state_dict, strict=False)
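
    # train, resuming from the selected checkpoint when requested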
    if cfg.TRAIN.RESUME:
        trainer.validate(model,
                         datamodule=datasets[0],
                         ckpt_path=cfg.TRAIN.PRETRAINED)
        trainer.fit(model,
                    datamodule=datasets[0],
                    ckpt_path=cfg.TRAIN.PRETRAINED)
    else:
        trainer.fit(model, datamodule=datasets[0])
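
    # report where checkpoints and other outputs are stored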
    checkpoint_folder = trainer.checkpoint_callback.dirpath
    logger.info(f"The checkpoints are stored in {checkpoint_folder}")
    logger.info(
        f"The outputs of this experiment are stored in {cfg.FOLDER_EXP}")

    logger.info("Training ends!")


if __name__ == "__main__":
    main()