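"""Training entry point for MotionStreamer (MLD-style Lightning pipeline).

Parses the experiment config, builds the datasets and the model, optionally
restores a previous run or pretrained VAE / denoiser weights, and launches
training with PyTorch Lightning.
"""
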
import os
from pprint import pformat
import pytorch_lightning as pl
import torch
from omegaconf import OmegaConf
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import ModelCheckpoint
# from pytorch_lightning.strategies.ddp import DDPStrategy
from mld.callback import ProgressLogger
from mld.config import parse_args
from mld.data.get_data import get_datasets
from mld.models.get_model import get_model
from mld.utils.logger import create_logger


def main():
cfg = parse_args() # parse config file
# create logger
logger = create_logger(cfg, phase="train")
    # resume: restore config, latest checkpoint, and wandb run id from a
    # previous experiment folder
if cfg.TRAIN.RESUME:
resume = cfg.TRAIN.RESUME
backcfg = cfg.TRAIN.copy()
if os.path.exists(resume):
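            # reload the most recently saved config (*.yaml) from the resume
            # folder, but keep the TRAIN settings of the current run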
file_list = sorted(os.listdir(resume), reverse=True)
for item in file_list:
if item.endswith(".yaml"):
cfg = OmegaConf.load(os.path.join(resume, item))
cfg.TRAIN = backcfg
break
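            # pick the newest "epoch=N.ckpt" file; the slice x[6:-5] extracts
            # the epoch number between the "epoch=" prefix and ".ckpt" suffix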
checkpoints = sorted(os.listdir(os.path.join(
resume, "checkpoints")),
key=lambda x: int(x[6:-5]),
reverse=True)
for checkpoint in checkpoints:
if "epoch=" in checkpoint:
cfg.TRAIN.PRETRAINED = os.path.join(
resume, "checkpoints", checkpoint)
break
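            # reuse the previous wandb run id so logging resumes in the same run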
if os.path.exists(os.path.join(resume, "wandb")):
wandb_list = sorted(os.listdir(os.path.join(resume, "wandb")),
reverse=True)
                for item in wandb_list:
                    if "run-" in item:
                        cfg.LOGGER.WANDB.RESUME_ID = item.split("-")[-1]
                        break
else:
            raise ValueError(f"Resume path {resume} does not exist.")
# set seed
pl.seed_everything(cfg.SEED_VALUE)
# gpu setting
if cfg.ACCELERATOR == "gpu":
# os.environ["PYTHONWARNINGS"] = "ignore"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# os.environ['CUDA_VISIBLE_DEVICES'] = ",".join(str(x) for x in cfg.DEVICE)
# tensorboard logger and wandb logger
loggers = []
if cfg.LOGGER.WANDB.PROJECT:
wandb_logger = pl_loggers.WandbLogger(
project=cfg.LOGGER.WANDB.PROJECT,
offline=cfg.LOGGER.WANDB.OFFLINE,
id=cfg.LOGGER.WANDB.RESUME_ID,
save_dir=cfg.FOLDER_EXP,
version="",
name=cfg.NAME,
anonymous=False,
log_model=False,
)
loggers.append(wandb_logger)
if cfg.LOGGER.TENSORBOARD:
tb_logger = pl_loggers.TensorBoardLogger(save_dir=cfg.FOLDER_EXP,
sub_dir="tensorboard",
version="",
name="")
loggers.append(tb_logger)
logger.info(OmegaConf.to_yaml(cfg))
# create dataset
datasets = get_datasets(cfg, logger=logger)
    logger.info("datasets module {} initialized".format(", ".join(
        cfg.TRAIN.DATASETS)))
# create model
model = get_model(cfg, datasets[0])
logger.info("model {} loaded".format(cfg.model.model_type))
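    # stage-dependent weight loading: the "gpt" stage expects a VQ-VAE style
    # checkpoint, other stages load a plain pretrained VAE if one is given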
if cfg.TRAIN.STAGE in ['gpt']:
        logger.info("Loading pretrained VAE from {}".format(
            cfg.TRAIN.PRETRAINED_VAE))
state_dict = torch.load(cfg.TRAIN.PRETRAINED_VAE,
map_location="cpu")["state_dict"]
# extract encoder/decoder
from collections import OrderedDict
vae_dict = OrderedDict()
for k, v in state_dict.items():
if k.split(".")[0] == "vae":
name = k.replace("vae.vqvae", "vqvae")
vae_dict[name] = v
model.vae.load_state_dict(vae_dict, strict=True)
else:
if cfg.TRAIN.PRETRAINED_VAE:
            logger.info("Loading pretrained VAE from {}".format(
                cfg.TRAIN.PRETRAINED_VAE))
state_dict = torch.load(cfg.TRAIN.PRETRAINED_VAE,
map_location="cpu")["state_dict"]
# extract encoder/decoder
from collections import OrderedDict
vae_dict = OrderedDict()
for k, v in state_dict.items():
if k.split(".")[0] == "vae":
name = k.replace("vae.", "")
vae_dict[name] = v
model.vae.load_state_dict(vae_dict, strict=True)
    # metrics surfaced by the progress logger (display name -> logged metric key)
metric_monitor = {
"Train_jf": "recons/text2jfeats/train",
"Val_jf": "recons/text2jfeats/val",
"Train_rf": "recons/text2rfeats/train",
"Val_rf": "recons/text2rfeats/val",
"APE root": "Metrics/APE_root",
"APE mean pose": "Metrics/APE_mean_pose",
"AVE root": "Metrics/AVE_root",
"AVE mean pose": "Metrics/AVE_mean_pose",
"R_TOP_1": "Metrics/R_precision_top_1",
"R_TOP_2": "Metrics/R_precision_top_2",
"R_TOP_3": "Metrics/R_precision_top_3",
"gt_R_TOP_1": "Metrics/gt_R_precision_top_1",
"gt_R_TOP_2": "Metrics/gt_R_precision_top_2",
"gt_R_TOP_3": "Metrics/gt_R_precision_top_3",
"FID": "Metrics/FID",
"gt_FID": "Metrics/gt_FID",
"Diversity": "Metrics/Diversity",
"gt_Diversity": "Metrics/gt_Diversity",
"MM dist": "Metrics/Matching_score",
"Accuracy": "Metrics/accuracy",
"gt_Accuracy": "Metrics/gt_accuracy",
}
# callbacks
callbacks = [
pl.callbacks.RichProgressBar(),
ProgressLogger(metric_monitor=metric_monitor),
# ModelCheckpoint(dirpath=os.path.join(cfg.FOLDER_EXP,'checkpoints'),filename='latest-{epoch}',every_n_epochs=1,save_top_k=1,save_last=True,save_on_train_epoch_end=True),
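        # save a checkpoint every SAVE_CHECKPOINT_EPOCH epochs and keep all of them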
ModelCheckpoint(
dirpath=os.path.join(cfg.FOLDER_EXP, "checkpoints"),
filename="{epoch}",
monitor="step",
mode="max",
every_n_epochs=cfg.LOGGER.SAVE_CHECKPOINT_EPOCH,
save_top_k=-1,
save_last=False,
save_on_train_epoch_end=True,
),
]
logger.info("Callbacks initialized")
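    # use DDP when more than one device is configured; otherwise fall back to
    # Lightning's default single-device strategy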
if len(cfg.DEVICE) > 1:
# ddp_strategy = DDPStrategy(find_unused_parameters=False)
ddp_strategy = "ddp"
else:
ddp_strategy = None
# trainer
trainer = pl.Trainer(
benchmark=False,
max_epochs=cfg.TRAIN.END_EPOCH,
accelerator=cfg.ACCELERATOR,
devices=cfg.DEVICE,
# gpus=2,
strategy=ddp_strategy,
# move_metrics_to_cpu=True,
default_root_dir=cfg.FOLDER_EXP,
log_every_n_steps=cfg.LOGGER.VAL_EVERY_STEPS,
deterministic=False,
detect_anomaly=False,
enable_progress_bar=True,
logger=loggers,
callbacks=callbacks,
check_val_every_n_epoch=cfg.LOGGER.VAL_EVERY_STEPS,
)
logger.info("Trainer initialized")
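    # derive the vae type name from the config (the value is not used later in
    # this script)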
if cfg.TRAIN.STAGE == 'temos':
vae_type = 'temos'
else:
vae_type = cfg.model.motion_vae.target.split(".")[-1].lower().replace(
"vae", "")
if cfg.TRAIN.PRETRAINED_MLD:
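        # initialize the denoiser from a pretrained MLD checkpoint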
        logger.info("Loading pretrained MLD from {}".format(
            cfg.TRAIN.PRETRAINED_MLD))
state_dict = torch.load(cfg.TRAIN.PRETRAINED_MLD,
map_location="cpu")["state_dict"]
from collections import OrderedDict
vae_dict = OrderedDict()
for k, v in state_dict.items():
if k.split(".")[0] == "denoiser":
name = k.replace("denoiser.", "")
vae_dict[name] = v
model.denoiser.load_state_dict(vae_dict, strict=True)
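    # optionally restore the full model from a pretrained checkpoint; this also
    # overwrites any VAE weights loaded above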
if cfg.TRAIN.PRETRAINED:
        logger.info("Loading pretrained model from {}".format(
            cfg.TRAIN.PRETRAINED))
        logger.info("Note: VAE weights will also be restored from this checkpoint")
state_dict = torch.load(cfg.TRAIN.PRETRAINED,
map_location="cpu")["state_dict"]
# remove mismatched and unused params
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
if k not in ["denoiser.sequence_pos_encoding.pe"]:
new_state_dict[k] = v
model.load_state_dict(new_state_dict, strict=False)
    # fitting: when resuming, run a validation pass and continue training from
    # the checkpoint; otherwise train from scratch
if cfg.TRAIN.RESUME:
trainer.validate(model, datamodule=datasets[0], ckpt_path=cfg.TRAIN.PRETRAINED)
trainer.fit(model,
datamodule=datasets[0],
ckpt_path=cfg.TRAIN.PRETRAINED)
else:
trainer.fit(model, datamodule=datasets[0])
# checkpoint
checkpoint_folder = trainer.checkpoint_callback.dirpath
logger.info(f"The checkpoints are stored in {checkpoint_folder}")
logger.info(
f"The outputs of this experiment are stored in {cfg.FOLDER_EXP}")
# end
logger.info("Training ends!")


if __name__ == "__main__":
main()