import math
import torch
import os
import argparse
import logging

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from copy import deepcopy
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader
from glob import glob
import yaml
from collections import OrderedDict
from time import time
from einops import rearrange, repeat
from diffusers import AutoencoderKL
from transformers import SpeechT5HifiGan
from audioldm2.utilities.data.dataset import AudioDataset
from constants import build_model
from utils import load_clip, load_clap, load_t5
from thop import profile

@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
    """
    Step the EMA model towards the current model.
    """
    ema_params = OrderedDict(ema_model.named_parameters())
    model_params = OrderedDict(model.named_parameters())

    for name, param in model_params.items():
        # TODO: Consider applying only to params that require grad to avoid small numerical changes of pos_embed
        ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
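
# With the default decay=0.9999, the EMA averages over an effective horizon of
# roughly 1 / (1 - decay) = 10,000 optimizer steps; calling update_ema(..., decay=0)
# copies the online weights verbatim, which is how the EMA is initialized in main().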

def requires_grad(model, flag=True):
    """
    Set requires_grad flag for all parameters in a model.
    """
    for p in model.parameters():
        p.requires_grad = flag


def cleanup():
    """
    End DDP training.
    """
    dist.destroy_process_group()

def create_logger(logging_dir):
    """
    Create a logger that writes to a log file and stdout.
    """
    if dist.get_rank() == 0:  # real logger
        logging.basicConfig(
            level=logging.INFO,
            format='[\033[34m%(asctime)s\033[0m] %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S',
            handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")]
        )
        logger = logging.getLogger(__name__)
    else:  # dummy logger (does nothing)
        logger = logging.getLogger(__name__)
        logger.addHandler(logging.NullHandler())
    return logger

class RF(torch.nn.Module):
    """
    Rectified-flow wrapper: computes the training loss in forward() and
    provides Euler-integration samplers.
    """
    def __init__(self, ln=True):
        super().__init__()
        self.ln = ln  # if True, draw t from a logit-normal distribution
        self.stratified = False

    def forward(self, model, x, **kwargs):
        b = x.size(0)
        if self.ln:
            if self.stratified:
                # stratified sampling of normals
                # first stratified sample from uniform
                quantiles = torch.linspace(0, 1, b + 1).to(x.device)
                z = quantiles[:-1] + torch.rand((b,)).to(x.device) / b
                # now transform to normal
                z = torch.erfinv(2 * z - 1) * math.sqrt(2)
                t = torch.sigmoid(z)
            else:
                nt = torch.randn((b,)).to(x.device)
                t = torch.sigmoid(nt)
        else:
            t = torch.rand((b,)).to(x.device)
        texp = t.view([b, *([1] * len(x.shape[1:]))])
        z1 = torch.randn_like(x)
        zt = (1 - texp) * x + texp * z1

        # make t, zt the same dtype as x
        zt, t = zt.to(x.dtype), t.to(x.dtype)
        vtheta = model(x=zt, t=t, **kwargs)

        batchwise_mse = ((z1 - x - vtheta) ** 2).mean(dim=list(range(1, len(x.shape))))
        tlist = batchwise_mse.detach().cpu().reshape(-1).tolist()
        ttloss = [(tv, tloss) for tv, tloss in zip(t, tlist)]
        return batchwise_mse.mean(), {"batchwise_loss": ttloss}
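
    # Training objective in forward() above (rectified flow): for t drawn from a
    # logit-normal distribution, the model sees the straight-line interpolation
    # z_t = (1 - t) * x + t * z1 with z1 ~ N(0, I) and regresses the constant
    # velocity v* = z1 - x, i.e. loss = ||v* - v_theta||^2. The returned (t, loss)
    # pairs make it easy to plot loss as a function of t for diagnostics.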

    def sample(self, model, z, conds, null_cond=None, sample_steps=50, cfg=2.0, **kwargs):
        b = z.size(0)
        dt = 1.0 / sample_steps
        dt = torch.tensor([dt] * b).to(z.device).view([b, *([1] * len(z.shape[1:]))])
        images = [z]
        for i in range(sample_steps, 0, -1):
            t = i / sample_steps
            t = torch.tensor([t] * b).to(z.device)

            vc = model(x=z, t=t, **conds)
            if null_cond is not None:
                vu = model(x=z, t=t, **null_cond)
                vc = vu + cfg * (vc - vu)

            z = z - dt * vc
            images.append(z)
        return images
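
    # sample() integrates dz/dt = v backwards from t = 1 (pure noise) to t = 0
    # (data) with sample_steps Euler steps. Classifier-free guidance replaces the
    # conditional velocity v_c by v_u + cfg * (v_c - v_u) when null_cond is given.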

    def sample_with_xps(self, model, z, conds, null_cond=None, sample_steps=50, cfg=2.0, **kwargs):
        b = z.size(0)
        dt = 1.0 / sample_steps
        dt = torch.tensor([dt] * b).to(z.device).view([b, *([1] * len(z.shape[1:]))])
        images = [z]
        for i in range(sample_steps, 0, -1):
            t = i / sample_steps
            t = torch.tensor([t] * b).to(z.device)

            vc = model(x=z, t=t, **conds)
            if null_cond is not None:
                vu = model(x=z, t=t, **null_cond)
                vc = vu + cfg * (vc - vu)

            x = z - i * dt * vc  # extrapolated clean-sample estimate
            z = z - dt * vc
            images.append(x)
        return images
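
    # Because rectified-flow paths are straight, x_hat = z_t - t * v is the model's
    # current one-step estimate of the clean sample; sample_with_xps() records this
    # extrapolated estimate at every step instead of the intermediate state z_t.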

def prepare_model_inputs(args, batch, device, vae, clap, t5):
    text_embedding, text_embedding_mask = batch['text_embedding'], batch['text_embedding_mask']
    text_embedding_t5, text_embedding_mask_t5 = batch['text_embedding_t5'], batch['text_embedding_mask_t5']

    # CLAP pooled text embedding
    text_embedding = text_embedding.to(device)
    text_embedding_mask = text_embedding_mask.to(device)
    with torch.no_grad():
        encoder_hidden_states = clap.hf_module(
            text_embedding.to(device),
            attention_mask=text_embedding_mask,
            output_hidden_states=False,
        )["pooler_output"]  # (bs, hidden_size)

    # T5 token-level text embedding
    text_embedding_t5 = text_embedding_t5.to(device).squeeze(1)
    text_embedding_mask_t5 = text_embedding_mask_t5.to(device).squeeze(1)
    with torch.no_grad():
        output_t5 = t5.hf_module(
            input_ids=text_embedding_t5,
            attention_mask=text_embedding_mask_t5,
            output_hidden_states=False,
        )
        encoder_hidden_states_t5 = output_t5["last_hidden_state"].detach()

    # encode the log-mel spectrogram into the VAE latent space
    with torch.no_grad():
        image = vae.encode(batch['log_mel_spec'].unsqueeze(1).to(device)).latent_dist.sample().mul_(vae.config.scaling_factor)

    # patchify the latent into 2x2 tokens and build positional ids
    bs, c, h, w = image.shape
    image = rearrange(image, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2).float()
    img_ids = torch.zeros(h // 2, w // 2, 3)
    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
    txt_ids = torch.zeros(bs, encoder_hidden_states_t5.shape[1], 3)

    # Model conditions
    model_kwargs = dict(
        img_ids=img_ids.to(image.device),
        txt=encoder_hidden_states_t5.to(image.device).float(),
        txt_ids=txt_ids.to(image.device),
        y=encoder_hidden_states.to(image.device).float(),
    )
    return image, model_kwargs
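
# Shape sketch for the patchify step above: a VAE latent of shape (b, c, h, w)
# becomes (b, (h/2) * (w/2), 4 * c) tokens, and img_ids carries one (0, row, col)
# triple per token. This 3-axis id layout matches Flux-style rotary position
# embeddings, which the model presumably applies to img_ids and txt_ids.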

def main(args):
    assert torch.cuda.is_available(), "Training currently requires at least one GPU."

    dist.init_process_group("nccl")
    assert args.global_batch_size % dist.get_world_size() == 0, "Batch size must be divisible by world size."
    rank = dist.get_rank()
    device = rank % torch.cuda.device_count()
    seed = args.global_seed * dist.get_world_size() + rank
    torch.manual_seed(seed)
    torch.cuda.set_device(device)
    print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")

    # Set up an experiment folder:
    if rank == 0:
        os.makedirs(args.results_dir, exist_ok=True)  # Make results folder (holds all experiment subfolders)
        experiment_index = len(glob(f"{args.results_dir}/*"))
        model_string_name = args.version.replace("/", "-")  # e.g., DiT-XL/2 --> DiT-XL-2 (for naming folders)
        experiment_dir = f"{args.results_dir}/{model_string_name}"  # Create an experiment folder
        checkpoint_dir = f"{experiment_dir}/checkpoints"  # Stores saved model checkpoints
        os.makedirs(checkpoint_dir, exist_ok=True)
        logger = create_logger(experiment_dir)
        logger.info(f"Experiment directory created at {experiment_dir}")
    else:
        logger = create_logger(None)

    model = build_model(args.version).to(device)
    parameters_sum = sum(x.numel() for x in model.parameters())
    logger.info(f"{parameters_sum / 1e6:.2f} M parameters")

    if args.resume is not None:
        print('load from: ', args.resume)
        resume_ckpt = torch.load(args.resume, map_location=lambda storage, loc: storage)['ema']
        model.load_state_dict(resume_ckpt)

    # Note that parameter initialization is done within the DiT constructor
    ema = deepcopy(model).to(device)  # Create an EMA of the model for use after training
    requires_grad(ema, False)
    model = DDP(model.to(device), device_ids=[rank])
    diffusion = RF()

    # NOTE: hard-coded asset root; the --vae-path argument is currently unused.
    model_path = '/maindata/data/shared/public/zhengcong.fei/dataset/dataset_music/audioldm2'
    vae = AutoencoderKL.from_pretrained(os.path.join(model_path, 'vae')).to(device)
    # vocoder = SpeechT5HifiGan.from_pretrained(os.path.join(model_path, 'vocoder')).to(device)
    t5 = load_t5(device, max_length=256)
    clap = load_clap(device, max_length=256)
    # clip = load_clip(device)

    opt = torch.optim.AdamW(model.parameters(), lr=3e-5, weight_decay=0)

    config = yaml.load(
        open('config/16k_64.yaml', 'r'),
        Loader=yaml.FullLoader,
    )
    dataset = AudioDataset(
        config=config, split="train",
        waveform_only=False,
        dataset_json_path=args.data_path,
        tokenizer=clap.tokenizer,
        uncond_pro=0.1,
        text_ctx_len=77,
        tokenizer_t5=t5.tokenizer,
        text_ctx_len_t5=256,
        uncond_pro_t5=0.1,
    )
    sampler = DistributedSampler(
        dataset,
        num_replicas=dist.get_world_size(),
        rank=rank,
        shuffle=True,
        seed=args.global_seed
    )
    loader = DataLoader(
        dataset,
        batch_size=int(args.global_batch_size // dist.get_world_size()),
        shuffle=False,
        sampler=sampler,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True
    )
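
    # Each rank draws global_batch_size // world_size samples per iteration, and
    # DistributedSampler hands every rank a disjoint shard, so one pass over the
    # loader on all ranks together covers roughly one full epoch of the dataset.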
| logger.info(f"Dataset contains {len(dataset):,}") | |

    # Prepare models for training:
    update_ema(ema, model.module, decay=0)  # Ensure EMA is initialized with synced weights
    model.train()  # important! This enables embedding dropout for classifier-free guidance
    ema.eval()  # EMA model should always be in eval mode

    # Variables for monitoring/logging purposes:
    train_steps = 0
    log_steps = 0
    running_loss = 0
    start_time = time()

    logger.info(f"Training for {args.epochs} epochs...")
    for epoch in range(args.epochs):
        sampler.set_epoch(epoch)
        logger.info(f"Beginning epoch {epoch}...")
        data_iter_step = 0
        for batch in loader:
            latents, model_kwargs = prepare_model_inputs(args, batch, device, vae, clap, t5)

            loss, _ = diffusion.forward(model=model, x=latents, **model_kwargs)
            # Accumulate gradients over accum_iter micro-batches, then step.
            (loss / args.accum_iter).backward()
            if (data_iter_step + 1) % args.accum_iter == 0:
                opt.step()
                opt.zero_grad()
                update_ema(ema, model.module)
            data_iter_step += 1
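
            # With the defaults (global_batch_size=32, accum_iter=16), one optimizer
            # step consumes 32 * 16 = 512 samples across all ranks. Note that DDP
            # still all-reduces gradients on every backward pass here; wrapping the
            # non-stepping micro-batches in model.no_sync() would avoid that traffic.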

            # Log loss values:
            running_loss += loss.item()
            log_steps += 1
            train_steps += 1
            if train_steps % args.log_every == 0:
                # Measure training speed:
                torch.cuda.synchronize()
                end_time = time()
                steps_per_sec = log_steps / (end_time - start_time)
                # Reduce loss history over all processes:
                avg_loss = torch.tensor(running_loss / log_steps, device=device)
                dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM)
                avg_loss = avg_loss.item() / dist.get_world_size()
                logger.info(f"(step={train_steps:07d}) Train Loss: {avg_loss:.4f}, Train Steps/Sec: {steps_per_sec:.2f}")
                # Reset monitoring variables:
                running_loss = 0
                log_steps = 0
                start_time = time()

            # Save DiT checkpoint:
            if train_steps % args.ckpt_every == 0 and train_steps > 0:
                if rank == 0:
                    checkpoint = {
                        # "model": model.module.state_dict(),
                        "ema": ema.state_dict(),
                        "opt": opt.state_dict(),
                        "args": args
                    }
                    checkpoint_path = f"{checkpoint_dir}/{train_steps:07d}.pt"
                    try:
                        torch.save(checkpoint, checkpoint_path)
                    except Exception as e:
                        print(e)
                    logger.info(f"Saved checkpoint to {checkpoint_path}")
                dist.barrier()

    # model.eval()  # important! This disables randomized embedding dropout
    # do any sampling/FID calculation/etc. with ema (or model) in eval mode ...

    logger.info("Done!")
    cleanup()

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-path", type=str, default='fma_dataset.json')
    parser.add_argument("--results-dir", type=str, default="results")
    parser.add_argument("--resume", type=str, default=None)
    parser.add_argument("--version", type=str, default="large")
    parser.add_argument("--vae-path", type=str, default='audioldm2/vae')
    parser.add_argument("--epochs", type=int, default=1400)
    parser.add_argument("--global_batch_size", type=int, default=32)
    parser.add_argument("--global-seed", type=int, default=1234)
    parser.add_argument("--num-workers", type=int, default=4)
    parser.add_argument("--log-every", type=int, default=100)
    parser.add_argument("--accum_iter", type=int, default=16)
    parser.add_argument("--ckpt-every", type=int, default=100_000)
    parser.add_argument("--local-rank", type=int, default=-1, help="local rank passed from distributed launcher")
    args = parser.parse_args()
    main(args)
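
# Example launch (assumption: a standard torchrun entry point; adjust the script
# name and GPU count to your setup):
#   torchrun --nproc_per_node=8 train.py --data-path fma_dataset.json --version large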