import os.path
import random

from accelerate.utils import set_seed
from diffusers import StableDiffusionPipeline
from torch.cuda.amp import autocast
from torchvision import transforms

from StableDiffuser import StableDiffuser
from finetuning import FineTunedModel
import torch
from tqdm import tqdm

from isolate_rng import isolate_rng
from memory_efficiency import MemoryEfficiencyWrapper
from torch.utils.tensorboard import SummaryWriter
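
# Module-level cancellation flag, checked once per training step; presumably set to
# True from another thread (e.g. a UI callback) to request an early stop.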
training_should_cancel = False
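

# Measure the erasure objective (same loss as training) on held-out validation
# prompt embeddings, render sample images with the fine-tuned weights applied,
# and log both to TensorBoard.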
def validate(diffuser: StableDiffuser, finetuner: FineTunedModel,
             validation_embeddings: torch.FloatTensor,
             neutral_embeddings: torch.FloatTensor,
             sample_embeddings: torch.FloatTensor,
             logger: SummaryWriter, use_amp: bool,
             global_step: int,
             validation_seed: int = 555,
             ):
    print("validating...")
    with isolate_rng(include_cuda=True), torch.no_grad():
        set_seed(validation_seed)
        criteria = torch.nn.MSELoss()
        negative_guidance = 1
        val_count = 5
        nsteps = 50

        num_validation_prompts = validation_embeddings.shape[0] // 2
        for i in range(0, num_validation_prompts):
            accumulated_loss = None
            this_validation_embeddings = validation_embeddings[i*2:i*2+2]
            for j in range(val_count):
                iteration = random.randint(1, nsteps)
                diffused_latents = get_diffused_latents(diffuser, nsteps, this_validation_embeddings, iteration, use_amp)

                with autocast(enabled=use_amp):
                    positive_latents = diffuser.predict_noise(iteration, diffused_latents, this_validation_embeddings, guidance_scale=1)
                    neutral_latents = diffuser.predict_noise(iteration, diffused_latents, neutral_embeddings, guidance_scale=1)

                with finetuner, autocast(enabled=use_amp):
                    negative_latents = diffuser.predict_noise(iteration, diffused_latents, this_validation_embeddings, guidance_scale=1)
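
                # Same erasure target as the training loss, evaluated on this
                # held-out validation prompt.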
                loss = criteria(negative_latents, neutral_latents - (negative_guidance*(positive_latents - neutral_latents)))
                accumulated_loss = (accumulated_loss or 0) + loss.item()
            logger.add_scalar(f"loss/val_{i}", accumulated_loss/val_count, global_step=global_step)
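
        # Render sample images with the fine-tuned weights applied and log them.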
        num_samples = sample_embeddings.shape[0] // 2
        for i in range(0, num_samples):
            print(f'making sample {i}...')
            with finetuner:
                pipeline = StableDiffusionPipeline(vae=diffuser.vae,
                                                   text_encoder=diffuser.text_encoder,
                                                   tokenizer=diffuser.tokenizer,
                                                   unet=diffuser.unet,
                                                   scheduler=diffuser.scheduler,
                                                   safety_checker=None,
                                                   feature_extractor=None,
                                                   requires_safety_checker=False)
                images = pipeline(prompt_embeds=sample_embeddings[i*2+1:i*2+2],
                                  negative_prompt_embeds=sample_embeddings[i*2:i*2+1],
                                  num_inference_steps=50)
                image_tensor = transforms.ToTensor()(images.images[0])
                logger.add_image(f"samples/{i}", img_tensor=image_tensor, global_step=global_step)
| """ | |
| with finetuner, torch.cuda.amp.autocast(enabled=use_amp): | |
| images = diffuser( | |
| combined_embeddings=sample_embeddings[i*2:i*2+2], | |
| n_steps=50 | |
| ) | |
| logger.add_images(f"samples/{i}", images) | |
| """ | |
| torch.cuda.empty_cache() | |
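

# Fine-tune the selected UNet `modules` (keeping `freeze_modules` at their original
# weights) so that the model's noise prediction for `prompt` is steered away from
# the concept, erasing it from the model.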
def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations, negative_guidance, lr, save_path,
          use_adamw8bit=True, use_xformers=True, use_amp=True, use_gradient_checkpointing=False, seed=-1,
          save_every_n_steps=-1, validate_every_n_steps=-1,
          validation_prompts=[], sample_positive_prompts=[], sample_negative_prompts=[]):

    diffuser = None
    loss = None
    optimizer = None
    finetuner = None
    negative_latents = None
    neutral_latents = None
    positive_latents = None

    nsteps = 50
    print(f"using img_size of {img_size}")
    diffuser = StableDiffuser(scheduler='DDIM', repo_id_or_path=repo_id_or_path, native_img_size=img_size).to('cuda')
    logger = SummaryWriter(log_dir=f"logs/{os.path.splitext(os.path.basename(save_path))[0]}")

    memory_efficiency_wrapper = MemoryEfficiencyWrapper(diffuser=diffuser, use_amp=use_amp, use_xformers=use_xformers,
                                                        use_gradient_checkpointing=use_gradient_checkpointing)
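    # While this context is active the wrapper applies the requested memory
    # optimisations (xformers attention, gradient checkpointing, AMP); its .step()
    # below presumably runs backward and the optimizer update (e.g. via a GradScaler).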
    with memory_efficiency_wrapper:
        diffuser.train()
        finetuner = FineTunedModel(diffuser, modules, frozen_modules=freeze_modules)

        if use_adamw8bit:
            print("using AdamW 8Bit optimizer")
            import bitsandbytes as bnb
            optimizer = bnb.optim.AdamW8bit(finetuner.parameters(),
                                            lr=lr,
                                            betas=(0.9, 0.999),
                                            weight_decay=0.010,
                                            eps=1e-8)
        else:
            print("using Adam optimizer")
            optimizer = torch.optim.Adam(finetuner.parameters(), lr=lr)

        criteria = torch.nn.MSELoss()

        pbar = tqdm(range(iterations))

        with torch.no_grad():
            neutral_text_embeddings = diffuser.get_cond_and_uncond_embeddings([''], n_imgs=1)
            positive_text_embeddings = diffuser.get_cond_and_uncond_embeddings([prompt], n_imgs=1)
            validation_embeddings = diffuser.get_cond_and_uncond_embeddings(validation_prompts, n_imgs=1)
            sample_embeddings = diffuser.get_cond_and_uncond_embeddings(sample_positive_prompts, sample_negative_prompts, n_imgs=1)

        #if use_amp:
        #    diffuser.vae = diffuser.vae.to(diffuser.vae.device, dtype=torch.float16)
        #del diffuser.text_encoder
        #del diffuser.tokenizer
        torch.cuda.empty_cache()

        if seed == -1:
            seed = random.randint(0, 2 ** 30)
        set_seed(int(seed))

        prev_losses = []
        start_loss = None
        max_prev_loss_count = 10

        try:
            for i in pbar:
                if training_should_cancel:
                    print("received cancellation request")
                    return None

                with torch.no_grad():
                    optimizer.zero_grad()

                    iteration = torch.randint(1, nsteps - 1, (1,)).item()

                    with finetuner:
                        diffused_latents = get_diffused_latents(diffuser, nsteps, positive_text_embeddings, iteration, use_amp)

                    # convert the DDIM step index to the scheduler's 1000-step timestep scale
                    iteration = int(iteration / nsteps * 1000)
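
                    # Noise predictions from the frozen (original) weights at the chosen
                    # timestep, for the prompt being erased and for the empty prompt.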
                    with autocast(enabled=use_amp):
                        positive_latents = diffuser.predict_noise(iteration, diffused_latents, positive_text_embeddings, guidance_scale=1)
                        neutral_latents = diffuser.predict_noise(iteration, diffused_latents, neutral_text_embeddings, guidance_scale=1)
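
                # The same prediction with the fine-tuned (trainable) modules swapped in.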
                with finetuner:
                    with autocast(enabled=use_amp):
                        negative_latents = diffuser.predict_noise(iteration, diffused_latents, positive_text_embeddings, guidance_scale=1)
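
                # ESD-style erasure target: the neutral prediction pushed away from the
                # concept by negative_guidance; the fine-tuned prediction is regressed
                # onto it with MSE.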
                positive_latents.requires_grad = False
                neutral_latents.requires_grad = False

                # loss = criteria(e_n, e_0) works the best try 5000 epochs
                loss = criteria(negative_latents, neutral_latents - (negative_guidance*(positive_latents - neutral_latents)))
                memory_efficiency_wrapper.step(optimizer, loss)
                optimizer.zero_grad()

                logger.add_scalar("loss", loss.item(), global_step=i)

                # print moving average loss
                prev_losses.append(loss.detach().clone())
                if len(prev_losses) > max_prev_loss_count:
                    prev_losses.pop(0)
                if start_loss is None:
                    start_loss = prev_losses[-1]
                if len(prev_losses) >= max_prev_loss_count:
                    moving_average_loss = sum(prev_losses) / len(prev_losses)
                    print(
                        f"step {i}: loss={loss.item()} (avg={moving_average_loss.item()}, start ∆={(moving_average_loss - start_loss).item()})")
                else:
                    print(f"step {i}: loss={loss.item()}")

                if save_every_n_steps > 0 and ((i+1) % save_every_n_steps) == 0:
                    torch.save(finetuner.state_dict(), save_path + f"__step_{i+1}.pt")

                if validate_every_n_steps > 0 and ((i+1) % validate_every_n_steps) == 0:
                    validate(diffuser, finetuner,
                             validation_embeddings=validation_embeddings,
                             sample_embeddings=sample_embeddings,
                             neutral_embeddings=neutral_text_embeddings,
                             logger=logger, use_amp=False, global_step=i)

            torch.save(finetuner.state_dict(), save_path)
            return save_path
        finally:
            del diffuser, loss, optimizer, finetuner, negative_latents, neutral_latents, positive_latents
            torch.cuda.empty_cache()
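

# Diffuse random initial latents for `end_iteration` of `nsteps` DDIM steps and
# return the resulting partially-denoised latents, which are used as input to the
# noise predictions above.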
def get_diffused_latents(diffuser, nsteps, text_embeddings, end_iteration, use_amp):
    diffuser.set_scheduler_timesteps(nsteps)
    latents = diffuser.get_initial_latents(1, n_prompts=1)
    latents_steps, _ = diffuser.diffusion(
        latents,
        text_embeddings,
        start_iteration=0,
        end_iteration=end_iteration,
        guidance_scale=3,
        show_progress=False,
        use_amp=use_amp
    )
    # because return_latents is not passed to diffuser.diffusion(), latents_steps should have only 1 entry
    # but we take the "last" (-1) entry because paranoia
    diffused_latents = latents_steps[-1]
    diffuser.set_scheduler_timesteps(1000)
    del latents_steps, latents
    return diffused_latents


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--repo_id_or_path", required=True)
    parser.add_argument("--img_size", type=int, required=False, default=512)
    parser.add_argument('--prompt', required=True)
    parser.add_argument('--modules', required=True)
    parser.add_argument('--freeze_modules', nargs='+', required=True)
    parser.add_argument('--save_path', required=True)
    parser.add_argument('--iterations', type=int, required=True)
    parser.add_argument('--lr', type=float, required=True)
    parser.add_argument('--negative_guidance', type=float, required=True)
    parser.add_argument('--seed', type=int, required=False, default=-1,
                        help='Training seed for reproducible results, or -1 to pick a random seed')
    parser.add_argument('--use_adamw8bit', action='store_true')
    parser.add_argument('--use_xformers', action='store_true')
    parser.add_argument('--use_amp', action='store_true')
    parser.add_argument('--use_gradient_checkpointing', action='store_true')

    train(**vars(parser.parse_args()))