| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| | import os
|
| | import math
|
| | import argparse
|
| | import shutil
|
| | import datetime
|
| | import logging
|
| | from omegaconf import OmegaConf
|
| |
|
| | from tqdm.auto import tqdm
|
| | from einops import rearrange
|
| |
|
| | import torch
|
| | import torch.nn.functional as F
|
| | import torch.distributed as dist
|
| | from torch.utils.data.distributed import DistributedSampler
|
| | from torch.nn.parallel import DistributedDataParallel as DDP
|
| |
|
| | import diffusers
|
| | from diffusers import AutoencoderKL, DDIMScheduler
|
| | from diffusers.utils.logging import get_logger
|
| | from diffusers.optimization import get_scheduler
|
| | from diffusers.utils.import_utils import is_xformers_available
|
| | from accelerate.utils import set_seed
|
| |
|
| | from latentsync.data.unet_dataset import UNetDataset
|
| | from latentsync.models.unet import UNet3DConditionModel
|
| | from latentsync.models.syncnet import SyncNet
|
| | from latentsync.pipelines.lipsync_pipeline import LipsyncPipeline
|
| | from latentsync.utils.util import (
|
| | init_dist,
|
| | cosine_loss,
|
| | reversed_forward,
|
| | )
|
| | from latentsync.utils.util import plot_loss_chart, gather_loss
|
| | from latentsync.whisper.audio2feature import Audio2Feature
|
| | from latentsync.trepa import TREPALoss
|
| | from eval.syncnet import SyncNetEval
|
| | from eval.syncnet_detect import SyncNetDetector
|
| | from eval.eval_sync_conf import syncnet_eval
|
| | import lpips
|
| |
|
| |
|
# Module-level logger, obtained through diffusers' logging helpers and named after this module.
logger = get_logger(name=__name__)
|
| |
|
| |
|
def _whisper_model_path(cross_attention_dim):
    """Return the Whisper checkpoint whose embedding width matches ``cross_attention_dim``.

    Raises:
        NotImplementedError: for any width other than 768 (small.pt) or 384 (tiny.pt).
    """
    if cross_attention_dim == 768:
        return "checkpoints/whisper/small.pt"
    if cross_attention_dim == 384:
        return "checkpoints/whisper/tiny.pt"
    raise NotImplementedError("cross_attention_dim must be 768 or 384")


def _any_rank(flag, device):
    """Return True on every rank if ``flag`` is True on at least one rank.

    This is a collective: every rank in the process group must call it the same
    number of times, otherwise the job deadlocks.
    """
    flag_tensor = torch.tensor(1 if flag else 0, device=device)
    dist.all_reduce(flag_tensor, op=dist.ReduceOp.MAX)
    return bool(flag_tensor.item())


def _encode_batch_audio(batch, audio_encoder, device):
    """Return fp16 audio embeddings stacked over the clips in ``batch``, or None on failure.

    For each clip, features are extracted from ``batch["video_path"]`` and cropped
    at ``batch["start_idx"]``. A failure on any clip is logged (exception type,
    message, offending path) and reported as ``None`` instead of raised, so the
    caller can drop the batch on all ranks together.
    """
    clip_embeds = []
    for video_path, start_idx in zip(batch["video_path"], batch["start_idx"]):
        try:
            with torch.no_grad():
                audio_feat = audio_encoder.audio2feat(video_path)
                clip_embeds.append(audio_encoder.crop_overlap_audio_window(audio_feat, start_idx))
        except Exception as e:  # best-effort: one unreadable clip must not stop training
            logger.info(f"{type(e).__name__} - {e} - {video_path}")
            return None
    return torch.stack(clip_embeds).to(device, dtype=torch.float16)


def _encode_latents(vae, images, num_frames):
    """VAE-encode frames laid out as ``(b f) c h w``; return scaled latents as ``b c f h w``.

    A latent is sampled from the VAE posterior (not its mean), then shifted by
    ``vae.config.shift_factor`` and multiplied by ``vae.config.scaling_factor``.
    """
    with torch.no_grad():
        latents = vae.encode(images).latent_dist.sample()
    latents = rearrange(latents, "(b f) c h w -> b c f h w", f=num_frames)
    return (latents - vae.config.shift_factor) * vae.config.scaling_factor


def _sample_training_noise(latents, num_frames, run_config):
    """Draw diffusion noise with the same shape as ``latents`` (b, c, f, h, w).

    With ``run_config.use_mixed_noise``, the noise is the sum of two parts:
    a frame-shared part (frame 0 repeated over all frames) with variance
    alpha^2 / (1 + alpha^2), and an independent per-frame part with variance
    1 / (1 + alpha^2). The total variance therefore stays 1. Otherwise frame 0's
    noise is simply repeated across all frames.
    """
    if run_config.use_mixed_noise:
        alpha = run_config.mixed_noise_alpha
        shared_std = (alpha**2 / (1 + alpha**2)) ** 0.5
        shared = torch.randn_like(latents) * shared_std
        shared = shared[:, :, 0:1].repeat(1, 1, num_frames, 1, 1)
        independent_std = (1 / (1 + alpha**2)) ** 0.5
        independent = torch.randn_like(latents) * independent_std
        return independent + shared
    noise = torch.randn_like(latents)
    return noise[:, :, 0:1].repeat(1, 1, num_frames, 1, 1)


def _save_checkpoint(unet_module, global_step, output_dir):
    """Save ``unet_module``'s weights plus ``global_step`` to ``output_dir/checkpoints``.

    Save failures are logged rather than raised so that training keeps going.
    """
    model_save_path = os.path.join(output_dir, f"checkpoints/checkpoint-{global_step}.pt")
    state_dict = {
        "global_step": global_step,
        "state_dict": unet_module.state_dict(),
    }
    try:
        torch.save(state_dict, model_save_path)
        logger.info(f"Saved checkpoint to {model_save_path}")
    except Exception as e:
        logger.error(f"Error saving model: {e}")


def main(config):
    """Train the lip-sync UNet with DistributedDataParallel, one process per GPU.

    ``init_dist`` is expected to set up the process group and return the local
    rank. Rank 0 owns the output folder. It copies the UNet and SyncNet configs
    there. Every ``config.ckpt.save_ckpt_steps`` steps it plots loss charts,
    saves a checkpoint and renders a validation video.

    Args:
        config: OmegaConf config loaded from the UNet YAML, with
            ``config.unet_config_path`` set by the caller. ``config.optimizer.lr``
            (when ``scale_lr``) and ``config.run.max_train_steps`` (when -1) are
            rewritten in place.

    Raises:
        ValueError: missing SyncNet checkpoint, xformers requested but missing,
            no step/epoch budget, empty training set, or SyncNet loss enabled
            for a batch that carries no mel spectrogram.
        NotImplementedError: unsupported ``cross_attention_dim`` or
            ``v_prediction`` scheduler.
    """
    local_rank = init_dist()
    global_rank = dist.get_rank()
    num_processes = dist.get_world_size()
    is_main_process = global_rank == 0

    # Offset the seed per rank so ranks draw different noise / timesteps.
    seed = config.run.seed + global_rank
    set_seed(seed)

    # Every rank computes the name, but only rank 0 creates and writes the folder.
    folder_name = "train" + datetime.datetime.now().strftime("-%Y_%m_%d-%H:%M:%S")
    output_dir = os.path.join(config.data.train_output_dir, folder_name)

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    if is_main_process:
        diffusers.utils.logging.set_verbosity_info()
        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(f"{output_dir}/checkpoints", exist_ok=True)
        os.makedirs(f"{output_dir}/val_videos", exist_ok=True)
        os.makedirs(f"{output_dir}/loss_charts", exist_ok=True)
        shutil.copy(config.unet_config_path, output_dir)
        shutil.copy(config.data.syncnet_config_path, output_dir)

    device = torch.device(local_rank)

    noise_scheduler = DDIMScheduler.from_pretrained("configs")

    # Frozen fp16 VAE; latents are scaled by 0.18215 with no shift.
    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
    vae.config.scaling_factor = 0.18215
    vae.config.shift_factor = 0
    vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
    vae.requires_grad_(False)
    vae.to(device)

    # Used only to score validation videos for sync confidence.
    syncnet_eval_model = SyncNetEval(device=device)
    syncnet_eval_model.loadParameters("checkpoints/auxiliary/syncnet_v2.model")
    syncnet_detector = SyncNetDetector(device=device, detect_results_dir="detect_results")

    audio_encoder = Audio2Feature(
        model_path=_whisper_model_path(config.model.cross_attention_dim),
        device=device,
        audio_embeds_cache_dir=config.data.audio_embeds_cache_dir,
        num_frames=config.data.num_frames,
    )

    unet, resume_global_step = UNet3DConditionModel.from_pretrained(
        OmegaConf.to_container(config.model),
        config.ckpt.resume_ckpt_path,
        device=device,
    )

    use_sync_loss = config.model.add_audio_layer and config.run.use_syncnet
    if use_sync_loss:
        # Frozen SyncNet providing the sync-loss supervision.
        syncnet_config = OmegaConf.load(config.data.syncnet_config_path)
        if syncnet_config.ckpt.inference_ckpt_path == "":
            raise ValueError("SyncNet path is not provided")
        syncnet = SyncNet(OmegaConf.to_container(syncnet_config.model)).to(device=device, dtype=torch.float16)
        # NOTE: torch.load unpickles the file; only point this at trusted checkpoints.
        syncnet_checkpoint = torch.load(syncnet_config.ckpt.inference_ckpt_path, map_location=device)
        syncnet.load_state_dict(syncnet_checkpoint["state_dict"])
        syncnet.requires_grad_(False)

    unet.requires_grad_(True)
    trainable_params = list(unet.parameters())

    if config.optimizer.scale_lr:
        config.optimizer.lr = config.optimizer.lr * num_processes

    optimizer = torch.optim.AdamW(trainable_params, lr=config.optimizer.lr)

    if is_main_process:
        logger.info(f"trainable params number: {len(trainable_params)}")
        logger.info(f"trainable params scale: {sum(p.numel() for p in trainable_params) / 1e6:.3f} M")

    if config.run.enable_xformers_memory_efficient_attention:
        if not is_xformers_available():
            raise ValueError("xformers is not available. Make sure it is installed correctly")
        unet.enable_xformers_memory_efficient_attention()

    if config.run.enable_gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    train_dataset = UNetDataset(config.data.train_data_dir, config)
    distributed_sampler = DistributedSampler(
        train_dataset,
        num_replicas=num_processes,
        rank=global_rank,
        shuffle=True,
        seed=config.run.seed,  # must be identical on every rank so the shards are disjoint
    )
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.data.batch_size,
        shuffle=False,  # shuffling is done by the sampler
        sampler=distributed_sampler,
        num_workers=config.data.num_workers,
        pin_memory=False,
        drop_last=True,  # every batch has exactly config.data.batch_size samples (SyncNet target relies on it)
        worker_init_fn=train_dataset.worker_init_fn,
    )

    # One optimizer step per batch (no gradient accumulation).
    num_update_steps_per_epoch = len(train_dataloader)
    if num_update_steps_per_epoch == 0:
        raise ValueError("The training dataloader is empty; check data.train_data_dir and data.batch_size")

    if config.run.max_train_steps == -1:
        if config.run.max_train_epochs == -1:
            raise ValueError("Either run.max_train_steps or run.max_train_epochs must be set")
        config.run.max_train_steps = config.run.max_train_epochs * num_update_steps_per_epoch

    lr_scheduler = get_scheduler(
        config.optimizer.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=config.optimizer.lr_warmup_steps,
        num_training_steps=config.run.max_train_steps,
    )

    pixel_space_supervise = config.run.pixel_space_supervise
    use_lpips = config.run.perceptual_loss_weight != 0 and pixel_space_supervise
    use_trepa = config.run.trepa_loss_weight != 0 and pixel_space_supervise
    if use_lpips:
        lpips_loss_func = lpips.LPIPS(net="vgg").to(device)
    if use_trepa:
        trepa_loss_func = TREPALoss(device=device)

    # The pipeline holds the unwrapped UNet, so validation always sees the current weights.
    pipeline = LipsyncPipeline(
        vae=vae,
        audio_encoder=audio_encoder,
        unet=unet,
        scheduler=noise_scheduler,
    ).to(device)
    pipeline.set_progress_bar_config(disable=True)

    unet = DDP(unet, device_ids=[local_rank], output_device=local_rank)

    num_train_epochs = math.ceil(config.run.max_train_steps / num_update_steps_per_epoch)
    total_batch_size = config.data.batch_size * num_processes

    if is_main_process:
        logger.info("***** Running training *****")
        logger.info(f" Num examples = {len(train_dataset)}")
        logger.info(f" Num Epochs = {num_train_epochs}")
        logger.info(f" Instantaneous batch size per device = {config.data.batch_size}")
        logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
        logger.info(f" Total optimization steps = {config.run.max_train_steps}")

    global_step = resume_global_step
    # A resumed run restarts the epoch containing resume_global_step from its first batch.
    first_epoch = resume_global_step // num_update_steps_per_epoch

    progress_bar = tqdm(
        range(0, config.run.max_train_steps),
        initial=resume_global_step,
        desc="Steps",
        disable=not is_main_process,
    )

    train_step_list = []
    sync_loss_list = []
    recon_loss_list = []
    val_step_list = []
    sync_conf_list = []

    scaler = torch.cuda.amp.GradScaler() if config.run.mixed_precision_training else None
    num_frames = config.data.num_frames

    for epoch in range(first_epoch, num_train_epochs):
        train_dataloader.sampler.set_epoch(epoch)
        unet.train()

        for batch in train_dataloader:
            # ---- Audio conditioning ----
            mel = None  # reset every step so a stale spectrogram is never reused
            if config.model.add_audio_layer:
                # The collated batch holds [] instead of a tensor when no mel was produced.
                if isinstance(batch["mel"], torch.Tensor):
                    mel = batch["mel"].to(device, dtype=torch.float16)
                audio_embeds = _encode_batch_audio(batch, audio_encoder, device)
                # Skip on *all* ranks if any rank failed: a rank skipping alone would
                # miss DDP's gradient all-reduce and leave the other ranks hanging.
                if _any_rank(audio_embeds is None, device):
                    continue
            else:
                audio_embeds = None

            # ---- Encode frames to latents ----
            gt_images = rearrange(batch["gt"].to(device, dtype=torch.float16), "b f c h w -> (b f) c h w")
            masked_images = rearrange(
                batch["masked_gt"].to(device, dtype=torch.float16), "b f c h w -> (b f) c h w"
            )
            mask = rearrange(batch["mask"].to(device, dtype=torch.float16), "b f c h w -> (b f) c h w")
            ref_images = rearrange(batch["ref"].to(device, dtype=torch.float16), "b f c h w -> (b f) c h w")

            gt_latents = _encode_latents(vae, gt_images, num_frames)
            masked_latents = _encode_latents(vae, masked_images, num_frames)
            ref_latents = _encode_latents(vae, ref_images, num_frames)

            # Bring the pixel-space mask down to latent resolution.
            mask = torch.nn.functional.interpolate(mask, size=config.data.resolution // vae_scale_factor)
            mask = rearrange(mask, "(b f) c h w -> b c f h w", f=num_frames)

            # ---- Forward diffusion ----
            noise = _sample_training_noise(gt_latents, num_frames, config.run)
            bsz = gt_latents.shape[0]
            # One random timestep per video.
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (bsz,), device=gt_latents.device
            ).long()
            noisy_latents = noise_scheduler.add_noise(gt_latents, noise, timesteps)

            if noise_scheduler.config.prediction_type == "epsilon":
                target = noise
            elif noise_scheduler.config.prediction_type == "v_prediction":
                raise NotImplementedError
            else:
                raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

            # Channels: noisy latents | mask | masked-frame latents | reference latents.
            unet_input = torch.cat([noisy_latents, mask, masked_latents, ref_latents], dim=1)

            with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=config.run.mixed_precision_training):
                pred_noise = unet(unet_input, timesteps, encoder_hidden_states=audio_embeds).sample

            # ---- Losses ----
            if config.run.recon_loss_weight != 0:
                recon_loss = F.mse_loss(pred_noise.float(), target.float(), reduction="mean")
            else:
                recon_loss = 0

            # Presumably the predicted clean latents (x0); decoded below for pixel-space losses.
            pred_latents = reversed_forward(noise_scheduler, pred_noise, timesteps, noisy_latents)

            if pixel_space_supervise:
                pred_images = vae.decode(
                    rearrange(pred_latents, "b c f h w -> (b f) c h w") / vae.config.scaling_factor
                    + vae.config.shift_factor
                ).sample

            if use_lpips:
                # Perceptual loss on the lower half of each frame (rows h // 2 .. h).
                pred_images_perceptual = pred_images[:, :, pred_images.shape[2] // 2 :, :]
                gt_images_perceptual = gt_images[:, :, gt_images.shape[2] // 2 :, :]
                lpips_loss = lpips_loss_func(pred_images_perceptual.float(), gt_images_perceptual.float()).mean()
            else:
                lpips_loss = 0

            if use_trepa:
                trepa_pred_images = rearrange(pred_images, "(b f) c h w -> b c f h w", f=num_frames)
                trepa_gt_images = rearrange(gt_images, "(b f) c h w -> b c f h w", f=num_frames)
                trepa_loss = trepa_loss_func(trepa_pred_images, trepa_gt_images)
            else:
                trepa_loss = 0

            if use_sync_loss:
                if mel is None:
                    raise ValueError("SyncNet loss is enabled but the batch contains no mel spectrogram")
                if pixel_space_supervise:
                    syncnet_input = rearrange(pred_images, "(b f) c h w -> b (f c) h w", f=num_frames)
                else:
                    syncnet_input = rearrange(pred_latents, "b c f h w -> b (f c) h w")
                if syncnet_config.data.lower_half:
                    height = syncnet_input.shape[2]
                    syncnet_input = syncnet_input[:, :, height // 2 :, :]
                # All-ones target label for cosine_loss (presumably "in sync") for every sample.
                ones_tensor = torch.ones((config.data.batch_size, 1)).float().to(device=device)
                # Named apart from audio_embeds (the UNet conditioning) to avoid shadowing it.
                vision_embeds, sync_audio_embeds = syncnet(syncnet_input, mel)
                sync_loss = cosine_loss(vision_embeds.float(), sync_audio_embeds.float(), ones_tensor).mean()
                sync_loss_list.append(gather_loss(sync_loss, device))
            else:
                sync_loss = 0

            loss = (
                recon_loss * config.run.recon_loss_weight
                + sync_loss * config.run.sync_loss_weight
                + lpips_loss * config.run.perceptual_loss_weight
                + trepa_loss * config.run.trepa_loss_weight
            )

            train_step_list.append(global_step)
            if config.run.recon_loss_weight != 0:
                recon_loss_list.append(gather_loss(recon_loss, device))

            # ---- Optimization ----
            optimizer.zero_grad()
            if config.run.mixed_precision_training:
                scaler.scale(loss).backward()
                # Unscale first so max_grad_norm applies to the true gradients.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(unet.parameters(), config.optimizer.max_grad_norm)
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(unet.parameters(), config.optimizer.max_grad_norm)
                optimizer.step()

            lr_scheduler.step()
            progress_bar.update(1)
            global_step += 1

            # ---- Checkpoint + validation (rank 0 only) ----
            if is_main_process and global_step % config.ckpt.save_ckpt_steps == 0:
                if config.run.recon_loss_weight != 0:
                    plot_loss_chart(
                        os.path.join(output_dir, f"loss_charts/recon_loss_chart-{global_step}.png"),
                        ("Reconstruction loss", train_step_list, recon_loss_list),
                    )
                if config.model.add_audio_layer and sync_loss_list:
                    plot_loss_chart(
                        os.path.join(output_dir, f"loss_charts/sync_loss_chart-{global_step}.png"),
                        ("Sync loss", train_step_list, sync_loss_list),
                    )
                _save_checkpoint(unet.module, global_step, output_dir)

                logger.info("Running validation... ")
                validation_video_out_path = os.path.join(output_dir, f"val_videos/val_video_{global_step}.mp4")
                validation_video_mask_path = os.path.join(output_dir, "val_videos/val_video_mask.mp4")

                # Render in eval mode; always restore train mode for the next step.
                unet.eval()
                try:
                    with torch.autocast(device_type="cuda", dtype=torch.float16):
                        pipeline(
                            config.data.val_video_path,
                            config.data.val_audio_path,
                            validation_video_out_path,
                            validation_video_mask_path,
                            num_frames=num_frames,
                            num_inference_steps=config.run.inference_steps,
                            guidance_scale=config.run.guidance_scale,
                            weight_dtype=torch.float16,
                            width=config.data.resolution,
                            height=config.data.resolution,
                            mask=config.data.mask,
                        )
                finally:
                    unet.train()

                logger.info(f"Saved validation video output to {validation_video_out_path}")
                val_step_list.append(global_step)

                if config.model.add_audio_layer:
                    try:
                        _, conf = syncnet_eval(syncnet_eval_model, syncnet_detector, validation_video_out_path, "temp")
                    except Exception as e:  # scoring is best-effort; record 0 on failure
                        logger.info(e)
                        conf = 0
                    sync_conf_list.append(conf)
                    plot_loss_chart(
                        os.path.join(output_dir, f"loss_charts/sync_conf_chart-{global_step}.png"),
                        ("Sync confidence", val_step_list, sync_conf_list),
                    )

            logs = {"step_loss": loss.item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)

            if global_step >= config.run.max_train_steps:
                break

        # Also leave the epoch loop; otherwise the next epoch would run one more step.
        if global_step >= config.run.max_train_steps:
            break

    progress_bar.close()
    dist.destroy_process_group()
|
| |
|
| |
|
if __name__ == "__main__":
    # Command-line entry: the only flag is the path of the UNet training YAML.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--unet_config_path", type=str, default="configs/unet.yaml")
    cli_args = arg_parser.parse_args()

    train_config = OmegaConf.load(cli_args.unet_config_path)
    # Record where the YAML came from; main() copies that file into the run folder.
    train_config.unet_config_path = cli_args.unet_config_path
    main(train_config)
|
| |
|