johnowhitaker committed
Commit 7fc747c · 1 Parent(s): 78b9d2f

Upload train_latent_diffusion.py


Adding the version of the training script used to train the model

Files changed (1)
  1. train_latent_diffusion.py (+586, -0)

train_latent_diffusion.py ADDED
#!/usr/bin/env python3

import argparse
from contextlib import contextmanager
from copy import deepcopy
from functools import partial
import math
import random
from pathlib import Path
import json
import pickle
import sys

from omegaconf import OmegaConf
from PIL import Image
sys.path.append('./taming-transformers')
from taming.models import cond_transformer, vqgan
sys.path.append('./latent-diffusion')
import ldm.models.autoencoder
sys.path.append('./v-diffusion-pytorch')
from diffusion import sampling
from diffusion import utils as diffusion_utils
import pytorch_lightning as pl
from pytorch_lightning.utilities.distributed import rank_zero_only
import torch
from torch import optim, nn
from torch.nn import functional as F
from torch.utils import data
from torchvision.io import read_image
from torchvision import transforms, utils, datasets
from torchvision.transforms import functional as TF
import torchvision.transforms as T
from tqdm import trange
import wandb

from CLIP import clip

sys.path.append('./cloob-training')
from cloob_training import model_pt, pretrained

# Define utility functions

def load_vqgan_model(config_path, checkpoint_path):
    config = OmegaConf.load(config_path)
    if config.model.target == 'taming.models.vqgan.VQModel':
        model = vqgan.VQModel(**config.model.params)
        model.eval().requires_grad_(False)
        model.init_from_ckpt(checkpoint_path)
    elif config.model.target == 'taming.models.vqgan.GumbelVQ':
        model = vqgan.GumbelVQ(**config.model.params)
        model.eval().requires_grad_(False)
        model.init_from_ckpt(checkpoint_path)
    elif config.model.target == 'taming.models.cond_transformer.Net2NetTransformer':
        parent_model = cond_transformer.Net2NetTransformer(**config.model.params)
        parent_model.eval().requires_grad_(False)
        parent_model.init_from_ckpt(checkpoint_path)
        model = parent_model.first_stage_model
    else:
        raise ValueError(f'unknown model type: {config.model.target}')
    del model.loss
    return model


@contextmanager
def train_mode(model, mode=True):
    """A context manager that places a model into training mode and restores
    the previous mode on exit."""
    modes = [module.training for module in model.modules()]
    try:
        yield model.train(mode)
    finally:
        for i, module in enumerate(model.modules()):
            module.training = modes[i]


def eval_mode(model):
    """A context manager that places a model into evaluation mode and restores
    the previous mode on exit."""
    return train_mode(model, False)


@torch.no_grad()
def ema_update(model, averaged_model, decay):
    """Incorporates updated model parameters into an exponential moving averaged
    version of a model. It should be called after each optimizer step."""
    model_params = dict(model.named_parameters())
    averaged_params = dict(averaged_model.named_parameters())
    assert model_params.keys() == averaged_params.keys()

    for name, param in model_params.items():
        averaged_params[name].mul_(decay).add_(param, alpha=1 - decay)

    model_buffers = dict(model.named_buffers())
    averaged_buffers = dict(averaged_model.named_buffers())
    assert model_buffers.keys() == averaged_buffers.keys()

    for name, buf in model_buffers.items():
        averaged_buffers[name].copy_(buf)


# Define the diffusion noise schedule

def get_alphas_sigmas(t):
    return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)


# Define the model (a residual U-Net)

class ResidualBlock(nn.Module):
    def __init__(self, main, skip=None):
        super().__init__()
        self.main = nn.Sequential(*main)
        self.skip = skip if skip else nn.Identity()

    def forward(self, input):
        return self.main(input) + self.skip(input)


class ResLinearBlock(ResidualBlock):
    def __init__(self, f_in, f_mid, f_out, is_last=False):
        skip = None if f_in == f_out else nn.Linear(f_in, f_out, bias=False)
        super().__init__([
            nn.Linear(f_in, f_mid),
            nn.ReLU(inplace=True),
            nn.Linear(f_mid, f_out),
            nn.ReLU(inplace=True) if not is_last else nn.Identity(),
        ], skip)


class Modulation2d(nn.Module):
    def __init__(self, state, feats_in, c_out):
        super().__init__()
        self.state = state
        self.layer = nn.Linear(feats_in, c_out * 2, bias=False)

    def forward(self, input):
        scales, shifts = self.layer(self.state['cond']).chunk(2, dim=-1)
        return torch.addcmul(shifts[..., None, None], input, scales[..., None, None] + 1)


class ResModConvBlock(ResidualBlock):
    def __init__(self, state, feats_in, c_in, c_mid, c_out, is_last=False):
        skip = None if c_in == c_out else nn.Conv2d(c_in, c_out, 1, bias=False)
        super().__init__([
            nn.Conv2d(c_in, c_mid, 3, padding=1),
            nn.GroupNorm(1, c_mid, affine=False),
            Modulation2d(state, feats_in, c_mid),
            nn.ReLU(inplace=True),
            nn.Conv2d(c_mid, c_out, 3, padding=1),
            nn.GroupNorm(1, c_out, affine=False) if not is_last else nn.Identity(),
            Modulation2d(state, feats_in, c_out) if not is_last else nn.Identity(),
            nn.ReLU(inplace=True) if not is_last else nn.Identity(),
        ], skip)


class SkipBlock(nn.Module):
    def __init__(self, main, skip=None):
        super().__init__()
        self.main = nn.Sequential(*main)
        self.skip = skip if skip else nn.Identity()

    def forward(self, input):
        return torch.cat([self.main(input), self.skip(input)], dim=1)


class FourierFeatures(nn.Module):
    def __init__(self, in_features, out_features, std=1.):
        super().__init__()
        assert out_features % 2 == 0
        self.weight = nn.Parameter(torch.randn([out_features // 2, in_features]) * std)
        self.weight.requires_grad_(False)
        # self.register_buffer('weight', torch.randn([out_features // 2, in_features]) * std)

    def forward(self, input):
        f = 2 * math.pi * input @ self.weight.T
        return torch.cat([f.cos(), f.sin()], dim=-1)


class SelfAttention2d(nn.Module):
    def __init__(self, c_in, n_head=1, dropout_rate=0.1):
        super().__init__()
        assert c_in % n_head == 0
        self.norm = nn.GroupNorm(1, c_in)
        self.n_head = n_head
        self.qkv_proj = nn.Conv2d(c_in, c_in * 3, 1)
        self.out_proj = nn.Conv2d(c_in, c_in, 1)
        self.dropout = nn.Identity()  # nn.Dropout2d(dropout_rate, inplace=True)

    def forward(self, input):
        n, c, h, w = input.shape
        qkv = self.qkv_proj(self.norm(input))
        qkv = qkv.view([n, self.n_head * 3, c // self.n_head, h * w]).transpose(2, 3)
        q, k, v = qkv.chunk(3, dim=1)
        scale = k.shape[3]**-0.25
        att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3)
        y = (att @ v).transpose(2, 3).contiguous().view([n, c, h, w])
        return input + self.dropout(self.out_proj(y))


def expand_to_planes(input, shape):
    return input[..., None, None].repeat([1, 1, shape[2], shape[3]])


class DiffusionModel(nn.Module):
    def __init__(self, base_channels, cm, autoencoder_scale=1):
        super().__init__()
        c = base_channels  # The base channel count
        cs = [c * cm[0], c * cm[1], c * cm[2], c * cm[3]]

        self.mapping_timestep_embed = FourierFeatures(1, 128)
        self.mapping = nn.Sequential(
            ResLinearBlock(512 + 128, 1024, 1024),
            ResLinearBlock(1024, 1024, 1024, is_last=True),
        )

        with torch.no_grad():
            for param in self.mapping.parameters():
                param *= 0.5**0.5

        self.state = {}
        conv_block = partial(ResModConvBlock, self.state, 1024)

        self.register_buffer('autoencoder_scale', autoencoder_scale)  # expects a tensor (see --autoencoder-scale)
        self.timestep_embed = FourierFeatures(1, 16)
        self.down = nn.AvgPool2d(2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

        self.net = nn.Sequential(  # 32x32
            conv_block(4 + 16, cs[0], cs[0]),
            conv_block(cs[0], cs[0], cs[0]),
            conv_block(cs[0], cs[0], cs[0]),
            conv_block(cs[0], cs[0], cs[0]),
            SkipBlock([
                self.down,  # 16x16
                conv_block(cs[0], cs[1], cs[1]),
                SelfAttention2d(cs[1], cs[1] // 64),
                conv_block(cs[1], cs[1], cs[1]),
                SelfAttention2d(cs[1], cs[1] // 64),
                conv_block(cs[1], cs[1], cs[1]),
                SelfAttention2d(cs[1], cs[1] // 64),
                conv_block(cs[1], cs[1], cs[1]),
                SelfAttention2d(cs[1], cs[1] // 64),
                SkipBlock([
                    self.down,  # 8x8
                    conv_block(cs[1], cs[2], cs[2]),
                    SelfAttention2d(cs[2], cs[2] // 64),
                    conv_block(cs[2], cs[2], cs[2]),
                    SelfAttention2d(cs[2], cs[2] // 64),
                    conv_block(cs[2], cs[2], cs[2]),
                    SelfAttention2d(cs[2], cs[2] // 64),
                    conv_block(cs[2], cs[2], cs[2]),
                    SelfAttention2d(cs[2], cs[2] // 64),
                    SkipBlock([
                        self.down,  # 4x4
                        conv_block(cs[2], cs[3], cs[3]),
                        SelfAttention2d(cs[3], cs[3] // 64),
                        conv_block(cs[3], cs[3], cs[3]),
                        SelfAttention2d(cs[3], cs[3] // 64),
                        conv_block(cs[3], cs[3], cs[3]),
                        SelfAttention2d(cs[3], cs[3] // 64),
                        conv_block(cs[3], cs[3], cs[3]),
                        SelfAttention2d(cs[3], cs[3] // 64),
                        conv_block(cs[3], cs[3], cs[3]),
                        SelfAttention2d(cs[3], cs[3] // 64),
                        conv_block(cs[3], cs[3], cs[3]),
                        SelfAttention2d(cs[3], cs[3] // 64),
                        conv_block(cs[3], cs[3], cs[3]),
                        SelfAttention2d(cs[3], cs[3] // 64),
                        conv_block(cs[3], cs[3], cs[2]),
                        SelfAttention2d(cs[2], cs[2] // 64),
                        self.up,
                    ]),
                    conv_block(cs[2] * 2, cs[2], cs[2]),
                    SelfAttention2d(cs[2], cs[2] // 64),
                    conv_block(cs[2], cs[2], cs[2]),
                    SelfAttention2d(cs[2], cs[2] // 64),
                    conv_block(cs[2], cs[2], cs[2]),
                    SelfAttention2d(cs[2], cs[2] // 64),
                    conv_block(cs[2], cs[2], cs[1]),
                    SelfAttention2d(cs[1], cs[1] // 64),
                    self.up,
                ]),
                conv_block(cs[1] * 2, cs[1], cs[1]),
                SelfAttention2d(cs[1], cs[1] // 64),
                conv_block(cs[1], cs[1], cs[1]),
                SelfAttention2d(cs[1], cs[1] // 64),
                conv_block(cs[1], cs[1], cs[1]),
                SelfAttention2d(cs[1], cs[1] // 64),
                conv_block(cs[1], cs[1], cs[0]),
                SelfAttention2d(cs[0], cs[0] // 64),
                self.up,
            ]),
            conv_block(cs[0] * 2, cs[0], cs[0]),
            conv_block(cs[0], cs[0], cs[0]),
            conv_block(cs[0], cs[0], cs[0]),
            conv_block(cs[0], cs[0], 4, is_last=True),
        )
        with torch.no_grad():
            for param in self.net.parameters():
                param *= 0.5**0.5

    def forward(self, input, t, clip_embed):
        clip_embed = F.normalize(clip_embed, dim=-1) * clip_embed.shape[-1]**0.5
        mapping_timestep_embed = self.mapping_timestep_embed(t[:, None])
        self.state['cond'] = self.mapping(torch.cat([clip_embed, mapping_timestep_embed], dim=1))
        timestep_embed = expand_to_planes(self.timestep_embed(t[:, None]), input.shape)
        out = self.net(torch.cat([input, timestep_embed], dim=1))
        self.state.clear()
        return out


class TokenizerWrapper:
    def __init__(self, max_len=None):
        self.tokenizer = clip.simple_tokenizer.SimpleTokenizer()
        self.sot_token = self.tokenizer.encoder['<|startoftext|>']
        self.eot_token = self.tokenizer.encoder['<|endoftext|>']
        self.context_length = 77
        self.max_len = self.context_length - 2 if max_len is None else max_len

    def __call__(self, texts):
        if isinstance(texts, str):
            texts = [texts]
        result = torch.zeros([len(texts), self.context_length], dtype=torch.long)
        for i, text in enumerate(texts):
            tokens_trunc = self.tokenizer.encode(text)[:self.max_len]
            tokens = [self.sot_token, *tokens_trunc, self.eot_token]
            result[i, :len(tokens)] = torch.tensor(tokens)
        return result


class ToMode:
    def __init__(self, mode):
        self.mode = mode

    def __call__(self, image):
        return image.convert(self.mode)


class LightningDiffusion(pl.LightningModule):
    def __init__(self, cloob_checkpoint, vqgan_model, train_dl, autoencoder_scale,
                 base_channels=128, channel_multipliers="4,4,8,8", ema_decay_at=200000,
                 load_from=None  # <<<
                 ):
        super().__init__()

        # autoencoder
        ae_config = OmegaConf.load(vqgan_model + '.yaml')
        self.ae_model = ldm.models.autoencoder.AutoencoderKL(**ae_config.model.params)
        self.ae_model.eval().requires_grad_(False)
        self.ae_model.init_from_ckpt(vqgan_model + '.ckpt')
        self.register_buffer('scale_factor', autoencoder_scale)

        # CLOOB
        cloob_config = pretrained.get_config(cloob_checkpoint)
        self.cloob = model_pt.get_pt_model(cloob_config)
        checkpoint = pretrained.download_checkpoint(cloob_config)
        self.cloob.load_state_dict(model_pt.get_pt_params(cloob_config, checkpoint))
        self.cloob.eval().requires_grad_(False)

        # Diffusion model
        self.model = DiffusionModel(base_channels,
                                    [int(i) for i in channel_multipliers.strip().split(",")],
                                    autoencoder_scale)

        if load_from is not None:  # <<<
            self.model.load_state_dict(torch.load(load_from))  # <<<

        self.model_ema = deepcopy(self.model)
        self.ema_decay_at = ema_decay_at

        self.rng = torch.quasirandom.SobolEngine(1, scramble=True)

    def encode(self, image):
        return self.ae_model.encode(image).sample() / self.scale_factor

    def decode(self, latent):
        return self.ae_model.decode(latent * self.scale_factor)

    def forward(self, *args, **kwargs):
        if self.training:
            return self.model(*args, **kwargs)
        return self.model_ema(*args, **kwargs)

    def configure_optimizers(self):
        return optim.AdamW(self.model.parameters(), lr=3e-5, weight_decay=0.01)
        # return optim.AdamW(self.model.parameters(), lr=5e-6, weight_decay=0.01)

    def eval_batch(self, batch):
        reals, _ = batch
        cloob_reals = F.interpolate(reals, (224, 224), mode='bicubic', align_corners=False)
        cond = self.cloob.image_encoder(self.cloob.normalize(cloob_reals))
        del cloob_reals
        reals = self.encode(reals * 2 - 1)
        p = torch.rand([reals.shape[0], 1], device=reals.device)
        # Zero the CLOOB conditioning ~20% of the time (for classifier-free guidance)
        cond = torch.where(p > 0.2, cond, torch.zeros_like(cond))

        # Sample timesteps
        t = self.rng.draw(reals.shape[0])[:, 0].to(reals)

        # Calculate the noise schedule parameters for those timesteps
        alphas, sigmas = get_alphas_sigmas(t)

        # Combine the ground truth images and the noise
        alphas = alphas[:, None, None, None]
        sigmas = sigmas[:, None, None, None]
        noise = torch.randn_like(reals)
        noised_reals = reals * alphas + noise * sigmas
        targets = noise * alphas - reals * sigmas

        # Compute the model output and the loss.
        v = self(noised_reals, t, cond)
        return F.mse_loss(v, targets)

    def training_step(self, batch, batch_idx):
        loss = self.eval_batch(batch)
        log_dict = {'train/loss': loss.detach()}
        self.log_dict(log_dict, prog_bar=True, on_step=True)
        return loss

    def on_before_zero_grad(self, *args, **kwargs):
        if self.trainer.global_step < 20000:
            decay = 0.99
        elif self.trainer.global_step < self.ema_decay_at:
            decay = 0.999
        else:
            decay = 0.9999
        ema_update(self.model, self.model_ema, decay)


class DemoCallback(pl.Callback):
    def __init__(self, prompts, prompts_toks, demo_every=2000):
        super().__init__()
        self.prompts = prompts
        self.prompts_toks = prompts_toks
        self.demo_every = demo_every

    @rank_zero_only
    @torch.no_grad()
    def on_batch_end(self, trainer, module):
        if trainer.global_step % self.demo_every != 0:
            return

        lines = [f'({i // 4}, {i % 4}) {line}' for i, line in enumerate(self.prompts)]
        lines_text = '\n'.join(lines)
        Path('demo_prompts_out.txt').write_text(lines_text)

        noise = torch.randn([16, 4, 32, 32], device=module.device)
        clip_embed = module.cloob.text_encoder(self.prompts_toks.to(module.device))
        t = torch.linspace(1, 0, 50 + 1)[:-1]
        steps = diffusion_utils.get_spliced_ddpm_cosine_schedule(t)

        def model_fn(x, t, clip_embed):
            x_in = torch.cat([x, x])
            t_in = torch.cat([t, t])
            clip_embed_in = torch.cat([torch.zeros_like(clip_embed), clip_embed])
            v_uncond, v_cond = module(x_in, t_in, clip_embed_in).chunk(2, dim=0)
            return v_uncond + (v_cond - v_uncond) * 3  # classifier-free guidance, scale 3

        with eval_mode(module):
            fakes = sampling.plms_sample(model_fn, noise, steps, {'clip_embed': clip_embed})
            # fakes = sample(module, noise, 1000, 1, {'clip_embed': clip_embed}, guidance_scale=3.)
            fakes = module.decode(fakes)

        grid = utils.make_grid(fakes, 4, padding=0).cpu()
        image = TF.to_pil_image(grid.add(1).div(2).clamp(0, 1))
        filename = f'demo_{trainer.global_step:08}.png'
        image.save(filename)
        log_dict = {'demo_grid': wandb.Image(image),
                    'prompts': wandb.Html(f'<pre>{lines_text}</pre>')}
        trainer.logger.experiment.log(log_dict, step=trainer.global_step)
        del clip_embed


class ExceptionCallback(pl.Callback):
    def on_exception(self, trainer, module, err):
        print(f'{type(err).__name__}: {err!s}', file=sys.stderr)


def worker_init_fn(worker_id):
    random.seed(torch.initial_seed())


def main():
    p = argparse.ArgumentParser()
    p.add_argument("--cloob-checkpoint", type=str,
                   default='cloob_laion_400m_vit_b_16_16_epochs',
                   help="the CLOOB to condition with")
    p.add_argument("--vqgan-model", type=str, required=True,
                   help="the VQGAN checkpoint")
    p.add_argument("--autoencoder-scale",
                   type=lambda x: torch.tensor(float(x)), required=True,
                   help="the VQGAN autoencoder scale")
    p.add_argument('--train-set', type=Path, required=True,
                   help='path to the text file containing your training paths')
    p.add_argument('--checkpoint-every', type=int, default=50000,
                   help='output a model checkpoint every N steps')
    p.add_argument('--resume-from', type=str, default=None,
                   help='resume from (or finetune) the checkpoint at path')
    p.add_argument('--demo-prompts', type=Path, required=True,
                   help='the demo prompts')
    p.add_argument('--demo-every', type=int, default=2000,
                   help='output a demo grid every N steps')
    p.add_argument('--wandb-project', type=str, required=True,
                   help='the wandb project to log to for this run')
    p.add_argument('--fprecision', type=int, default=32,
                   help='the precision to train in (32, 16, etc.)')
    p.add_argument('--num-gpus', type=int, default=1,
                   help='the number of gpus to train with')
    p.add_argument('--num-workers', type=int, default=12,
                   help='the number of workers to load batches with')
    p.add_argument('--batch-size', type=int, default=64,
                   help='the batch size to use per step')
    p.add_argument('--base-channels', type=int, default=128,
                   help='the base channel count (width) for the model')
    p.add_argument('--channel-multipliers', type=str, default="4,4,8,8",
                   help='comma separated multiplier constants for the four model resolutions')
    p.add_argument('--ema-decay-at', type=int, default=200000,
                   help='the step to tighten ema decay at')
    args = p.parse_args()

    batch_size = args.batch_size
    size = 256

    TRAIN_PATHS = args.train_set

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print('Using device:', device)

    def tf(image):
        return transforms.Compose([
            ToMode('RGB'),
            transforms.Resize(size, interpolation=transforms.InterpolationMode.LANCZOS),
            transforms.CenterCrop(size),
            transforms.ToTensor(),
        ])(image)

    tok_wrap = TokenizerWrapper()

    class CustomDataset(data.Dataset):
        def __init__(self, train_paths, transform=None, target_transform=None):
            with open(train_paths) as infile:
                self.paths = [line.strip() for line in infile.readlines() if line.strip()]
            self.transform = transform
            self.target_transform = target_transform

        def __len__(self):
            return len(self.paths)

        def __getitem__(self, idx):
            img_path = self.paths[idx]
            image = Image.open(img_path)
            if self.transform:
                image = self.transform(image)
            return image, 0  # Pretend this is a None

    train_set = CustomDataset(TRAIN_PATHS, transform=tf)
    train_dl = data.DataLoader(train_set, batch_size, shuffle=True, drop_last=True,
                               num_workers=args.num_workers, persistent_workers=True, pin_memory=True)

    # Keep the raw prompt text for logging and the tokenized prompts for the text encoder
    demo_prompts = Path(args.demo_prompts).read_text().strip().split('\n')
    demo_prompts_toks = tok_wrap(demo_prompts)

    model = LightningDiffusion(args.cloob_checkpoint, args.vqgan_model, train_dl,
                               args.autoencoder_scale,
                               args.base_channels, args.channel_multipliers, args.ema_decay_at,
                               load_from=args.resume_from  # <<<
                               )

    wandb_logger = pl.loggers.WandbLogger(project=args.wandb_project)
    wandb_logger.watch(model.model)
    ckpt_callback = pl.callbacks.ModelCheckpoint(every_n_train_steps=args.checkpoint_every, save_top_k=-1)
    demo_callback = DemoCallback(demo_prompts, demo_prompts_toks, args.demo_every)
    exc_callback = ExceptionCallback()
    trainer = pl.Trainer(
        gpus=args.num_gpus,
        num_nodes=1,
        strategy='ddp',
        precision=args.fprecision,
        callbacks=[ckpt_callback, demo_callback, exc_callback],
        logger=wandb_logger,
        log_every_n_steps=1,
        max_epochs=10000000,
        # resume_from_checkpoint=args.resume_from,  # <<<
    )

    trainer.fit(model, train_dl)


if __name__ == '__main__':
    main()
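
For reference, a minimal sketch (separate from the uploaded script) of the noise parameterisation the script trains against, mirroring `get_alphas_sigmas` and the target construction in `eval_batch`. The tensor shapes follow the 4x32x32 latent space used above; the batch size and values here are illustrative only.

    import math
    import torch

    def get_alphas_sigmas(t):
        # Cosine schedule from the script: alpha = cos(t*pi/2), sigma = sin(t*pi/2)
        return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)

    reals = torch.randn(2, 4, 32, 32)   # stand-in for encoded latents
    t = torch.rand(2)                   # timesteps in [0, 1]
    alphas, sigmas = get_alphas_sigmas(t)
    alphas, sigmas = alphas[:, None, None, None], sigmas[:, None, None, None]
    noise = torch.randn_like(reals)
    noised_reals = reals * alphas + noise * sigmas   # model input
    targets = noise * alphas - reals * sigmas        # "v" regression target
    # Given a perfect v prediction, the clean latent is recoverable:
    recovered = noised_reals * alphas - targets * sigmas
    assert torch.allclose(recovered, reals, atol=1e-4)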