Spaces:
Runtime error
Runtime error
Damian Stewart
commited on
Commit
·
d8ffb68
1
Parent(s):
209d166
allow different resolutions for w/h
Browse files- StableDiffuser.py +23 -33
- app.py +78 -14
- finetuning.py +2 -7
- train.py +1 -1
StableDiffuser.py
CHANGED
|
@@ -1,17 +1,13 @@
|
|
| 1 |
import argparse
|
| 2 |
-
import traceback
|
| 3 |
|
| 4 |
import torch
|
| 5 |
from baukit import TraceDict
|
| 6 |
-
from diffusers import
|
| 7 |
from PIL import Image
|
| 8 |
from tqdm.auto import tqdm
|
| 9 |
-
from transformers import CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor
|
| 10 |
-
from diffusers.schedulers import EulerAncestralDiscreteScheduler
|
| 11 |
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
| 12 |
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
| 13 |
from diffusers.schedulers.scheduling_lms_discrete import LMSDiscreteScheduler
|
| 14 |
-
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
|
| 15 |
import util
|
| 16 |
|
| 17 |
|
|
@@ -39,31 +35,17 @@ class StableDiffuser(torch.nn.Module):
|
|
| 39 |
def __init__(self,
|
| 40 |
scheduler='LMS',
|
| 41 |
repo_id_or_path="CompVis/stable-diffusion-v1-4",
|
|
|
|
| 42 |
):
|
| 43 |
|
| 44 |
super().__init__()
|
| 45 |
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
self.
|
| 52 |
-
repo_id_or_path, subfolder="tokenizer")
|
| 53 |
-
self.text_encoder = CLIPTextModel.from_pretrained(
|
| 54 |
-
repo_id_or_path, subfolder="text_encoder")
|
| 55 |
-
|
| 56 |
-
# The UNet model for generating the latents.
|
| 57 |
-
self.unet = UNet2DConditionModel.from_pretrained(
|
| 58 |
-
repo_id_or_path, subfolder="unet")
|
| 59 |
-
|
| 60 |
-
try:
|
| 61 |
-
self.feature_extractor = CLIPFeatureExtractor.from_pretrained(repo_id_or_path, subfolder="feature_extractor")
|
| 62 |
-
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(repo_id_or_path, subfolder="safety_checker")
|
| 63 |
-
except Exception as error:
|
| 64 |
-
print(f"caught exception {error} making feature extractor / safety checker")
|
| 65 |
-
self.feature_extractor = None
|
| 66 |
-
self.safety_checker = None
|
| 67 |
|
| 68 |
if scheduler == 'LMS':
|
| 69 |
self.scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
|
|
@@ -74,10 +56,14 @@ class StableDiffuser(torch.nn.Module):
|
|
| 74 |
|
| 75 |
self.eval()
|
| 76 |
|
| 77 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
param = list(self.parameters())[0]
|
| 79 |
return torch.randn(
|
| 80 |
-
(batch_size, self.unet.in_channels,
|
| 81 |
generator=generator).type(param.dtype).to(param.device)
|
| 82 |
|
| 83 |
def add_noise(self, latents, noise, step):
|
|
@@ -109,8 +95,8 @@ class StableDiffuser(torch.nn.Module):
|
|
| 109 |
def set_scheduler_timesteps(self, n_steps):
|
| 110 |
self.scheduler.set_timesteps(n_steps, device=self.unet.device)
|
| 111 |
|
| 112 |
-
def get_initial_latents(self, n_imgs,
|
| 113 |
-
noise = self.get_noise(n_imgs,
|
| 114 |
latents = noise * self.scheduler.init_noise_sigma
|
| 115 |
return latents
|
| 116 |
|
|
@@ -196,7 +182,8 @@ class StableDiffuser(torch.nn.Module):
|
|
| 196 |
def __call__(self,
|
| 197 |
prompts,
|
| 198 |
negative_prompts,
|
| 199 |
-
|
|
|
|
| 200 |
n_steps=50,
|
| 201 |
n_imgs=1,
|
| 202 |
end_iteration=None,
|
|
@@ -210,7 +197,7 @@ class StableDiffuser(torch.nn.Module):
|
|
| 210 |
prompts = [prompts]
|
| 211 |
|
| 212 |
self.set_scheduler_timesteps(n_steps)
|
| 213 |
-
latents = self.get_initial_latents(n_imgs,
|
| 214 |
text_embeddings = self.get_text_embeddings(prompts,negative_prompts,n_imgs=n_imgs)
|
| 215 |
end_iteration = end_iteration or n_steps
|
| 216 |
latents_steps, trace_steps = self.diffusion(
|
|
@@ -239,13 +226,16 @@ class StableDiffuser(torch.nn.Module):
|
|
| 239 |
|
| 240 |
return images_steps
|
| 241 |
|
|
|
|
|
|
|
|
|
|
| 242 |
|
| 243 |
if __name__ == '__main__':
|
| 244 |
|
| 245 |
parser = default_parser()
|
| 246 |
args = parser.parse_args()
|
| 247 |
|
| 248 |
-
diffuser = StableDiffuser(
|
| 249 |
|
| 250 |
images = diffuser(args.prompts,
|
| 251 |
n_steps=args.nsteps,
|
|
|
|
| 1 |
import argparse
|
|
|
|
| 2 |
|
| 3 |
import torch
|
| 4 |
from baukit import TraceDict
|
| 5 |
+
from diffusers import StableDiffusionPipeline
|
| 6 |
from PIL import Image
|
| 7 |
from tqdm.auto import tqdm
|
|
|
|
|
|
|
| 8 |
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
| 9 |
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
| 10 |
from diffusers.schedulers.scheduling_lms_discrete import LMSDiscreteScheduler
|
|
|
|
| 11 |
import util
|
| 12 |
|
| 13 |
|
|
|
|
| 35 |
def __init__(self,
|
| 36 |
scheduler='LMS',
|
| 37 |
repo_id_or_path="CompVis/stable-diffusion-v1-4",
|
| 38 |
+
variant='fp16'
|
| 39 |
):
|
| 40 |
|
| 41 |
super().__init__()
|
| 42 |
|
| 43 |
+
self.pipeline = StableDiffusionPipeline.from_pretrained(repo_id_or_path, variant=variant)
|
| 44 |
+
|
| 45 |
+
self.vae = self.pipeline.vae
|
| 46 |
+
self.unet = self.pipeline.unet
|
| 47 |
+
self.tokenizer = self.pipeline.tokenizer
|
| 48 |
+
self.text_encoder = self.pipeline.text_encoder
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
|
| 50 |
if scheduler == 'LMS':
|
| 51 |
self.scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
|
|
|
|
| 56 |
|
| 57 |
self.eval()
|
| 58 |
|
| 59 |
+
@property
|
| 60 |
+
def safety_checker(self):
|
| 61 |
+
return self.pipeline.safety_checker
|
| 62 |
+
|
| 63 |
+
def get_noise(self, batch_size, width, height, generator=None):
|
| 64 |
param = list(self.parameters())[0]
|
| 65 |
return torch.randn(
|
| 66 |
+
(batch_size, self.unet.in_channels, width // 8, height // 8),
|
| 67 |
generator=generator).type(param.dtype).to(param.device)
|
| 68 |
|
| 69 |
def add_noise(self, latents, noise, step):
|
|
|
|
| 95 |
def set_scheduler_timesteps(self, n_steps):
|
| 96 |
self.scheduler.set_timesteps(n_steps, device=self.unet.device)
|
| 97 |
|
| 98 |
+
def get_initial_latents(self, n_imgs, width, height, n_prompts, generator=None):
|
| 99 |
+
noise = self.get_noise(n_imgs, width, height, generator=generator).repeat(n_prompts, 1, 1, 1)
|
| 100 |
latents = noise * self.scheduler.init_noise_sigma
|
| 101 |
return latents
|
| 102 |
|
|
|
|
| 182 |
def __call__(self,
|
| 183 |
prompts,
|
| 184 |
negative_prompts,
|
| 185 |
+
width=512,
|
| 186 |
+
height=512,
|
| 187 |
n_steps=50,
|
| 188 |
n_imgs=1,
|
| 189 |
end_iteration=None,
|
|
|
|
| 197 |
prompts = [prompts]
|
| 198 |
|
| 199 |
self.set_scheduler_timesteps(n_steps)
|
| 200 |
+
latents = self.get_initial_latents(n_imgs, width, height, len(prompts), generator=generator)
|
| 201 |
text_embeddings = self.get_text_embeddings(prompts,negative_prompts,n_imgs=n_imgs)
|
| 202 |
end_iteration = end_iteration or n_steps
|
| 203 |
latents_steps, trace_steps = self.diffusion(
|
|
|
|
| 226 |
|
| 227 |
return images_steps
|
| 228 |
|
| 229 |
+
def save_pretrained(self, path, **kwargs):
|
| 230 |
+
self.pipeline.save_pretrained(path, **kwargs)
|
| 231 |
+
|
| 232 |
|
| 233 |
if __name__ == '__main__':
|
| 234 |
|
| 235 |
parser = default_parser()
|
| 236 |
args = parser.parse_args()
|
| 237 |
|
| 238 |
+
diffuser = StableDiffuser(scheduler='DDIM').to(torch.device(args.device)).half()
|
| 239 |
|
| 240 |
images = diffuser(args.prompts,
|
| 241 |
n_steps=args.nsteps,
|
app.py
CHANGED
|
@@ -86,8 +86,16 @@ class Demo:
|
|
| 86 |
label="Seed",
|
| 87 |
value=42
|
| 88 |
)
|
| 89 |
-
self.
|
| 90 |
-
label="Image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
minimum=256,
|
| 92 |
maximum=1024,
|
| 93 |
value=512,
|
|
@@ -190,11 +198,51 @@ class Demo:
|
|
| 190 |
|
| 191 |
self.download = gr.Files()
|
| 192 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
self.infr_button.click(self.inference, inputs = [
|
| 194 |
self.prompt_input_infr,
|
| 195 |
self.negative_prompt_input_infr,
|
| 196 |
self.seed_infr,
|
| 197 |
-
self.
|
|
|
|
| 198 |
self.model_dropdown,
|
| 199 |
self.base_repo_id_or_path_input_infr
|
| 200 |
],
|
|
@@ -214,6 +262,14 @@ class Demo:
|
|
| 214 |
],
|
| 215 |
outputs=[self.train_button, self.train_status, self.download, self.model_dropdown]
|
| 216 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 217 |
|
| 218 |
def train(self, repo_id_or_path, img_size, prompt, train_method, neg_guidance, iterations, lr, pbar = gr.Progress(track_tqdm=True)):
|
| 219 |
|
|
@@ -251,42 +307,50 @@ class Demo:
|
|
| 251 |
|
| 252 |
return [gr.update(interactive=True, value='Train'), gr.update(value='Done Training! \n Try your custom model in the "Test" tab'), save_path, gr.Dropdown.update(choices=list(model_map.keys()), value='Custom')]
|
| 253 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 254 |
|
| 255 |
-
def inference(self, prompt, negative_prompt, seed,
|
| 256 |
|
| 257 |
seed = seed or 42
|
| 258 |
-
generator = torch.manual_seed(seed)
|
| 259 |
model_path = model_map[model_name]
|
| 260 |
checkpoint = torch.load(model_path)
|
| 261 |
|
| 262 |
self.diffuser = StableDiffuser(scheduler='DDIM', repo_id_or_path=base_repo_id_or_path).to('cuda').eval().half()
|
| 263 |
finetuner = FineTunedModel.from_checkpoint(self.diffuser, checkpoint).eval().half()
|
| 264 |
-
torch.cuda.empty_cache()
|
| 265 |
|
|
|
|
|
|
|
|
|
|
| 266 |
images = self.diffuser(
|
| 267 |
prompt,
|
| 268 |
negative_prompt,
|
| 269 |
-
|
|
|
|
| 270 |
n_steps=50,
|
| 271 |
generator=generator
|
| 272 |
)
|
| 273 |
-
|
| 274 |
-
|
| 275 |
orig_image = images[0][0]
|
| 276 |
|
| 277 |
torch.cuda.empty_cache()
|
| 278 |
-
|
| 279 |
-
generator = torch.manual_seed(seed)
|
| 280 |
-
|
| 281 |
with finetuner:
|
| 282 |
-
|
| 283 |
images = self.diffuser(
|
| 284 |
prompt,
|
| 285 |
negative_prompt,
|
|
|
|
|
|
|
| 286 |
n_steps=50,
|
| 287 |
generator=generator
|
| 288 |
)
|
| 289 |
-
|
| 290 |
edited_image = images[0][0]
|
| 291 |
|
| 292 |
del finetuner
|
|
|
|
| 86 |
label="Seed",
|
| 87 |
value=42
|
| 88 |
)
|
| 89 |
+
self.img_width_infr = gr.Slider(
|
| 90 |
+
label="Image width",
|
| 91 |
+
minimum=256,
|
| 92 |
+
maximum=1024,
|
| 93 |
+
value=512,
|
| 94 |
+
step=64
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
self.img_height_infr = gr.Slider(
|
| 98 |
+
label="Image height",
|
| 99 |
minimum=256,
|
| 100 |
maximum=1024,
|
| 101 |
value=512,
|
|
|
|
| 198 |
|
| 199 |
self.download = gr.Files()
|
| 200 |
|
| 201 |
+
with gr.Tab("Export") as export_column:
|
| 202 |
+
|
| 203 |
+
with gr.Row():
|
| 204 |
+
|
| 205 |
+
self.explain_train= gr.Markdown(interactive=False,
|
| 206 |
+
value='Export a model to Diffusers format. Please enter the base model and select the editing weights.')
|
| 207 |
+
|
| 208 |
+
with gr.Row():
|
| 209 |
+
|
| 210 |
+
with gr.Column(scale=3):
|
| 211 |
+
|
| 212 |
+
self.base_repo_id_or_path_input_export = gr.Text(
|
| 213 |
+
label="Base model",
|
| 214 |
+
value="CompVis/stable-diffusion-v1-4",
|
| 215 |
+
info="Path or huggingface repo id of the base model that this edit was done against"
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
self.model_dropdown_export = gr.Dropdown(
|
| 219 |
+
label="ESD Model",
|
| 220 |
+
choices=list(model_map.keys()),
|
| 221 |
+
value='Van Gogh',
|
| 222 |
+
interactive=True
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
self.save_path_input_export = gr.Text(
|
| 226 |
+
label="Output path",
|
| 227 |
+
placeholder="./exported_models/model_name",
|
| 228 |
+
info="Path to export the model to. A diffusers folder will be written to this location."
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
self.save_half_export = gr.Checkbox(
|
| 232 |
+
label="Save as fp16"
|
| 233 |
+
)
|
| 234 |
+
|
| 235 |
+
with gr.Column(scale=1):
|
| 236 |
+
self.export_button = gr.Button(
|
| 237 |
+
value="Export",
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
self.infr_button.click(self.inference, inputs = [
|
| 241 |
self.prompt_input_infr,
|
| 242 |
self.negative_prompt_input_infr,
|
| 243 |
self.seed_infr,
|
| 244 |
+
self.img_width_infr,
|
| 245 |
+
self.img_height_infr,
|
| 246 |
self.model_dropdown,
|
| 247 |
self.base_repo_id_or_path_input_infr
|
| 248 |
],
|
|
|
|
| 262 |
],
|
| 263 |
outputs=[self.train_button, self.train_status, self.download, self.model_dropdown]
|
| 264 |
)
|
| 265 |
+
self.export_button.click(self.export, inputs = [
|
| 266 |
+
self.model_dropdown_export,
|
| 267 |
+
self.base_repo_id_or_path_input_export,
|
| 268 |
+
self.save_path_input_export,
|
| 269 |
+
self.save_half_export
|
| 270 |
+
],
|
| 271 |
+
outputs=[self.export_button]
|
| 272 |
+
)
|
| 273 |
|
| 274 |
def train(self, repo_id_or_path, img_size, prompt, train_method, neg_guidance, iterations, lr, pbar = gr.Progress(track_tqdm=True)):
|
| 275 |
|
|
|
|
| 307 |
|
| 308 |
return [gr.update(interactive=True, value='Train'), gr.update(value='Done Training! \n Try your custom model in the "Test" tab'), save_path, gr.Dropdown.update(choices=list(model_map.keys()), value='Custom')]
|
| 309 |
|
| 310 |
+
def export(self, model_name, base_repo_id_or_path, save_path, save_half):
|
| 311 |
+
model_path = model_map[model_name]
|
| 312 |
+
checkpoint = torch.load(model_path)
|
| 313 |
+
self.diffuser = StableDiffuser(scheduler='DDIM', repo_id_or_path=base_repo_id_or_path).to('cuda').eval()
|
| 314 |
+
finetuner = FineTunedModel.from_checkpoint(self.diffuser, checkpoint).eval()
|
| 315 |
+
with finetuner:
|
| 316 |
+
if save_half:
|
| 317 |
+
self.diffuser = self.diffuser.half()
|
| 318 |
+
self.diffuser.pipeline.to(torch.float16, torch_device=self.diffuser.device)
|
| 319 |
+
self.diffuser.save_pretrained(save_path)
|
| 320 |
+
|
| 321 |
|
| 322 |
+
def inference(self, prompt, negative_prompt, seed, width, height, model_name, base_repo_id_or_path, pbar = gr.Progress(track_tqdm=True)):
|
| 323 |
|
| 324 |
seed = seed or 42
|
|
|
|
| 325 |
model_path = model_map[model_name]
|
| 326 |
checkpoint = torch.load(model_path)
|
| 327 |
|
| 328 |
self.diffuser = StableDiffuser(scheduler='DDIM', repo_id_or_path=base_repo_id_or_path).to('cuda').eval().half()
|
| 329 |
finetuner = FineTunedModel.from_checkpoint(self.diffuser, checkpoint).eval().half()
|
|
|
|
| 330 |
|
| 331 |
+
generator = torch.manual_seed(seed)
|
| 332 |
+
|
| 333 |
+
torch.cuda.empty_cache()
|
| 334 |
images = self.diffuser(
|
| 335 |
prompt,
|
| 336 |
negative_prompt,
|
| 337 |
+
width=width,
|
| 338 |
+
height=height,
|
| 339 |
n_steps=50,
|
| 340 |
generator=generator
|
| 341 |
)
|
|
|
|
|
|
|
| 342 |
orig_image = images[0][0]
|
| 343 |
|
| 344 |
torch.cuda.empty_cache()
|
|
|
|
|
|
|
|
|
|
| 345 |
with finetuner:
|
|
|
|
| 346 |
images = self.diffuser(
|
| 347 |
prompt,
|
| 348 |
negative_prompt,
|
| 349 |
+
width=width,
|
| 350 |
+
height=height,
|
| 351 |
n_steps=50,
|
| 352 |
generator=generator
|
| 353 |
)
|
|
|
|
| 354 |
edited_image = images[0][0]
|
| 355 |
|
| 356 |
del finetuner
|
finetuning.py
CHANGED
|
@@ -2,11 +2,12 @@ import copy
|
|
| 2 |
import re
|
| 3 |
import torch
|
| 4 |
import util
|
|
|
|
| 5 |
|
| 6 |
class FineTunedModel(torch.nn.Module):
|
| 7 |
|
| 8 |
def __init__(self,
|
| 9 |
-
model,
|
| 10 |
modules,
|
| 11 |
frozen_modules=[]
|
| 12 |
):
|
|
@@ -24,11 +25,8 @@ class FineTunedModel(torch.nn.Module):
|
|
| 24 |
|
| 25 |
for module_name, module in model.named_modules():
|
| 26 |
for ft_module_regex in modules:
|
| 27 |
-
|
| 28 |
match = re.search(ft_module_regex, module_name)
|
| 29 |
-
|
| 30 |
if match is not None:
|
| 31 |
-
|
| 32 |
ft_module = copy.deepcopy(module)
|
| 33 |
|
| 34 |
self.orig_modules[module_name] = module
|
|
@@ -39,13 +37,10 @@ class FineTunedModel(torch.nn.Module):
|
|
| 39 |
print(f"=> Finetuning {module_name}")
|
| 40 |
|
| 41 |
for ft_module_name, module in ft_module.named_modules():
|
| 42 |
-
|
| 43 |
ft_module_name = f"{module_name}.{ft_module_name}"
|
| 44 |
-
|
| 45 |
for freeze_module_name in frozen_modules:
|
| 46 |
|
| 47 |
match = re.search(freeze_module_name, ft_module_name)
|
| 48 |
-
|
| 49 |
if match:
|
| 50 |
print(f"=> Freezing {ft_module_name}")
|
| 51 |
util.freeze(module)
|
|
|
|
| 2 |
import re
|
| 3 |
import torch
|
| 4 |
import util
|
| 5 |
+
from StableDiffuser import StableDiffuser
|
| 6 |
|
| 7 |
class FineTunedModel(torch.nn.Module):
|
| 8 |
|
| 9 |
def __init__(self,
|
| 10 |
+
model: StableDiffuser,
|
| 11 |
modules,
|
| 12 |
frozen_modules=[]
|
| 13 |
):
|
|
|
|
| 25 |
|
| 26 |
for module_name, module in model.named_modules():
|
| 27 |
for ft_module_regex in modules:
|
|
|
|
| 28 |
match = re.search(ft_module_regex, module_name)
|
|
|
|
| 29 |
if match is not None:
|
|
|
|
| 30 |
ft_module = copy.deepcopy(module)
|
| 31 |
|
| 32 |
self.orig_modules[module_name] = module
|
|
|
|
| 37 |
print(f"=> Finetuning {module_name}")
|
| 38 |
|
| 39 |
for ft_module_name, module in ft_module.named_modules():
|
|
|
|
| 40 |
ft_module_name = f"{module_name}.{ft_module_name}"
|
|
|
|
| 41 |
for freeze_module_name in frozen_modules:
|
| 42 |
|
| 43 |
match = re.search(freeze_module_name, ft_module_name)
|
|
|
|
| 44 |
if match:
|
| 45 |
print(f"=> Freezing {ft_module_name}")
|
| 46 |
util.freeze(module)
|
train.py
CHANGED
|
@@ -36,7 +36,7 @@ def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations
|
|
| 36 |
optimizer.zero_grad()
|
| 37 |
|
| 38 |
iteration = torch.randint(1, nsteps - 1, (1,)).item()
|
| 39 |
-
latents = diffuser.get_initial_latents(1, img_size, 1)
|
| 40 |
|
| 41 |
with finetuner:
|
| 42 |
latents_steps, _ = diffuser.diffusion(
|
|
|
|
| 36 |
optimizer.zero_grad()
|
| 37 |
|
| 38 |
iteration = torch.randint(1, nsteps - 1, (1,)).item()
|
| 39 |
+
latents = diffuser.get_initial_latents(1, width=img_size, height=img_size, n_prompts=1)
|
| 40 |
|
| 41 |
with finetuner:
|
| 42 |
latents_steps, _ = diffuser.diffusion(
|