Nupur Kumari committed on
Commit 8173ae1 · 0 Parent(s)

concept ablation

.gitattributes ADDED
@@ -0,0 +1,35 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ images/applications.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
1
+ __pycache__
README.md ADDED
@@ -0,0 +1,53 @@
1
+ ---
2
+ title: Erasing Concepts from Diffusion Models
3
+ emoji: 💡
4
+ colorFrom: indigo
5
+ colorTo: gray
6
+ sdk: gradio
7
+ sdk_version: 3.21.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ ---
12
+
13
+
14
+
15
+ # Erasing Concepts from Diffusion Models
16
+
17
+ Project Website [https://erasing.baulab.info](https://erasing.baulab.info) <br>
18
+ Arxiv Preprint [https://arxiv.org/pdf/2303.07345.pdf](https://arxiv.org/pdf/2303.07345.pdf) <br>
19
+ Fine-tuned Weights [https://erasing.baulab.info/weights/esd_models/](https://erasing.baulab.info/weights/esd_models/) <br>
20
+ <div align='center'>
21
+ <img src = 'images/applications.png'>
22
+ </div>
23
+
24
+ Motivated by recent advancements in text-to-image diffusion, we study erasure of specific concepts from the model's weights. While Stable Diffusion has shown promise in producing explicit or realistic artwork, it has raised concerns regarding its potential for misuse. We propose a fine-tuning method that can erase a visual concept from a pre-trained diffusion model, given only the name of the style and using negative guidance as a teacher. We benchmark our method against previous approaches that remove sexually explicit content and demonstrate its effectiveness, performing on par with Safe Latent Diffusion and censored training.
25
+
26
+ To evaluate artistic style removal, we conduct experiments erasing five modern artists from the network and run a user study to assess human perception of the removed styles. Unlike previous methods, our approach removes concepts from a diffusion model permanently rather than modifying the output at inference time, so it cannot be circumvented even if a user has access to the model weights.
27
+
28
+ Given only a short text description of an undesired visual concept and no additional data, our method fine-tunes model weights to erase the targeted concept. Our method can avoid NSFW content, stop imitation of a specific artist's style, or even erase a whole object class from model output, while preserving the model's behavior and capabilities on other topics.
29
+
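+ The "negative guidance as a teacher" objective mentioned above can be sketched in a few lines. This is only an illustration, not code from this repository: the names `esd_target`, `frozen_unet`, and `eta`, and the use of a diffusers-style U-Net, are assumptions.
+
+ ```python
+ import torch
+
+ def esd_target(frozen_unet, x_t, t, concept_emb, uncond_emb, eta=1.0):
+     """Teacher signal: steer the prediction for the concept away from the concept."""
+     with torch.no_grad():
+         eps_uncond = frozen_unet(x_t, t, encoder_hidden_states=uncond_emb).sample
+         eps_concept = frozen_unet(x_t, t, encoder_hidden_states=concept_emb).sample
+     return eps_uncond - eta * (eps_concept - eps_uncond)
+
+ # Fine-tuning minimizes || eps_theta(x_t, t, concept_emb) - esd_target(...) ||^2,
+ # so conditioning on the erased concept no longer reproduces it.
+ ```
+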
30
+ ## Demo vs. GitHub
31
+
32
+ This demo uses an updated implementation of the original Erasing codebase on which the publication is based.
33
+
34
+ ## Running locally
35
+
36
+ 1.) Create an environment using the packages included in the requirements.txt file
37
+
38
+ 2.) Run `python app.py`
39
+
40
+ 3.) Open the application in browser at `http://127.0.0.1:7860/`
41
+
42
+ 4.) Train, evaluate, and save models using our method
43
+
44
+ ## Citing our work
45
+ The preprint can be cited as follows
46
+ ```
47
+ @article{gandikota2023erasing,
48
+ title={Erasing Concepts from Diffusion Models},
49
+ author={Rohit Gandikota and Joanna Materzy\'nska and Jaden Fiotto-Kaufman and David Bau},
50
+ journal={arXiv preprint arXiv:2303.07345},
51
+ year={2023}
52
+ }
53
+ ```
__init__.py ADDED
File without changes
app.py ADDED
@@ -0,0 +1,233 @@
1
+ import gradio as gr
2
+ import torch
3
+ import sys
4
+ from pathlib import Path
5
+ from trainer import train_submit, inference
6
+
7
+ import os
8
+ model_map = {'Van Gogh' : 'models/vangogh_ablation_delta.bin',
9
+ 'Greg Rutkowski' : 'models/greg_rutkowski_ablation_delta.bin',
10
+ }
11
+
12
+ ORIGINAL_SPACE_ID = 'nupurkmr9/concept-ablation'
13
+ SPACE_ID = os.getenv('SPACE_ID')
14
+
15
+ SHARED_UI_WARNING = f'''## Attention - this demo requires at least 24 GB of VRAM for style and object removal, and 24 GB of VRAM for memorized image removal. Please clone this repository to run it on your own machine.
16
+ <center><a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14" alt="Duplicate Space"></a></center>. This demo is partly adapted from https://huggingface.co/spaces/baulab/Erasing-Concepts-In-Diffusion.
17
+ '''
18
+
19
+ sys.path.append("concept-ablation-diffusers")
20
+
21
+ class Demo:
22
+
23
+ def __init__(self) -> None:
24
+
25
+ self.training = False
26
+ self.generating = False
27
+
28
+ # self.diffuser = StableDiffuser(scheduler='DDIM').to('cuda').eval().half()
29
+
30
+ with gr.Blocks() as demo:
31
+ self.layout()
32
+ demo.queue(concurrency_count=5).launch()
33
+
34
+
35
+ def layout(self):
36
+
37
+ with gr.Row():
38
+
39
+ if SPACE_ID == ORIGINAL_SPACE_ID:
40
+
41
+ self.warning = gr.Markdown(SHARED_UI_WARNING)
42
+
43
+ with gr.Row():
44
+
45
+ with gr.Tab("Test") as inference_column:
46
+
47
+ with gr.Row():
48
+
49
+ self.explain_infr = gr.Markdown(interactive=False,
50
+ value='This is a demo of [Concept Ablation](https://www.cs.cmu.edu/~concept-ablation/). To try out a model where a concept has been erased, select a model and enter any prompt. For example, if you select the model "Van Gogh" you can generate images for the prompt "A portrait in the style of Van Gogh" and compare the ablated and pre-trained models. We have also provided several other pre-fine-tuned models with artistic styles and concepts ablated (Check out the "Ablated Model" drop-down). You can also train and run your own custom models. Check out the "train" section for custom ablation of concepts.')
51
+
52
+ with gr.Row():
53
+
54
+ with gr.Column(scale=1):
55
+
56
+ self.prompt_input_infr = gr.Text(
57
+ placeholder="Enter prompt...",
58
+ label="Prompt",
59
+ info="Prompt to generate"
60
+ )
61
+
62
+ with gr.Row():
63
+
64
+ self.model_dropdown = gr.Dropdown(
65
+ label="Ablated Models",
66
+ choices= list(model_map.keys()),
67
+ value='Van Gogh',
68
+ interactive=True
69
+ )
70
+
71
+ self.seed_infr = gr.Number(
72
+ label="Seed",
73
+ value=42
74
+ )
75
+
76
+ with gr.Column(scale=2):
77
+
78
+ self.infr_button = gr.Button(
79
+ value="Generate",
80
+ interactive=True
81
+ )
82
+
83
+ with gr.Row():
84
+
85
+ self.image_new = gr.Image(
86
+ label="Ablated",
87
+ interactive=False
88
+ )
89
+ self.image_orig = gr.Image(
90
+ label="SD",
91
+ interactive=False
92
+ )
93
+
94
+ with gr.Tab("Train") as training_column:
95
+
96
+ with gr.Row():
97
+
98
+ self.explain_train= gr.Markdown(interactive=False,
99
+ value='In this part you can ablate any concept from Stable Diffusion. Enter the name of the concept and select the kind of concept (e.g. object, style, memorization). You will also need to select a parent anchor concept, e.g. cats when ablating grumpy cat, or painting when ablating an artist\'s style. When ablating a specific object or memorized image, you also need to either provide an OpenAI API key or upload a file with 50-200 prompts corresponding to the ablation concept. With default settings, it takes about 20 minutes to fine-tune the model; then you can try inference above or download the weights. The training code used here is slightly different from the code tested in the original paper. Code and details are at [github link](https://github.com/nupurkmr9/concept-ablation).')
100
+
101
+ with gr.Row():
102
+
103
+ with gr.Column(scale=3):
104
+ mem_impath = []
105
+
106
+ self.prompt_input = gr.Text(
107
+ placeholder="Enter concept to remove... e.g. van gogh",
108
+ label="prompt",
109
+ info="Name of the concept to ablate from Model"
110
+ )
111
+ self.anchor_prompt = gr.Text(
112
+ placeholder="Enter anchor concept... e.g. painting",
113
+ label="anchor prompt",
114
+ info="Name of the anchor concept (superset of the concept to be ablated)"
115
+ )
116
+
117
+ choices = ['style', 'object', 'memorization']
118
+
119
+ self.concept_type = gr.Dropdown(
120
+ choices=choices,
121
+ value='style',
122
+ label='Ablated concept type',
123
+ info='Ablated concept type'
124
+ )
125
+
126
+ self.reg_lambda = gr.Number(
127
+ value=0,
128
+ label="Regularization loss",
129
+ info='Whether to add a regularization loss on the anchor concept. Use 1.0 when the ablated and anchor prompts share common words, e.g. grumpy cat and cat'
130
+ )
131
+
132
+ self.iterations_input = gr.Number(
133
+ value=100,
134
+ precision=0,
135
+ label="Iterations",
136
+ info='iterations used to train'
137
+ )
138
+
139
+ self.lr_input = gr.Number(
140
+ value=2e-6,
141
+ label="Learning Rate",
142
+ info='Learning rate used to train'
143
+ )
144
+
145
+ visible_openai_key = True
146
+ self.openai_key = gr.Text(
147
+ placeholder="Enter openAI API key or atleast 50 prompts if concept type is object/memorization",
148
+ label="OpenAI API key or Prompts (Required when concept type is object or memorization)",
149
+ info="If concept type is object, we use chatGPT to generate a set of prompts correspondig to the ablation concept. If concept type is memorization, we use ChatGPT to generate paraphrases of the text prompt that generates memorized image. You can either provide the api key or a set of desired prompts (atleast 50). For reference please check example prompts at https://github.com/nupurkmr9/concept-ablation/blob/main/assets/finetune_prompts/ ",
150
+ visible=visible_openai_key
151
+ )
152
+
153
+ visible = True
154
+ mem_impath.append(gr.Files(label=f'''Upload the memorized image if concept type is memorization''', visible=visible))
155
+
156
+
157
+ with gr.Column(scale=1):
158
+
159
+ self.train_status = gr.Button(value='', variant='primary', label='Status', interactive=False)
160
+
161
+ self.train_button = gr.Button(
162
+ value="Train",
163
+ )
164
+
165
+ self.download = gr.Files()
166
+
167
+ self.infr_button.click(self.inference, inputs = [
168
+ self.prompt_input_infr,
169
+ self.seed_infr,
170
+ self.model_dropdown
171
+ ],
172
+ outputs=[
173
+ self.image_new,
174
+ self.image_orig
175
+ ]
176
+ )
177
+ self.train_button.click(self.train, inputs = [
178
+ self.prompt_input,
179
+ self.anchor_prompt,
180
+ self.concept_type,
181
+ self.reg_lambda,
182
+ self.iterations_input,
183
+ self.lr_input,
184
+ self.openai_key,
185
+ ] + mem_impath,
186
+ outputs=[self.train_button, self.train_status, self.download, self.model_dropdown]
187
+ )
188
+
189
+ def train(self, prompt, anchor_prompt, concept_type, reg_lambda, iterations, lr, openai_key, *inputs):
190
+ self.train_status.update(value='')
191
+ if self.training:
192
+ return [gr.update(interactive=True, value='Train'), gr.update(value='Someone else is training... Try again soon'), None, gr.update()]
193
+
194
+ randn = torch.randint(1, 10000000, (1,)).item()
195
+
196
+ save_path = f"models/{randn}_{prompt.lower().replace(' ', '')}"
197
+ os.makedirs(save_path, exist_ok=True)
198
+ self.training = True
199
+ mem_impath = inputs[:1]
200
+ train_submit(prompt, anchor_prompt, concept_type, reg_lambda, iterations, lr, openai_key, save_path, mem_impath)
201
+
202
+ self.training = False
203
+
204
+ torch.cuda.empty_cache()
205
+
206
+ modelpath = sorted(Path(save_path).glob('*.bin'))[0]
207
+ model_map[f"Custom_{prompt.lower().replace(' ', '')}"] = modelpath
208
+
209
+ return [gr.update(interactive=True, value='Train'), gr.update(value='Done Training! \n Try your custom model in the "Test" tab'), modelpath, gr.Dropdown.update(choices=list(model_map.keys()), value=f"Custom_{prompt.lower().replace(' ', '')}")]
210
+
211
+
212
+ def inference(self, prompt, seed, model_name, pbar = gr.Progress(track_tqdm=True)):
213
+
214
+ seed = seed or 42
215
+ n_steps = 50
216
+
217
+ generator = torch.manual_seed(seed)
218
+
219
+ model_path = model_map[model_name]
220
+
221
+ torch.cuda.empty_cache()
222
+
223
+ generator = torch.manual_seed(seed)
224
+
225
+ orig_image, edited_image = inference(model_path, prompt, n_steps, generator)
226
+
227
+ torch.cuda.empty_cache()
228
+
229
+ return edited_image, orig_image
230
+
231
+
232
+ demo = Demo()
233
+
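The `inference` helper imported from `trainer` returns the pre-trained and ablated images for a prompt, as used in `Demo.inference` above. A minimal headless sketch of calling it outside Gradio; it assumes the same `trainer` module is importable, that a delta checkpoint exists at the illustrative path below, and that PIL images are returned:

```python
import torch
from trainer import inference  # the same helper app.py imports

generator = torch.manual_seed(42)
orig_image, edited_image = inference(
    'models/vangogh_ablation_delta.bin',    # delta checkpoint, as listed in model_map
    'A portrait in the style of Van Gogh',  # prompt
    50,                                     # denoising steps
    generator,
)
orig_image.save('pretrained.png')   # assuming PIL.Image outputs
edited_image.save('ablated.png')
```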
assets/painting.txt ADDED
@@ -0,0 +1,200 @@
1
+ Purple Paints
2
+ Painting a Seascape in Water color part 1
3
+ Why Arts & Crafts Are Beneficial to Early Childhood Education, Creve Coeur, Missouri
4
+ Simple Cherry Blossom Tree Painting images
5
+ Diy Wall Painting, Mural Painting, Paintings, Murals For Kids, Kids Wall Murals, Watercolor Walls, Mural Wall Art, Diy Canvas Art, Art Abstrait
6
+ Watch Doris Allen paints Roses on porcelain, China Painter Doris Drew Allen
7
+ Oil Painting with Kelsey May Connor in the tradition of Classical Realism | Art Classes Malta
8
+ Peinture LHuile Fruits Ltude Youtube
9
+ coffee painting canvas diy acrylic step by step for beginners #stepbysteppainting #tracie
10
+ Sara Barcus paints Bob Ross
11
+ Stock Video Footage of Little Baby Girl Paints Water Colors
12
+ "Jill Baker working on her painting ""Lady In Waiting"" (2016)"
13
+ Children painting a wall
14
+ Close up shot of the process of mixing oil paint with palette knife Stock Footage
15
+ toile : A young blonde woman draws a tree with paints Vidéos Libres De Droits
16
+ Acrylic Paint Grass - How To Paint Grass - with Acrylic
17
+ Snowfield Sentinel - En Plein Air Oil Demonstration
18
+ Artist painting A Stock Photography
19
+ peindre l 39 aquarelle youtube
20
+ Monet inspired Water Lily Pond - good way to practice the art of shadows
21
+ щетка для волос : A pretty woman with red hair, paints a picture on canvas, which stands on the easel. The lady is in the open air near the lake of the river, she draws from life Стоковые видеозаписи
22
+ Wall Painting Techniques Youtube
23
+ Painting under the pergola, Yeniköy 2014
24
+ Practical Mom: Bleeding Crepe Paper Art Project for Little Kids
25
+ painter man at work with paint brush, easel, canvas and...
26
+ Watercolor landscape painting
27
+ How to paint birds in watercolour
28
+ Acrylic Painting Techniques - How to Paint Flowers - Parrot Tulip
29
+ Painting cars toddler activity
30
+ A painting student with her canvas at Dickinson Farm.
31
+ Blonde woman in denim clothing finding creativity by painting a canvas hung on white wall.
32
+ beautiful young woman painter at work,
33
+ which : A pretty woman with red hair, paints a picture on canvas, which stands on the easel. The lady is in the open air near the lake of the river, she draws from life Stock Footage
34
+ how-to-paint-a-crown-crested-crane-in-acrylic-background
35
+ Slide Painting: A super fun indoor or outdoor process art activity for toddlers or preschoolers! Use cars, balls, or anything that rolls!
36
+ Poppy Spree - Week 2 Painting
37
+ How-To Make Colored Milk Explosion
38
+ Abstract Acrylic Painting Demo - Watercolor Look - Color Explosion - Easy how to abstract
39
+ Seine artist by triciamary
40
+ Free Painting Lessons - Cherry Blossom Bridge - Commentary by Acrylic Artist Brandon Schaefer
41
+ 3 blue trees on ice - painting video step by step demonstration
42
+ Martin Grealish painting in Venice.jpg
43
+ Live Painter for Elegant Wedding
44
+ Long-haired artist paints on canvas
45
+ Powe Painting Co Cover Photo
46
+ Live Painting Part 1 (how to begin)
47
+ which : A pretty woman with red hair, paints a picture on canvas, which stands on the easel. The lady is in the open air near the lake of the river, she draws from life Stock Footage
48
+ Artist paints with oil paints
49
+ Stock Video Footage of hand of a young woman painting in purple
50
+ Watercolour tutorial - Summer shadows, cottages and flowers
51
+ Scrape pastel with knife to release a snowfall!
52
+ Acrylic pour painting process video
53
+ Woman painting stock photo
54
+ student painting at the Sainsbury Centre
55
+ Honfleur Painter
56
+ Bald Eagle in Wetlands Original Mini Painting on Easel
57
+ painting in nature things to do in virginia beach
58
+ Time lapse aquarelle painting Stock Footage
59
+ Watercolour Aquarelle Poppies Poppy Painting Demo Youtube
60
+ Wedding Artist Ben Keys of Wed on Canvas painting live at wedding reception. // The Graystone Inn, Wilmington, NC // Photo Courtesy of Blueberry Creative
61
+ NIOS - 225 Painting - Guide Book For Class 10th - English Medium
62
+ Fearless Watercolor #1 Frozen Marsh Full Watercolor Demonstration Music Ted Yoder
63
+ Oil Painting with Kelsey May Connor in the tradition of Classical Realism | Art Classes Malta
64
+ Michael Chambers, painting in process
65
+ KNIFE PAINTING - WATER LILY by NELLY LESTRADE (english subtitles)
66
+ A pretty woman with red hair, paints a picture on canvas, which stands on the easel. The lady is in the open air near the lake of the river, she draws from life
67
+ Sip & Sketch at Acquiesce Winery!
68
+ painting my bedroom wall youtube how to paint a galaxy wall mural in a spaceship themed
69
+ Flat art painter workshop with paint supplies equipment tools ba Stock Illustration
70
+ Gold Leaf Gilding in Westminster London
71
+ Commission a portrait-artist-brisbane-perth-sydney-melbourne
72
+ My very first brushstrokes as a mural painter
73
+ Stock Video Footage of Little Baby Girl Paints Water Colors
74
+ bing images of flower portraits watercolor by billy showell | VIDEO :: DVD - Watercolour Flower Portraits with Billy Showell
75
+ Haidee-Jo Summers painting the winning painting at Paint Out Norwich 2014
76
+ Artist paints a picture of oil paint brush in hand with palette Stock Footage
77
+ img 9546 Sam Flores, recap: live painting in Golden Gate Park upper playground Sam Flores power to the peaceful paint live painting graffiti fifty24sf gallery art
78
+ Artist at Michael Orwick (DragonFire Gallery) Stormy Weather Arts Festival Quick Draw event at Tolovana Inn
79
+ Painting at the Tower - Stephen B Whatley, 2000 (Stephen B Whatley) Tags: uk windows england orange detail building london art architecture painting studio paint artist 2000 vibrant perspective millenium canvas creation painter expressionism expressionist toweroflondon palette damncool stkatharinedocks mywinners platinumphoto goldstaraward stephenbwhatley
80
+ Video aula curso Pintura em tela com espatula part 1
81
+ "getlinkyoutube.com-LIBERO! Full video ""orchidee"" dal artista Igor Sakharov"
82
+ Weekly Specials Painting: Royal Paint By Number Mini Junior Dolphins
83
+ Woman artist painting landscape outdoor Stock Footage
84
+ painting styrofoam trays for a spring toddler art project
85
+ Madinah Wilson, with the Biden Institute, signs a poster on the Green at the University of Delaware on Tuesday during an event for National Voter Registration Day that got students visiting the Green to register to vote.
86
+ female hands painting a water landscape with oil paints using palette knife. 4k - cavalletto attrezzatura per arti e mestieri video stock e b–roll
87
+ Art studio #miniature. Dioramas and Clever Things Dollhouse Miniatures and Accessories
88
+ INDIANA - BOB ROSS EXHIBIT - 1-29-2021 (1).jpeg
89
+ How to paint a CLOUDY sky tutorial - Jennings644
90
+ Catherine Soucy
91
+ Pin By Marionette Taboniar On Art Videos By Marionette
92
+ coffee painting canvas diy acrylic step by step for beginners #stepbysteppainting #tracie
93
+ Anita Jansen Youtube Watercolor Art Watercolor Trees Art
94
+ Beautiful young woman painter at work, on grey background — Stock Photo
95
+ In my natural habitat 🎨🎨 ~ I'm currently accepting tea recommendations. Tell me you're favourite teas(or hot drinks) in my stories! ☕️☕️ ~ #creatingsunflowers #mystudio #markmaking #naturallycurlyhair #curlybeauties #homeinthestudio #coloraddict #strongfit #fitnessenthusiast #collectart #originalpainting #colormehappy #workingprocess #intuitiveart #decorativeart #girlgains #worksonpaper #modernartist #womenwhopaint #girlgains
96
+ Judith paints in a Bob Ross style.
97
+ Royal & Langnickel: Painting by Number (Splish Splash)
98
+ Close up of a needlework landscape being made at hong ngoc handicraft center Stock Footage
99
+ Watercolour Tutorial Blea Tarn Cottage Lake District
100
+ Squishy Painting from make-it-your-own.com (Crafts & activities for kids)
101
+ Painter-girl en plein air
102
+ Kids painting art outdoor activity, montessori homeschooling education
103
+ HM Altered Book Project Page 3 - Art Journaling - Mixed Media - How to
104
+ Fototapete - Talented male Artist Works on Abstract Oil Painting, with Broad Strokes of Paint Brush he Creates Modern Masterpiece. Dark and Messy Creative Studio where Large Canvas Stands on Easel Illuminated
105
+ How I Start My Plein Air Paintings
106
+ acrylic painting how to step by step flower painting mixed media flower heads 7 acrylic
107
+ Oil painting classes oil painting classes in Ameerpet institution for oil painting classes in hyderabad canvas painting classes in hyderabad Oil paintings for sales oil painting images oil painting in canvas
108
+ палитра : Young Beautiful Female Artist is in an Art Studio, Sitting Behind an Easel and Painting on Canvas. Drawing Process: in the Art Studio of the Artists Hand Art Girl with a Brush Painting on Canvas.4K Стоковые видеозаписи
109
+ What is Plein Air Painting like in Ireland?
110
+ Seven Tips for Setting up an Impromptu Garden Art Studio
111
+ young woman painting - artist stock videos & royalty-free footage
112
+ Pébéo - Mixed Media: mixing Studio Acrylics, Vitrail et Fantasy colors on a 3D frame - YouTube
113
+ A painter To paint
114
+ How to paint a pumpkin canvas, art skills not required!!!
115
+ http://funkidos.com/pictures-world/art-world/hyper-realistic-paintings-on-wood
116
+ Web Design for Oil Painters
117
+ Painters palette with oil paints Stock Footage
118
+ live-wedding-painting-by-ben-keys-live-event-artist-of-wed-on-canvas-garden-wedding-gibbes-museum-pure-luxe-bride-for-the-knot
119
+ Lee's Painting logo
120
+ Sir Winston Churchill painting
121
+ The girl takes blue oil paint from the palette and spreads it on the canvas with Stock Footage
122
+ Van Gogh Sunflowers | You Paint the Masters
123
+ Painting lesson high school - tutorial de pintura para secundaria
124
+ Crayon Art... now this is even cooler than the other kind of
125
+ CoachingMasteryStudio-940-v1-png
126
+ finished tree with pink flowers and within the circle on a gold background
127
+ 3 blue trees on ice - painting video step by step demonstration
128
+ Talented male Artist Works on Abstract Oil Painting, with Broad Strokes of Paint Brush he Creates Modern Masterpiece. Dark and Messy Creative Studio where Large Canvas Stands on Easel Illuminated | Shutterstock HD Video #1036269983
129
+ magic painting logo designs
130
+ Workshop artistic, painting pigments color palette, color picker
131
+ Toddler Art — Autumn Tree Painting
132
+ Children painting pottery 2 Royalty Free Stock Photography
133
+ 4K Talented young graffiti street artist working on a mural in urban area Stock Footage
134
+ fairytale castle mural for kids bedroom
135
+ how to paint how to paint clouds in acrylic instructional painting
136
+ A pretty woman with red hair, paints a picture on canvas, which stands on the easel. The lady is in the open air near the lake of the river, she draws from life
137
+ The girl draws a picture by numbers,with yellow acrylic paints. Hobby for adults Stock Footage
138
+ Sahara Mural (Jamila BC) Tags: cloud art sahara painting twilight mural desert dusk jamila jamilaproductions
139
+ ressam : Close Up of Little Girl Drawing in Nature with Teacher. Art School Classes in the City Park. Slow Motion. Creativity Inspiration Expression Concept Stok Video
140
+ Watercolor Tutorial Mountains | blue ridge mountains easy step by step watercolor
141
+ Stock Video Footage of Street painting
142
+ which : A pretty woman with red hair, paints a picture on canvas, which stands on the easel. The lady is in the open air near the lake of the river, she draws from life Stock Footage
143
+ Palette, canvas and easel Stock Image
144
+ A woman paints a landscape with oil paints with a palette and brush. Slider shot Stock Footage
145
+ 1280x720 Blue Hydrangea Watercolor Doodle On Cotton Xuan Rice Paper By Amy
146
+ Zero technique in watercolour but had loads of fun
147
+ Taos - Plein Air Painter
148
+ Artist drawing picture of a city landscape Stock Footage
149
+ """Painters and paintings"""
150
+ Puff, Pass and Paint- 420-friendly painting in Orange County! 21+ tickets
151
+ Melted Crayon Art | Unsimple Living
152
+ Live Online Art Classes Webinar Oil Painting Sunset Pt1
153
+ finger paint col.jpg
154
+ Painting and wine class
155
+ Artist Nancy Tankersley works on a plein air oil study of the landscape at King's Creek/Choptank Wetlands preserve.
156
+ Using the palette knife to scrape away paint
157
+ Close-up shot of a woman artist making brush strokes with oil on canvas. Impressionist painter working en plein air
158
+ Students create paintings inspired by VanGogh's textured impasto
159
+ Close-up shot of impressionist painter making final brush strokes during plein air painting
160
+ Preschool Painting Activity Royalty Free Stock Images
161
+ Advance Paint Pour Workshops
162
+ Painting By Numbers Poppy Field
163
+ Artist paints a picture of oil paint brush in hand with palette. Slider shot Stock Footage
164
+ Art painting hobby leisure girl drawing picture. Art painting hobby. creative leisure. girl drawing a picture. talent inspiration creation and self expression stock photography
165
+ 1.5-Hour Children's Creative Art Session for 1 Child
166
+ Painting the Nude in Watercolour Demonstration by Anthony Barrow BA MIFL
167
+ The cherry blossoms in the Mt. Fuji Acrylic Painting Buying a new proven fact that everyone in your family will like doing together outside? Canvas Painting Tutorials, Easy Canvas Painting, Easy Paintings, Amazing Paintings, Mini Paintings, Pour Painting, Painting Videos, Painting Lessons, Small Canvas Art
168
+ paintings in oil
169
+ Printable Wall Murals japanese cherry blossoms by sole junkie youtube
170
+ How to paint a Rose in watercolors, the impressionistic way
171
+ Artist�s hand with paintbrush painting the picture photo
172
+ Childrens wall murals the muralist kids custom art for Equestrian wall mural
173
+ St Patrick's Day Toddler Craft - Toilet paper shamrocks + pop up art
174
+ Como pintar estrada, cerca e rvore - How to paint tree - how to paint landscape
175
+ How to paint tree branches - Painting Tutorial
176
+ spraying with a spray bottle
177
+ Miniature paintings on the ground, a painting on the wall and a painting on an easel of a woman
178
+ At the hairdresser on the market
179
+ Paint With Kevin Hill Snow Covered Valley
180
+ C. Rasmussen in her Los Angeles studio, 2018.
181
+ Stock Video Footage of View from behind on child paints a sheet of paper by hand
182
+ Sandra Gal seen through a ring flash tube as she works on her latest painting.
183
+ Watercolour Art Classes
184
+ Painting icon. Stock Footage
185
+ Watercolour Trees - How To Paint A Tree In Watercolour
186
+ Watercolor with Birgit O'Connor: Within the Flower
187
+ London-Fine-Art-Studios-Painting.png
188
+ Beautiful young woman painter at work, isolated on white
189
+ Kindergarten and Day Care Cartoon Wall Art Work In Mumbai
190
+ Close-up of woman artist painting still life picture on canvas in art studio
191
+ Modified Paint.Net logo
192
+ Im Grand Canyon wird Kunst zelebriert. - Foto: Arizona Office of Tourism
193
+ Melted crayons art
194
+ Woman paints a picture outdoors.
195
+ Snow Leopard Cub - Oils
196
+ Stock Video Footage of close up of a brush that paints a canvas - painter
197
+ Woman hand painting Stock Photo
198
+ Watercolour Aquarelle Demo Poppies Coquelicots Aquarelle By
199
+ painting 9 year 7 year pop painting american doll saige
200
+ Contemporary man with smartphone taking photograph of the painting on easel
concept-ablation-diffusers/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2023 pix2pixzero
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
concept-ablation-diffusers/model_pipeline.py ADDED
@@ -0,0 +1,237 @@
1
+ from typing import Callable, Optional
2
+
3
+ import torch
4
+ from accelerate.logging import get_logger
5
+ from diffusers.models import AutoencoderKL, UNet2DConditionModel
6
+ from diffusers.models.cross_attention import CrossAttention
7
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
8
+ from diffusers.pipelines.stable_diffusion.safety_checker import (
9
+ StableDiffusionSafetyChecker,
10
+ )
11
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
12
+ from diffusers.utils.import_utils import is_xformers_available
13
+ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
14
+
15
+ if is_xformers_available():
16
+ import xformers
17
+ import xformers.ops
18
+ else:
19
+ xformers = None
20
+
21
+ logger = get_logger(__name__)
22
+
23
+
24
+ def set_use_memory_efficient_attention_xformers(
25
+ self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
26
+ ):
27
+ if use_memory_efficient_attention_xformers:
28
+ if self.added_kv_proj_dim is not None:
29
+ # TODO(Anton, Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
30
+ # which uses this type of cross attention ONLY because the attention mask of format
31
+ # [0, ..., -10.000, ..., 0, ...,] is not supported
32
+ raise NotImplementedError(
33
+ "Memory efficient attention with `xformers` is currently not supported when"
34
+ " `self.added_kv_proj_dim` is defined."
35
+ )
36
+ elif not is_xformers_available():
37
+ raise ModuleNotFoundError(
38
+ (
39
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
40
+ " xformers"
41
+ ),
42
+ name="xformers",
43
+ )
44
+ elif not torch.cuda.is_available():
45
+ raise ValueError(
46
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
47
+ " only available for GPU "
48
+ )
49
+ else:
50
+ try:
51
+ # Make sure we can run the memory efficient attention
52
+ _ = xformers.ops.memory_efficient_attention(
53
+ torch.randn((1, 2, 40), device="cuda"),
54
+ torch.randn((1, 2, 40), device="cuda"),
55
+ torch.randn((1, 2, 40), device="cuda"),
56
+ )
57
+ except Exception as e:
58
+ raise e
59
+
60
+ processor = CustomDiffusionXFormersAttnProcessor(
61
+ attention_op=attention_op)
62
+ else:
63
+ processor = CustomDiffusionAttnProcessor()
64
+
65
+ self.set_processor(processor)
66
+
67
+
68
+ class CustomDiffusionAttnProcessor:
69
+ def __call__(
70
+ self,
71
+ attn: CrossAttention,
72
+ hidden_states,
73
+ encoder_hidden_states=None,
74
+ attention_mask=None,
75
+ ):
76
+ batch_size, sequence_length, _ = hidden_states.shape
77
+ attention_mask = attn.prepare_attention_mask(
78
+ attention_mask, sequence_length, batch_size)
79
+ query = attn.to_q(hidden_states)
80
+
81
+ crossattn = False
82
+ if encoder_hidden_states is None:
83
+ encoder_hidden_states = hidden_states
84
+ else:
85
+ crossattn = True
86
+ if attn.cross_attention_norm:
87
+ encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
88
+
89
+ key = attn.to_k(encoder_hidden_states)
90
+ value = attn.to_v(encoder_hidden_states)
91
+ if crossattn:
92
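+ # Gradients flow through the key/value projections of every text token except the first (start-of-text) token, which is detached below.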
+ detach = torch.ones_like(key)
93
+ detach[:, :1, :] = detach[:, :1, :] * 0.
94
+ key = detach * key + (1 - detach) * key.detach()
95
+ value = detach * value + (1 - detach) * value.detach()
96
+
97
+ query = attn.head_to_batch_dim(query)
98
+ key = attn.head_to_batch_dim(key)
99
+ value = attn.head_to_batch_dim(value)
100
+
101
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
102
+ hidden_states = torch.bmm(attention_probs, value)
103
+ hidden_states = attn.batch_to_head_dim(hidden_states)
104
+
105
+ # linear proj
106
+ hidden_states = attn.to_out[0](hidden_states)
107
+ # dropout
108
+ hidden_states = attn.to_out[1](hidden_states)
109
+
110
+ return hidden_states
111
+
112
+
113
+ class CustomDiffusionXFormersAttnProcessor:
114
+ def __init__(self, attention_op: Optional[Callable] = None):
115
+ self.attention_op = attention_op
116
+
117
+ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
118
+ batch_size, sequence_length, _ = hidden_states.shape
119
+
120
+ attention_mask = attn.prepare_attention_mask(
121
+ attention_mask, sequence_length, batch_size)
122
+
123
+ query = attn.to_q(hidden_states)
124
+
125
+ crossattn = False
126
+ if encoder_hidden_states is None:
127
+ encoder_hidden_states = hidden_states
128
+ else:
129
+ crossattn = True
130
+ if attn.cross_attention_norm:
131
+ encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
132
+
133
+ key = attn.to_k(encoder_hidden_states)
134
+ value = attn.to_v(encoder_hidden_states)
135
+ if crossattn:
136
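+ # As above: gradients flow through the key/value projections of every text token except the first (start-of-text) token, which is detached below.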
+ detach = torch.ones_like(key)
137
+ detach[:, :1, :] = detach[:, :1, :] * 0.
138
+ key = detach * key + (1 - detach) * key.detach()
139
+ value = detach * value + (1 - detach) * value.detach()
140
+
141
+ query = attn.head_to_batch_dim(query).contiguous()
142
+ key = attn.head_to_batch_dim(key).contiguous()
143
+ value = attn.head_to_batch_dim(value).contiguous()
144
+
145
+ hidden_states = xformers.ops.memory_efficient_attention(
146
+ query, key, value, attn_bias=attention_mask, op=self.attention_op
147
+ )
148
+ hidden_states = hidden_states.to(query.dtype)
149
+ hidden_states = attn.batch_to_head_dim(hidden_states)
150
+
151
+ # linear proj
152
+ hidden_states = attn.to_out[0](hidden_states)
153
+ # dropout
154
+ hidden_states = attn.to_out[1](hidden_states)
155
+ return hidden_states
156
+
157
+
158
+ class CustomDiffusionPipeline(StableDiffusionPipeline):
159
+ r"""
160
+ Pipeline for custom diffusion model.
161
+
162
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
163
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.).
164
+
165
+ Args:
166
+ vae ([`AutoencoderKL`]):
167
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
168
+ text_encoder ([`CLIPTextModel`]):
169
+ Frozen text-encoder. Stable Diffusion uses the text portion of
170
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
171
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
172
+ tokenizer (`CLIPTokenizer`):
173
+ Tokenizer of class
174
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
175
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
176
+ scheduler ([`SchedulerMixin`]):
177
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents.
178
+ safety_checker ([`StableDiffusionSafetyChecker`]):
179
+ Classification module that estimates whether generated images could be considered offensive or harmful.
180
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
181
+ feature_extractor ([`CLIPFeatureExtractor`]):
182
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
183
+ modifier_token_id: list of ids of the tokens related to the target concept that are modified when ablated.
184
+ """
185
+ _optional_components = ["safety_checker",
186
+ "feature_extractor", "modifier_token_id"]
187
+
188
+ def __init__(
189
+ self,
190
+ vae: AutoencoderKL,
191
+ text_encoder: CLIPTextModel,
192
+ tokenizer: CLIPTokenizer,
193
+ unet: UNet2DConditionModel,
194
+ scheduler: SchedulerMixin,
195
+ safety_checker: StableDiffusionSafetyChecker,
196
+ feature_extractor: CLIPFeatureExtractor,
197
+ requires_safety_checker: bool = True,
198
+ modifier_token_id: list = [],
199
+ ):
200
+ super().__init__(vae,
201
+ text_encoder,
202
+ tokenizer,
203
+ unet,
204
+ scheduler,
205
+ safety_checker,
206
+ feature_extractor,
207
+ requires_safety_checker)
208
+
209
+ self.modifier_token_id = modifier_token_id
210
+
211
+ def save_pretrained(self, save_path, parameter_group="cross-attn", all=False):
212
+ if all:
213
+ super().save_pretrained(save_path)
214
+ else:
215
+ delta_dict = {'unet': {}}
216
+ if parameter_group == 'embedding':
217
+ delta_dict['text_encoder'] = self.text_encoder.state_dict()
218
+ for name, params in self.unet.named_parameters():
219
+ if parameter_group == "cross-attn":
220
+ if 'attn2.to_k' in name or 'attn2.to_v' in name:
221
+ delta_dict['unet'][name] = params.cpu().clone()
222
+ elif parameter_group == "full-weight":
223
+ delta_dict['unet'][name] = params.cpu().clone()
224
+ else:
225
+ raise ValueError(
226
+ "parameter_group argument only supports one of [cross-attn, full-weight, embedding]"
227
+ )
228
+ torch.save(delta_dict, save_path)
229
+
230
+ def load_model(self, save_path):
231
+ st = torch.load(save_path)
232
+ print(st.keys())
233
+ if 'text_encoder' in st:
234
+ self.text_encoder.load_state_dict(st['text_encoder'])
235
+ for name, params in self.unet.named_parameters():
236
+ if name in st['unet']:
237
+ params.data.copy_(st['unet'][f'{name}'])
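`save_pretrained` above stores only the modified parameters (by default the cross-attention `to_k`/`to_v` weights) in a single `.bin` file, and `load_model` copies them back into a freshly loaded pipeline. A minimal round-trip sketch; the base checkpoint id and file names are illustrative, not taken from this repository:

```python
import torch
from model_pipeline import CustomDiffusionPipeline

# Load the base pipeline (checkpoint id is illustrative).
pipe = CustomDiffusionPipeline.from_pretrained(
    'CompVis/stable-diffusion-v1-4', torch_dtype=torch.float16
).to('cuda')

# ... fine-tune pipe.unet with the ablation objective ...

# Save only the cross-attention key/value projections.
pipe.save_pretrained('ablation_delta.bin', parameter_group='cross-attn')

# Later: copy the saved weights back into a fresh copy of the base model.
fresh = CustomDiffusionPipeline.from_pretrained(
    'CompVis/stable-diffusion-v1-4', torch_dtype=torch.float16
).to('cuda')
fresh.load_model('ablation_delta.bin')
```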
concept-ablation-diffusers/train.py ADDED
@@ -0,0 +1,1199 @@
1
+ # This code is modified from the Hugging Face diffusers repository: https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora.py
2
+ import argparse
3
+ import hashlib
4
+ import itertools
5
+ import json
6
+ import logging
7
+ import math
8
+ import os
9
+ import warnings
10
+ from pathlib import Path
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn.functional as F
15
+ import torch.utils.checkpoint
16
+ import transformers
17
+ from accelerate import Accelerator
18
+ from accelerate.logging import get_logger
19
+ from accelerate.utils import ProjectConfiguration, set_seed
20
+ from huggingface_hub import HfApi, create_repo
21
+ from model_pipeline import (
22
+ CustomDiffusionAttnProcessor,
23
+ CustomDiffusionPipeline,
24
+ set_use_memory_efficient_attention_xformers,
25
+ )
26
+ from packaging import version
27
+ from tqdm.auto import tqdm
28
+ from transformers import AutoTokenizer, PretrainedConfig
29
+ from utils import (
30
+ CustomDiffusionDataset,
31
+ PromptDataset,
32
+ collate_fn,
33
+ filter,
34
+ getanchorprompts,
35
+ )
36
+
37
+ import diffusers
38
+ from diffusers import (
39
+ AutoencoderKL,
40
+ DDPMScheduler,
41
+ DiffusionPipeline,
42
+ DPMSolverMultistepScheduler,
43
+ UNet2DConditionModel,
44
+ )
45
+ from diffusers.models.cross_attention import CrossAttention
46
+ from diffusers.optimization import get_scheduler
47
+ from diffusers.utils import check_min_version, is_wandb_available
48
+ from diffusers.utils.import_utils import is_xformers_available
49
+
50
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
51
+ check_min_version("0.14.0")
52
+
53
+ logger = get_logger(__name__)
54
+
55
+
56
+ def create_custom_diffusion(unet, parameter_group):
57
+ for name, params in unet.named_parameters():
58
+ if parameter_group == "cross-attn":
59
+ if 'attn2.to_k' in name or 'attn2.to_v' in name:
60
+ params.requires_grad = True
61
+ else:
62
+ params.requires_grad = False
63
+ elif parameter_group == 'full-weight':
64
+ params.requires_grad = True
65
+ elif parameter_group == 'embedding':
66
+ params.requires_grad = False
67
+ else:
68
+ raise ValueError(
69
+ "parameter_group argument only cross-attn, full-weight, embedding"
70
+ )
71
+
72
+ # change attn class
73
+ def change_attn(unet):
74
+ for layer in unet.children():
75
+ if type(layer) == CrossAttention:
76
+ bound_method = set_use_memory_efficient_attention_xformers.__get__(
77
+ layer, layer.__class__)
78
+ setattr(
79
+ layer, 'set_use_memory_efficient_attention_xformers', bound_method)
80
+ else:
81
+ change_attn(layer)
82
+
83
+ change_attn(unet)
84
+ unet.set_attn_processor(CustomDiffusionAttnProcessor())
85
+ return unet
86
+
87
+
88
+ def save_model_card(repo_id: str, images=None, base_model=str, prompt=str, repo_folder=None):
89
+ img_str = ""
90
+ for i, image in enumerate(images):
91
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
92
+ img_str += f"./image_{i}.png\n"
93
+
94
+ yaml = f"""
95
+ ---
96
+ license: creativeml-openrail-m
97
+ base_model: {base_model}
98
+ instance_prompt: {prompt}
99
+ tags:
100
+ - stable-diffusion
101
+ - stable-diffusion-diffusers
102
+ - text-to-image
103
+ - diffusers
104
+ - custom diffusion
105
+ inference: true
106
+ ---
107
+ """
108
+ model_card = f"""
109
+ # Custom Diffusion - {repo_id}
110
+
111
+ These are Custom Diffusion adaption weights for {base_model}. The weights were trained on {prompt} using [Custom Diffusion](https://www.cs.cmu.edu/~custom-diffusion). You can find some example images in the following. \n
112
+ {img_str[0]}
113
+ """
114
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
115
+ f.write(yaml + model_card)
116
+
117
+
118
+ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
119
+ text_encoder_config = PretrainedConfig.from_pretrained(
120
+ pretrained_model_name_or_path,
121
+ subfolder="text_encoder",
122
+ revision=revision,
123
+ )
124
+ model_class = text_encoder_config.architectures[0]
125
+
126
+ if model_class == "CLIPTextModel":
127
+ from transformers import CLIPTextModel
128
+
129
+ return CLIPTextModel
130
+ elif model_class == "RobertaSeriesModelWithTransformation":
131
+ from diffusers.pipelines.alt_diffusion.modeling_roberta_series import (
132
+ RobertaSeriesModelWithTransformation,
133
+ )
134
+
135
+ return RobertaSeriesModelWithTransformation
136
+ else:
137
+ raise ValueError(f"{model_class} is not supported.")
138
+
139
+
140
+ def freeze_params(params):
141
+ for param in params:
142
+ param.requires_grad = False
143
+
144
+
145
+ def parse_args(input_args=None):
146
+ parser = argparse.ArgumentParser(
147
+ description="Simple example of a training script.")
148
+ parser.add_argument(
149
+ "--pretrained_model_name_or_path",
150
+ type=str,
151
+ default=None,
152
+ required=True,
153
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
154
+ )
155
+ parser.add_argument(
156
+ "--revision",
157
+ type=str,
158
+ default=None,
159
+ required=False,
160
+ help="Revision of pretrained model identifier from huggingface.co/models.",
161
+ )
162
+ parser.add_argument(
163
+ "--tokenizer_name",
164
+ type=str,
165
+ default=None,
166
+ help="Pretrained tokenizer name or path if not the same as model_name",
167
+ )
168
+ parser.add_argument(
169
+ "--concept_type",
170
+ type=str,
171
+ required=True,
172
+ choices=['style', 'object', 'memorization'],
173
+ help='the type of removed concepts'
174
+ )
175
+ parser.add_argument(
176
+ "--caption_target",
177
+ type=str,
178
+ required=True,
179
+ help="target style to remove, used when kldiv loss",
180
+ )
181
+ parser.add_argument(
182
+ "--instance_data_dir",
183
+ type=str,
184
+ default=None,
185
+ help="A folder containing the training data of instance images.",
186
+ )
187
+ parser.add_argument(
188
+ "--class_data_dir",
189
+ type=str,
190
+ default=None,
191
+ help="A folder containing the training data of class images.",
192
+ )
193
+ parser.add_argument(
194
+ "--instance_prompt",
195
+ type=str,
196
+ help="The prompt with identifier specifying the instance",
197
+ )
198
+ parser.add_argument(
199
+ "--class_prompt",
200
+ type=str,
201
+ default=None,
202
+ help="The prompt to specify images in the same class as provided instance images.",
203
+ )
204
+ parser.add_argument(
205
+ "--mem_impath",
206
+ type=str,
207
+ default="",
208
+ help='the path to saved memorized image. Required when concept_type is memorization'
209
+ )
210
+ parser.add_argument(
211
+ "--validation_prompt",
212
+ type=str,
213
+ default=None,
214
+ help="A prompt that is used during validation to verify that the model is learning.",
215
+ )
216
+ parser.add_argument(
217
+ "--num_validation_images",
218
+ type=int,
219
+ default=2,
220
+ help="Number of images that should be generated during validation with `validation_prompt`.",
221
+ )
222
+ parser.add_argument(
223
+ "--validation_steps",
224
+ type=int,
225
+ default=500,
226
+ help=(
227
+ "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt"
228
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
229
+ ),
230
+ )
231
+ parser.add_argument(
232
+ "--with_prior_preservation",
233
+ default=False,
234
+ action="store_true",
235
+ help="Flag to add prior preservation loss.",
236
+ )
237
+ parser.add_argument("--prior_loss_weight", type=float,
238
+ default=1.0, help="The weight of prior preservation loss.")
239
+ parser.add_argument(
240
+ "--train_size",
241
+ type=int,
242
+ default=1000,
243
+ help='the number of generated images used for ablating the concept'
244
+ )
245
+ parser.add_argument(
246
+ "--output_dir",
247
+ type=str,
248
+ default="custom-diffusion-model",
249
+ help="The output directory where the model predictions and checkpoints will be written.",
250
+ )
251
+ parser.add_argument(
252
+ "--num_class_images",
253
+ type=int,
254
+ default=1000,
255
+ help=(
256
+ "Minimal anchor class images. If there are not enough images already present in"
257
+ " class_data_dir, additional images will be sampled with class_prompt."
258
+ ),
259
+ )
260
+ parser.add_argument(
261
+ "--num_class_prompts",
262
+ type=int,
263
+ default=200,
264
+ help=(
265
+ "Minimal prompts used to generate anchor class images"
266
+ ),
267
+ )
268
+ parser.add_argument("--seed", type=int, default=42,
269
+ help="A seed for reproducible training.")
270
+ parser.add_argument(
271
+ "--resolution",
272
+ type=int,
273
+ default=512,
274
+ help=(
275
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
276
+ " resolution"
277
+ ),
278
+ )
279
+ parser.add_argument(
280
+ "--center_crop",
281
+ default=False,
282
+ action="store_true",
283
+ help=(
284
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
285
+ " cropped. The images will be resized to the resolution first before cropping."
286
+ ),
287
+ )
288
+ parser.add_argument(
289
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
290
+ )
291
+ parser.add_argument(
292
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
293
+ )
294
+ parser.add_argument("--num_train_epochs", type=int, default=1)
295
+ parser.add_argument(
296
+ "--max_train_steps",
297
+ type=int,
298
+ default=None,
299
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
300
+ )
301
+ parser.add_argument(
302
+ "--checkpointing_steps",
303
+ type=int,
304
+ default=250,
305
+ help=(
306
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
307
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
308
+ " training using `--resume_from_checkpoint`."
309
+ ),
310
+ )
311
+ parser.add_argument(
312
+ "--checkpoints_total_limit",
313
+ type=int,
314
+ default=None,
315
+ help=(
316
+ "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
317
+ " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
318
+ " for more docs"
319
+ ),
320
+ )
321
+ parser.add_argument(
322
+ "--resume_from_checkpoint",
323
+ type=str,
324
+ default=None,
325
+ help=(
326
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
327
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
328
+ ),
329
+ )
330
+ parser.add_argument(
331
+ "--gradient_accumulation_steps",
332
+ type=int,
333
+ default=1,
334
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
335
+ )
336
+ parser.add_argument(
337
+ "--gradient_checkpointing",
338
+ action="store_true",
339
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
340
+ )
341
+ parser.add_argument(
342
+ "--learning_rate",
343
+ type=float,
344
+ default=1e-5,
345
+ help="Initial learning rate (after the potential warmup period) to use.",
346
+ )
347
+ parser.add_argument(
348
+ "--scale_lr",
349
+ action="store_true",
350
+ default=False,
351
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
352
+ )
353
+ parser.add_argument(
354
+ "--dataloader_num_workers",
355
+ type=int,
356
+ default=2,
357
+ help=(
358
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
359
+ ),
360
+ )
361
+ parser.add_argument(
362
+ "--parameter_group",
363
+ type=str,
364
+ default='cross-attn',
365
+ choices=['full-weight', 'cross-attn', 'embedding'],
366
+ help='parameter groups to finetune. Default: full-weight for memorization and cross-attn for others'
367
+ )
368
+ parser.add_argument(
369
+ "--loss_type_reverse",
370
+ type=str,
371
+ default='model-based',
372
+ help="loss type for reverse fine-tuning",
373
+ )
374
+ parser.add_argument(
375
+ "--lr_scheduler",
376
+ type=str,
377
+ default="constant",
378
+ help=(
379
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
380
+ ' "constant", "constant_with_warmup"]'
381
+ ),
382
+ )
383
+ parser.add_argument(
384
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
385
+ )
386
+ parser.add_argument(
387
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
388
+ )
389
+ parser.add_argument("--adam_beta1", type=float, default=0.9,
390
+ help="The beta1 parameter for the Adam optimizer.")
391
+ parser.add_argument("--adam_beta2", type=float, default=0.999,
392
+ help="The beta2 parameter for the Adam optimizer.")
393
+ parser.add_argument("--adam_weight_decay", type=float,
394
+ default=1e-2, help="Weight decay to use.")
395
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08,
396
+ help="Epsilon value for the Adam optimizer")
397
+ parser.add_argument("--max_grad_norm", default=1.0,
398
+ type=float, help="Max gradient norm.")
399
+ parser.add_argument("--push_to_hub", action="store_true",
400
+ help="Whether or not to push the model to the Hub.")
401
+ parser.add_argument("--hub_token", type=str, default=None,
402
+ help="The token to use to push to the Model Hub.")
403
+ parser.add_argument(
404
+ "--hub_model_id",
405
+ type=str,
406
+ default=None,
407
+ help="The name of the repository to keep in sync with the local `output_dir`.",
408
+ )
409
+ parser.add_argument(
410
+ "--logging_dir",
411
+ type=str,
412
+ default="logs",
413
+ help=(
414
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
415
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
416
+ ),
417
+ )
418
+ parser.add_argument(
419
+ "--allow_tf32",
420
+ action="store_true",
421
+ help=(
422
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
423
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
424
+ ),
425
+ )
426
+ parser.add_argument(
427
+ "--report_to",
428
+ type=str,
429
+ default="tensorboard",
430
+ help=(
431
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
432
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
433
+ ),
434
+ )
435
+ parser.add_argument(
436
+ "--mixed_precision",
437
+ type=str,
438
+ default=None,
439
+ choices=["no", "fp16", "bf16"],
440
+ help=(
441
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
442
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
443
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
444
+ ),
445
+ )
446
+ parser.add_argument(
447
+ "--prior_generation_precision",
448
+ type=str,
449
+ default=None,
450
+ choices=["no", "fp32", "fp16", "bf16"],
451
+ help=(
452
+ "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
453
+ " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
454
+ ),
455
+ )
456
+ parser.add_argument(
457
+ "--concepts_list",
458
+ type=str,
459
+ default=None,
460
+ help="Path to json containing multiple concepts, will overwrite parameters like instance_prompt, class_prompt, etc.",
461
+ )
462
+ parser.add_argument(
463
+ "--openai_key",
464
+ type=str,
465
+ default="",
466
+ help=(
467
+ "OPENAI API key. required for ablating objects and memorized images."
468
+ ),
469
+ )
470
+ parser.add_argument("--local_rank", type=int, default=-1,
471
+ help="For distributed training: local_rank")
472
+ parser.add_argument(
473
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
474
+ )
475
+ parser.add_argument("--hflip", action="store_true",
476
+ help="Apply horizontal flip data augmentation.")
477
+ parser.add_argument("--noaug", action="store_true",
478
+ help="Dont apply augmentation during data augmentation when this flag is enabled.")
479
+
480
+ if input_args is not None:
481
+ args = parser.parse_args(input_args)
482
+ else:
483
+ args = parser.parse_args()
484
+
485
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
486
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
487
+ args.local_rank = env_local_rank
488
+
489
+ if args.with_prior_preservation:
490
+ if args.concepts_list is None:
491
+ if args.class_data_dir is None:
492
+ raise ValueError(
493
+ "You must specify a data directory for class images.")
494
+ if args.class_prompt is None:
495
+ raise ValueError("You must specify prompt for class images.")
496
+ else:
497
+ # logger is not available yet
498
+ if args.class_data_dir is not None:
499
+ warnings.warn(
500
+ "You need not use --class_data_dir without --with_prior_preservation.")
501
+ if args.class_prompt is not None:
502
+ warnings.warn(
503
+ "You need not use --class_prompt without --with_prior_preservation.")
504
+
505
+ return args
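For reference, a minimal sketch of driving this parser programmatically instead of via the CLI; the flag values below are purely illustrative, and it assumes any flags defined earlier in the file either have defaults or are supplied here:

```python
# Hypothetical programmatic invocation; mirrors what `accelerate launch
# train.py ...` passes on the command line. Values are illustrative only.
example_args = parse_args([
    "--pretrained_model_name_or_path", "CompVis/stable-diffusion-v1-4",
    "--output_dir", "./logs/ablate_vangogh",   # illustrative output path
    "--caption_target", "van gogh",            # concept to ablate
    "--concept_type", "style",
    "--parameter_group", "cross-attn",
    "--train_batch_size", "4",
    "--max_train_steps", "200",
])
# main(example_args) would then run the fine-tuning loop defined below.
```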
506
+
507
+
508
+ def main(args):
509
+ logging_dir = Path(args.output_dir, args.logging_dir)
510
+
511
+ accelerator_project_config = ProjectConfiguration(
512
+ total_limit=args.checkpoints_total_limit)
513
+
514
+ accelerator = Accelerator(
515
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
516
+ mixed_precision=args.mixed_precision,
517
+ log_with=args.report_to,
518
+ project_dir=logging_dir,
519
+ project_config=accelerator_project_config,
520
+ )
521
+
522
+ if args.report_to == "wandb":
523
+ if not is_wandb_available():
524
+ raise ImportError(
525
+ "Make sure to install wandb if you want to use it for logging during training.")
526
+ import wandb
527
+
528
+ # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
529
+ # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
530
+ # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
531
+ # Make one log on every process with the configuration for debugging.
532
+ logging.basicConfig(
533
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
534
+ datefmt="%m/%d/%Y %H:%M:%S",
535
+ level=logging.INFO,
536
+ )
537
+ logger.info(accelerator.state, main_process_only=False)
538
+ if accelerator.is_local_main_process:
539
+ transformers.utils.logging.set_verbosity_warning()
540
+ diffusers.utils.logging.set_verbosity_info()
541
+ else:
542
+ transformers.utils.logging.set_verbosity_error()
543
+ diffusers.utils.logging.set_verbosity_error()
544
+
545
+ # We need to initialize the trackers we use, and also store our configuration.
546
+ # The trackers are initialized automatically on the main process.
547
+ if accelerator.is_main_process:
548
+ print(vars(args))
549
+ accelerator.init_trackers("custom-diffusion", config=vars(args))
550
+
551
+ # If passed along, set the training seed now.
552
+ if args.seed is not None:
553
+ set_seed(args.seed)
554
+ if args.concepts_list is None:
555
+ args.concepts_list = [
556
+ {
557
+ "instance_prompt": args.instance_prompt,
558
+ "class_prompt": args.class_prompt,
559
+ "instance_data_dir": args.instance_data_dir,
560
+ "class_data_dir": args.class_data_dir,
561
+ "caption_target": args.caption_target,
562
+ }
563
+ ]
564
+ else:
565
+ with open(args.concepts_list, "r") as f:
566
+ args.concepts_list = json.load(f)
567
+
568
+ # Generate class images if prior preservation is enabled.
569
+ for i, concept in enumerate(args.concepts_list):
570
+ # the path to ablation images and their corresponding prompts was provided directly; skip class-image generation.
571
+ if (concept['instance_prompt'] is not None and concept['instance_data_dir'] is not None):
572
+ break
573
+
574
+ class_images_dir = Path(concept['class_data_dir'])
575
+ if not class_images_dir.exists():
576
+ class_images_dir.mkdir(parents=True, exist_ok=True)
577
+ os.makedirs(f'{class_images_dir}/images', exist_ok=True)
578
+
579
+ # we need to generate training images
580
+ if len(list(Path(os.path.join(class_images_dir, 'images')).iterdir())) < args.num_class_images:
581
+ torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
582
+ if args.prior_generation_precision == "fp32":
583
+ torch_dtype = torch.float32
584
+ elif args.prior_generation_precision == "fp16":
585
+ torch_dtype = torch.float16
586
+ elif args.prior_generation_precision == "bf16":
587
+ torch_dtype = torch.bfloat16
588
+ pipeline = DiffusionPipeline.from_pretrained(
589
+ args.pretrained_model_name_or_path,
590
+ torch_dtype=torch_dtype,
591
+ safety_checker=None,
592
+ revision=args.revision,
593
+ )
594
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
595
+ pipeline.scheduler.config)
596
+
597
+ pipeline.set_progress_bar_config(disable=True)
598
+ pipeline.to(accelerator.device)
599
+
600
+ # need to create prompts using class_prompt.
601
+ if not os.path.isfile(concept['class_prompt']):
602
+ # style-based prompts are retrieved from the LAION dataset
603
+ if args.concept_type == 'style':
604
+ with open(os.path.join(class_images_dir, 'painting.txt')) as f:
605
+ class_prompt_collection = [
606
+ x.strip() for x in f.readlines()]
607
+
608
+ # LLM based prompt collection.
609
+ else:
610
+ class_prompt = concept['class_prompt']
611
+ # for object ablation, query ChatGPT to generate captions containing the anchor category
612
+ if args.concept_type == 'object':
613
+ class_prompt_collection, _ = getanchorprompts(
614
+ pipeline, accelerator, class_prompt, args.concept_type, class_images_dir, args.openai_key, args.num_class_prompts)
615
+ with open(class_images_dir / 'caption_anchor.txt', 'w') as f:
616
+ for prompt in class_prompt_collection:
617
+ f.write(prompt + '\n')
618
+ # for memorization, query ChatGPT to generate different captions that are paraphrases of the original caption
619
+ elif args.concept_type == 'memorization':
620
+ class_prompt_collection, caption_target = getanchorprompts(
621
+ pipeline, accelerator, class_prompt, args.concept_type, class_images_dir, args.openai_key, args.num_class_prompts, mem_impath=args.mem_impath)
622
+ concept['caption_target'] += f';*+{caption_target}'
623
+ with open(class_images_dir / 'caption_target.txt', 'w') as f:
624
+ f.write(concept['caption_target'])
625
+ print(class_prompt_collection,
626
+ concept['caption_target'])
627
+ # class_prompt is filepath to prompts.
628
+ else:
629
+ with open(concept['class_prompt']) as f:
630
+ class_prompt_collection = [
631
+ x.strip() for x in f.readlines()]
632
+
633
+ num_new_images = args.num_class_images
634
+ logger.info(
635
+ f"Number of class images to sample: {num_new_images}.")
636
+
637
+ sample_dataset = PromptDataset(
638
+ class_prompt_collection, num_new_images)
639
+ sample_dataloader = torch.utils.data.DataLoader(
640
+ sample_dataset, batch_size=args.sample_batch_size)
641
+
642
+ sample_dataloader = accelerator.prepare(sample_dataloader)
643
+
644
+ if os.path.exists(f'{class_images_dir}/caption.txt'):
645
+ os.remove(f'{class_images_dir}/caption.txt')
646
+ if os.path.exists(f'{class_images_dir}/images.txt'):
647
+ os.remove(f'{class_images_dir}/images.txt')
648
+
649
+ for example in tqdm(
650
+ sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
651
+ ):
652
+ accelerator.wait_for_everyone()
653
+ with open(f'{class_images_dir}/caption.txt', 'a') as f1, open(f'{class_images_dir}/images.txt', 'a') as f2:
654
+ images = pipeline(example["prompt"], num_inference_steps=25, guidance_scale=6., eta=1.).images
655
+
656
+ for i, image in enumerate(images):
657
+ hash_image = hashlib.sha1(
658
+ image.tobytes()).hexdigest()
659
+ image_filename = class_images_dir / \
660
+ f"images/{example['index'][i]}-{hash_image}.jpg"
661
+ image.save(image_filename)
662
+ f2.write(str(image_filename)+'\n')
663
+ f1.write('\n'.join(example["prompt"]) + '\n')
664
+ accelerator.wait_for_everyone()
665
+
666
+ del pipeline
667
+
668
+ if args.concept_type == 'memorization':
669
+ filter(class_images_dir, args.mem_impath,
670
+ outpath=str(class_images_dir / 'filtered'))
671
+ if os.path.exists(class_images_dir / 'caption_target.txt'):
672
+ with open(class_images_dir / 'caption_target.txt', 'r') as f:
673
+ concept['caption_target'] = f.readlines()[0].strip()
674
+ class_images_dir = class_images_dir / 'filtered'
675
+
676
+ concept['class_prompt'] = os.path.join(
677
+ class_images_dir, 'caption.txt')
678
+ concept['class_data_dir'] = os.path.join(
679
+ class_images_dir, 'images.txt')
680
+ concept['instance_prompt'] = os.path.join(
681
+ class_images_dir, 'caption.txt')
682
+ concept['instance_data_dir'] = os.path.join(
683
+ class_images_dir, 'images.txt')
684
+
685
+ if torch.cuda.is_available():
686
+ torch.cuda.empty_cache()
687
+
688
+ # Handle the repository creation
689
+ if accelerator.is_main_process:
690
+ if args.output_dir is not None:
691
+ os.makedirs(args.output_dir, exist_ok=True)
692
+
693
+ if args.push_to_hub:
694
+ print(args.hub_model_id or Path(args.output_dir).name)
695
+ repo_id = create_repo(
696
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
697
+ )
698
+ print(repo_id)
699
+ repo_id = args.hub_model_id
700
+
701
+ # Load the tokenizer
702
+ if args.tokenizer_name:
703
+ tokenizer = AutoTokenizer.from_pretrained(
704
+ args.tokenizer_name,
705
+ revision=args.revision,
706
+ use_fast=False,
707
+ )
708
+ elif args.pretrained_model_name_or_path:
709
+ tokenizer = AutoTokenizer.from_pretrained(
710
+ args.pretrained_model_name_or_path,
711
+ subfolder="tokenizer",
712
+ revision=args.revision,
713
+ use_fast=False,
714
+ )
715
+
716
+ # import correct text encoder class
717
+ text_encoder_cls = import_model_class_from_model_name_or_path(
718
+ args.pretrained_model_name_or_path, args.revision)
719
+
720
+ # Load scheduler and models
721
+ noise_scheduler = DDPMScheduler.from_pretrained(
722
+ args.pretrained_model_name_or_path, subfolder="scheduler")
723
+ text_encoder = text_encoder_cls.from_pretrained(
724
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
725
+ )
726
+ vae = AutoencoderKL.from_pretrained(
727
+ args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
728
+ unet = UNet2DConditionModel.from_pretrained(
729
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
730
+ )
731
+
732
+ vae.requires_grad_(False)
733
+ if args.parameter_group != 'embedding':
734
+ text_encoder.requires_grad_(False)
735
+ unet = create_custom_diffusion(unet, args.parameter_group)
736
+
737
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
738
+ # as these models are only used for inference; keeping their weights in full precision is not required.
739
+ weight_dtype = torch.float32
740
+ if accelerator.mixed_precision == "fp16":
741
+ weight_dtype = torch.float16
742
+ elif accelerator.mixed_precision == "bf16":
743
+ weight_dtype = torch.bfloat16
744
+
745
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
746
+ if accelerator.mixed_precision != "fp16":
747
+ unet.to(accelerator.device, dtype=weight_dtype)
748
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
749
+ vae.to(accelerator.device, dtype=weight_dtype)
750
+
751
+ if args.enable_xformers_memory_efficient_attention:
752
+ if is_xformers_available():
753
+ import xformers
754
+ xformers_version = version.parse(xformers.__version__)
755
+ if xformers_version == version.parse("0.0.16"):
756
+ logger.warn(
757
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
758
+ )
759
+ unet.enable_xformers_memory_efficient_attention()
760
+ else:
761
+ raise ValueError(
762
+ "xformers is not available. Make sure it is installed correctly")
763
+
764
+ if args.gradient_checkpointing:
765
+ unet.enable_gradient_checkpointing()
766
+ if args.parameter_group == 'embedding':
767
+ text_encoder.gradient_checkpointing_enable()
768
+ # Enable TF32 for faster training on Ampere GPUs,
769
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
770
+ if args.allow_tf32:
771
+ torch.backends.cuda.matmul.allow_tf32 = True
772
+
773
+ if args.scale_lr:
774
+ args.learning_rate = (
775
+ args.learning_rate * args.gradient_accumulation_steps *
776
+ args.train_batch_size * accelerator.num_processes
777
+ )
778
+ if args.with_prior_preservation:
779
+ args.learning_rate = args.learning_rate * 2.
780
+
781
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model on 16GB GPUs
782
+ if args.use_8bit_adam:
783
+ try:
784
+ import bitsandbytes as bnb
785
+ except ImportError:
786
+ raise ImportError(
787
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
788
+ )
789
+
790
+ optimizer_class = bnb.optim.AdamW8bit
791
+ else:
792
+ optimizer_class = torch.optim.AdamW
793
+
794
+ # Adding a modifier token which is optimized ####
795
+ # Code taken from https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion.py
796
+ modifier_token_id = []
797
+ if args.parameter_group == 'embedding':
798
+ assert args.concept_type != 'memorization', "embedding finetuning is not supported for memorization"
799
+
800
+ for concept in args.concepts_list:
801
+ # Convert the caption_target to ids
802
+ token_ids = tokenizer.encode(
803
+ [concept['caption_target']], add_special_tokens=False)
804
+ print(token_ids)
805
+ # Check if initializer_token is a single token or a sequence of tokens
806
+ modifier_token_id += token_ids
807
+
808
+ # Freeze all parameters except for the token embeddings in text encoder
809
+ params_to_freeze = itertools.chain(
810
+ text_encoder.text_model.encoder.parameters(),
811
+ text_encoder.text_model.final_layer_norm.parameters(),
812
+ text_encoder.text_model.embeddings.position_embedding.parameters(),
813
+ )
814
+ freeze_params(params_to_freeze)
815
+ params_to_optimize = itertools.chain(
816
+ text_encoder.get_input_embeddings().parameters())
817
+ else:
818
+ if args.parameter_group == 'cross-attn':
819
+ params_to_optimize = itertools.chain([x[1] for x in unet.named_parameters() if (
820
+ 'attn2.to_k' in x[0] or 'attn2.to_v' in x[0])])
821
+ if args.parameter_group == 'full-weight':
822
+ params_to_optimize = itertools.chain(unet.parameters())
823
+
824
+ # Optimizer creation
825
+ optimizer = optimizer_class(
826
+ params_to_optimize,
827
+ lr=args.learning_rate,
828
+ betas=(args.adam_beta1, args.adam_beta2),
829
+ weight_decay=args.adam_weight_decay,
830
+ eps=args.adam_epsilon,
831
+ )
832
+
833
+ # Dataset and DataLoaders creation:
834
+ train_dataset = CustomDiffusionDataset(
835
+ concepts_list=args.concepts_list,
836
+ concept_type=args.concept_type,
837
+ tokenizer=tokenizer,
838
+ with_prior_preservation=args.with_prior_preservation,
839
+ size=args.resolution,
840
+ center_crop=args.center_crop,
841
+ num_class_images=args.num_class_images,
842
+ hflip=args.hflip, aug=not args.noaug,
843
+ )
844
+
845
+ train_dataloader = torch.utils.data.DataLoader(
846
+ train_dataset,
847
+ batch_size=args.train_batch_size,
848
+ shuffle=True,
849
+ collate_fn=lambda examples: collate_fn(
850
+ examples, args.with_prior_preservation),
851
+ num_workers=args.dataloader_num_workers,
852
+ )
853
+
854
+ # Scheduler and math around the number of training steps.
855
+ overrode_max_train_steps = False
856
+ num_update_steps_per_epoch = math.ceil(
857
+ len(train_dataloader) / args.gradient_accumulation_steps)
858
+ if args.max_train_steps is None:
859
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
860
+ overrode_max_train_steps = True
861
+
862
+ lr_scheduler = get_scheduler(
863
+ args.lr_scheduler,
864
+ optimizer=optimizer,
865
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
866
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
867
+ )
868
+
869
+ # Prepare everything with our `accelerator`.
870
+ if args.parameter_group == 'embedding':
871
+ text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
872
+ text_encoder, optimizer, train_dataloader, lr_scheduler
873
+ )
874
+ else:
875
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
876
+ unet, optimizer, train_dataloader, lr_scheduler
877
+ )
878
+
879
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
880
+ num_update_steps_per_epoch = math.ceil(
881
+ len(train_dataloader) / args.gradient_accumulation_steps)
882
+ if overrode_max_train_steps:
883
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
884
+ # Afterwards we recalculate our number of training epochs
885
+ args.num_train_epochs = math.ceil(
886
+ args.max_train_steps / num_update_steps_per_epoch)
887
+
888
+ # Train!
889
+ total_batch_size = args.train_batch_size * \
890
+ accelerator.num_processes * args.gradient_accumulation_steps
891
+
892
+ logger.info("***** Running training *****")
893
+ logger.info(f" Num examples = {len(train_dataset)}")
894
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
895
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
896
+ logger.info(
897
+ f" Instantaneous batch size per device = {args.train_batch_size}")
898
+ logger.info(
899
+ f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
900
+ logger.info(
901
+ f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
902
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
903
+ global_step = 0
904
+ first_epoch = 0
905
+
906
+ # Potentially load in the weights and states from a previous save
907
+ if args.resume_from_checkpoint:
908
+ if args.resume_from_checkpoint != "latest":
909
+ path = os.path.basename(args.resume_from_checkpoint)
910
+ else:
911
+ # Get the most recent checkpoint
912
+ dirs = os.listdir(args.output_dir)
913
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
914
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
915
+ path = dirs[-1] if len(dirs) > 0 else None
916
+
917
+ if path is None:
918
+ accelerator.print(
919
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
920
+ )
921
+ args.resume_from_checkpoint = None
922
+ else:
923
+ accelerator.print(f"Resuming from checkpoint {path}")
924
+ accelerator.load_state(os.path.join(args.output_dir, path))
925
+ global_step = int(path.split("-")[1])
926
+
927
+ resume_global_step = global_step * args.gradient_accumulation_steps
928
+ first_epoch = global_step // num_update_steps_per_epoch
929
+ resume_step = resume_global_step % (
930
+ num_update_steps_per_epoch * args.gradient_accumulation_steps)
931
+
932
+ # Only show the progress bar once on each machine.
933
+ progress_bar = tqdm(range(global_step, args.max_train_steps),
934
+ disable=not accelerator.is_local_main_process)
935
+ progress_bar.set_description("Steps")
936
+
937
+ for epoch in range(first_epoch, args.num_train_epochs):
938
+ if args.parameter_group == 'embedding':
939
+ text_encoder.train()
940
+ else:
941
+ unet.train()
942
+ for step, batch in enumerate(train_dataloader):
943
+ # Skip steps until we reach the resumed step
944
+ if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
945
+ if step % args.gradient_accumulation_steps == 0:
946
+ progress_bar.update(1)
947
+ continue
948
+
949
+ with accelerator.accumulate(unet) if args.parameter_group != 'embedding' else accelerator.accumulate(text_encoder):
950
+ # Convert images to latent space
951
+ latents = vae.encode(batch["pixel_values"].to(
952
+ dtype=weight_dtype)).latent_dist.sample()
953
+ latents = latents * vae.config.scaling_factor
954
+
955
+ # Sample noise that we'll add to the latents
956
+ noise = torch.randn_like(latents)
957
+ bsz = latents.shape[0]
958
+ # Sample a random timestep for each image
959
+ timesteps = torch.randint(
960
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
961
+ timesteps = timesteps.long()
962
+
963
+ # Add noise to the latents according to the noise magnitude at each timestep
964
+ # (this is the forward diffusion process)
965
+ noisy_latents = noise_scheduler.add_noise(
966
+ latents, noise, timesteps)
967
+
968
+ # Get the text embedding for conditioning
969
+ encoder_hidden_states = text_encoder(batch["input_ids"])[0]
970
+ encoder_anchor_hidden_states = text_encoder(
971
+ batch["input_anchor_ids"])[0]
972
+
973
+ # Predict the noise residual
974
+ model_pred = unet(noisy_latents, timesteps,
975
+ encoder_hidden_states).sample
976
+ with torch.no_grad():
977
+ model_pred_anchor = unet(noisy_latents[:encoder_anchor_hidden_states.size(
978
+ 0)], timesteps[:encoder_anchor_hidden_states.size(0)], encoder_anchor_hidden_states).sample
979
+
980
+ # Get the target for loss depending on the prediction type
981
+ if args.loss_type_reverse == 'model-based':
982
+ if args.with_prior_preservation:
983
+ target_prior = torch.chunk(noise, 2, dim=0)[1]
984
+ target = model_pred_anchor
985
+ else:
986
+ if noise_scheduler.config.prediction_type == "epsilon":
987
+ target = noise
988
+ elif noise_scheduler.config.prediction_type == "v_prediction":
989
+ target = noise_scheduler.get_velocity(
990
+ latents, noise, timesteps)
991
+ else:
992
+ raise ValueError(
993
+ f"Unknown prediction type {noise_scheduler.config.prediction_type}")
994
+ if args.with_prior_preservation:
995
+ target, target_prior = torch.chunk(target, 2, dim=0)
996
+
997
+ if args.with_prior_preservation:
998
+ # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
999
+ model_pred, model_pred_prior = torch.chunk(
1000
+ model_pred, 2, dim=0)
1001
+ mask = torch.chunk(batch["mask"], 2, dim=0)[0]
1002
+ # Compute instance loss
1003
+ loss = F.mse_loss(model_pred.float(),
1004
+ target.float(), reduction="none")
1005
+ loss = (
1006
+ (loss * mask).sum([1, 2, 3]) / mask.sum([1, 2, 3])).mean()
1007
+
1008
+ # Compute prior loss
1009
+ prior_loss = F.mse_loss(
1010
+ model_pred_prior.float(), target_prior.float(), reduction="mean")
1011
+
1012
+ # Add the prior loss to the instance loss.
1013
+ loss = loss + args.prior_loss_weight * prior_loss
1014
+ else:
1015
+ mask = batch["mask"]
1016
+ loss = F.mse_loss(model_pred.float(),
1017
+ target.float(), reduction="none")
1018
+ loss = (
1019
+ (loss * mask).sum([1, 2, 3]) / mask.sum([1, 2, 3])).mean()
1020
+
1021
+ accelerator.backward(loss)
1022
+ # Zero out the gradients for all token embeddings except the newly added
1023
+ # embeddings for the concept, as we only want to optimize the concept embeddings
1024
+ if args.parameter_group == 'embedding':
1025
+ if accelerator.num_processes > 1:
1026
+ grads_text_encoder = text_encoder.module.get_input_embeddings().weight.grad
1027
+ else:
1028
+ grads_text_encoder = text_encoder.get_input_embeddings().weight.grad
1029
+ # Get the index for tokens that we want to zero the grads for
1030
+ index_grads_to_zero = torch.arange(
1031
+ len(tokenizer)) != modifier_token_id[0]
1032
+ for i in range(len(modifier_token_id[1:])):
1033
+ index_grads_to_zero = index_grads_to_zero & (
1034
+ torch.arange(len(tokenizer)) != modifier_token_id[i])
1035
+ grads_text_encoder.data[index_grads_to_zero,
1036
+ :] = grads_text_encoder.data[index_grads_to_zero, :].fill_(0)
1037
+
1038
+ if accelerator.sync_gradients:
1039
+ params_to_clip = (
1040
+ itertools.chain(text_encoder.parameters())
1041
+ if args.parameter_group == 'embedding'
1042
+ else itertools.chain([x[1] for x in unet.named_parameters() if ('attn2' in x[0])])
1043
+ if args.parameter_group == 'cross-attn'
1044
+ else itertools.chain(unet.parameters())
1045
+ )
1046
+ accelerator.clip_grad_norm_(
1047
+ params_to_clip, args.max_grad_norm)
1048
+ optimizer.step()
1049
+ lr_scheduler.step()
1050
+ optimizer.zero_grad()
1051
+
1052
+ # Checks if the accelerator has performed an optimization step behind the scenes
1053
+ if accelerator.sync_gradients:
1054
+ progress_bar.update(1)
1055
+ global_step += 1
1056
+
1057
+ if global_step % args.checkpointing_steps == 0:
1058
+ if accelerator.is_main_process:
1059
+ pipeline = CustomDiffusionPipeline.from_pretrained(
1060
+ args.pretrained_model_name_or_path,
1061
+ unet=accelerator.unwrap_model(unet),
1062
+ text_encoder=accelerator.unwrap_model(
1063
+ text_encoder),
1064
+ tokenizer=tokenizer,
1065
+ revision=args.revision,
1066
+ modifier_token_id=modifier_token_id,
1067
+ )
1068
+ save_path = os.path.join(
1069
+ args.output_dir, f"delta-{global_step}")
1070
+ pipeline.save_pretrained(
1071
+ save_path, parameter_group=args.parameter_group)
1072
+ logger.info(f"Saved state to {save_path}")
1073
+
1074
+ logs = {"loss": loss.detach().item(
1075
+ ), "lr": lr_scheduler.get_last_lr()[0]}
1076
+ progress_bar.set_postfix(**logs)
1077
+ accelerator.log(logs, step=global_step)
1078
+
1079
+ if global_step >= args.max_train_steps:
1080
+ break
1081
+
1082
+ if accelerator.is_main_process:
1083
+ if args.validation_prompt is not None and global_step % args.validation_steps == 0:
1084
+ logger.info(
1085
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
1086
+ f" {args.validation_prompt}."
1087
+ )
1088
+ # create pipeline
1089
+ pipeline = CustomDiffusionPipeline.from_pretrained(
1090
+ args.pretrained_model_name_or_path,
1091
+ unet=accelerator.unwrap_model(unet),
1092
+ text_encoder=accelerator.unwrap_model(text_encoder),
1093
+ tokenizer=tokenizer,
1094
+ revision=args.revision,
1095
+ modifier_token_id=modifier_token_id,
1096
+ )
1097
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
1098
+ pipeline.scheduler.config)
1099
+ pipeline = pipeline.to(accelerator.device)
1100
+ pipeline.set_progress_bar_config(disable=True)
1101
+
1102
+ # run inference
1103
+ generator = torch.Generator(
1104
+ device=accelerator.device).manual_seed(args.seed)
1105
+ images = [
1106
+ pipeline(args.validation_prompt, num_inference_steps=25,
1107
+ generator=generator, eta=1.).images[0]
1108
+ for _ in range(args.num_validation_images)
1109
+ ]
1110
+
1111
+ for tracker in accelerator.trackers:
1112
+ if tracker.name == "tensorboard":
1113
+ np_images = np.stack([np.asarray(img)
1114
+ for img in images])
1115
+ tracker.writer.add_images(
1116
+ "validation", np_images, epoch, dataformats="NHWC")
1117
+ if tracker.name == "wandb":
1118
+ tracker.log(
1119
+ {
1120
+ "validation": [
1121
+ wandb.Image(
1122
+ image, caption=f"{i}: {args.validation_prompt}")
1123
+ for i, image in enumerate(images)
1124
+ ]
1125
+ }
1126
+ )
1127
+
1128
+ del pipeline
1129
+ torch.cuda.empty_cache()
1130
+
1131
+ accelerator.wait_for_everyone()
1132
+ if accelerator.is_main_process:
1133
+ unet = unet.to(torch.float32)
1134
+ pipeline = CustomDiffusionPipeline.from_pretrained(
1135
+ args.pretrained_model_name_or_path,
1136
+ unet=accelerator.unwrap_model(unet),
1137
+ text_encoder=accelerator.unwrap_model(text_encoder),
1138
+ tokenizer=tokenizer,
1139
+ revision=args.revision,
1140
+ modifier_token_id=modifier_token_id,
1141
+ )
1142
+ save_path = os.path.join(args.output_dir, "delta.bin")
1143
+ pipeline.save_pretrained(
1144
+ save_path, parameter_group=args.parameter_group)
1145
+
1146
+ # run inference
1147
+ if args.validation_prompt and args.num_validation_images > 0:
1148
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
1149
+ pipeline.scheduler.config)
1150
+ pipeline = pipeline.to(accelerator.device)
1151
+ pipeline.set_progress_bar_config(disable=True)
1152
+
1153
+ # run inference
1154
+ generator = torch.Generator(
1155
+ device=accelerator.device).manual_seed(args.seed)
1156
+ images = [
1157
+ pipeline(args.validation_prompt, num_inference_steps=25,
1158
+ generator=generator, eta=1.).images[0]
1159
+ for _ in range(args.num_validation_images)
1160
+ ]
1161
+
1162
+ for tracker in accelerator.trackers:
1163
+ if tracker.name == "tensorboard":
1164
+ np_images = np.stack([np.asarray(img) for img in images])
1165
+ tracker.writer.add_images(
1166
+ "test", np_images, epoch, dataformats="NHWC")
1167
+ if tracker.name == "wandb":
1168
+ tracker.log(
1169
+ {
1170
+ "test": [
1171
+ wandb.Image(
1172
+ image, caption=f"{i}: {args.validation_prompt}")
1173
+ for i, image in enumerate(images)
1174
+ ]
1175
+ }
1176
+ )
1177
+
1178
+ if args.push_to_hub:
1179
+ save_model_card(
1180
+ repo_id,
1181
+ images=images,
1182
+ base_model=args.pretrained_model_name_or_path,
1183
+ prompt=args.instance_prompt,
1184
+ repo_folder=args.output_dir,
1185
+ )
1186
+ api = HfApi(token=args.hub_token)
1187
+ api.upload_folder(
1188
+ repo_id=repo_id,
1189
+ folder_path=args.output_dir,
1190
+ path_in_repo='.',
1191
+ repo_type='model'
1192
+ )
1193
+
1194
+ accelerator.end_training()
1195
+
1196
+
1197
+ if __name__ == "__main__":
1198
+ args = parse_args()
1199
+ main(args)
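When `--loss_type_reverse model-based` is used, the loss computed in the loop above regresses the U-Net's noise prediction for the target-concept prompt onto its own, detached prediction for the anchor prompt. A stripped-down sketch of that objective, with illustrative names and without the prior-preservation branch:

```python
import torch
import torch.nn.functional as F

def model_based_ablation_loss(unet, noisy_latents, timesteps,
                              target_embeds, anchor_embeds, mask):
    # Trainable path: prediction conditioned on the target-concept prompt.
    pred_target = unet(noisy_latents, timesteps, target_embeds).sample
    # Frozen teacher: prediction conditioned on the anchor prompt.
    with torch.no_grad():
        pred_anchor = unet(noisy_latents, timesteps, anchor_embeds).sample
    # Masked MSE between the two noise predictions, averaged over the batch.
    loss = F.mse_loss(pred_target.float(), pred_anchor.float(), reduction="none")
    return ((loss * mask).sum([1, 2, 3]) / mask.sum([1, 2, 3])).mean()
```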
concept-ablation-diffusers/utils.py ADDED
@@ -0,0 +1,443 @@
1
+ import os
2
+ import random
3
+ import shutil
4
+ from io import BytesIO
5
+ from pathlib import Path
6
+
7
+ import numpy as np
8
+ import openai
9
+ import regex as re
10
+ import requests
11
+ import torch
12
+ from PIL import Image
13
+ from torch.utils.data import Dataset
14
+ from torchvision import transforms
15
+ from tqdm.auto import tqdm
16
+
17
+ from diffusers import DPMSolverMultistepScheduler
18
+
19
+ normalize = transforms.Normalize(
20
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225],
21
+ )
22
+ small_288 = transforms.Compose([
23
+ transforms.Resize(288),
24
+ transforms.ToTensor(),
25
+ normalize,
26
+ ])
27
+
28
+ def collate_fn(examples, with_prior_preservation):
29
+ input_ids = [example["instance_prompt_ids"] for example in examples]
30
+ input_anchor_ids = [example["instance_anchor_prompt_ids"]
31
+ for example in examples]
32
+ pixel_values = [example["instance_images"] for example in examples]
33
+ mask = [example["mask"] for example in examples]
34
+ # Concat class and instance examples for prior preservation.
35
+ # We do this to avoid doing two forward passes.
36
+ if with_prior_preservation:
37
+ input_ids += [example["class_prompt_ids"] for example in examples]
38
+ pixel_values += [example["class_images"] for example in examples]
39
+ mask += [example["class_mask"] for example in examples]
40
+
41
+ input_ids = torch.cat(input_ids, dim=0)
42
+ input_anchor_ids = torch.cat(input_anchor_ids, dim=0)
43
+ pixel_values = torch.stack(pixel_values)
44
+ mask = torch.stack(mask)
45
+ pixel_values = pixel_values.to(
46
+ memory_format=torch.contiguous_format).float()
47
+ mask = mask.to(memory_format=torch.contiguous_format).float()
48
+
49
+ batch = {
50
+ "input_ids": input_ids,
51
+ "input_anchor_ids": input_anchor_ids,
52
+ "pixel_values": pixel_values,
53
+ "mask": mask.unsqueeze(1)
54
+ }
55
+ return batch
56
+
57
+
58
+ class PromptDataset(Dataset):
59
+ "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
60
+
61
+ def __init__(self, prompt, num_samples):
62
+ self.prompt = prompt
63
+ self.num_samples = num_samples
64
+
65
+ def __len__(self):
66
+ return self.num_samples
67
+
68
+ def __getitem__(self, index):
69
+ example = {}
70
+ example["prompt"] = self.prompt[index % len(self.prompt)]
71
+ example["index"] = index
72
+ return example
73
+
74
+
75
+ class CustomDiffusionDataset(Dataset):
76
+ """
77
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
78
+ It pre-processes the images and tokenizes the prompts.
79
+ """
80
+
81
+ def __init__(
82
+ self,
83
+ concepts_list,
84
+ concept_type,
85
+ tokenizer,
86
+ size=512,
87
+ center_crop=False,
88
+ with_prior_preservation=False,
89
+ num_class_images=200,
90
+ hflip=False,
91
+ aug=True,
92
+ ):
93
+ self.size = size
94
+ self.center_crop = center_crop
95
+ self.tokenizer = tokenizer
96
+ self.interpolation = Image.BILINEAR
97
+ self.aug = aug
98
+ self.concept_type = concept_type
99
+
100
+ self.instance_images_path = []
101
+ self.class_images_path = []
102
+ self.with_prior_preservation = with_prior_preservation
103
+ for concept in concepts_list:
104
+ with open(concept["instance_data_dir"], "r") as f:
105
+ inst_images_path = f.read().splitlines()
106
+ with open(concept["instance_prompt"], "r") as f:
107
+ inst_prompt = f.read().splitlines()
108
+ inst_img_path = [(x, y, concept['caption_target'])
109
+ for (x, y) in zip(inst_images_path, inst_prompt)]
110
+ self.instance_images_path.extend(inst_img_path)
111
+
112
+ if with_prior_preservation:
113
+ class_data_root = Path(concept["class_data_dir"])
114
+ if os.path.isdir(class_data_root):
115
+ class_images_path = list(class_data_root.iterdir())
116
+ class_prompt = [concept["class_prompt"]
117
+ for _ in range(len(class_images_path))]
118
+ else:
119
+ with open(class_data_root, "r") as f:
120
+ class_images_path = f.read().splitlines()
121
+ with open(concept["class_prompt"], "r") as f:
122
+ class_prompt = f.read().splitlines()
123
+
124
+ class_img_path = [(x, y) for (x, y) in zip(
125
+ class_images_path, class_prompt)]
126
+ self.class_images_path.extend(
127
+ class_img_path[:num_class_images])
128
+
129
+ random.shuffle(self.instance_images_path)
130
+ self.num_instance_images = len(self.instance_images_path)
131
+ self.num_class_images = len(self.class_images_path)
132
+ self._length = max(self.num_class_images, self.num_instance_images)
133
+ self.flip = transforms.RandomHorizontalFlip(0.5 * hflip)
134
+
135
+ self.image_transforms = transforms.Compose(
136
+ [
137
+ self.flip,
138
+ transforms.Resize(
139
+ size, interpolation=transforms.InterpolationMode.BILINEAR),
140
+ transforms.CenterCrop(
141
+ size) if center_crop else transforms.RandomCrop(size),
142
+ transforms.ToTensor(),
143
+ transforms.Normalize([0.5], [0.5]),
144
+ ]
145
+ )
146
+
147
+ def __len__(self):
148
+ return self._length
149
+
150
+ def preprocess(self, image, scale, resample):
151
+ outer, inner = self.size, scale
152
+ if scale > self.size:
153
+ outer, inner = scale, self.size
154
+ top, left = np.random.randint(
155
+ 0, outer - inner + 1), np.random.randint(0, outer - inner + 1)
156
+ image = image.resize((scale, scale), resample=resample)
157
+ image = np.array(image).astype(np.uint8)
158
+ image = (image / 127.5 - 1.0).astype(np.float32)
159
+ instance_image = np.zeros((self.size, self.size, 3), dtype=np.float32)
160
+ mask = np.zeros((self.size // 8, self.size // 8))
161
+ if scale > self.size:
162
+ instance_image = image[top: top + inner, left: left + inner, :]
163
+ mask = np.ones((self.size // 8, self.size // 8))
164
+ else:
165
+ instance_image[top: top + inner, left: left + inner, :] = image
166
+ mask[top // 8 + 1: (top + scale) // 8 - 1, left //
167
+ 8 + 1: (left + scale) // 8 - 1] = 1.
168
+ return instance_image, mask
169
+
170
+ def __getprompt__(self, instance_prompt, instance_target):
171
+ if self.concept_type == 'style':
172
+ r = np.random.choice([0, 1, 2])
173
+ instance_prompt = f'{instance_prompt}, in the style of {instance_target}' if r == 0 else f'in {instance_target}\'s style, {instance_prompt}' if r == 1 else f'in {instance_target}\'s style, {instance_prompt}'
174
+ elif self.concept_type == 'object':
175
+ anchor, target = instance_target.split('+')
176
+ instance_prompt = instance_prompt.replace(anchor, target)
177
+ elif self.concept_type == 'memorization':
178
+ instance_prompt = instance_target.split('+')[1]
179
+ return instance_prompt
180
+
181
+ def __getitem__(self, index):
182
+ example = {}
183
+ instance_image, instance_prompt, instance_target = self.instance_images_path[
184
+ index % self.num_instance_images]
185
+ instance_image = Image.open(instance_image)
186
+ if not instance_image.mode == "RGB":
187
+ instance_image = instance_image.convert("RGB")
188
+ instance_image = self.flip(instance_image)
189
+ # modify instance prompt according to the concept_type to include target concept
190
+ # multiple style/object fine-tuning
191
+ if ';' in instance_target:
192
+ instance_target = instance_target.split(';')
193
+ instance_target = instance_target[index % len(instance_target)]
194
+
195
+ instance_anchor_prompt = instance_prompt
196
+ instance_prompt = self.__getprompt__(instance_prompt, instance_target)
197
+ # apply resize augmentation and create a valid image region mask
198
+ random_scale = self.size
199
+ if self.aug:
200
+ random_scale = np.random.randint(self.size // 3, self.size + 1) if np.random.uniform(
201
+ ) < 0.66 else np.random.randint(int(1.2 * self.size), int(1.4 * self.size))
202
+ instance_image, mask = self.preprocess(
203
+ instance_image, random_scale, self.interpolation)
204
+
205
+ if random_scale < 0.6 * self.size:
206
+ instance_prompt = np.random.choice(
207
+ ["a far away ", "very small "]) + instance_prompt
208
+ elif random_scale > self.size:
209
+ instance_prompt = np.random.choice(
210
+ ["zoomed in ", "close up "]) + instance_prompt
211
+
212
+ example["instance_images"] = torch.from_numpy(
213
+ instance_image).permute(2, 0, 1)
214
+ example["mask"] = torch.from_numpy(mask)
215
+
216
+ example["instance_prompt_ids"] = self.tokenizer(
217
+ instance_prompt,
218
+ truncation=True,
219
+ padding="max_length",
220
+ max_length=self.tokenizer.model_max_length,
221
+ return_tensors="pt",
222
+ ).input_ids
223
+ example["instance_anchor_prompt_ids"] = self.tokenizer(
224
+ instance_anchor_prompt,
225
+ truncation=True,
226
+ padding="max_length",
227
+ max_length=self.tokenizer.model_max_length,
228
+ return_tensors="pt",
229
+ ).input_ids
230
+
231
+ if self.with_prior_preservation:
232
+ class_image, class_prompt = self.class_images_path[index %
233
+ self.num_class_images]
234
+ class_image = Image.open(class_image)
235
+ if not class_image.mode == "RGB":
236
+ class_image = class_image.convert("RGB")
237
+ example["class_images"] = self.image_transforms(class_image)
238
+ example["class_mask"] = torch.ones_like(example["mask"])
239
+ example["class_prompt_ids"] = self.tokenizer(
240
+ class_prompt,
241
+ truncation=True,
242
+ padding="max_length",
243
+ max_length=self.tokenizer.model_max_length,
244
+ return_tensors="pt",
245
+ ).input_ids
246
+
247
+ return example
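A minimal sketch of wiring this dataset into a DataLoader together with the `collate_fn` defined above, mirroring the construction in train.py; the tokenizer checkpoint, file paths, and concept values are illustrative only:

```python
import torch
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="tokenizer")
concepts_list = [{                               # single illustrative concept
    "instance_data_dir": "samples/images.txt",   # one image path per line
    "instance_prompt": "samples/caption.txt",    # one caption per line
    "class_data_dir": "samples/images.txt",
    "class_prompt": "samples/caption.txt",
    "caption_target": "van gogh",
}]
dataset = CustomDiffusionDataset(
    concepts_list, concept_type="style", tokenizer=tokenizer,
    size=512, with_prior_preservation=False, hflip=True)
loader = torch.utils.data.DataLoader(
    dataset, batch_size=4, shuffle=True,
    collate_fn=lambda examples: collate_fn(examples, False))
batch = next(iter(loader))
# batch keys: input_ids, input_anchor_ids, pixel_values, mask
```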
248
+
249
+
250
+ def isimage(path):
251
+ return 'png' in path.lower() or 'jpg' in path.lower() or 'jpeg' in path.lower()
253
+
254
+
255
+ def filter(folder, impath, outpath=None, unfiltered_path=None, threshold=0.15,
256
+ image_threshold=0.5, anchor_size=10, target_size=3, return_score=False):
257
+ model = torch.jit.load(
258
+ "./assets/sscd_imagenet_mixup.torchscript.pt")
259
+ if isinstance(folder, list):
260
+ image_paths = folder
261
+ image_captions = ["None" for _ in range(len(image_paths))]
262
+ elif Path(folder / 'images.txt').exists():
263
+ with open(f'{folder}/images.txt', "r") as f:
264
+ image_paths = f.read().splitlines()
265
+ with open(f'{folder}/caption.txt', "r") as f:
266
+ image_captions = f.read().splitlines()
267
+ else:
268
+ image_paths = [os.path.join(str(folder), file_path)
269
+ for file_path in os.listdir(folder) if isimage(file_path)]
270
+ image_captions = ["None" for _ in range(len(image_paths))]
271
+
272
+ batch = small_288(Image.open(impath).convert('RGB')).unsqueeze(0)
273
+ embedding_target = model(batch)[0, :]
274
+
275
+ filtered_paths = []
276
+ filtered_captions = []
277
+ unfiltered_paths = []
278
+ unfiltered_captions = []
279
+ count_dict = {}
280
+ for im, c in zip(image_paths, image_captions):
281
+ if c not in count_dict:
282
+ count_dict[c] = 0
283
+ if isinstance(folder, list):
284
+ batch = small_288(im).unsqueeze(0)
285
+ else:
286
+ batch = small_288(Image.open(im).convert('RGB')).unsqueeze(0)
287
+ embedding = model(batch)[0, :]
288
+
289
+ diff_sscd = (embedding * embedding_target).sum()
290
+
291
+ if diff_sscd <= image_threshold:
292
+ filtered_paths.append(im)
293
+ filtered_captions.append(c)
294
+ count_dict[c] += 1
295
+ else:
296
+ unfiltered_paths.append(im)
297
+ unfiltered_captions.append(c)
298
+
299
+ # only return score
300
+ if return_score:
301
+ score = len(unfiltered_paths) / \
302
+ (len(unfiltered_paths)+len(filtered_paths))
303
+ return score
304
+
305
+ os.makedirs(outpath, exist_ok=True)
306
+ os.makedirs(f'{outpath}/samples', exist_ok=True)
307
+ with open(f'{outpath}/caption.txt', 'w') as f:
308
+ for each in filtered_captions:
309
+ f.write(each.strip() + '\n')
310
+
311
+ with open(f'{outpath}/images.txt', 'w') as f:
312
+ for each in filtered_paths:
313
+ f.write(each.strip() + '\n')
314
+ imbase = Path(each).name
315
+ shutil.copy(each, f'{outpath}/samples/{imbase}')
316
+
317
+ print('++++++++++++++++++++++++++++++++++++++++++++++++')
318
+ print('+ Filter Summary +')
319
+ print(f'+ Remaining images: {len(filtered_paths)}')
320
+ print(f'+ Filtered images: {len(unfiltered_paths)}')
321
+ print('++++++++++++++++++++++++++++++++++++++++++++++++')
322
+
323
+ sorted_list = sorted(list(count_dict.items()),
324
+ key=lambda x: x[1], reverse=True)
325
+ anchor_prompts = [c[0] for c in sorted_list[:anchor_size]]
326
+ target_prompts = [c[0] for c in sorted_list[-target_size:]]
327
+ return anchor_prompts, target_prompts, len(filtered_paths)
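`filter` can also be used purely as a scorer: with `return_score=True` it reports the fraction of images whose SSCD embedding similarity to the reference image exceeds `image_threshold`. A hedged usage sketch, assuming the SSCD torchscript checkpoint has already been downloaded to `./assets/` (as trainer.py does) and that `generated_images` is a list of PIL images produced for a single prompt:

```python
# Hypothetical scoring call; generated_images is a list of PIL.Image objects.
score = filter(generated_images, "memorized_target.jpg", return_score=True)
print(f"fraction of generations matching the memorized image: {score:.2f}")
```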
328
+
329
+
330
+ def getanchorprompts(pipeline, accelerator, class_prompt, concept_type, class_images_dir, api_key, num_class_images=200, mem_impath=None):
331
+ openai.api_key = api_key
332
+ class_prompt_collection = []
333
+ caption_target = []
334
+ if concept_type == 'object':
335
+ messages = [{"role": "system", "content": "You can describe any image via text and provide captions for wide variety of images that is possible to generate."}]
336
+ messages = [{"role": "user", "content": f"Generate {num_class_images} captions for images containing a {class_prompt}. The caption should also contain the word \"{class_prompt}\" "}]
337
+ while True:
338
+ completion = openai.ChatCompletion.create(
339
+ model="gpt-3.5-turbo",
340
+ messages=messages
341
+ )
342
+ class_prompt_collection += [x for x in completion.choices[0].message.content.lower(
343
+ ).split('\n') if class_prompt in x]
344
+ messages.append(
345
+ {"role": "assistant", "content": completion.choices[0].message.content})
346
+ messages.append(
347
+ {"role": "user", "content": f"Generate {num_class_images-len(class_prompt_collection)} more captions"})
348
+ if len(class_prompt_collection) >= num_class_images:
349
+ break
350
+ class_prompt_collection = clean_prompt(class_prompt_collection)[
351
+ :num_class_images]
352
+
353
+ elif concept_type == 'memorization':
354
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
355
+ pipeline.scheduler.config)
356
+ num_prompts_firstpass = 5
357
+ num_prompts_secondpass = 2
358
+ threshold = 0.3
359
+ # Generate num_prompts_firstpass paraphrases that produce non-memorized content at least (1 - threshold) of the time.
360
+ os.makedirs(class_images_dir / 'temp/', exist_ok=True)
361
+ class_prompt_collection_counter = []
362
+ caption_target = []
363
+ prev_captions = []
364
+ messages = [{"role": "user", "content": f"Generate {4*num_prompts_firstpass} different paraphrase of the caption: {class_prompt}. Preserve the meaning when paraphrasing."}]
365
+ while True:
366
+ completion = openai.ChatCompletion.create(
367
+ model="gpt-3.5-turbo",
368
+ messages=messages
369
+ )
370
+ # print(completion.choices[0].message.content.lower().split('\n'))
371
+ class_prompt_collection_ = [x.strip(
372
+ ) for x in completion.choices[0].message.content.lower().split('\n') if x.strip() != '']
373
+ class_prompt_collection_ = clean_prompt(class_prompt_collection_)
374
+ # print(class_prompt_collection_)
375
+ for prompt in tqdm(
376
+ class_prompt_collection_, desc="Generating anchor and target prompts ", disable=not accelerator.is_local_main_process
377
+ ):
378
+ print(f'Prompt: {prompt}')
379
+ images = pipeline([prompt]*10, num_inference_steps=25,).images
380
+
381
+ score = filter(images, mem_impath, return_score=True)
382
+ print(f'Memorization rate: {score}')
383
+ if score <= threshold and prompt not in class_prompt_collection and len(class_prompt_collection) < num_prompts_firstpass:
384
+ class_prompt_collection += [prompt]
385
+ class_prompt_collection_counter += [score]
386
+ elif score >= 0.6 and prompt not in caption_target and len(caption_target) < 2:
387
+ caption_target += [prompt]
388
+ if len(class_prompt_collection) >= num_prompts_firstpass and len(caption_target) >= 2:
389
+ break
390
+
391
+ if len(class_prompt_collection) >= num_prompts_firstpass:
392
+ break
393
+ # print("prompts till now", class_prompt_collection, caption_target)
394
+ # print("prompts till now", len(
395
+ # class_prompt_collection), len(caption_target))
396
+ prev_captions += class_prompt_collection_
397
+ prev_captions_ = ','.join(prev_captions[-40:])
398
+
399
+ messages = [
400
+ {"role": "user", "content": f"Generate {4*(num_prompts_firstpass- len(class_prompt_collection))} different paraphrase of the caption: {class_prompt}. Preserve the meaning the most when paraphrasing. Also make sure that the new captions are different from the following captions: {prev_captions_[:4000]}"}]
401
+
402
+ # Generate more paraphrases using the captions we retrieved above.
403
+ for prompt in class_prompt_collection[:num_prompts_firstpass]:
404
+ completion = openai.ChatCompletion.create(
405
+ model="gpt-3.5-turbo",
406
+ messages=[
407
+ {"role": "user", "content": f"Generate {num_prompts_secondpass} different paraphrases of: {prompt}. "}]
408
+
409
+ )
410
+ class_prompt_collection += clean_prompt(
411
+ [x.strip() for x in completion.choices[0].message.content.lower().split('\n') if x.strip() != ''])
412
+
413
+ for prompt in tqdm(class_prompt_collection[num_prompts_firstpass:], desc="Memorization rate for final prompts"):
414
+ images = pipeline([prompt]*10, num_inference_steps=25,).images
415
+
416
+ class_prompt_collection_counter += [
417
+ filter(images, mem_impath, return_score=True)]
418
+
419
+ # select least ten and most memorized text prompts to be selected as anchor and target prompts.
420
+ class_prompt_collection = sorted(
421
+ zip(class_prompt_collection, class_prompt_collection_counter), key=lambda x: x[1])
422
+ caption_target += [x for (x, y) in class_prompt_collection if y >= 0.6]
423
+ class_prompt_collection = [
424
+ x for (x, y) in class_prompt_collection if y <= threshold][:10]
425
+ print("Anchor prompts:", class_prompt_collection)
426
+ print("Target prompts:", caption_target)
427
+ return class_prompt_collection, ';*+'.join(caption_target)
428
+
429
+
430
+ def clean_prompt(class_prompt_collection):
431
+ class_prompt_collection = [re.sub(
432
+ r"[0-9]+", lambda num: '' * len(num.group(0)), prompt) for prompt in class_prompt_collection]
433
+ class_prompt_collection = [re.sub(
434
+ r"^\.+", lambda dots: '' * len(dots.group(0)), prompt) for prompt in class_prompt_collection]
435
+ class_prompt_collection = [x.strip() for x in class_prompt_collection]
436
+ class_prompt_collection = [x.replace('"', '') for x in class_prompt_collection]
437
+ return class_prompt_collection
438
+
439
+
440
+ def safe_dir(dir):
441
+ if not dir.exists():
442
+ dir.mkdir()
443
+ return dir
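`clean_prompt` strips the list numbering and stray quotes that ChatGPT tends to prepend to each caption. A quick illustrative example (note it removes digits anywhere in the caption, not just leading ones):

```python
raw = ['1. "a painting of a cat"', '2. a dog on the beach']
print(clean_prompt(raw))
# -> ['a painting of a cat', 'a dog on the beach']
```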
images/applications.png ADDED

Git LFS Details

  • SHA256: deec6225d4533fb66380321c76a918ff0b9932502192a25f2f32e6f11f2c5db2
  • Pointer size: 132 Bytes
  • Size of remote file: 2.01 MB
models/greg_rutkowski_ablation_delta.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5e096c941d90652bbed08e912f899de7aaeb895a3797a11bcaf49a2733580a9d
3
+ size 76685761
models/vangogh_ablation_delta.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2b92f3c3a7aa8887c6f358ac73eb6a4d41387c4b4044dceda179a21846948f50
3
+ size 76685761
requirements.txt ADDED
@@ -0,0 +1,13 @@
1
+ gradio
2
+ torch
3
+ torchvision
4
+ scipy
5
+ accelerate
6
+ modelcards
7
+ transformers>=4.25.1
8
+ diffusers==0.19.0
9
+ tqdm
10
+ openai
11
+ triton
12
+ xformers>=0.0.20
13
+ bitsandbytes
trainer.py ADDED
@@ -0,0 +1,139 @@
1
+ import gradio as gr
2
+ import PIL.Image
3
+ import shlex
4
+ import shutil
5
+ import subprocess
6
+ from pathlib import Path
7
+ import os
8
+ import torch
9
+ from tqdm import tqdm
10
+
11
+
12
+
13
+ def pad_image(image: PIL.Image.Image) -> PIL.Image.Image:
14
+ w, h = image.size
15
+ if w == h:
16
+ return image
17
+ elif w > h:
18
+ new_image = PIL.Image.new(image.mode, (w, w), (0, 0, 0))
19
+ new_image.paste(image, (0, (w - h) // 2))
20
+ return new_image
21
+ else:
22
+ new_image = PIL.Image.new(image.mode, (h, h), (0, 0, 0))
23
+ new_image.paste(image, ((h - w) // 2, 0))
24
+ return new_image
25
+
26
+
27
+
28
+ def train_submit(
29
+ prompt, anchor_prompt, concept_type, reg_lambda, iterations, lr, openai_key, save_path, mem_impath=None
30
+ ):
31
+ if not torch.cuda.is_available():
32
+ raise gr.Error('CUDA is not available.')
33
+
34
+ torch.cuda.empty_cache()
35
+ original_prompt = prompt
36
+ parameter_group = "cross-attn"
37
+ train_batch_size = 4
38
+ if concept_type == 'style':
39
+ class_data_dir = f'./data/samples_painting/'
40
+ anchor_prompt = f'./assets/painting.txt'
41
+ openai_key = ''
42
+ elif concept_type == 'object':
43
+ os.makedirs('temp', exist_ok=True)
44
+ class_data_dir = f'./temp/{anchor_prompt}'
45
+ name = save_path.split('/')[-1]
46
+ prompt = f'{anchor_prompt}+{prompt}'
47
+ assert openai_key is not None
48
+
49
+ if len(openai_key.split('\n')) > 1:
50
+ openai_key = openai_key.split('\n')
51
+ with open(f'./temp/{name}.txt', 'w') as f:
52
+ for prompt_ in openai_key:
53
+ f.write(prompt_.strip()+'\n')
54
+ openai_key = ''
55
+ anchor_prompt = f'./temp/{name}.txt'
56
+ elif concept_type == 'memorization':
57
+ os.system("wget https://dl.fbaipublicfiles.com/sscd-copy-detection/sscd_imagenet_mixup.torchscript.pt -P assets/")
58
+ os.makedirs('temp', exist_ok=True)
59
+ prompt = f'*+{prompt}'
60
+ name = save_path.split('/')[-1]
61
+ train_batch_size = 1
62
+ lr = 5e-7
63
+ parameter_group = "full-weight"
64
+
65
+ assert openai_key is not None
66
+ assert mem_impath is not None
67
+
68
+ if len(openai_key.split('\n')) > 1:
69
+ openai_key = openai_key.split('\n')
70
+ with open(f'./temp/{name}.txt', 'w') as f:
71
+ for prompt_ in openai_key:
72
+ f.write(prompt_.strip()+'\n')
73
+ openai_key = ''
74
+ anchor_prompt = f'./temp/{name}.txt'
75
+ else:
76
+ anchor_prompt = prompt
77
+
78
+ print(mem_impath)
79
+ image = PIL.Image.open(mem_impath[0][0].name)
80
+ image = pad_image(image)
81
+ image = image.convert('RGB')
82
+ mem_impath = f"./temp/{original_prompt.lower().replace(' ', '')}.jpg"
83
+ image.save(mem_impath, format='JPEG', quality=100)
84
+
85
+ class_data_dir = f"./temp/{original_prompt.lower().replace(' ', '')}"
86
+
87
+
88
+ command = f'''
89
+ accelerate launch concept-ablation-diffusers/train.py \
90
+ --pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4" \
91
+ --output_dir={save_path} \
92
+ --class_data_dir={class_data_dir} \
93
+ --class_prompt="{anchor_prompt}" \
94
+ --caption_target "{prompt}" \
95
+ --concept_type {concept_type} \
96
+ --resolution=512 \
97
+ --train_batch_size={train_batch_size} \
98
+ --learning_rate={lr} \
99
+ --max_train_steps={iterations} \
100
+ --scale_lr --hflip \
101
+ --parameter_group {parameter_group} \
102
+ --openai_key "{openai_key}" \
103
+ --enable_xformers_memory_efficient_attention --num_class_images 500
104
+ '''
105
+
106
+ if concept_type == 'style':
107
+ command += f' --noaug'
108
+
109
+ if concept_type == 'memorization':
110
+ command += f' --use_8bit_adam --with_prior_preservation --prior_loss_weight=1.0 --mem_impath {mem_impath}'
111
+
112
+ with open(f'{save_path}/train.sh', 'w') as f:
113
+ command_s = ' '.join(command.split())
114
+ f.write(command_s)
115
+
116
+ res = subprocess.run(shlex.split(command))
117
+
118
+ if res.returncode == 0:
119
+ result_message = 'Training Completed!'
120
+ else:
121
+ result_message = 'Training Failed!'
122
+ weight_paths = sorted(Path(save_path).glob('*.bin'))
123
+ print(weight_paths)
124
+ return gr.update(value=result_message), weight_paths[0]
125
+
126
+
127
+ def inference(model_path, prompt, n_steps, generator):
128
+ import sys
129
+ sys.path.append('concept-ablation/diffusers/.')
130
+ from model_pipeline import CustomDiffusionPipeline
131
+ import torch
132
+
133
+ pipe = CustomDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16).to("cuda")
134
+ image1 = pipe(prompt, num_inference_steps=n_steps, guidance_scale=6., eta=1., generator=generator).images[0]
135
+
136
+ pipe.load_model(model_path)
137
+ image2 = pipe(prompt, num_inference_steps=n_steps, guidance_scale=6., eta=1., generator=generator).images[0]
138
+
139
+ return image1, image2
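A hedged sketch of calling `inference` with a fixed seed so the before/after images are directly comparable; the delta path is illustrative, and either of the `*_ablation_delta.bin` files under `models/` in this repository would work:

```python
import torch

generator = torch.Generator(device="cuda").manual_seed(42)
before, after = inference(
    "models/vangogh_ablation_delta.bin",      # ablated-concept delta weights
    "a landscape in the style of Van Gogh",   # illustrative prompt
    n_steps=25,
    generator=generator,
)
before.save("before_ablation.png")   # base Stable Diffusion output
after.save("after_ablation.png")     # output after loading the ablation delta
```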