HBDing committed
Commit f53fb95 · 1 Parent(s): 796dc92
Files changed (7)
  1. Dockerfile +6 -0
  2. app.py +40 -5
  3. migc/migc_arch.py +220 -0
  4. migc/migc_layers.py +241 -0
  5. migc/migc_pipeline.py +928 -0
  6. migc/migc_utils.py +143 -0
  7. requirements.txt +3 -1
Dockerfile CHANGED
@@ -1,6 +1,12 @@
 
 FROM python:3.10
 
+# Set up a new user named "user" with user ID 1000
+RUN useradd -m -u 1000 user
+
+# Switch to the "user" user
+USER user
+
 WORKDIR /code
 
 COPY --link --chown=1000 . .
app.py CHANGED
@@ -4,14 +4,44 @@ from gradio_image_annotation import image_annotator
 from diffusers import StableDiffusionPipeline
 import os
 import torch
+from diffusers import EulerDiscreteScheduler
+from migc.migc_utils import seed_everything
+from migc.migc_pipeline import StableDiffusionMIGCPipeline, MIGCProcessor, AttentionStore
+
+
+from huggingface_hub import hf_hub_download
+
+# Download the checkpoint files
+migc_ckpt_path = hf_hub_download(
+    repo_id="limuloo1999/MIGC",
+    filename="MIGC_SD14.ckpt",
+    repo_type="model"  # may be omitted; "model" is the default
+)
+
+RV_path = hf_hub_download(
+    repo_id="SG161222/Realistic_Vision_V6.0_B1_noVAE",
+    filename="Realistic_Vision_V6.0_NV_B1.safetensors",
+    repo_type="model"  # may be omitted; "model" is the default
+)
+
 
 
 # Load model
-pipe = StableDiffusionPipeline.from_pretrained(
-    "runwayml/stable-diffusion-v1-5",
-    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
-).to("cuda" if torch.cuda.is_available() else "cpu")
+# pipe = StableDiffusionMIGCPipeline.from_pretrained(
+#     "rSG161222/Realistic_Vision_V6.0_B1_noVAE",
+#     torch_dtype=torch.float32
+# )
+pipe = StableDiffusionMIGCPipeline.from_single_file(
+    RV_path,
+    torch_dtype=torch.float32
+)
 pipe.safety_checker = None
+pipe.attention_store = AttentionStore()
+from migc.migc_utils import load_migc
+load_migc(pipe.unet, pipe.attention_store,
+          migc_ckpt_path, attn_processor=MIGCProcessor)
+pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")
+pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
 
 example_annotation = {
     "image": os.path.join(os.path.dirname(__file__), "background.png"),
@@ -26,15 +56,20 @@ def get_boxes_json(annotations):
     height = image.shape[0]
     boxes = annotations["boxes"]
     prompt_final = [[]]
+    bboxes = [[]]
     for box in boxes:
         box["xmin"] = box["xmin"] / width
         box["xmax"] = box["xmax"] / width
         box["ymin"] = box["ymin"] / height
         box["ymax"] = box["ymax"] / height
         prompt_final[0].append(box["label"])
+        bboxes[0].append([box["xmin"], box["ymin"], box["xmax"], box["ymax"]])
     # import pdb; pdb.set_trace()
     prompt = ", ".join(prompt_final[0])
-    image = pipe(prompt).images[0]
+    prompt_final[0].insert(0, prompt)
+    negative_prompt = 'worst quality, low quality, bad anatomy, watermark, text, blurry'
+    image = pipe(prompt_final, bboxes, num_inference_steps=30, guidance_scale=7.5,
+                 MIGCsteps=15, aug_phase_with_and=False, negative_prompt=negative_prompt).images[0]
     return image
     # return annotations["boxes"]
 
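For reference, the MIGC pipeline call introduced above expects nested lists: prompt_final[0] holds the global prompt at index 0 followed by one phrase per instance, and bboxes[0] holds one normalized [xmin, ymin, xmax, ymax] box per instance. A minimal sketch of a direct call with illustrative boxes, assuming `pipe` has been built exactly as in app.py above:

# Illustrative values only; `pipe` is the StableDiffusionMIGCPipeline constructed in app.py.
prompt_final = [["a cat and a dog on the grass",   # global prompt at index 0
                 "a cat", "a dog"]]                # one phrase per instance
bboxes = [[[0.10, 0.40, 0.45, 0.90],               # normalized [xmin, ymin, xmax, ymax] for "a cat"
           [0.55, 0.35, 0.95, 0.90]]]              # box for "a dog"
negative_prompt = 'worst quality, low quality, bad anatomy, watermark, text, blurry'
image = pipe(prompt_final, bboxes, num_inference_steps=30, guidance_scale=7.5,
             MIGCsteps=15, aug_phase_with_and=False, negative_prompt=negative_prompt).images[0]
image.save("migc_demo.png")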
migc/migc_arch.py ADDED
@@ -0,0 +1,220 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+ from migc.migc_layers import CBAM, CrossAttention, LayoutAttention
6
+
7
+
8
+ class FourierEmbedder():
9
+ def __init__(self, num_freqs=64, temperature=100):
10
+ self.num_freqs = num_freqs
11
+ self.temperature = temperature
12
+ self.freq_bands = temperature ** ( torch.arange(num_freqs) / num_freqs )
13
+
14
+ @ torch.no_grad()
15
+ def __call__(self, x, cat_dim=-1):
16
+ out = []
17
+ for freq in self.freq_bands:
18
+ out.append( torch.sin( freq*x ) )
19
+ out.append( torch.cos( freq*x ) )
20
+ return torch.cat(out, cat_dim) # torch.Size([5, 30, 64])
21
+
22
+
23
+ class PositionNet(nn.Module):
24
+ def __init__(self, in_dim, out_dim, fourier_freqs=8):
25
+ super().__init__()
26
+ self.in_dim = in_dim
27
+ self.out_dim = out_dim
28
+
29
+ self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs)
30
+ self.position_dim = fourier_freqs * 2 * 4 # 2 is sin&cos, 4 is xyxy
31
+
32
+ # -------------------------------------------------------------- #
33
+ self.linears_position = nn.Sequential(
34
+ nn.Linear(self.position_dim, 512),
35
+ nn.SiLU(),
36
+ nn.Linear(512, 512),
37
+ nn.SiLU(),
38
+ nn.Linear(512, out_dim),
39
+ )
40
+
41
+ def forward(self, boxes):
42
+
43
+ # embedding position (it may include padding as placeholder)
44
+ xyxy_embedding = self.fourier_embedder(boxes) # B*1*4 --> B*1*C torch.Size([5, 1, 64])
45
+ xyxy_embedding = self.linears_position(xyxy_embedding) # B*1*C --> B*1*768 torch.Size([5, 1, 768])
46
+
47
+ return xyxy_embedding
48
+
49
+
50
+ class SAC(nn.Module):
51
+ def __init__(self, C, number_pro=30):
52
+ super().__init__()
53
+ self.C = C
54
+ self.number_pro = number_pro
55
+ self.conv1 = nn.Conv2d(C + 1, C, 1, 1)
56
+ self.cbam1 = CBAM(C)
57
+ self.conv2 = nn.Conv2d(C, 1, 1, 1)
58
+ self.cbam2 = CBAM(number_pro, reduction_ratio=1)
59
+
60
+ def forward(self, x, guidance_mask, sac_scale=None):
61
+ '''
62
+ :param x: (B, phase_num, HW, C)
63
+ :param guidance_mask: (B, phase_num, H, W)
64
+ :return:
65
+ '''
66
+ B, phase_num, HW, C = x.shape
67
+ _, _, H, W = guidance_mask.shape
68
+ guidance_mask = guidance_mask.view(guidance_mask.shape[0], phase_num, -1)[
69
+ ..., None] # (B, phase_num, HW, 1)
70
+
71
+ null_x = torch.zeros_like(x[:, [0], ...]).to(x.device)
72
+ null_mask = torch.zeros_like(guidance_mask[:, [0], ...]).to(guidance_mask.device)
73
+
74
+ x = torch.cat([x, null_x], dim=1)
75
+ guidance_mask = torch.cat([guidance_mask, null_mask], dim=1)
76
+ phase_num += 1
77
+
78
+
79
+ scale = torch.cat([x, guidance_mask], dim=-1) # (B, phase_num, HW, C+1)
80
+ scale = scale.view(-1, H, W, C + 1) # (B * phase_num, H, W, C+1)
81
+ scale = scale.permute(0, 3, 1, 2) # (B * phase_num, C+1, H, W)
82
+ scale = self.conv1(scale) # (B * phase_num, C, H, W)
83
+ scale = self.cbam1(scale) # (B * phase_num, C, H, W)
84
+ scale = self.conv2(scale) # (B * phase_num, 1, H, W)
85
+ scale = scale.view(B, phase_num, H, W) # (B, phase_num, H, W)
86
+
87
+ null_scale = scale[:, [-1], ...]
88
+ scale = scale[:, :-1, ...]
89
+ x = x[:, :-1, ...]
90
+
91
+ pad_num = self.number_pro - phase_num + 1
92
+
93
+ ori_phase_num = scale[:, 1:-1, ...].shape[1]
94
+ phase_scale = torch.cat([scale[:, 1:-1, ...], null_scale.repeat(1, pad_num, 1, 1)], dim=1)
95
+ shuffled_order = torch.randperm(phase_scale.shape[1])
96
+ inv_shuffled_order = torch.argsort(shuffled_order)
97
+
98
+ random_phase_scale = phase_scale[:, shuffled_order, ...]
99
+
100
+ scale = torch.cat([scale[:, [0], ...], random_phase_scale, scale[:, [-1], ...]], dim=1)
101
+ # (B, number_pro, H, W)
102
+
103
+ scale = self.cbam2(scale) # (B, number_pro, H, W)
104
+ scale = scale.view(B, self.number_pro, HW)[..., None] # (B, number_pro, HW)
105
+
106
+ random_phase_scale = scale[:, 1: -1, ...]
107
+ phase_scale = random_phase_scale[:, inv_shuffled_order[:ori_phase_num], :]
108
+ if sac_scale is not None:
109
+ instance_num = len(sac_scale)
110
+ for i in range(instance_num):
111
+ phase_scale[:, i, ...] = phase_scale[:, i, ...] * sac_scale[i]
112
+
113
+
114
+ scale = torch.cat([scale[:, [0], ...], phase_scale, scale[:, [-1], ...]], dim=1)
115
+
116
+ scale = scale.softmax(dim=1) # (B, phase_num, HW, 1)
117
+ out = (x * scale).sum(dim=1, keepdims=True) # (B, 1, HW, C)
118
+ return out, scale
119
+
120
+
121
+ class MIGC(nn.Module):
122
+ def __init__(self, C, attn_type='base', context_dim=768, heads=8):
123
+ super().__init__()
124
+ self.ea = CrossAttention(query_dim=C, context_dim=context_dim,
125
+ heads=heads, dim_head=C // heads,
126
+ dropout=0.0)
127
+ self.la = LayoutAttention(query_dim=C,
128
+ heads=heads, dim_head=C // heads,
129
+ dropout=0.0)
130
+ self.norm = nn.LayerNorm(C)
131
+ self.sac = SAC(C)
132
+ self.pos_net = PositionNet(in_dim=768, out_dim=768)
133
+
134
+ def forward(self, ca_x, guidance_mask, other_info, return_fuser_info=False):
135
+ # x: (B, instance_num+1, HW, C)
136
+ # guidance_mask: (B, instance_num, H, W)
137
+ # box: (instance_num, 4)
138
+ # image_token: (B, instance_num+1, HW, C)
139
+ full_H = other_info['height']
140
+ full_W = other_info['width']
141
+ B, _, HW, C = ca_x.shape
142
+ instance_num = guidance_mask.shape[1]
143
+ down_scale = int(math.sqrt(full_H * full_W // ca_x.shape[2]))
144
+ H = full_H // down_scale
145
+ W = full_W // down_scale
146
+ guidance_mask = F.interpolate(guidance_mask, size=(H, W), mode='bilinear') # (B, instance_num, H, W)
147
+
148
+
149
+ supplement_mask = other_info['supplement_mask'] # (B, 1, 64, 64)
150
+ supplement_mask = F.interpolate(supplement_mask, size=(H, W), mode='bilinear') # (B, 1, H, W)
151
+ image_token = other_info['image_token']
152
+ assert image_token.shape == ca_x.shape
153
+ context = other_info['context_pooler']
154
+ box = other_info['box']
155
+ box = box.view(B * instance_num, 1, -1)
156
+ box_token = self.pos_net(box)
157
+ context = torch.cat([context[1:, ...], box_token], dim=1)
158
+ ca_scale = other_info['ca_scale'] if 'ca_scale' in other_info else None
159
+ ea_scale = other_info['ea_scale'] if 'ea_scale' in other_info else None
160
+ sac_scale = other_info['sac_scale'] if 'sac_scale' in other_info else None
161
+
162
+ ea_x, ea_attn = self.ea(self.norm(image_token[:, 1:, ...].view(B * instance_num, HW, C)),
163
+ context=context, return_attn=True)
164
+ ea_x = ea_x.view(B, instance_num, HW, C)
165
+ ea_x = ea_x * guidance_mask.view(B, instance_num, HW, 1)
166
+
167
+ ca_x[:, 1:, ...] = ca_x[:, 1:, ...] * guidance_mask.view(B, instance_num, HW, 1) # (B, phase_num, HW, C)
168
+ if ca_scale is not None:
169
+ assert len(ca_scale) == instance_num
170
+ for i in range(instance_num):
171
+ ca_x[:, i+1, ...] = ca_x[:, i+1, ...] * ca_scale[i] + ea_x[:, i, ...] * ea_scale[i]
172
+ else:
173
+ ca_x[:, 1:, ...] = ca_x[:, 1:, ...] + ea_x
174
+
175
+ ori_image_token = image_token[:, 0, ...] # (B, HW, C)
176
+ fusion_template = self.la(x=ori_image_token, guidance_mask=torch.cat([guidance_mask[:, :, ...], supplement_mask], dim=1)) # (B, HW, C)
177
+ fusion_template = fusion_template.view(B, 1, HW, C) # (B, 1, HW, C)
178
+
179
+ ca_x = torch.cat([ca_x, fusion_template], dim = 1)
180
+ ca_x[:, 0, ...] = ca_x[:, 0, ...] * supplement_mask.view(B, HW, 1)
181
+ guidance_mask = torch.cat([
182
+ supplement_mask,
183
+ guidance_mask,
184
+ torch.ones(B, 1, H, W).to(guidance_mask.device)
185
+ ], dim=1)
186
+
187
+
188
+ out_MIGC, sac_scale = self.sac(ca_x, guidance_mask, sac_scale=sac_scale)
189
+ if return_fuser_info:
190
+ fuser_info = {}
191
+ fuser_info['sac_scale'] = sac_scale.view(B, instance_num + 2, H, W)
192
+ fuser_info['ea_attn'] = ea_attn.mean(dim=1).view(B, instance_num, H, W, 2)
193
+ return out_MIGC, fuser_info
194
+ else:
195
+ return out_MIGC
196
+
197
+
198
+ class NaiveFuser(nn.Module):
199
+ def __init__(self):
200
+ super().__init__()
201
+ def forward(self, ca_x, guidance_mask, other_info, return_fuser_info=False):
202
+ # ca_x: (B, instance_num+1, HW, C)
203
+ # guidance_mask: (B, instance_num, H, W)
204
+ # box: (instance_num, 4)
205
+ # image_token: (B, instance_num+1, HW, C)
206
+ full_H = other_info['height']
207
+ full_W = other_info['width']
208
+ B, _, HW, C = ca_x.shape
209
+ instance_num = guidance_mask.shape[1]
210
+ down_scale = int(math.sqrt(full_H * full_W // ca_x.shape[2]))
211
+ H = full_H // down_scale
212
+ W = full_W // down_scale
213
+ guidance_mask = F.interpolate(guidance_mask, size=(H, W), mode='bilinear') # (B, instance_num, H, W)
214
+ guidance_mask = torch.cat([torch.ones(B, 1, H, W).to(guidance_mask.device), guidance_mask * 10], dim=1) # (B, instance_num+1, H, W)
215
+ guidance_mask = guidance_mask.view(B, instance_num + 1, HW, 1)
216
+ out_MIGC = (ca_x * guidance_mask).sum(dim=1) / (guidance_mask.sum(dim=1) + 1e-6)
217
+ if return_fuser_info:
218
+ return out_MIGC, None
219
+ else:
220
+ return out_MIGC
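As a quick shape check for the code above: PositionNet Fourier-encodes each normalized xyxy box into fourier_freqs * 2 * 4 = 64 position features and lifts them to out_dim with a small MLP, which is how MIGC builds the box token concatenated to each instance's context. A minimal sketch, assuming the migc package added in this commit is importable:

import torch
from migc.migc_arch import PositionNet

boxes = torch.rand(5, 1, 4)                     # (B * instance_num, 1, 4), normalized [xmin, ymin, xmax, ymax]
pos_net = PositionNet(in_dim=768, out_dim=768)  # fourier_freqs defaults to 8, so 8 * 2 * 4 = 64 features per box
box_tokens = pos_net(boxes)                     # MLP: 64 -> 512 -> 512 -> 768
print(box_tokens.shape)                         # torch.Size([5, 1, 768])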
migc/migc_layers.py ADDED
@@ -0,0 +1,241 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import random
5
+ import math
6
+ from inspect import isfunction
7
+ from einops import rearrange, repeat
8
+ from torch import nn, einsum
9
+
10
+
11
+ def exists(val):
12
+ return val is not None
13
+
14
+
15
+ def default(val, d):
16
+ if exists(val):
17
+ return val
18
+ return d() if isfunction(d) else d
19
+
20
+
21
+ class CrossAttention(nn.Module):
22
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
23
+ super().__init__()
24
+ inner_dim = dim_head * heads
25
+ context_dim = default(context_dim, query_dim)
26
+
27
+ self.scale = dim_head ** -0.5
28
+ self.heads = heads
29
+
30
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
31
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
32
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
33
+
34
+ self.to_out = nn.Sequential(
35
+ nn.Linear(inner_dim, query_dim),
36
+ nn.Dropout(dropout)
37
+ )
38
+
39
+ def forward(self, x, context=None, mask=None, return_attn=False, need_softmax=True, guidance_mask=None,
40
+ forward_layout_guidance=False):
41
+ h = self.heads
42
+ b = x.shape[0]
43
+
44
+ q = self.to_q(x)
45
+ context = default(context, x)
46
+ k = self.to_k(context)
47
+ v = self.to_v(context)
48
+
49
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
50
+
51
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
52
+ if forward_layout_guidance:
53
+ # sim: (B * phase_num * h, HW, 77), b = B * phase_num
54
+ # guidance_mask: (B, phase_num, 64, 64)
55
+ HW = sim.shape[1]
56
+ H = W = int(math.sqrt(HW))
57
+ guidance_mask = F.interpolate(guidance_mask, size=(H, W), mode='nearest') # (B, phase_num, H, W)
58
+ sim = sim.view(b, h, HW, 77)
59
+ guidance_mask = guidance_mask.view(b, 1, HW, 1)
60
+ guidance_mask[guidance_mask == 1] = 5.0
61
+ guidance_mask[guidance_mask == 0] = 0.1
62
+ sim[:, :, :, 1:] = sim[:, :, :, 1:] * guidance_mask
63
+ sim = sim.view(b * h, HW, 77)
64
+
65
+ if exists(mask):
66
+ mask = rearrange(mask, 'b ... -> b (...)')
67
+ max_neg_value = -torch.finfo(sim.dtype).max
68
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
69
+ sim.masked_fill_(~mask, max_neg_value)
70
+
71
+ if need_softmax:
72
+ attn = sim.softmax(dim=-1)
73
+ else:
74
+ attn = sim
75
+
76
+ out = einsum('b i j, b j d -> b i d', attn, v)
77
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
78
+ if return_attn:
79
+ attn = attn.view(b, h, attn.shape[-2], attn.shape[-1])
80
+ return self.to_out(out), attn
81
+ else:
82
+ return self.to_out(out)
83
+
84
+
85
+ class LayoutAttention(nn.Module):
86
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., use_lora=False):
87
+ super().__init__()
88
+ inner_dim = dim_head * heads
89
+ context_dim = default(context_dim, query_dim)
90
+
91
+ self.use_lora = use_lora
92
+ self.scale = dim_head ** -0.5
93
+ self.heads = heads
94
+
95
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
96
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
97
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
98
+
99
+ self.to_out = nn.Sequential(
100
+ nn.Linear(inner_dim, query_dim),
101
+ nn.Dropout(dropout)
102
+ )
103
+
104
+ def forward(self, x, context=None, mask=None, return_attn=False, need_softmax=True, guidance_mask=None):
105
+ h = self.heads
106
+ b = x.shape[0]
107
+
108
+ q = self.to_q(x)
109
+ context = default(context, x)
110
+ k = self.to_k(context)
111
+ v = self.to_v(context)
112
+
113
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
114
+
115
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
116
+
117
+ _, phase_num, H, W = guidance_mask.shape
118
+ HW = H * W
119
+ guidance_mask_o = guidance_mask.view(b * phase_num, HW, 1)
120
+ guidance_mask_t = guidance_mask.view(b * phase_num, 1, HW)
121
+ guidance_mask_sim = torch.bmm(guidance_mask_o, guidance_mask_t) # (B * phase_num, HW, HW)
122
+ guidance_mask_sim = guidance_mask_sim.view(b, phase_num, HW, HW).sum(dim=1)
123
+ guidance_mask_sim[guidance_mask_sim > 1] = 1 # (B, HW, HW)
124
+ guidance_mask_sim = guidance_mask_sim.view(b, 1, HW, HW)
125
+ guidance_mask_sim = guidance_mask_sim.repeat(1, self.heads, 1, 1)
126
+ guidance_mask_sim = guidance_mask_sim.view(b * self.heads, HW, HW) # (B * head, HW, HW)
127
+
128
+ sim[:, :, :HW][guidance_mask_sim == 0] = -torch.finfo(sim.dtype).max
129
+
130
+ if exists(mask):
131
+ mask = rearrange(mask, 'b ... -> b (...)')
132
+ max_neg_value = -torch.finfo(sim.dtype).max
133
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
134
+ sim.masked_fill_(~mask, max_neg_value)
135
+
136
+ # attention, what we cannot get enough of
137
+
138
+ if need_softmax:
139
+ attn = sim.softmax(dim=-1)
140
+ else:
141
+ attn = sim
142
+
143
+ out = einsum('b i j, b j d -> b i d', attn, v)
144
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
145
+ if return_attn:
146
+ attn = attn.view(b, h, attn.shape[-2], attn.shape[-1])
147
+ return self.to_out(out), attn
148
+ else:
149
+ return self.to_out(out)
150
+
151
+
152
+ class BasicConv(nn.Module):
153
+ def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=False, bias=False):
154
+ super(BasicConv, self).__init__()
155
+ self.out_channels = out_planes
156
+ self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
157
+ self.bn = nn.BatchNorm2d(out_planes,eps=1e-5, momentum=0.01, affine=True) if bn else None
158
+ self.relu = nn.ReLU() if relu else None
159
+
160
+ def forward(self, x):
161
+ x = self.conv(x)
162
+ if self.bn is not None:
163
+ x = self.bn(x)
164
+ if self.relu is not None:
165
+ x = self.relu(x)
166
+ return x
167
+
168
+ class Flatten(nn.Module):
169
+ def forward(self, x):
170
+ return x.view(x.size(0), -1)
171
+
172
+ class ChannelGate(nn.Module):
173
+ def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']):
174
+ super(ChannelGate, self).__init__()
175
+ self.gate_channels = gate_channels
176
+ self.mlp = nn.Sequential(
177
+ Flatten(),
178
+ nn.Linear(gate_channels, gate_channels // reduction_ratio),
179
+ nn.ReLU(),
180
+ nn.Linear(gate_channels // reduction_ratio, gate_channels)
181
+ )
182
+ self.pool_types = pool_types
183
+ def forward(self, x):
184
+ channel_att_sum = None
185
+ for pool_type in self.pool_types:
186
+ if pool_type=='avg':
187
+ avg_pool = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
188
+ channel_att_raw = self.mlp( avg_pool )
189
+ elif pool_type=='max':
190
+ max_pool = F.max_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
191
+ channel_att_raw = self.mlp( max_pool )
192
+ elif pool_type=='lp':
193
+ lp_pool = F.lp_pool2d( x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
194
+ channel_att_raw = self.mlp( lp_pool )
195
+ elif pool_type=='lse':
196
+ # LSE pool only
197
+ lse_pool = logsumexp_2d(x)
198
+ channel_att_raw = self.mlp( lse_pool )
199
+
200
+ if channel_att_sum is None:
201
+ channel_att_sum = channel_att_raw
202
+ else:
203
+ channel_att_sum = channel_att_sum + channel_att_raw
204
+
205
+ scale = F.sigmoid( channel_att_sum ).unsqueeze(2).unsqueeze(3).expand_as(x)
206
+ return x * scale
207
+
208
+ def logsumexp_2d(tensor):
209
+ tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1)
210
+ s, _ = torch.max(tensor_flatten, dim=2, keepdim=True)
211
+ outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log()
212
+ return outputs
213
+
214
+ class ChannelPool(nn.Module):
215
+ def forward(self, x):
216
+ return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 )
217
+
218
+ class SpatialGate(nn.Module):
219
+ def __init__(self):
220
+ super(SpatialGate, self).__init__()
221
+ kernel_size = 7
222
+ self.compress = ChannelPool()
223
+ self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False)
224
+ def forward(self, x):
225
+ x_compress = self.compress(x)
226
+ x_out = self.spatial(x_compress)
227
+ scale = F.sigmoid(x_out) # broadcasting
228
+ return x * scale
229
+
230
+ class CBAM(nn.Module):
231
+ def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False):
232
+ super(CBAM, self).__init__()
233
+ self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
234
+ self.no_spatial=no_spatial
235
+ if not no_spatial:
236
+ self.SpatialGate = SpatialGate()
237
+ def forward(self, x):
238
+ x_out = self.ChannelGate(x)
239
+ if not self.no_spatial:
240
+ x_out = self.SpatialGate(x_out)
241
+ return x_out
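The CBAM block defined above is the attention module SAC in migc_arch.py uses when aggregating per-instance outputs; it runs a channel gate followed by a 7x7 spatial gate and preserves the input shape. A minimal sketch, assuming the migc package added in this commit is importable:

import torch
from migc.migc_layers import CBAM

x = torch.randn(2, 320, 32, 32)                     # (B, C, H, W) feature map
cbam = CBAM(gate_channels=320, reduction_ratio=16)
y = cbam(x)                                         # channel attention, then spatial attention
print(y.shape)                                      # torch.Size([2, 320, 32, 32])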
migc/migc_pipeline.py ADDED
@@ -0,0 +1,928 @@
1
+ import glob
2
+ import random
3
+ import time
4
+ from typing import Any, Callable, Dict, List, Optional, Union
5
+ # import moxing as mox
6
+ import numpy as np
7
+ import torch
8
+ from diffusers.loaders import TextualInversionLoaderMixin
9
+ from diffusers.models import AutoencoderKL, UNet2DConditionModel
10
+ from diffusers.models.attention_processor import Attention
11
+ from diffusers.pipelines.stable_diffusion import (
12
+ StableDiffusionPipeline,
13
+ StableDiffusionPipelineOutput,
14
+ StableDiffusionSafetyChecker,
15
+ )
16
+ from diffusers.schedulers import KarrasDiffusionSchedulers
17
+ from diffusers.utils import logging
18
+ from PIL import Image, ImageDraw, ImageFont
19
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
20
+ import inspect
21
+ import os
22
+ import math
23
+ import torch.nn as nn
24
+ import torch.nn.functional as F
25
+ # from utils import load_utils
26
+ import argparse
27
+ import yaml
28
+ import cv2
29
+ import math
30
+ from migc.migc_arch import MIGC, NaiveFuser
31
+ from scipy.ndimage import uniform_filter, gaussian_filter
32
+
33
+ logger = logging.get_logger(__name__)
34
+
35
+ class AttentionStore:
36
+ @staticmethod
37
+ def get_empty_store():
38
+ return {"down": [], "mid": [], "up": []}
39
+
40
+ def __call__(self, attn, is_cross: bool, place_in_unet: str):
41
+ if is_cross:
42
+ if attn.shape[1] in self.attn_res:
43
+ self.step_store[place_in_unet].append(attn)
44
+
45
+ self.cur_att_layer += 1
46
+ if self.cur_att_layer == self.num_att_layers:
47
+ self.cur_att_layer = 0
48
+ self.between_steps()
49
+
50
+ def between_steps(self):
51
+ self.attention_store = self.step_store
52
+ self.step_store = self.get_empty_store()
53
+
54
+ def maps(self, block_type: str):
55
+ return self.attention_store[block_type]
56
+
57
+ def reset(self):
58
+ self.cur_att_layer = 0
59
+ self.step_store = self.get_empty_store()
60
+ self.attention_store = {}
61
+
62
+ def __init__(self, attn_res=[64*64, 32*32, 16*16, 8*8]):
63
+ """
64
+ Initialize an empty AttentionStore :param step_index: used to visualize only a specific step in the diffusion
65
+ process
66
+ """
67
+ self.num_att_layers = -1
68
+ self.cur_att_layer = 0
69
+ self.step_store = self.get_empty_store()
70
+ self.attention_store = {}
71
+ self.curr_step_index = 0
72
+ self.attn_res = attn_res
73
+
74
+
75
+ def get_sup_mask(mask_list):
76
+ or_mask = np.zeros_like(mask_list[0])
77
+ for mask in mask_list:
78
+ or_mask += mask
79
+ or_mask[or_mask >= 1] = 1
80
+ sup_mask = 1 - or_mask
81
+ return sup_mask
82
+
83
+
84
+ class MIGCProcessor(nn.Module):
85
+ def __init__(self, config, attnstore, place_in_unet):
86
+ super().__init__()
87
+ self.attnstore = attnstore
88
+ self.place_in_unet = place_in_unet
89
+ self.not_use_migc = config['not_use_migc']
90
+ self.naive_fuser = NaiveFuser()
91
+ self.embedding = {}
92
+ if not self.not_use_migc:
93
+ self.migc = MIGC(config['C'])
94
+
95
+ def __call__(
96
+ self,
97
+ attn: Attention,
98
+ hidden_states,
99
+ encoder_hidden_states=None,
100
+ attention_mask=None,
101
+ prompt_nums=[],
102
+ bboxes=[],
103
+ ith=None,
104
+ embeds_pooler=None,
105
+ timestep=None,
106
+ height=512,
107
+ width=512,
108
+ MIGCsteps=20,
109
+ NaiveFuserSteps=-1,
110
+ ca_scale=None,
111
+ ea_scale=None,
112
+ sac_scale=None,
113
+ use_sa_preserve=False,
114
+ sa_preserve=False,
115
+ ):
116
+ batch_size, sequence_length, _ = hidden_states.shape
117
+ assert batch_size == 2, "We currently only implement sampling with batch_size=1, \
118
+ and we will implement sampling with batch_size=N as soon as possible."
119
+ attention_mask = attn.prepare_attention_mask(
120
+ attention_mask, sequence_length, batch_size
121
+ )
122
+
123
+ instance_num = len(bboxes[0])
124
+
125
+ if ith > MIGCsteps:
126
+ not_use_migc = True
127
+ else:
128
+ not_use_migc = self.not_use_migc
129
+ is_vanilla_cross = (not_use_migc and ith > NaiveFuserSteps)
130
+ if instance_num == 0:
131
+ is_vanilla_cross = True
132
+
133
+ is_cross = encoder_hidden_states is not None
134
+
135
+ ori_hidden_states = hidden_states.clone()
136
+
137
+ # Only Need Negative Prompt and Global Prompt.
138
+ if is_cross and is_vanilla_cross:
139
+ encoder_hidden_states = encoder_hidden_states[:2, ...]
140
+
141
+ # In this case, we need to use MIGC or naive_fuser, so we copy the hidden_states_cond (instance_num+1) times for QKV
142
+ if is_cross and not is_vanilla_cross:
143
+ hidden_states_uncond = hidden_states[[0], ...]
144
+ hidden_states_cond = hidden_states[[1], ...].repeat(instance_num + 1, 1, 1)
145
+ hidden_states = torch.cat([hidden_states_uncond, hidden_states_cond])
146
+
147
+ # QKV Operation of Vanilla Self-Attention or Cross-Attention
148
+ query = attn.to_q(hidden_states)
149
+
150
+ if (
151
+ not is_cross
152
+ and use_sa_preserve
153
+ and timestep.item() in self.embedding
154
+ and self.place_in_unet == "up"
155
+ ):
156
+ hidden_states = torch.cat((hidden_states, torch.from_numpy(self.embedding[timestep.item()]).to(hidden_states.device)), dim=1)
157
+
158
+ if not is_cross and sa_preserve and self.place_in_unet == "up":
159
+ self.embedding[timestep.item()] = ori_hidden_states.cpu().numpy()
160
+
161
+ encoder_hidden_states = (
162
+ encoder_hidden_states
163
+ if encoder_hidden_states is not None
164
+ else hidden_states
165
+ )
166
+ key = attn.to_k(encoder_hidden_states)
167
+ value = attn.to_v(encoder_hidden_states)
168
+ query = attn.head_to_batch_dim(query)
169
+ key = attn.head_to_batch_dim(key)
170
+ value = attn.head_to_batch_dim(value)
171
+ attention_probs = attn.get_attention_scores(query, key, attention_mask) # 48 4096 77
172
+ self.attnstore(attention_probs, is_cross, self.place_in_unet)
173
+ hidden_states = torch.bmm(attention_probs, value)
174
+ hidden_states = attn.batch_to_head_dim(hidden_states)
175
+ hidden_states = attn.to_out[0](hidden_states)
176
+ hidden_states = attn.to_out[1](hidden_states)
177
+
178
+ ###### Self-Attention Results ######
179
+ if not is_cross:
180
+ return hidden_states
181
+
182
+ ###### Vanilla Cross-Attention Results ######
183
+ if is_vanilla_cross:
184
+ return hidden_states
185
+
186
+ ###### Cross-Attention with MIGC ######
187
+ assert (not is_vanilla_cross)
188
+ # hidden_states: torch.Size([1+1+instance_num, HW, C]), the first 1 is the uncond ca output, the second 1 is the global ca output.
189
+ hidden_states_uncond = hidden_states[[0], ...] # torch.Size([1, HW, C])
190
+ cond_ca_output = hidden_states[1: , ...].unsqueeze(0) # torch.Size([1, 1+instance_num, 5, 64, 1280])
191
+ guidance_masks = []
192
+ in_box = []
193
+ # Construct Instance Guidance Mask
194
+ for bbox in bboxes[0]:
195
+ guidance_mask = np.zeros((height, width))
196
+ w_min = int(width * bbox[0])
197
+ w_max = int(width * bbox[2])
198
+ h_min = int(height * bbox[1])
199
+ h_max = int(height * bbox[3])
200
+ guidance_mask[h_min: h_max, w_min: w_max] = 1.0
201
+ guidance_masks.append(guidance_mask[None, ...])
202
+ in_box.append([bbox[0], bbox[2], bbox[1], bbox[3]])
203
+
204
+ # Construct Background Guidance Mask
205
+ sup_mask = get_sup_mask(guidance_masks)
206
+ supplement_mask = torch.from_numpy(sup_mask[None, ...])
207
+ supplement_mask = F.interpolate(supplement_mask, (height//8, width//8), mode='bilinear').float()
208
+ supplement_mask = supplement_mask.to(hidden_states.device) # (1, 1, H, W)
209
+
210
+ guidance_masks = np.concatenate(guidance_masks, axis=0)
211
+ guidance_masks = guidance_masks[None, ...]
212
+ guidance_masks = torch.from_numpy(guidance_masks).float().to(cond_ca_output.device)
213
+ guidance_masks = F.interpolate(guidance_masks, (height//8, width//8), mode='bilinear') # (1, instance_num, H, W)
214
+
215
+ in_box = torch.from_numpy(np.array(in_box))[None, ...].float().to(cond_ca_output.device) # (1, instance_num, 4)
216
+
217
+ other_info = {}
218
+ other_info['image_token'] = hidden_states_cond[None, ...]
219
+ other_info['context'] = encoder_hidden_states[1:, ...]
220
+ other_info['box'] = in_box
221
+ other_info['context_pooler'] = embeds_pooler # (instance_num, 1, 768)
222
+ other_info['supplement_mask'] = supplement_mask
223
+ other_info['attn2'] = None
224
+ other_info['attn'] = attn
225
+ other_info['height'] = height
226
+ other_info['width'] = width
227
+ other_info['ca_scale'] = ca_scale
228
+ other_info['ea_scale'] = ea_scale
229
+ other_info['sac_scale'] = sac_scale
230
+
231
+ if not not_use_migc:
232
+ hidden_states_cond, fuser_info = self.migc(cond_ca_output,
233
+ guidance_masks,
234
+ other_info=other_info,
235
+ return_fuser_info=True)
236
+ else:
237
+ hidden_states_cond, fuser_info = self.naive_fuser(cond_ca_output,
238
+ guidance_masks,
239
+ other_info=other_info,
240
+ return_fuser_info=True)
241
+ hidden_states_cond = hidden_states_cond.squeeze(1)
242
+
243
+ hidden_states = torch.cat([hidden_states_uncond, hidden_states_cond])
244
+ return hidden_states
245
+
246
+
247
+ class StableDiffusionMIGCPipeline(StableDiffusionPipeline):
248
+ def __init__(
249
+ self,
250
+ vae: AutoencoderKL,
251
+ text_encoder: CLIPTextModel,
252
+ tokenizer: CLIPTokenizer,
253
+ unet: UNet2DConditionModel,
254
+ scheduler: KarrasDiffusionSchedulers,
255
+ safety_checker: StableDiffusionSafetyChecker,
256
+ feature_extractor: CLIPImageProcessor,
257
+ image_encoder: CLIPVisionModelWithProjection = None,
258
+ requires_safety_checker: bool = True,
259
+ ):
260
+ # Get the parameter signature of the parent class constructor
261
+ parent_init_signature = inspect.signature(super().__init__)
262
+ parent_init_params = parent_init_signature.parameters
263
+
264
+ # Dynamically build a parameter dictionary based on the parameters of the parent class constructor
265
+ init_kwargs = {
266
+ "vae": vae,
267
+ "text_encoder": text_encoder,
268
+ "tokenizer": tokenizer,
269
+ "unet": unet,
270
+ "scheduler": scheduler,
271
+ "safety_checker": safety_checker,
272
+ "feature_extractor": feature_extractor,
273
+ "requires_safety_checker": requires_safety_checker
274
+ }
275
+ if 'image_encoder' in parent_init_params.items():
276
+ init_kwargs['image_encoder'] = image_encoder
277
+ super().__init__(**init_kwargs)
278
+
279
+ self.instance_set = set()
280
+ self.embedding = {}
281
+
282
+ def _encode_prompt(
283
+ self,
284
+ prompts,
285
+ device,
286
+ num_images_per_prompt,
287
+ do_classifier_free_guidance,
288
+ negative_prompt=None,
289
+ prompt_embeds: Optional[torch.FloatTensor] = None,
290
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
291
+ ):
292
+ r"""
293
+ Encodes the prompt into text encoder hidden states.
294
+
295
+ Args:
296
+ prompt (`str` or `List[str]`, *optional*):
297
+ prompt to be encoded
298
+ device: (`torch.device`):
299
+ torch device
300
+ num_images_per_prompt (`int`):
301
+ number of images that should be generated per prompt
302
+ do_classifier_free_guidance (`bool`):
303
+ whether to use classifier free guidance or not
304
+ negative_prompt (`str` or `List[str]`, *optional*):
305
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
306
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
307
+ less than `1`).
308
+ prompt_embeds (`torch.FloatTensor`, *optional*):
309
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
310
+ provided, text embeddings will be generated from `prompt` input argument.
311
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
312
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
313
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
314
+ argument.
315
+ """
316
+ if prompts is not None and isinstance(prompts, str):
317
+ batch_size = 1
318
+ elif prompts is not None and isinstance(prompts, list):
319
+ batch_size = len(prompts)
320
+ else:
321
+ batch_size = prompt_embeds.shape[0]
322
+
323
+ prompt_embeds_none_flag = (prompt_embeds is None)
324
+ prompt_embeds_list = []
325
+ embeds_pooler_list = []
326
+ for prompt in prompts:
327
+ if prompt_embeds_none_flag:
328
+ # textual inversion: process multi-vector tokens if necessary
329
+ if isinstance(self, TextualInversionLoaderMixin):
330
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
331
+
332
+ text_inputs = self.tokenizer(
333
+ prompt,
334
+ padding="max_length",
335
+ max_length=self.tokenizer.model_max_length,
336
+ truncation=True,
337
+ return_tensors="pt",
338
+ )
339
+ text_input_ids = text_inputs.input_ids
340
+ untruncated_ids = self.tokenizer(
341
+ prompt, padding="longest", return_tensors="pt"
342
+ ).input_ids
343
+
344
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[
345
+ -1
346
+ ] and not torch.equal(text_input_ids, untruncated_ids):
347
+ removed_text = self.tokenizer.batch_decode(
348
+ untruncated_ids[:, self.tokenizer.model_max_length - 1: -1]
349
+ )
350
+ logger.warning(
351
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
352
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
353
+ )
354
+
355
+ if (
356
+ hasattr(self.text_encoder.config, "use_attention_mask")
357
+ and self.text_encoder.config.use_attention_mask
358
+ ):
359
+ attention_mask = text_inputs.attention_mask.to(device)
360
+ else:
361
+ attention_mask = None
362
+
363
+ prompt_embeds = self.text_encoder(
364
+ text_input_ids.to(device),
365
+ attention_mask=attention_mask,
366
+ )
367
+ embeds_pooler = prompt_embeds.pooler_output
368
+ prompt_embeds = prompt_embeds[0]
369
+
370
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
371
+ embeds_pooler = embeds_pooler.to(dtype=self.text_encoder.dtype, device=device)
372
+
373
+ bs_embed, seq_len, _ = prompt_embeds.shape
374
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
375
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
376
+ embeds_pooler = embeds_pooler.repeat(1, num_images_per_prompt)
377
+ prompt_embeds = prompt_embeds.view(
378
+ bs_embed * num_images_per_prompt, seq_len, -1
379
+ )
380
+ embeds_pooler = embeds_pooler.view(
381
+ bs_embed * num_images_per_prompt, -1
382
+ )
383
+ prompt_embeds_list.append(prompt_embeds)
384
+ embeds_pooler_list.append(embeds_pooler)
385
+ prompt_embeds = torch.cat(prompt_embeds_list, dim=0)
386
+ embeds_pooler = torch.cat(embeds_pooler_list, dim=0)
387
+ # negative_prompt_embeds: (prompt_nums[0]+prompt_nums[1]+...prompt_nums[n], token_num, token_channel), <class 'torch.Tensor'>
388
+
389
+ # get unconditional embeddings for classifier free guidance
390
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
391
+ uncond_tokens: List[str]
392
+ if negative_prompt is None:
393
+ negative_prompt = "worst quality, low quality, bad anatomy"
394
+ uncond_tokens = [negative_prompt] * batch_size
395
+
396
+ # textual inversion: process multi-vector tokens if necessary
397
+ if isinstance(self, TextualInversionLoaderMixin):
398
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
399
+
400
+ max_length = prompt_embeds.shape[1]
401
+ uncond_input = self.tokenizer(
402
+ uncond_tokens,
403
+ padding="max_length",
404
+ max_length=max_length,
405
+ truncation=True,
406
+ return_tensors="pt",
407
+ )
408
+
409
+ if (
410
+ hasattr(self.text_encoder.config, "use_attention_mask")
411
+ and self.text_encoder.config.use_attention_mask
412
+ ):
413
+ attention_mask = uncond_input.attention_mask.to(device)
414
+ else:
415
+ attention_mask = None
416
+
417
+ negative_prompt_embeds = self.text_encoder(
418
+ uncond_input.input_ids.to(device),
419
+ attention_mask=attention_mask,
420
+ )
421
+ negative_prompt_embeds = negative_prompt_embeds[0]
422
+
423
+ if do_classifier_free_guidance:
424
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
425
+ seq_len = negative_prompt_embeds.shape[1]
426
+
427
+ negative_prompt_embeds = negative_prompt_embeds.to(
428
+ dtype=self.text_encoder.dtype, device=device
429
+ )
430
+
431
+ negative_prompt_embeds = negative_prompt_embeds.repeat(
432
+ 1, num_images_per_prompt, 1
433
+ )
434
+ negative_prompt_embeds = negative_prompt_embeds.view(
435
+ batch_size * num_images_per_prompt, seq_len, -1
436
+ )
437
+ # negative_prompt_embeds: (len(prompt_nums), token_num, token_channel), <class 'torch.Tensor'>
438
+
439
+ # For classifier free guidance, we need to do two forward passes.
440
+ # Here we concatenate the unconditional and text embeddings into a single batch
441
+ # to avoid doing two forward passes
442
+ final_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
443
+
444
+ return final_prompt_embeds, prompt_embeds, embeds_pooler[:, None, :]
445
+
446
+ def check_inputs(
447
+ self,
448
+ prompt,
449
+ token_indices,
450
+ bboxes,
451
+ height,
452
+ width,
453
+ callback_steps,
454
+ negative_prompt=None,
455
+ prompt_embeds=None,
456
+ negative_prompt_embeds=None,
457
+ ):
458
+ if height % 8 != 0 or width % 8 != 0:
459
+ raise ValueError(
460
+ f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
461
+ )
462
+
463
+ if (callback_steps is None) or (
464
+ callback_steps is not None
465
+ and (not isinstance(callback_steps, int) or callback_steps <= 0)
466
+ ):
467
+ raise ValueError(
468
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
469
+ f" {type(callback_steps)}."
470
+ )
471
+
472
+ if prompt is not None and prompt_embeds is not None:
473
+ raise ValueError(
474
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
475
+ " only forward one of the two."
476
+ )
477
+ elif prompt is None and prompt_embeds is None:
478
+ raise ValueError(
479
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
480
+ )
481
+ elif prompt is not None and (
482
+ not isinstance(prompt, str) and not isinstance(prompt, list)
483
+ ):
484
+ raise ValueError(
485
+ f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
486
+ )
487
+
488
+ if negative_prompt is not None and negative_prompt_embeds is not None:
489
+ raise ValueError(
490
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
491
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
492
+ )
493
+
494
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
495
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
496
+ raise ValueError(
497
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
498
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
499
+ f" {negative_prompt_embeds.shape}."
500
+ )
501
+
502
+ if token_indices is not None:
503
+ if isinstance(token_indices, list):
504
+ if isinstance(token_indices[0], list):
505
+ if isinstance(token_indices[0][0], list):
506
+ token_indices_batch_size = len(token_indices)
507
+ elif isinstance(token_indices[0][0], int):
508
+ token_indices_batch_size = 1
509
+ else:
510
+ raise TypeError(
511
+ "`token_indices` must be a list of lists of integers or a list of integers."
512
+ )
513
+ else:
514
+ raise TypeError(
515
+ "`token_indices` must be a list of lists of integers or a list of integers."
516
+ )
517
+ else:
518
+ raise TypeError(
519
+ "`token_indices` must be a list of lists of integers or a list of integers."
520
+ )
521
+
522
+ if bboxes is not None:
523
+ if isinstance(bboxes, list):
524
+ if isinstance(bboxes[0], list):
525
+ if (
526
+ isinstance(bboxes[0][0], list)
527
+ and len(bboxes[0][0]) == 4
528
+ and all(isinstance(x, float) for x in bboxes[0][0])
529
+ ):
530
+ bboxes_batch_size = len(bboxes)
531
+ elif (
532
+ isinstance(bboxes[0], list)
533
+ and len(bboxes[0]) == 4
534
+ and all(isinstance(x, float) for x in bboxes[0])
535
+ ):
536
+ bboxes_batch_size = 1
537
+ else:
538
+ print(isinstance(bboxes[0], list), len(bboxes[0]))
539
+ raise TypeError(
540
+ "`bboxes` must be a list of lists of list with four floats or a list of tuples with four floats."
541
+ )
542
+ else:
543
+ print(isinstance(bboxes[0], list), len(bboxes[0]))
544
+ raise TypeError(
545
+ "`bboxes` must be a list of lists of list with four floats or a list of tuples with four floats."
546
+ )
547
+ else:
548
+ print(isinstance(bboxes[0], list), len(bboxes[0]))
549
+ raise TypeError(
550
+ "`bboxes` must be a list of lists of list with four floats or a list of tuples with four floats."
551
+ )
552
+
553
+ if prompt is not None and isinstance(prompt, str):
554
+ prompt_batch_size = 1
555
+ elif prompt is not None and isinstance(prompt, list):
556
+ prompt_batch_size = len(prompt)
557
+ elif prompt_embeds is not None:
558
+ prompt_batch_size = prompt_embeds.shape[0]
559
+
560
+ if token_indices_batch_size != prompt_batch_size:
561
+ raise ValueError(
562
+ f"token indices batch size must be same as prompt batch size. token indices batch size: {token_indices_batch_size}, prompt batch size: {prompt_batch_size}"
563
+ )
564
+
565
+ if bboxes_batch_size != prompt_batch_size:
566
+ raise ValueError(
567
+ f"bbox batch size must be same as prompt batch size. bbox batch size: {bboxes_batch_size}, prompt batch size: {prompt_batch_size}"
568
+ )
569
+
570
+ def get_indices(self, prompt: str) -> Dict[str, int]:
571
+ """Utility function to list the indices of the tokens you wish to alte"""
572
+ ids = self.tokenizer(prompt).input_ids
573
+ indices = {
574
+ i: tok
575
+ for tok, i in zip(
576
+ self.tokenizer.convert_ids_to_tokens(ids), range(len(ids))
577
+ )
578
+ }
579
+ return indices
580
+
581
+ @staticmethod
582
+ def draw_box(pil_img: Image, bboxes: List[List[float]]) -> Image:
583
+ """Utility function to draw bbox on the image"""
584
+ width, height = pil_img.size
585
+ draw = ImageDraw.Draw(pil_img)
586
+
587
+ for obj_box in bboxes:
588
+ x_min, y_min, x_max, y_max = (
589
+ obj_box[0] * width,
590
+ obj_box[1] * height,
591
+ obj_box[2] * width,
592
+ obj_box[3] * height,
593
+ )
594
+ draw.rectangle(
595
+ [int(x_min), int(y_min), int(x_max), int(y_max)],
596
+ outline="red",
597
+ width=4,
598
+ )
599
+
600
+ return pil_img
601
+
602
+
603
+ @staticmethod
604
+ def draw_box_desc(pil_img: Image, bboxes: List[List[float]], prompt: List[str]) -> Image:
605
+ """Utility function to draw bbox on the image"""
606
+ color_list = ['red', 'blue', 'yellow', 'purple', 'green', 'black', 'brown', 'orange', 'white', 'gray']
607
+ width, height = pil_img.size
608
+ draw = ImageDraw.Draw(pil_img)
609
+ font_folder = os.path.dirname(os.path.dirname(__file__))
610
+ font_path = os.path.join(font_folder, 'Rainbow-Party-2.ttf')
611
+ font = ImageFont.truetype(font_path, 30)
612
+
613
+ for box_id in range(len(bboxes)):
614
+ obj_box = bboxes[box_id]
615
+ text = prompt[box_id]
616
+ fill = 'black'
617
+ for color in prompt[box_id].split(' '):
618
+ if color in color_list:
619
+ fill = color
620
+ text = text.split(',')[0]
621
+ x_min, y_min, x_max, y_max = (
622
+ obj_box[0] * width,
623
+ obj_box[1] * height,
624
+ obj_box[2] * width,
625
+ obj_box[3] * height,
626
+ )
627
+ draw.rectangle(
628
+ [int(x_min), int(y_min), int(x_max), int(y_max)],
629
+ outline=fill,
630
+ width=4,
631
+ )
632
+ draw.text((int(x_min), int(y_min)), text, fill=fill, font=font)
633
+
634
+ return pil_img
635
+
636
+
637
+ @torch.no_grad()
638
+ def __call__(
639
+ self,
640
+ prompt: List[List[str]] = None,
641
+ bboxes: List[List[List[float]]] = None,
642
+ height: Optional[int] = None,
643
+ width: Optional[int] = None,
644
+ num_inference_steps: int = 50,
645
+ guidance_scale: float = 7.5,
646
+ negative_prompt: Optional[Union[str, List[str]]] = None,
647
+ num_images_per_prompt: Optional[int] = 1,
648
+ eta: float = 0.0,
649
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
650
+ latents: Optional[torch.FloatTensor] = None,
651
+ prompt_embeds: Optional[torch.FloatTensor] = None,
652
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
653
+ output_type: Optional[str] = "pil",
654
+ return_dict: bool = True,
655
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
656
+ callback_steps: int = 1,
657
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
658
+ MIGCsteps=20,
659
+ NaiveFuserSteps=-1,
660
+ ca_scale=None,
661
+ ea_scale=None,
662
+ sac_scale=None,
663
+ aug_phase_with_and=False,
664
+ sa_preserve=False,
665
+ use_sa_preserve=False,
666
+ clear_set=False,
667
+ GUI_progress=None
668
+ ):
669
+ r"""
670
+ Function invoked when calling the pipeline for generation.
671
+
672
+ Args:
673
+ prompt (`str` or `List[str]`, *optional*):
674
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
675
+ instead.
676
+ token_indices (Union[List[List[List[int]]], List[List[int]]], optional):
677
+ The list of the indexes in the prompt to layout. Defaults to None.
678
+ bboxes (Union[List[List[List[float]]], List[List[float]]], optional):
679
+ The bounding boxes of the indexes to maintain layout in the image. Defaults to None.
680
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
681
+ The height in pixels of the generated image.
682
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
683
+ The width in pixels of the generated image.
684
+ num_inference_steps (`int`, *optional*, defaults to 50):
685
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
686
+ expense of slower inference.
687
+ guidance_scale (`float`, *optional*, defaults to 7.5):
688
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
689
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
690
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
691
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
692
+ usually at the expense of lower image quality.
693
+ negative_prompt (`str` or `List[str]`, *optional*):
694
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
695
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
696
+ less than `1`).
697
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
698
+ The number of images to generate per prompt.
699
+ eta (`float`, *optional*, defaults to 0.0):
700
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
701
+ [`schedulers.DDIMScheduler`], will be ignored for others.
702
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
703
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
704
+ to make generation deterministic.
705
+ latents (`torch.FloatTensor`, *optional*):
706
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
707
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
708
+ tensor will be generated by sampling using the supplied random `generator`.
709
+ prompt_embeds (`torch.FloatTensor`, *optional*):
710
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
711
+ provided, text embeddings will be generated from `prompt` input argument.
712
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
713
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
714
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
715
+ argument.
716
+ output_type (`str`, *optional*, defaults to `"pil"`):
717
+ The output format of the generate image. Choose between
718
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
719
+ return_dict (`bool`, *optional*, defaults to `True`):
720
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
721
+ plain tuple.
722
+ callback (`Callable`, *optional*):
723
+ A function that will be called every `callback_steps` steps during inference. The function will be
724
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
725
+ callback_steps (`int`, *optional*, defaults to 1):
726
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
727
+ called at every step.
728
+ cross_attention_kwargs (`dict`, *optional*):
729
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
730
+ `self.processor` in
731
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
732
+ max_guidance_iter (`int`, *optional*, defaults to `10`):
733
+ The maximum number of iterations for the layout guidance on attention maps in diffusion mode.
734
+ max_guidance_iter_per_step (`int`, *optional*, defaults to `5`):
735
+ The maximum number of iterations to run during each time step for layout guidance.
736
+ scale_factor (`int`, *optional*, defaults to `50`):
737
+ The scale factor used to update the latents during optimization.
738
+
739
+ Examples:
740
+
741
+ Returns:
742
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
743
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
744
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
745
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
746
+ (nsfw) content, according to the `safety_checker`.
747
+ """
748
+ def aug_phase_with_and_function(phase, instance_num):
749
+ instance_num = min(instance_num, 7)
750
+ copy_phase = [phase] * instance_num
751
+ phase = ', and '.join(copy_phase)
752
+ return phase
753
+
754
+ if aug_phase_with_and:
755
+ instance_num = len(prompt[0]) - 1
756
+ for i in range(1, len(prompt[0])):
757
+ prompt[0][i] = aug_phase_with_and_function(prompt[0][i],
758
+ instance_num)
759
+ # 0. Default height and width to unet
760
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
761
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
762
+
763
+ # 2. Define call parameters
764
+ if prompt is not None and isinstance(prompt, str):
765
+ batch_size = 1
766
+ elif prompt is not None and isinstance(prompt, list):
767
+ batch_size = len(prompt)
768
+ else:
769
+ batch_size = prompt_embeds.shape[0]
770
+
771
+ prompt_nums = [0] * len(prompt)
772
+ for i, _ in enumerate(prompt):
773
+ prompt_nums[i] = len(_)
774
+
775
+ device = self._execution_device
776
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
777
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
778
+ # corresponds to doing no classifier free guidance.
779
+ do_classifier_free_guidance = guidance_scale > 1.0
780
+
781
+ # 3. Encode input prompt
782
+ prompt_embeds, cond_prompt_embeds, embeds_pooler = self._encode_prompt(
783
+ prompt,
784
+ device,
785
+ num_images_per_prompt,
786
+ do_classifier_free_guidance,
787
+ negative_prompt,
788
+ prompt_embeds=prompt_embeds,
789
+ negative_prompt_embeds=negative_prompt_embeds,
790
+ )
791
+ # print(prompt_embeds.shape) 3 77 768
792
+
793
+ # 4. Prepare timesteps
794
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
795
+ timesteps = self.scheduler.timesteps
796
+
797
+ # 5. Prepare latent variables
798
+ num_channels_latents = self.unet.config.in_channels
799
+ latents = self.prepare_latents(
800
+ batch_size * num_images_per_prompt,
801
+ num_channels_latents,
802
+ height,
803
+ width,
804
+ prompt_embeds.dtype,
805
+ device,
806
+ generator,
807
+ latents,
808
+ )
809
+
810
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
811
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
812
+
813
+ # 7. Denoising loop
814
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
815
+
816
+ if clear_set:
817
+ self.instance_set = set()
818
+ self.embedding = {}
819
+
820
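+ # Track which (bbox, phrase) instances changed since the previous call: the symmetric
+ # difference with the stored instance set marks the regions that must be re-generated.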
+ now_set = set()
821
+ for i in range(len(bboxes[0])):
822
+ now_set.add((tuple(bboxes[0][i]), prompt[0][i + 1]))
823
+
824
+ mask_set = (now_set | self.instance_set) - (now_set & self.instance_set)
825
+ self.instance_set = now_set
826
+
827
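+ # Build a preservation mask at latent resolution (1/8 of the image): 1 where earlier
+ # content may be kept, 0 inside the (slightly expanded) boxes of changed instances.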
+ guidance_mask = np.full((4, height // 8, width // 8), 1.0)
828
+
829
+ for bbox, _ in mask_set:
830
+ w_min = max(0, int(width * bbox[0] // 8) - 5)
831
+ w_max = min(width, int(width * bbox[2] // 8) + 5)
832
+ h_min = max(0, int(height * bbox[1] // 8) - 5)
833
+ h_max = min(height, int(height * bbox[3] // 8) + 5)
834
+ guidance_mask[:, h_min:h_max, w_min:w_max] = 0
835
+
836
+ # Soften the hard box edges with a uniform (box) filter so preserved and
+ # regenerated regions blend smoothly in latent space.
+ kernel_size = 5
837
+ guidance_mask = uniform_filter(
838
+ guidance_mask, axes=(1, 2), size=kernel_size
839
+ )
840
+
841
+ guidance_mask = torch.from_numpy(guidance_mask).to(self.device).unsqueeze(0)
842
+
843
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
844
+ for i, t in enumerate(timesteps):
845
+ if GUI_progress is not None:
846
+ GUI_progress[0] = int((i + 1) / len(timesteps) * 100)
847
+ # expand the latents if we are doing classifier free guidance
848
+ latent_model_input = (
849
+ torch.cat([latents] * 2) if do_classifier_free_guidance else latents
850
+ )
851
+ latent_model_input = self.scheduler.scale_model_input(
852
+ latent_model_input, t
853
+ )
854
+
855
+ # predict the noise residual
856
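+ # Route the layout conditioning (per-prompt phrase counts, instance boxes, MIGC/naive-fuser
+ # step limits and attention scales) to the MIGC attention processors via cross_attention_kwargs.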
+ cross_attention_kwargs = {'prompt_nums': prompt_nums,
857
+ 'bboxes': bboxes,
858
+ 'ith': i,
859
+ 'embeds_pooler': embeds_pooler,
860
+ 'timestep': t,
861
+ 'height': height,
862
+ 'width': width,
863
+ 'MIGCsteps': MIGCsteps,
864
+ 'NaiveFuserSteps': NaiveFuserSteps,
865
+ 'ca_scale': ca_scale,
866
+ 'ea_scale': ea_scale,
867
+ 'sac_scale': sac_scale,
868
+ 'sa_preserve': sa_preserve,
869
+ 'use_sa_preserve': use_sa_preserve}
870
+
871
+ self.unet.eval()
872
+ noise_pred = self.unet(
873
+ latent_model_input,
874
+ t,
875
+ encoder_hidden_states=prompt_embeds,
876
+ cross_attention_kwargs=cross_attention_kwargs,
877
+ ).sample
878
+
879
+ # perform guidance
880
+ if do_classifier_free_guidance:
881
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
882
+ noise_pred = noise_pred_uncond + guidance_scale * (
883
+ noise_pred_text - noise_pred_uncond
884
+ )
885
+
886
+ step_output = self.scheduler.step(
887
+ noise_pred, t, latents, **extra_step_kwargs
888
+ )
889
+ latents = step_output.prev_sample
890
+
891
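+ # With use_sa_preserve, latents outside the changed boxes (guidance_mask == 1) are
+ # restored from the trajectory cached in self.embedding; with sa_preserve, the current
+ # step's latents are cached for reuse in a later call.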
+ ori_input = latents.detach().clone()
892
+ if use_sa_preserve and i in self.embedding:
893
+ latents = (
894
+ latents * (1.0 - guidance_mask)
895
+ + torch.from_numpy(self.embedding[i]).to(latents.device) * guidance_mask
896
+ ).float()
897
+
898
+ if sa_preserve:
899
+ self.embedding[i] = ori_input.cpu().numpy()
900
+
901
+ # call the callback, if provided
902
+ if i == len(timesteps) - 1 or (
903
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
904
+ ):
905
+ progress_bar.update()
906
+ if callback is not None and i % callback_steps == 0:
907
+ callback(i, t, latents)
908
+
909
+ if output_type == "latent":
910
+ image = latents
911
+ elif output_type == "pil":
912
+ # 8. Post-processing
913
+ image = self.decode_latents(latents)
914
+ image = self.numpy_to_pil(image)
915
+ else:
916
+ # 8. Post-processing
917
+ image = self.decode_latents(latents)
918
+
919
+ # Offload last model to CPU
920
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
921
+ self.final_offload_hook.offload()
922
+
923
+ if not return_dict:
924
+ return (image, None)
925
+
926
+ return StableDiffusionPipelineOutput(
927
+ images=image, nsfw_content_detected=None
928
+ )
migc/migc_utils.py ADDED
@@ -0,0 +1,143 @@
1
+ import argparse
2
+ import numpy as np
3
+ import torch
4
+ import os
5
+ import yaml
6
+ import random
7
+ from diffusers.utils.import_utils import is_accelerate_available
8
+ from transformers import CLIPTextModel, CLIPTokenizer
9
+ from migc.migc_pipeline import StableDiffusionMIGCPipeline, MIGCProcessor, AttentionStore
10
+ from diffusers import EulerDiscreteScheduler
11
+ from contextlib import nullcontext  # imported unconditionally: it is the fallback context manager below
12
+ if is_accelerate_available():
13
+ from accelerate import init_empty_weights
14
+
15
+
16
+ def seed_everything(seed):
17
+ # np.random.seed(seed)
18
+ torch.manual_seed(seed)
19
+ torch.cuda.manual_seed_all(seed)
20
+ random.seed(seed)
21
+
22
+
23
+ import torch
24
+ from typing import Callable, Dict, List, Optional, Union
25
+ from collections import defaultdict
26
+
27
+ LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
28
+
29
+ # We need to set Attention Processors for the following keys.
30
+ all_processor_keys = [
31
+ 'down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor', 'down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor',
32
+ 'down_blocks.0.attentions.1.transformer_blocks.0.attn1.processor', 'down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor',
33
+ 'down_blocks.1.attentions.0.transformer_blocks.0.attn1.processor', 'down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor',
34
+ 'down_blocks.1.attentions.1.transformer_blocks.0.attn1.processor', 'down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor',
35
+ 'down_blocks.2.attentions.0.transformer_blocks.0.attn1.processor', 'down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor',
36
+ 'down_blocks.2.attentions.1.transformer_blocks.0.attn1.processor', 'down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor',
37
+ 'up_blocks.1.attentions.0.transformer_blocks.0.attn1.processor', 'up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor',
38
+ 'up_blocks.1.attentions.1.transformer_blocks.0.attn1.processor', 'up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor',
39
+ 'up_blocks.1.attentions.2.transformer_blocks.0.attn1.processor', 'up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor',
40
+ 'up_blocks.2.attentions.0.transformer_blocks.0.attn1.processor', 'up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor',
41
+ 'up_blocks.2.attentions.1.transformer_blocks.0.attn1.processor', 'up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor',
42
+ 'up_blocks.2.attentions.2.transformer_blocks.0.attn1.processor', 'up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor',
43
+ 'up_blocks.3.attentions.0.transformer_blocks.0.attn1.processor', 'up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor',
44
+ 'up_blocks.3.attentions.1.transformer_blocks.0.attn1.processor', 'up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor',
45
+ 'up_blocks.3.attentions.2.transformer_blocks.0.attn1.processor', 'up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor',
46
+ 'mid_block.attentions.0.transformer_blocks.0.attn1.processor', 'mid_block.attentions.0.transformer_blocks.0.attn2.processor'
47
+ ]
48
+
49
+ def load_migc(unet, attention_store, pretrained_MIGC_path: Union[str, Dict[str, torch.Tensor]], attn_processor,
50
+ **kwargs):
51
+
52
+ state_dict = torch.load(pretrained_MIGC_path, map_location="cpu")
53
+
54
+ # fill attn processors
55
+ attn_processors = {}
56
+ state_dict = state_dict['state_dict']
57
+
58
+
59
+ adapter_grouped_dict = defaultdict(dict)
60
+
61
+ # Remap MIGC checkpoint keys from the LDM naming (input_blocks / middle_block / output_blocks)
+ # to the diffusers UNet naming (down_blocks / mid_block / up_blocks).
62
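+ # e.g. an output_blocks entry with block index 3 maps to
+ # 'up_blocks.1.attentions.0.transformer_blocks.0' (index // 3 selects the block, index % 3 the attention).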
+ for key, value in state_dict.items():
63
+ key_list = key.split(".")
64
+ assert 'migc' in key_list
65
+ if 'input_blocks' in key_list:
66
+ model_type = 'down_blocks'
67
+ elif 'middle_block' in key_list:
68
+ model_type = 'mid_block'
69
+ else:
70
+ model_type = 'up_blocks'
71
+ index_number = int(key_list[3])
72
+ if model_type == 'down_blocks':
73
+ input_num1 = str(index_number//3)
74
+ input_num2 = str((index_number%3)-1)
75
+ elif model_type == 'mid_block':
76
+ input_num1 = '0'
77
+ input_num2 = '0'
78
+ else:
79
+ input_num1 = str(index_number//3)
80
+ input_num2 = str(index_number%3)
81
+ attn_key_list = [model_type,input_num1,'attentions',input_num2,'transformer_blocks','0']
82
+ if model_type == 'mid_block':
83
+ attn_key_list = [model_type,'attentions',input_num2,'transformer_blocks','0']
84
+ attn_processor_key = '.'.join(attn_key_list)
85
+ sub_key = '.'.join(key_list[key_list.index('migc'):])
86
+ adapter_grouped_dict[attn_processor_key][sub_key] = value
87
+
88
+ # Create MIGC Processor
89
+ config = {'not_use_migc': False}
90
+ for key, value_dict in adapter_grouped_dict.items():
91
+ dim = value_dict['migc.norm.bias'].shape[0]
92
+ config['C'] = dim
93
+ key_final = key + '.attn2.processor'
94
+ if key_final.startswith("mid_block"):
95
+ place_in_unet = "mid"
96
+ elif key_final.startswith("up_blocks"):
97
+ place_in_unet = "up"
98
+ elif key_final.startswith("down_blocks"):
99
+ place_in_unet = "down"
100
+
101
+ attn_processors[key_final] = attn_processor(config, attention_store, place_in_unet)
102
+ attn_processors[key_final].load_state_dict(value_dict)
103
+ attn_processors[key_final].to(device=unet.device, dtype=unet.dtype)
104
+
105
+ # Create CrossAttention/SelfAttention Processor
106
+ config = {'not_use_migc': True}
107
+ for key in all_processor_keys:
108
+ if key not in attn_processors.keys():
109
+ if key.startswith("mid_block"):
110
+ place_in_unet = "mid"
111
+ elif key.startswith("up_blocks"):
112
+ place_in_unet = "up"
113
+ elif key.startswith("down_blocks"):
114
+ place_in_unet = "down"
115
+ attn_processors[key] = attn_processor(config, attention_store, place_in_unet)
116
+ unet.set_attn_processor(attn_processors)
117
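+ # 32 matches the number of attention processors registered above (attn1 + attn2 across all listed blocks).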
+ attention_store.num_att_layers = 32
118
+
119
+
120
+ def offlinePipelineSetupWithSafeTensor(sd_safetensors_path):
121
+ project_dir = os.path.dirname(os.path.dirname(__file__))
122
+ migc_ckpt_path = os.path.join(project_dir, 'pretrained_weights/MIGC_SD14.ckpt')
123
+ clip_model_path = os.path.join(project_dir, 'migc_gui_weights/clip/text_encoder')
124
+ clip_tokenizer_path = os.path.join(project_dir, 'migc_gui_weights/clip/tokenizer')
125
+ original_config_file = os.path.join(project_dir, 'migc_gui_weights/v1-inference.yaml')
126
+ ctx = init_empty_weights if is_accelerate_available() else nullcontext
127
+ with ctx():
128
+ # text_encoder = CLIPTextModel(config)
129
+ text_encoder = CLIPTextModel.from_pretrained(clip_model_path)
130
+ tokenizer = CLIPTokenizer.from_pretrained(clip_tokenizer_path)
131
+ pipe = StableDiffusionMIGCPipeline.from_single_file(sd_safetensors_path,
132
+ original_config_file=original_config_file,
133
+ text_encoder=text_encoder,
134
+ tokenizer=tokenizer,
135
+ load_safety_checker=False)
136
+ print('Initializing pipeline')
137
+ pipe.attention_store = AttentionStore()
138
+ from migc.migc_utils import load_migc
139
+ load_migc(pipe.unet , pipe.attention_store,
140
+ migc_ckpt_path, attn_processor=MIGCProcessor)
141
+
142
+ pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
143
+ return pipe
requirements.txt CHANGED
@@ -4,4 +4,6 @@ diffusers
4
  invisible_watermark
5
  torch
6
  transformers
7
- xformers
7
+ xformers
8
+ einops
9
+ scipy