# adapted from EveryDream2Trainer
import contextlib
import traceback

import torch
from torch.cuda.amp import GradScaler

from StableDiffuser import StableDiffuser

class MemoryEfficiencyWrapper:
    def __init__(self,
                 diffuser: StableDiffuser,
                 use_amp: bool,
                 use_xformers: bool,
                 use_gradient_checkpointing: bool):
        self.diffuser = diffuser
        # SD1.x UNets report attention_head_dim as 8, either as a single int
        # or as [8, 8, 8, 8] depending on the diffusers version.
        self.is_sd1attn = diffuser.unet.config["attention_head_dim"] == [8, 8, 8, 8]
        self.is_sd1attn = diffuser.unet.config["attention_head_dim"] == 8 or self.is_sd1attn
        self.use_amp = use_amp
        self.use_xformers = use_xformers
        self.use_gradient_checkpointing = use_gradient_checkpointing

    def __enter__(self):
        if self.use_gradient_checkpointing:
            self.diffuser.unet.enable_gradient_checkpointing()
            self.diffuser.text_encoder.gradient_checkpointing_enable()
        if self.use_xformers:
            # Only enable xformers when AMP is on or the UNet is not SD1.x;
            # otherwise fall back to attention slicing.
            if (self.use_amp and self.is_sd1attn) or (not self.is_sd1attn):
                try:
                    self.diffuser.unet.enable_xformers_memory_efficient_attention()
                    print("Enabled xformers")
                except Exception as ex:
                    print(f"Failed to load xformers ({ex}), using attention slicing instead")
                    self.diffuser.unet.set_attention_slice("auto")
            elif not self.use_amp and self.is_sd1attn:
                print("AMP is disabled but model is SD1.X, using attention slicing instead of xformers")
                self.diffuser.unet.set_attention_slice("auto")
        else:
            print("xformers disabled via arg, using attention slicing instead")
            self.diffuser.unet.set_attention_slice("auto")
        # self.diffuser.vae = self.diffuser.vae.to(self.diffuser.vae.device, dtype=torch.float16 if self.use_amp else torch.float32)
        self.diffuser.unet = self.diffuser.unet.to(self.diffuser.unet.device, dtype=torch.float32)
        try:
            # unet = torch.compile(unet)
            # text_encoder = torch.compile(text_encoder)
            # vae = torch.compile(vae)
            torch.set_float32_matmul_precision('high')
            torch.backends.cudnn.allow_tf32 = True
            # logging.info("Successfully compiled models")
        except Exception as ex:
            print(f"Failed to compile model, continuing anyway, ex: {ex}")
        self.grad_scaler = GradScaler(
            enabled=self.use_amp,
            init_scale=2 ** 17.5,
            growth_factor=2,
            backoff_factor=1.0 / 2,
            growth_interval=25,
        )
        # Return the wrapper so it can be bound with `with ... as wrapper`.
        return self

    def step(self, optimizer, loss):
        # Scale the loss for AMP, backprop, take an optimizer step, then let
        # the scaler adjust its scale factor for the next iteration.
        self.grad_scaler.scale(loss).backward()
        self.grad_scaler.step(optimizer)
        self.grad_scaler.update()
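
    # Variant sketch (an assumption, not in the original file): if gradient
    # clipping is wanted under AMP, gradients must be unscaled before
    # clipping, per the PyTorch GradScaler docs. The method name and the
    # `max_norm` default here are hypothetical.
    def step_with_clipping(self, optimizer, loss, parameters, max_norm=1.0):
        self.grad_scaler.scale(loss).backward()
        # Unscale in place so the clip threshold applies to the true gradients.
        self.grad_scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(parameters, max_norm)
        self.grad_scaler.step(optimizer)
        self.grad_scaler.update()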

    def __exit__(self, exc_type, exc_value, tb):
        if exc_type is not None:
            traceback.print_exception(exc_type, exc_value, tb)
            # return False  # uncomment to re-raise immediately, skipping the cleanup below
        self.diffuser.unet.disable_gradient_checkpointing()
        try:
            self.diffuser.text_encoder.gradient_checkpointing_disable()
        except AttributeError:
            # self.diffuser.text_encoder has likely been `del`eted already
            pass
        self.diffuser.unet.disable_xformers_memory_efficient_attention()
        self.diffuser.unet.set_attention_slice("auto")
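

# Illustrative usage (an assumption, not part of the original file): the
# wrapper is designed to bracket a training loop, enabling the memory
# optimizations on entry and undoing them on exit. `build_diffuser`,
# `dataloader`, and `compute_loss` are hypothetical stand-ins for whatever
# the calling training script provides:
#
#     diffuser = build_diffuser()  # returns a StableDiffuser
#     optimizer = torch.optim.AdamW(diffuser.unet.parameters(), lr=1e-5)
#     with MemoryEfficiencyWrapper(diffuser,
#                                  use_amp=True,
#                                  use_xformers=True,
#                                  use_gradient_checkpointing=True) as wrapper:
#         for batch in dataloader:
#             optimizer.zero_grad()
#             loss = compute_loss(diffuser, batch)
#             wrapper.step(optimizer, loss)  # scaled backward + step + update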