import os
import gc
import argparse
import datetime
from io import BytesIO
from glob import glob

from tqdm.auto import tqdm
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import v2, InterpolationMode

import datasets
import bitsandbytes as bnb
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, UNet2DConditionModel


def parse_args():
    parser = argparse.ArgumentParser(
        description="Flow-matching training script for an SD-style UNet",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./outputs",
        help="Output directory for training results",
    )
    parser.add_argument(
        "--unet",
        type=str,
        default="./sd_flow_unet",
        help="Folder containing the initial UNet weights",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Seed for reproducible training",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=16,
        help="Training batch size",
    )
    parser.add_argument(
        "--base_lr",
        type=float,
        default=2e-6,
        help="Base learning rate, will be scaled by sqrt(batch_size)",
    )
    parser.add_argument(
        "--shift",
        type=float,
        default=2.0,
        help="Noise schedule shift for training (shift > 1 will spend more effort on early timesteps/high noise)",
    )
    parser.add_argument(
        "--dropout",
        type=float,
        default=0.1,
        help="Probability to drop out conditioning (to support CFG)",
    )
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=50_000,
        help="Total number of training steps",
    )
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=1000,
        help="Save a checkpoint of the training state every X steps",
    )
    args = parser.parse_args()
    return args


def train(args):
    device = "cuda"
    torch.backends.cuda.matmul.allow_tf32 = True  # faster but slightly less accurate
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # per-run output directory, also used as the TensorBoard log dir
    date_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    real_output_dir = os.path.join(args.output_dir, date_time)
    os.makedirs(real_output_dir, exist_ok=True)
    t_writer = SummaryWriter(log_dir=real_output_dir, flush_secs=60)

    # stream the parquet shards instead of loading the whole dataset into memory
    data_files = glob("E:/datasets/commoncatalog-cc-by/**/*.parquet", recursive=True)
    train_dataset = datasets.load_dataset("parquet", data_files=data_files, split="train", streaming=True)
    train_dataset = train_dataset.shuffle(seed=args.seed, buffer_size=1000)

    image_transforms = v2.Compose([
        v2.ToImage(),
        v2.ToDtype(dtype=torch.float32, scale=True),
        v2.Resize(512),
        v2.CenterCrop(512),
    ])

    def collate_fn(examples):
        # decode JPEG bytes, apply the transforms, and map pixels to [-1, 1]
        captions = []
        pixel_values = []
        for example in examples:
            captions.append(example["blip2_caption"])
            image = Image.open(BytesIO(example["jpg"])).convert("RGB")
            image = image_transforms(image) * 2 - 1
            image = torch.clamp(torch.nan_to_num(image), min=-1, max=1)
            pixel_values.append(image)
        pixel_values = torch.stack(pixel_values, dim=0).contiguous()
        return pixel_values, captions

    train_dataloader = DataLoader(
        dataset=train_dataset,
        batch_size=args.batch_size,
        collate_fn=collate_fn,
        num_workers=0,
    )

    # frozen text encoder and VAE in bfloat16; only the UNet is trained
    tokenizer = CLIPTokenizer.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="text_encoder")
    text_encoder = text_encoder.to(dtype=torch.bfloat16, device=device)
    text_encoder.requires_grad_(False)
    text_encoder.eval()

    vae = AutoencoderKL.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="vae")
    vae = vae.to(dtype=torch.bfloat16, device=device)
    vae.requires_grad_(False)
    vae.eval()

    unet = UNet2DConditionModel.from_pretrained(args.unet).to(device)
    unet.requires_grad_(True)
    unet.enable_gradient_checkpointing()
    unet.train()

    optimizer = bnb.optim.AdamW8bit(
        unet.parameters(),
        lr=args.base_lr * (args.batch_size ** 0.5),  # scale the base LR by sqrt(batch_size)
    )

    global_step = 0
    train_logs = {"train_step": [], "train_loss": [], "train_timestep": []}

    def encode_captions(captions):
        input_ids = []
        for caption in captions:
            if torch.rand(1) < args.dropout:
                caption = ""  # caption dropout for better CFG
            ids = tokenizer(
                caption,
                max_length=tokenizer.model_max_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            ).input_ids
            input_ids.append(ids)
        # each `ids` is already (1, seq_len), so concatenate along the batch dimension
        input_ids = torch.cat(input_ids, dim=0).to(device)
        return text_encoder(input_ids, return_dict=False)[0].float()

    def vae_encode(pixels):
        latents = vae.encode(pixels.to(dtype=torch.bfloat16, device=device)).latent_dist.sample()
        return latents.float() * vae.config.scaling_factor

    def get_pred(batch, log_to=None):
        pixels, captions = batch
        encoder_hidden_states = encode_captions(captions)
        latents = vae_encode(pixels)

        # sample sigmas uniformly, then apply the timestep shift (shift > 1 biases toward high noise)
        sigmas = torch.rand(latents.shape[0]).to(device)
        sigmas = (args.shift * sigmas) / (1 + (args.shift - 1) * sigmas)
        timesteps = sigmas * 1000
        sigmas = sigmas[:, None, None, None]

        # rectified-flow interpolation between data and noise; the target is the velocity
        noise = torch.randn_like(latents)
        noisy_latents = noise * sigmas + latents * (1 - sigmas)
        target = noise - latents

        pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]
        loss = F.mse_loss(pred.float(), target.float(), reduction="none")
        loss = loss.mean(dim=list(range(1, len(loss.shape))))  # reduce over all dimensions except batch

        if log_to is not None:
            for i in range(timesteps.shape[0]):
                log_to["train_step"].append(global_step)
                log_to["train_loss"].append(loss[i].item())
                log_to["train_timestep"].append(timesteps[i].item())

        return loss.mean()

    def plot_logs(log_dict):
        # scatter of per-sample loss vs. timestep, colored by training step
        plt.scatter(log_dict["train_timestep"], log_dict["train_loss"], s=3, c=log_dict["train_step"], marker=".", cmap="cool")
        plt.xlabel("timestep")
        plt.ylabel("loss")
        plt.yscale("log")

    progress_bar = tqdm(range(0, args.max_train_steps))
    while True:
        for step, batch in enumerate(train_dataloader):
            loss = get_pred(batch, log_to=train_logs)
            t_writer.add_scalar("train/loss", loss.detach().item(), global_step)
            loss.backward()

            grad_norm = torch.nn.utils.clip_grad_norm_(unet.parameters(), 2.0)
            t_writer.add_scalar("train/grad_norm", grad_norm.detach().item(), global_step)

            optimizer.step()
            optimizer.zero_grad()

            progress_bar.update(1)
            global_step += 1

            if global_step % 100 == 0:
                plot_logs(train_logs)
                t_writer.add_figure("train_loss", plt.gcf(), global_step)

            if global_step >= args.max_train_steps or global_step % args.checkpointing_steps == 0:
                checkpoint_path = os.path.join(real_output_dir, f"checkpoint-{global_step:08}")
                unet.save_pretrained(os.path.join(checkpoint_path, "unet"), safe_serialization=True)

            if global_step >= args.max_train_steps:
                break
        # also leave the outer epoch loop once the step budget is reached
        if global_step >= args.max_train_steps:
            break


if __name__ == "__main__":
    train(parse_args())
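# Example invocation (illustrative only; the script filename below is an assumption,
# not part of the source):
#   python train_flow_unet.py --unet ./sd_flow_unet --batch_size 16 --base_lr 2e-6 \
#       --shift 2.0 --dropout 0.1 --max_train_steps 50000 --output_dir ./outputs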