Spaces:
Runtime error
Runtime error
Nupur Kumari
commited on
Commit
·
8173ae1
0
Parent(s):
concept ablation
Browse files- .gitattributes +35 -0
- .gitignore +1 -0
- README.md +53 -0
- __init__.py +0 -0
- app.py +233 -0
- assets/painting.txt +200 -0
- concept-ablation-diffusers/LICENSE +21 -0
- concept-ablation-diffusers/model_pipeline.py +237 -0
- concept-ablation-diffusers/train.py +1199 -0
- concept-ablation-diffusers/utils.py +443 -0
- images/applications.png +3 -0
- models/greg_rutkowski_ablation_delta.bin +3 -0
- models/vangogh_ablation_delta.bin +3 -0
- requirements.txt +13 -0
- trainer.py +139 -0
.gitattributes
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
images/applications.png filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
__pycache__
|
README.md
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Erasing Concepts from Diffusion Models
|
| 3 |
+
emoji: 💡
|
| 4 |
+
colorFrom: indigo
|
| 5 |
+
colorTo: gray
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 3.21.0
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
license: mit
|
| 11 |
+
---
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# Erasing Concepts from Diffusion Models
|
| 16 |
+
|
| 17 |
+
Project Website [https://erasing.baulab.info](https://erasing.baulab.info) <br>
|
| 18 |
+
Arxiv Preprint [https://arxiv.org/pdf/2303.07345.pdf](https://arxiv.org/pdf/2303.07345.pdf) <br>
|
| 19 |
+
Fine-tuned Weights [https://erasing.baulab.info/weights/esd_models/](https://erasing.baulab.info/weights/esd_models/) <br>
|
| 20 |
+
<div align='center'>
|
| 21 |
+
<img src = 'images/applications.png'>
|
| 22 |
+
</div>
|
| 23 |
+
|
| 24 |
+
Motivated by recent advancements in text-to-image diffusion, we study erasure of specific concepts from the model's weights. While Stable Diffusion has shown promise in producing explicit or realistic artwork, it has raised concerns regarding its potential for misuse. We propose a fine-tuning method that can erase a visual concept from a pre-trained diffusion model, given only the name of the style and using negative guidance as a teacher. We benchmark our method against previous approaches that remove sexually explicit content and demonstrate its effectiveness, performing on par with Safe Latent Diffusion and censored training.
|
| 25 |
+
|
| 26 |
+
To evaluate artistic style removal, we conduct experiments erasing five modern artists from the network and conduct a user study to assess the human perception of the removed styles. Unlike previous methods, our approach can remove concepts from a diffusion model permanently rather than modifying the output at the inference time, so it cannot be circumvented even if a user has access to model weights
|
| 27 |
+
|
| 28 |
+
Given only a short text description of an undesired visual concept and no additional data, our method fine-tunes model weights to erase the targeted concept. Our method can avoid NSFW content, stop imitation of a specific artist's style, or even erase a whole object class from model output, while preserving the model's behavior and capabilities on other topics.
|
| 29 |
+
|
| 30 |
+
## Demo vs github
|
| 31 |
+
|
| 32 |
+
This demo uses an updated implementation from the original Erasing codebase the publication is based from.
|
| 33 |
+
|
| 34 |
+
## Running locally
|
| 35 |
+
|
| 36 |
+
1.) Create an environment using the packages included in the requirements.txt file
|
| 37 |
+
|
| 38 |
+
2.) Run `python app.py`
|
| 39 |
+
|
| 40 |
+
3.) Open the application in browser at `http://127.0.0.1:7860/`
|
| 41 |
+
|
| 42 |
+
4.) Train, evaluate, and save models using our method
|
| 43 |
+
|
| 44 |
+
## Citing our work
|
| 45 |
+
The preprint can be cited as follows
|
| 46 |
+
```
|
| 47 |
+
@article{gandikota2023erasing,
|
| 48 |
+
title={Erasing Concepts from Diffusion Models},
|
| 49 |
+
author={Rohit Gandikota and Joanna Materzy\'nska and Jaden Fiotto-Kaufman and David Bau},
|
| 50 |
+
journal={arXiv preprint arXiv:2303.07345},
|
| 51 |
+
year={2023}
|
| 52 |
+
}
|
| 53 |
+
```
|
__init__.py
ADDED
|
File without changes
|
app.py
ADDED
|
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import torch
|
| 3 |
+
import sys
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from trainer import train_submit, inference
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
model_map = {'Van Gogh' : 'models/vangogh_ablation_delta.bin',
|
| 9 |
+
'Greg Rutkowski' : 'models/greg_rutkowski_ablation_delta.bin',
|
| 10 |
+
}
|
| 11 |
+
|
| 12 |
+
ORIGINAL_SPACE_ID = 'nupurkmr9/concept-ablation'
|
| 13 |
+
SPACE_ID = os.getenv('SPACE_ID')
|
| 14 |
+
|
| 15 |
+
SHARED_UI_WARNING = f'''## Attention - the demo requires at least 24GB VRAM for style and object removal, 24 GB VRAM for memorized image removal. Please clone this repository to run on your own machine.
|
| 16 |
+
<center><a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14" alt="Duplicate Space"></a></center>. This demo is partly adapted from https://huggingface.co/spaces/baulab/Erasing-Concepts-In-Diffusion.
|
| 17 |
+
'''
|
| 18 |
+
|
| 19 |
+
sys.path.append("concept-ablation-diffusers")
|
| 20 |
+
|
| 21 |
+
class Demo:
|
| 22 |
+
|
| 23 |
+
def __init__(self) -> None:
|
| 24 |
+
|
| 25 |
+
self.training = False
|
| 26 |
+
self.generating = False
|
| 27 |
+
|
| 28 |
+
# self.diffuser = StableDiffuser(scheduler='DDIM').to('cuda').eval().half()
|
| 29 |
+
|
| 30 |
+
with gr.Blocks() as demo:
|
| 31 |
+
self.layout()
|
| 32 |
+
demo.queue(concurrency_count=5).launch()
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def layout(self):
|
| 36 |
+
|
| 37 |
+
with gr.Row():
|
| 38 |
+
|
| 39 |
+
if SPACE_ID == ORIGINAL_SPACE_ID:
|
| 40 |
+
|
| 41 |
+
self.warning = gr.Markdown(SHARED_UI_WARNING)
|
| 42 |
+
|
| 43 |
+
with gr.Row():
|
| 44 |
+
|
| 45 |
+
with gr.Tab("Test") as inference_column:
|
| 46 |
+
|
| 47 |
+
with gr.Row():
|
| 48 |
+
|
| 49 |
+
self.explain_infr = gr.Markdown(interactive=False,
|
| 50 |
+
value='This is a demo of [Concept Ablation](https://www.cs.cmu.edu/~concept-ablation/). To try out a model where a concept has been erased, select a model and enter any prompt. For example, if you select the model "Van Gogh" you can generate images for the prompt "A portrait in the style of Van Gogh" and compare the ablated and pre-trained models. We have also provided several other pre-fine-tuned models with artistic styles and concepts ablated (Check out the "Ablated Model" drop-down). You can also train and run your own custom models. Check out the "train" section for custom ablation of concepts.')
|
| 51 |
+
|
| 52 |
+
with gr.Row():
|
| 53 |
+
|
| 54 |
+
with gr.Column(scale=1):
|
| 55 |
+
|
| 56 |
+
self.prompt_input_infr = gr.Text(
|
| 57 |
+
placeholder="Enter prompt...",
|
| 58 |
+
label="Prompt",
|
| 59 |
+
info="Prompt to generate"
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
with gr.Row():
|
| 63 |
+
|
| 64 |
+
self.model_dropdown = gr.Dropdown(
|
| 65 |
+
label="Ablated Models",
|
| 66 |
+
choices= list(model_map.keys()),
|
| 67 |
+
value='Van Gogh',
|
| 68 |
+
interactive=True
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
self.seed_infr = gr.Number(
|
| 72 |
+
label="Seed",
|
| 73 |
+
value=42
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
with gr.Column(scale=2):
|
| 77 |
+
|
| 78 |
+
self.infr_button = gr.Button(
|
| 79 |
+
value="Generate",
|
| 80 |
+
interactive=True
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
with gr.Row():
|
| 84 |
+
|
| 85 |
+
self.image_new = gr.Image(
|
| 86 |
+
label="Ablated",
|
| 87 |
+
interactive=False
|
| 88 |
+
)
|
| 89 |
+
self.image_orig = gr.Image(
|
| 90 |
+
label="SD",
|
| 91 |
+
interactive=False
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
with gr.Tab("Train") as training_column:
|
| 95 |
+
|
| 96 |
+
with gr.Row():
|
| 97 |
+
|
| 98 |
+
self.explain_train= gr.Markdown(interactive=False,
|
| 99 |
+
value='In this part you can ablate any concept from Stable Diffusion. Enter the name of the concept and select the kind of concept (e.g. object, style, memorization). You will also need to select a parent anchor concept e.g. cats when ablating grumpy cat, painting when ablating an artists\' style. When ablating a specific object or memorized image, you also need to either provide OpenAI API key or upload a file with 50-200 prompts corresponding to the ablation concept. With default settings, it takes about 20 minutes to fine-tune the model; then you can try inference above or download the weights. The training code used here is slightly different than the code tested in the original paper. Code and details are at [github link](https://github.com/nupurkmr9/concept-ablation).')
|
| 100 |
+
|
| 101 |
+
with gr.Row():
|
| 102 |
+
|
| 103 |
+
with gr.Column(scale=3):
|
| 104 |
+
mem_impath = []
|
| 105 |
+
|
| 106 |
+
self.prompt_input = gr.Text(
|
| 107 |
+
placeholder="Enter concept to remove... e.g. van gogh",
|
| 108 |
+
label="prompt",
|
| 109 |
+
info="Name of the concept to ablate from Model"
|
| 110 |
+
)
|
| 111 |
+
self.anchor_prompt = gr.Text(
|
| 112 |
+
placeholder="Enter anchor concept... e.g. painting",
|
| 113 |
+
label="anchor prompt",
|
| 114 |
+
info="Name of the anchor concept (superset of the concept to be ablated)"
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
choices = ['style', 'object', 'memorization']
|
| 118 |
+
|
| 119 |
+
self.concept_type = gr.Dropdown(
|
| 120 |
+
choices=choices,
|
| 121 |
+
value='style',
|
| 122 |
+
label='Ablated concept type',
|
| 123 |
+
info='Ablated concept type'
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
self.reg_lambda = gr.Number(
|
| 127 |
+
value=0,
|
| 128 |
+
label="Regularization loss",
|
| 129 |
+
info='Whether to add regularization loss on anchor concept. 1.0 when common words in ablated and anchor prompt e.g. grumpy cat and cat'
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
self.iterations_input = gr.Number(
|
| 133 |
+
value=100,
|
| 134 |
+
precision=0,
|
| 135 |
+
label="Iterations",
|
| 136 |
+
info='iterations used to train'
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
self.lr_input = gr.Number(
|
| 140 |
+
value=2e-6,
|
| 141 |
+
label="Learning Rate",
|
| 142 |
+
info='Learning rate used to train'
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
visible_openai_key = True
|
| 146 |
+
self.openai_key = gr.Text(
|
| 147 |
+
placeholder="Enter openAI API key or atleast 50 prompts if concept type is object/memorization",
|
| 148 |
+
label="OpenAI API key or Prompts (Required when concept type is object or memorization)",
|
| 149 |
+
info="If concept type is object, we use chatGPT to generate a set of prompts correspondig to the ablation concept. If concept type is memorization, we use ChatGPT to generate paraphrases of the text prompt that generates memorized image. You can either provide the api key or a set of desired prompts (atleast 50). For reference please check example prompts at https://github.com/nupurkmr9/concept-ablation/blob/main/assets/finetune_prompts/ ",
|
| 150 |
+
visible=visible_openai_key
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
visible = True
|
| 154 |
+
mem_impath.append(gr.Files(label=f'''Upload the memorized image if concept type is memorization''', visible=visible))
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
with gr.Column(scale=1):
|
| 158 |
+
|
| 159 |
+
self.train_status = gr.Button(value='', variant='primary', label='Status', interactive=False)
|
| 160 |
+
|
| 161 |
+
self.train_button = gr.Button(
|
| 162 |
+
value="Train",
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
self.download = gr.Files()
|
| 166 |
+
|
| 167 |
+
self.infr_button.click(self.inference, inputs = [
|
| 168 |
+
self.prompt_input_infr,
|
| 169 |
+
self.seed_infr,
|
| 170 |
+
self.model_dropdown
|
| 171 |
+
],
|
| 172 |
+
outputs=[
|
| 173 |
+
self.image_new,
|
| 174 |
+
self.image_orig
|
| 175 |
+
]
|
| 176 |
+
)
|
| 177 |
+
self.train_button.click(self.train, inputs = [
|
| 178 |
+
self.prompt_input,
|
| 179 |
+
self.anchor_prompt,
|
| 180 |
+
self.concept_type,
|
| 181 |
+
self.reg_lambda,
|
| 182 |
+
self.iterations_input,
|
| 183 |
+
self.lr_input,
|
| 184 |
+
self.openai_key,
|
| 185 |
+
] + mem_impath,
|
| 186 |
+
outputs=[self.train_button, self.train_status, self.download, self.model_dropdown]
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
def train(self, prompt, anchor_prompt, concept_type, reg_lambda, iterations, lr, openai_key, *inputs):
|
| 190 |
+
self.train_status.update(value='')
|
| 191 |
+
if self.training:
|
| 192 |
+
return [gr.update(interactive=True, value='Train'), gr.update(value='Someone else is training... Try again soon'), None, gr.update()]
|
| 193 |
+
|
| 194 |
+
randn = torch.randint(1, 10000000, (1,)).item()
|
| 195 |
+
|
| 196 |
+
save_path = f"models/{randn}_{prompt.lower().replace(' ', '')}"
|
| 197 |
+
os.makedirs(save_path, exist_ok=True)
|
| 198 |
+
self.training = True
|
| 199 |
+
mem_impath = inputs[:1]
|
| 200 |
+
train_submit(prompt, anchor_prompt, concept_type, reg_lambda, iterations, lr, openai_key, save_path, mem_impath)
|
| 201 |
+
|
| 202 |
+
self.training = False
|
| 203 |
+
|
| 204 |
+
torch.cuda.empty_cache()
|
| 205 |
+
|
| 206 |
+
modelpath = sorted(Path(save_path).glob('*.bin'))[0]
|
| 207 |
+
model_map[f"Custom_{prompt.lower().replace(' ', '')}"] = modelpath
|
| 208 |
+
|
| 209 |
+
return [gr.update(interactive=True, value='Train'), gr.update(value='Done Training! \n Try your custom model in the "Test" tab'), modelpath, gr.Dropdown.update(choices=list(model_map.keys()), value='Custom')]
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def inference(self, prompt, seed, model_name, pbar = gr.Progress(track_tqdm=True)):
|
| 213 |
+
|
| 214 |
+
seed = seed or 42
|
| 215 |
+
n_steps = 50
|
| 216 |
+
|
| 217 |
+
generator = torch.manual_seed(seed)
|
| 218 |
+
|
| 219 |
+
model_path = model_map[model_name]
|
| 220 |
+
|
| 221 |
+
torch.cuda.empty_cache()
|
| 222 |
+
|
| 223 |
+
generator = torch.manual_seed(seed)
|
| 224 |
+
|
| 225 |
+
orig_image, edited_image = inference(model_path, prompt, n_steps, generator)
|
| 226 |
+
|
| 227 |
+
torch.cuda.empty_cache()
|
| 228 |
+
|
| 229 |
+
return edited_image, orig_image
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
demo = Demo()
|
| 233 |
+
|
assets/painting.txt
ADDED
|
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Purple Paints
|
| 2 |
+
Painting a Seascape in Water color part 1
|
| 3 |
+
Why Arts & Crafts Are Beneficial to Early Childhood Education, Creve Coeur, Missouri
|
| 4 |
+
Simple Cherry Blossom Tree Painting images
|
| 5 |
+
Diy Wall Painting, Mural Painting, Paintings, Murals For Kids, Kids Wall Murals, Watercolor Walls, Mural Wall Art, Diy Canvas Art, Art Abstrait
|
| 6 |
+
Watch Doris Allen paints Roses on porcelain, China Painter Doris Drew Allen
|
| 7 |
+
Oil Painting with Kelsey May Connor in the tradition of Classical Realism | Art Classes Malta
|
| 8 |
+
Peinture LHuile Fruits Ltude Youtube
|
| 9 |
+
coffee painting canvas diy acrylic step by step for beginners #stepbysteppainting #tracie
|
| 10 |
+
Sara Barcus paints Bob Ross
|
| 11 |
+
Stock Video Footage of Little Baby Girl Paints Water Colors
|
| 12 |
+
"Jill Baker working on her painting ""Lady In Waiting"" (2016)"
|
| 13 |
+
Children painting a wall
|
| 14 |
+
Close up shot of the process of mixing oil paint with palette knife Stock Footage
|
| 15 |
+
toile : A young blonde woman draws a tree with paints Vidéos Libres De Droits
|
| 16 |
+
Acrylic Paint Grass - How To Paint Grass - with Acrylic
|
| 17 |
+
Snowfield Sentinel - En Plein Air Oil Demonstration
|
| 18 |
+
Artist painting A Stock Photography
|
| 19 |
+
peindre l 39 aquarelle youtube
|
| 20 |
+
Monet inspired Water Lily Pond - good way to practice the art of shadows
|
| 21 |
+
щетка для волос : A pretty woman with red hair, paints a picture on canvas, which stands on the easel. The lady is in the open air near the lake of the river, she draws from life Стоковые видеозаписи
|
| 22 |
+
Wall Painting Techniques Youtube
|
| 23 |
+
Painting under the pergola, Yeniköy 2014
|
| 24 |
+
Practical Mom: Bleeding Crepe Paper Art Project for Little Kids
|
| 25 |
+
painter man at work with paint brush, easel, canvas and...
|
| 26 |
+
Watercolor landscape painting
|
| 27 |
+
How to paint birds in watercolour
|
| 28 |
+
Acrylic Painting Techniques - How to Paint Flowers - Parrot Tulip
|
| 29 |
+
Painting cars toddler activity
|
| 30 |
+
A painting student with her canvas at Dickinson Farm.
|
| 31 |
+
Blonde woman in denim clothing finding creativity by painting a canvas hung on white wall.
|
| 32 |
+
beautiful young woman painter at work,
|
| 33 |
+
which : A pretty woman with red hair, paints a picture on canvas, which stands on the easel. The lady is in the open air near the lake of the river, she draws from life Stock Footage
|
| 34 |
+
how-to-paint-a-crown-crested-crane-in-acrylic-background
|
| 35 |
+
Slide Painting: A super fun indoor or outdoor process art activity for toddlers or preschoolers! Use cars, balls, or anything that rolls!
|
| 36 |
+
Poppy Spree - Week 2 Painting
|
| 37 |
+
How-To Make Colored Milk Explosion
|
| 38 |
+
Abstract Acrylic Painting Demo - Watercolor Look - Color Explosion - Easy how to abstract
|
| 39 |
+
Seine artist by triciamary
|
| 40 |
+
Free Painting Lessons - Cherry Blossom Bridge - Commentary by Acrylic Artist Brandon Schaefer
|
| 41 |
+
3 blue trees on ice - painting video step by step demonstration
|
| 42 |
+
Martin Grealish painting in Venice.jpg
|
| 43 |
+
Live Painter for Elegant Wedding
|
| 44 |
+
Long-haired artist paints on canvas
|
| 45 |
+
Powe Painting Co Cover Photo
|
| 46 |
+
Live Painting Part 1 (how to begin)
|
| 47 |
+
which : A pretty woman with red hair, paints a picture on canvas, which stands on the easel. The lady is in the open air near the lake of the river, she draws from life Stock Footage
|
| 48 |
+
Artist paints with oil paints
|
| 49 |
+
Stock Video Footage of hand of a young woman painting in purple
|
| 50 |
+
Watercolour tutorial - Summer shadows, cottages and flowers
|
| 51 |
+
Scrape pastel with knife to release a snowfall!
|
| 52 |
+
Acrylic pour painting process video
|
| 53 |
+
Woman painting stock photo
|
| 54 |
+
student painting at the Sainsbury Centre
|
| 55 |
+
Honfleur Painter
|
| 56 |
+
Bald Eagle in Wetlands Original Mini Painting on Easel
|
| 57 |
+
painting in nature things to do in virginia beach
|
| 58 |
+
Time lapse aquarelle painting Stock Footage
|
| 59 |
+
Watercolour Aquarelle Poppies Poppy Painting Demo Youtube
|
| 60 |
+
Wedding Artist Ben Keys of Wed on Canvas painting live at wedding reception. // The Graystone Inn, Wilmington, NC // Photo Courtesy of Blueberry Creative
|
| 61 |
+
NIOS - 225 Painting - Guide Book For Class 10th - English Medium
|
| 62 |
+
Fearless Watercolor #1 Frozen Marsh Full Watercolor Demonstration Music Ted Yoder
|
| 63 |
+
Oil Painting with Kelsey May Connor in the tradition of Classical Realism | Art Classes Malta
|
| 64 |
+
Michael Chambers, painting in process
|
| 65 |
+
KNIFE PAINTING - WATER LILY by NELLY LESTRADE (english subtitles)
|
| 66 |
+
A pretty woman with red hair, paints a picture on canvas, which stands on the easel. The lady is in the open air near the lake of the river, she draws from life
|
| 67 |
+
Sip & Sketch at Acquiesce Winery!
|
| 68 |
+
painting my bedroom wall youtube how to paint a galaxy wall mural in a spaceship themed
|
| 69 |
+
Flat art painter workshop with paint supplies equipment tools ba Stock Illustration
|
| 70 |
+
Gold Leaf Gilding in Westminster London
|
| 71 |
+
Commission a portrait-artist-brisbane-perth-sydney-melbourne
|
| 72 |
+
My very first brushstrokes as a mural painter
|
| 73 |
+
Stock Video Footage of Little Baby Girl Paints Water Colors
|
| 74 |
+
bing images of flower portraits watercolor by billy showell | VIDEO :: DVD - Watercolour Flower Portraits with Billy Showell
|
| 75 |
+
Haidee-Jo Summers painting the winning painting at Paint Out Norwich 2014
|
| 76 |
+
Artist paints a picture of oil paint brush in hand with palette Stock Footage
|
| 77 |
+
img 9546 Sam Flores, recap: live painting in Golden Gate Park upper playground Sam Flores power to the peaceful paint live painting graffiti fifty24sf gallery art
|
| 78 |
+
Artist at Michael Orwick (DragonFire Gallery) Stormy Weather Arts Festival Quick Draw event at Tolovana Inn
|
| 79 |
+
Painting at the Tower - Stephen B Whatley, 2000 (Stephen B Whatley) Tags: uk windows england orange detail building london art architecture painting studio paint artist 2000 vibrant perspective millenium canvas creation painter expressionism expressionist toweroflondon palette damncool stkatharinedocks mywinners platinumphoto goldstaraward stephenbwhatley
|
| 80 |
+
Video aula curso Pintura em tela com espatula part 1
|
| 81 |
+
"getlinkyoutube.com-LIBERO! Full video ""orchidee"" dal artista Igor Sakharov"
|
| 82 |
+
Weekly Specials Painting: Royal Paint By Number Mini Junior Dolphins
|
| 83 |
+
Woman artist painting landscape outdoor Stock Footage
|
| 84 |
+
painting styrofoam trays for a spring toddler art project
|
| 85 |
+
Madinah Wilson, with the Biden Institute, signs a poster on the Green at the University of Delaware on Tuesday during an event for National Voter Registration Day that got students visiting the Green to register to vote.
|
| 86 |
+
female hands painting a water landscape with oil paints using palette knife. 4k - cavalletto attrezzatura per arti e mestieri video stock e b–roll
|
| 87 |
+
Art studio #miniature. Dioramas and Clever Things Dollhouse Miniatures and Accessories
|
| 88 |
+
INDIANA - BOB ROSS EXHIBIT - 1-29-2021 (1).jpeg
|
| 89 |
+
How to paint a CLOUDY sky tutorial - Jennings644
|
| 90 |
+
Catherine Soucy
|
| 91 |
+
Pin By Marionette Taboniar On Art Videos By Marionette
|
| 92 |
+
coffee painting canvas diy acrylic step by step for beginners #stepbysteppainting #tracie
|
| 93 |
+
Anita Jansen Youtube Watercolor Art Watercolor Trees Art
|
| 94 |
+
Beautiful young woman painter at work, on grey background — Stock Photo
|
| 95 |
+
In my natural habitat 🎨🎨 ~ I'm currently accepting tea recommendations. Tell me you're favourite teas(or hot drinks) in my stories! ☕️☕️ ~ #creatingsunflowers #mystudio #markmaking #naturallycurlyhair #curlybeauties #homeinthestudio #coloraddict #strongfit #fitnessenthusiast #collectart #originalpainting #colormehappy #workingprocess #intuitiveart #decorativeart #girlgains #worksonpaper #modernartist #womenwhopaint #girlgains
|
| 96 |
+
Judith paints in a Bob Ross style.
|
| 97 |
+
Royal & Langnickel: Painting by Number (Splish Splash)
|
| 98 |
+
Close up of a needlework landscape being made at hong ngoc handicraft center Stock Footage
|
| 99 |
+
Watercolour Tutorial Blea Tarn Cottage Lake District
|
| 100 |
+
Squishy Painting from make-it-your-own.com (Crafts & activities for kids)
|
| 101 |
+
Painter-girl en plein air
|
| 102 |
+
Kids painting art outdoor activity, montessori homeschooling education
|
| 103 |
+
HM Altered Book Project Page 3 - Art Journaling - Mixed Media - How to
|
| 104 |
+
Fototapete - Talented male Artist Works on Abstract Oil Painting, with Broad Strokes of Paint Brush he Creates Modern Masterpiece. Dark and Messy Creative Studio where Large Canvas Stands on Easel Illuminated
|
| 105 |
+
How I Start My Plein Air Paintings
|
| 106 |
+
acrylic painting how to step by step flower painting mixed media flower heads 7 acrylic
|
| 107 |
+
Oil painting classes oil painting classes in Ameerpet institution for oil painting classes in hyderabad canvas painting classes in hyderabad Oil paintings for sales oil painting images oil painting in canvas
|
| 108 |
+
палитра : Young Beautiful Female Artist is in an Art Studio, Sitting Behind an Easel and Painting on Canvas. Drawing Process: in the Art Studio of the Artists Hand Art Girl with a Brush Painting on Canvas.4K Стоковые видеозаписи
|
| 109 |
+
What is Plein Air Painting like in Ireland?
|
| 110 |
+
Seven Tips for Setting up an Impromptu Garden Art Studio
|
| 111 |
+
young woman painting - artist stock videos & royalty-free footage
|
| 112 |
+
Pébéo - Mixed Media: mixing Studio Acrylics, Vitrail et Fantasy colors on a 3D frame - YouTube
|
| 113 |
+
A painter To paint
|
| 114 |
+
How to paint a pumpkin canvas, art skills not required!!!
|
| 115 |
+
http://funkidos.com/pictures-world/art-world/hyper-realistic-paintings-on-wood
|
| 116 |
+
Web Design for Oil Painters
|
| 117 |
+
Painters palette with oil paints Stock Footage
|
| 118 |
+
live-wedding-painting-by-ben-keys-live-event-artist-of-wed-on-canvas-garden-wedding-gibbes-museum-pure-luxe-bride-for-the-knot
|
| 119 |
+
Lee's Painting logo
|
| 120 |
+
Sir Winston Churchill painting
|
| 121 |
+
The girl takes blue oil paint from the palette and spreads it on the canvas with Stock Footage
|
| 122 |
+
Van Gogh Sunflowers | You Paint the Masters
|
| 123 |
+
Painting lesson high school - tutorial de pintura para secundaria
|
| 124 |
+
Crayon Art... now this is even cooler than the other kind of
|
| 125 |
+
CoachingMasteryStudio-940-v1-png
|
| 126 |
+
finished tree with pink flowers and within the circle on a gold background
|
| 127 |
+
3 blue trees on ice - painting video step by step demonstration
|
| 128 |
+
Talented male Artist Works on Abstract Oil Painting, with Broad Strokes of Paint Brush he Creates Modern Masterpiece. Dark and Messy Creative Studio where Large Canvas Stands on Easel Illuminated | Shutterstock HD Video #1036269983
|
| 129 |
+
magic painting logo designs
|
| 130 |
+
Workshop artistic, painting pigments color palette, color picker
|
| 131 |
+
Toddler Art — Autumn Tree Painting
|
| 132 |
+
Children painting pottery 2 Royalty Free Stock Photography
|
| 133 |
+
4K Talented young graffiti street artist working on a mural in urban area Stock Footage
|
| 134 |
+
fairytale castle mural for kids bedroom
|
| 135 |
+
how to paint how to paint clouds in acrylic instructional painting
|
| 136 |
+
A pretty woman with red hair, paints a picture on canvas, which stands on the easel. The lady is in the open air near the lake of the river, she draws from life
|
| 137 |
+
The girl draws a picture by numbers,with yellow acrylic paints. Hobby for adults Stock Footage
|
| 138 |
+
Sahara Mural (Jamila BC) Tags: cloud art sahara painting twilight mural desert dusk jamila jamilaproductions
|
| 139 |
+
ressam : Close Up of Little Girl Drawing in Nature with Teacher. Art School Classes in the City Park. Slow Motion. Creativity Inspiration Expression Concept Stok Video
|
| 140 |
+
Watercolor Tutorial Mountains | blue ridge mountains easy step by step watercolor
|
| 141 |
+
Stock Video Footage of Street painting
|
| 142 |
+
which : A pretty woman with red hair, paints a picture on canvas, which stands on the easel. The lady is in the open air near the lake of the river, she draws from life Stock Footage
|
| 143 |
+
Palette, canvas and easel Stock Image
|
| 144 |
+
A woman paints a landscape with oil paints with a palette and brush. Slider shot Stock Footage
|
| 145 |
+
1280x720 Blue Hydrangea Watercolor Doodle On Cotton Xuan Rice Paper By Amy
|
| 146 |
+
Zero technique in watercolour but had loads of fun
|
| 147 |
+
Taos - Plein Air Painter
|
| 148 |
+
Artist drawing picture of a city landscape Stock Footage
|
| 149 |
+
"""Painters and paintings"""
|
| 150 |
+
Puff, Pass and Paint- 420-friendly painting in Orange County! 21+ tickets
|
| 151 |
+
Melted Crayon Art | Unsimple Living
|
| 152 |
+
Live Online Art Classes Webinar Oil Painting Sunset Pt1
|
| 153 |
+
finger paint col.jpg
|
| 154 |
+
Painting and wine class
|
| 155 |
+
Artist Nancy Tankersley works on a plein air oil study of the landscape at King's Creek/Choptank Wetlands preserve.
|
| 156 |
+
Using the palette knife to scrape away paint
|
| 157 |
+
Close-up shot of a woman artist making brush strokes with oil on canvas. Impressionist painter working en plein air
|
| 158 |
+
Students create paintings inspired by VanGogh's textured impasto
|
| 159 |
+
Close-up shot of impressionist painter making final brush strokes during plein air painting
|
| 160 |
+
Preschool Painting Activity Royalty Free Stock Images
|
| 161 |
+
Advance Paint Pour Workshops
|
| 162 |
+
Painting By Numbers Poppy Field
|
| 163 |
+
Artist paints a picture of oil paint brush in hand with palette. Slider shot Stock Footage
|
| 164 |
+
Art painting hobby leisure girl drawing picture. Art painting hobby. creative leisure. girl drawing a picture. talent inspiration creation and self expression stock photography
|
| 165 |
+
1.5-Hour Children's Creative Art Session for 1 Child
|
| 166 |
+
Painting the Nude in Watercolour Demonstration by Anthony Barrow BA MIFL
|
| 167 |
+
The cherry blossoms in the Mt. Fuji Acrylic Painting Buying a new proven fact that everyone in your family will like doing together outside? Canvas Painting Tutorials, Easy Canvas Painting, Easy Paintings, Amazing Paintings, Mini Paintings, Pour Painting, Painting Videos, Painting Lessons, Small Canvas Art
|
| 168 |
+
paintings in oil
|
| 169 |
+
Printable Wall Murals japanese cherry blossoms by sole junkie youtube
|
| 170 |
+
How to paint a Rose in watercolors, the impressionistic way
|
| 171 |
+
Artist�s hand with paintbrush painting the picture photo
|
| 172 |
+
Childrens wall murals the muralist kids custom art for Equestrian wall mural
|
| 173 |
+
St Patrick's Day Toddler Craft - Toilet paper shamrocks + pop up art
|
| 174 |
+
Como pintar estrada, cerca e rvore - How to paint tree - how to paint landscape
|
| 175 |
+
How to paint tree branches - Painting Tutorial
|
| 176 |
+
spraying with a spray bottle
|
| 177 |
+
Miniature paintings on the ground, a painting on the wall and a painting on an easel of a woman
|
| 178 |
+
At the hairdresser on the market
|
| 179 |
+
Paint With Kevin Hill Snow Covered Valley
|
| 180 |
+
C. Rasmussen in her Los Angeles studio, 2018.
|
| 181 |
+
Stock Video Footage of View from behind on child paints a sheet of paper by hand
|
| 182 |
+
Sandra Gal seen through a ring flash tube as she works on her latest painting.
|
| 183 |
+
Watercolour Art Classes
|
| 184 |
+
Painting icon. Stock Footage
|
| 185 |
+
Watercolour Trees - How To Paint A Tree In Watercolour
|
| 186 |
+
Watercolor with Birgit O'Connor: Within the Flower
|
| 187 |
+
London-Fine-Art-Studios-Painting.png
|
| 188 |
+
Beautiful young woman painter at work, isolated on white
|
| 189 |
+
Kindergarten and Day Care Cartoon Wall Art Work In Mumbai
|
| 190 |
+
Close-up of woman artist painting still life picture on canvas in art studio
|
| 191 |
+
Modified Paint.Net logo
|
| 192 |
+
Im Grand Canyon wird Kunst zelebriert. - Foto: Arizona Office of Tourism
|
| 193 |
+
Melted crayons art
|
| 194 |
+
Woman paints a picture outdoors.
|
| 195 |
+
Snow Leopard Cub - Oils
|
| 196 |
+
Stock Video Footage of close up of a brush that paints a canvas - painter
|
| 197 |
+
Woman hand painting Stock Photo
|
| 198 |
+
Watercolour Aquarelle Demo Poppies Coquelicots Aquarelle By
|
| 199 |
+
painting 9 year 7 year pop painting american doll saige
|
| 200 |
+
Contemporary man with smartphone taking photograph of the painting on easel
|
concept-ablation-diffusers/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2023 pix2pixzero
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
concept-ablation-diffusers/model_pipeline.py
ADDED
|
@@ -0,0 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Callable, Optional
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from accelerate.logging import get_logger
|
| 5 |
+
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
| 6 |
+
from diffusers.models.cross_attention import CrossAttention
|
| 7 |
+
from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
|
| 8 |
+
from diffusers.pipelines.stable_diffusion.safety_checker import (
|
| 9 |
+
StableDiffusionSafetyChecker,
|
| 10 |
+
)
|
| 11 |
+
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
| 12 |
+
from diffusers.utils.import_utils import is_xformers_available
|
| 13 |
+
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
| 14 |
+
|
| 15 |
+
if is_xformers_available():
|
| 16 |
+
import xformers
|
| 17 |
+
import xformers.ops
|
| 18 |
+
else:
|
| 19 |
+
xformers = None
|
| 20 |
+
|
| 21 |
+
logger = get_logger(__name__)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def set_use_memory_efficient_attention_xformers(
|
| 25 |
+
self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
|
| 26 |
+
):
|
| 27 |
+
if use_memory_efficient_attention_xformers:
|
| 28 |
+
if self.added_kv_proj_dim is not None:
|
| 29 |
+
# TODO(Anton, Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
|
| 30 |
+
# which uses this type of cross attention ONLY because the attention mask of format
|
| 31 |
+
# [0, ..., -10.000, ..., 0, ...,] is not supported
|
| 32 |
+
raise NotImplementedError(
|
| 33 |
+
"Memory efficient attention with `xformers` is currently not supported when"
|
| 34 |
+
" `self.added_kv_proj_dim` is defined."
|
| 35 |
+
)
|
| 36 |
+
elif not is_xformers_available():
|
| 37 |
+
raise ModuleNotFoundError(
|
| 38 |
+
(
|
| 39 |
+
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
|
| 40 |
+
" xformers"
|
| 41 |
+
),
|
| 42 |
+
name="xformers",
|
| 43 |
+
)
|
| 44 |
+
elif not torch.cuda.is_available():
|
| 45 |
+
raise ValueError(
|
| 46 |
+
"torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
|
| 47 |
+
" only available for GPU "
|
| 48 |
+
)
|
| 49 |
+
else:
|
| 50 |
+
try:
|
| 51 |
+
# Make sure we can run the memory efficient attention
|
| 52 |
+
_ = xformers.ops.memory_efficient_attention(
|
| 53 |
+
torch.randn((1, 2, 40), device="cuda"),
|
| 54 |
+
torch.randn((1, 2, 40), device="cuda"),
|
| 55 |
+
torch.randn((1, 2, 40), device="cuda"),
|
| 56 |
+
)
|
| 57 |
+
except Exception as e:
|
| 58 |
+
raise e
|
| 59 |
+
|
| 60 |
+
processor = CustomDiffusionXFormersAttnProcessor(
|
| 61 |
+
attention_op=attention_op)
|
| 62 |
+
else:
|
| 63 |
+
processor = CustomDiffusionAttnProcessor()
|
| 64 |
+
|
| 65 |
+
self.set_processor(processor)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class CustomDiffusionAttnProcessor:
|
| 69 |
+
def __call__(
|
| 70 |
+
self,
|
| 71 |
+
attn: CrossAttention,
|
| 72 |
+
hidden_states,
|
| 73 |
+
encoder_hidden_states=None,
|
| 74 |
+
attention_mask=None,
|
| 75 |
+
):
|
| 76 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
| 77 |
+
attention_mask = attn.prepare_attention_mask(
|
| 78 |
+
attention_mask, sequence_length, batch_size)
|
| 79 |
+
query = attn.to_q(hidden_states)
|
| 80 |
+
|
| 81 |
+
crossattn = False
|
| 82 |
+
if encoder_hidden_states is None:
|
| 83 |
+
encoder_hidden_states = hidden_states
|
| 84 |
+
else:
|
| 85 |
+
crossattn = True
|
| 86 |
+
if attn.cross_attention_norm:
|
| 87 |
+
encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
|
| 88 |
+
|
| 89 |
+
key = attn.to_k(encoder_hidden_states)
|
| 90 |
+
value = attn.to_v(encoder_hidden_states)
|
| 91 |
+
if crossattn:
|
| 92 |
+
detach = torch.ones_like(key)
|
| 93 |
+
detach[:, :1, :] = detach[:, :1, :] * 0.
|
| 94 |
+
key = detach * key + (1 - detach) * key.detach()
|
| 95 |
+
value = detach * value + (1 - detach) * value.detach()
|
| 96 |
+
|
| 97 |
+
query = attn.head_to_batch_dim(query)
|
| 98 |
+
key = attn.head_to_batch_dim(key)
|
| 99 |
+
value = attn.head_to_batch_dim(value)
|
| 100 |
+
|
| 101 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
| 102 |
+
hidden_states = torch.bmm(attention_probs, value)
|
| 103 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
| 104 |
+
|
| 105 |
+
# linear proj
|
| 106 |
+
hidden_states = attn.to_out[0](hidden_states)
|
| 107 |
+
# dropout
|
| 108 |
+
hidden_states = attn.to_out[1](hidden_states)
|
| 109 |
+
|
| 110 |
+
return hidden_states
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class CustomDiffusionXFormersAttnProcessor:
|
| 114 |
+
def __init__(self, attention_op: Optional[Callable] = None):
|
| 115 |
+
self.attention_op = attention_op
|
| 116 |
+
|
| 117 |
+
def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
| 118 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
| 119 |
+
|
| 120 |
+
attention_mask = attn.prepare_attention_mask(
|
| 121 |
+
attention_mask, sequence_length, batch_size)
|
| 122 |
+
|
| 123 |
+
query = attn.to_q(hidden_states)
|
| 124 |
+
|
| 125 |
+
crossattn = False
|
| 126 |
+
if encoder_hidden_states is None:
|
| 127 |
+
encoder_hidden_states = hidden_states
|
| 128 |
+
else:
|
| 129 |
+
crossattn = True
|
| 130 |
+
if attn.cross_attention_norm:
|
| 131 |
+
encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
|
| 132 |
+
|
| 133 |
+
key = attn.to_k(encoder_hidden_states)
|
| 134 |
+
value = attn.to_v(encoder_hidden_states)
|
| 135 |
+
if crossattn:
|
| 136 |
+
detach = torch.ones_like(key)
|
| 137 |
+
detach[:, :1, :] = detach[:, :1, :] * 0.
|
| 138 |
+
key = detach * key + (1 - detach) * key.detach()
|
| 139 |
+
value = detach * value + (1 - detach) * value.detach()
|
| 140 |
+
|
| 141 |
+
query = attn.head_to_batch_dim(query).contiguous()
|
| 142 |
+
key = attn.head_to_batch_dim(key).contiguous()
|
| 143 |
+
value = attn.head_to_batch_dim(value).contiguous()
|
| 144 |
+
|
| 145 |
+
hidden_states = xformers.ops.memory_efficient_attention(
|
| 146 |
+
query, key, value, attn_bias=attention_mask, op=self.attention_op
|
| 147 |
+
)
|
| 148 |
+
hidden_states = hidden_states.to(query.dtype)
|
| 149 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
| 150 |
+
|
| 151 |
+
# linear proj
|
| 152 |
+
hidden_states = attn.to_out[0](hidden_states)
|
| 153 |
+
# dropout
|
| 154 |
+
hidden_states = attn.to_out[1](hidden_states)
|
| 155 |
+
return hidden_states
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
class CustomDiffusionPipeline(StableDiffusionPipeline):
|
| 159 |
+
r"""
|
| 160 |
+
Pipeline for custom diffusion model.
|
| 161 |
+
|
| 162 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
| 163 |
+
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.).
|
| 164 |
+
|
| 165 |
+
Args:
|
| 166 |
+
vae ([`AutoencoderKL`]):
|
| 167 |
+
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
| 168 |
+
text_encoder ([`CLIPTextModel`]):
|
| 169 |
+
Frozen text-encoder. Stable Diffusion uses the text portion of
|
| 170 |
+
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
| 171 |
+
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
| 172 |
+
tokenizer (`CLIPTokenizer`):
|
| 173 |
+
Tokenizer of class
|
| 174 |
+
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
| 175 |
+
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
|
| 176 |
+
scheduler ([`SchedulerMixin`]):
|
| 177 |
+
A scheduler to be used in combination with `unet` to denoise the encoded image latents.
|
| 178 |
+
safety_checker ([`StableDiffusionSafetyChecker`]):
|
| 179 |
+
Classification module that estimates whether generated images could be considered offensive or harmful.
|
| 180 |
+
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
|
| 181 |
+
feature_extractor ([`CLIPFeatureExtractor`]):
|
| 182 |
+
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
| 183 |
+
modifier_token_id: list of id of tokens related to the target concept that are modified when ablated.
|
| 184 |
+
"""
|
| 185 |
+
_optional_components = ["safety_checker",
|
| 186 |
+
"feature_extractor", "modifier_token_id"]
|
| 187 |
+
|
| 188 |
+
def __init__(
|
| 189 |
+
self,
|
| 190 |
+
vae: AutoencoderKL,
|
| 191 |
+
text_encoder: CLIPTextModel,
|
| 192 |
+
tokenizer: CLIPTokenizer,
|
| 193 |
+
unet: UNet2DConditionModel,
|
| 194 |
+
scheduler: SchedulerMixin,
|
| 195 |
+
safety_checker: StableDiffusionSafetyChecker,
|
| 196 |
+
feature_extractor: CLIPFeatureExtractor,
|
| 197 |
+
requires_safety_checker: bool = True,
|
| 198 |
+
modifier_token_id: list = [],
|
| 199 |
+
):
|
| 200 |
+
super().__init__(vae,
|
| 201 |
+
text_encoder,
|
| 202 |
+
tokenizer,
|
| 203 |
+
unet,
|
| 204 |
+
scheduler,
|
| 205 |
+
safety_checker,
|
| 206 |
+
feature_extractor,
|
| 207 |
+
requires_safety_checker)
|
| 208 |
+
|
| 209 |
+
self.modifier_token_id = modifier_token_id
|
| 210 |
+
|
| 211 |
+
def save_pretrained(self, save_path, parameter_group="cross-attn", all=False):
|
| 212 |
+
if all:
|
| 213 |
+
super().save_pretrained(save_path)
|
| 214 |
+
else:
|
| 215 |
+
delta_dict = {'unet': {}}
|
| 216 |
+
if parameter_group == 'embedding':
|
| 217 |
+
delta_dict['text_encoder'] = self.text_encoder.state_dict()
|
| 218 |
+
for name, params in self.unet.named_parameters():
|
| 219 |
+
if parameter_group == "cross-attn":
|
| 220 |
+
if 'attn2.to_k' in name or 'attn2.to_v' in name:
|
| 221 |
+
delta_dict['unet'][name] = params.cpu().clone()
|
| 222 |
+
elif parameter_group == "full-weight":
|
| 223 |
+
delta_dict['unet'][name] = params.cpu().clone()
|
| 224 |
+
else:
|
| 225 |
+
raise ValueError(
|
| 226 |
+
"parameter_group argument only supports one of [cross-attn, full-weight, embedding]"
|
| 227 |
+
)
|
| 228 |
+
torch.save(delta_dict, save_path)
|
| 229 |
+
|
| 230 |
+
def load_model(self, save_path):
|
| 231 |
+
st = torch.load(save_path)
|
| 232 |
+
print(st.keys())
|
| 233 |
+
if 'text_encoder' in st:
|
| 234 |
+
self.text_encoder.load_state_dict(st['text_encoder'])
|
| 235 |
+
for name, params in self.unet.named_parameters():
|
| 236 |
+
if name in st['unet']:
|
| 237 |
+
params.data.copy_(st['unet'][f'{name}'])
|
concept-ablation-diffusers/train.py
ADDED
|
@@ -0,0 +1,1199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This code is modified from the Huggingface repository: https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora.py, and
|
| 2 |
+
import argparse
|
| 3 |
+
import hashlib
|
| 4 |
+
import itertools
|
| 5 |
+
import json
|
| 6 |
+
import logging
|
| 7 |
+
import math
|
| 8 |
+
import os
|
| 9 |
+
import warnings
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
import torch.utils.checkpoint
|
| 16 |
+
import transformers
|
| 17 |
+
from accelerate import Accelerator
|
| 18 |
+
from accelerate.logging import get_logger
|
| 19 |
+
from accelerate.utils import ProjectConfiguration, set_seed
|
| 20 |
+
from huggingface_hub import HfApi, create_repo
|
| 21 |
+
from model_pipeline import (
|
| 22 |
+
CustomDiffusionAttnProcessor,
|
| 23 |
+
CustomDiffusionPipeline,
|
| 24 |
+
set_use_memory_efficient_attention_xformers,
|
| 25 |
+
)
|
| 26 |
+
from packaging import version
|
| 27 |
+
from tqdm.auto import tqdm
|
| 28 |
+
from transformers import AutoTokenizer, PretrainedConfig
|
| 29 |
+
from utils import (
|
| 30 |
+
CustomDiffusionDataset,
|
| 31 |
+
PromptDataset,
|
| 32 |
+
collate_fn,
|
| 33 |
+
filter,
|
| 34 |
+
getanchorprompts,
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
import diffusers
|
| 38 |
+
from diffusers import (
|
| 39 |
+
AutoencoderKL,
|
| 40 |
+
DDPMScheduler,
|
| 41 |
+
DiffusionPipeline,
|
| 42 |
+
DPMSolverMultistepScheduler,
|
| 43 |
+
UNet2DConditionModel,
|
| 44 |
+
)
|
| 45 |
+
from diffusers.models.cross_attention import CrossAttention
|
| 46 |
+
from diffusers.optimization import get_scheduler
|
| 47 |
+
from diffusers.utils import check_min_version, is_wandb_available
|
| 48 |
+
from diffusers.utils.import_utils import is_xformers_available
|
| 49 |
+
|
| 50 |
+
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
| 51 |
+
check_min_version("0.14.0")
|
| 52 |
+
|
| 53 |
+
logger = get_logger(__name__)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def create_custom_diffusion(unet, parameter_group):
|
| 57 |
+
for name, params in unet.named_parameters():
|
| 58 |
+
if parameter_group == "cross-attn":
|
| 59 |
+
if 'attn2.to_k' in name or 'attn2.to_v' in name:
|
| 60 |
+
params.requires_grad = True
|
| 61 |
+
else:
|
| 62 |
+
params.requires_grad = False
|
| 63 |
+
elif parameter_group == 'full-weight':
|
| 64 |
+
params.requires_grad = True
|
| 65 |
+
elif parameter_group == 'embedding':
|
| 66 |
+
params.requires_grad = False
|
| 67 |
+
else:
|
| 68 |
+
raise ValueError(
|
| 69 |
+
"parameter_group argument only cross-attn, full-weight, embedding"
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
# change attn class
|
| 73 |
+
def change_attn(unet):
|
| 74 |
+
for layer in unet.children():
|
| 75 |
+
if type(layer) == CrossAttention:
|
| 76 |
+
bound_method = set_use_memory_efficient_attention_xformers.__get__(
|
| 77 |
+
layer, layer.__class__)
|
| 78 |
+
setattr(
|
| 79 |
+
layer, 'set_use_memory_efficient_attention_xformers', bound_method)
|
| 80 |
+
else:
|
| 81 |
+
change_attn(layer)
|
| 82 |
+
|
| 83 |
+
change_attn(unet)
|
| 84 |
+
unet.set_attn_processor(CustomDiffusionAttnProcessor())
|
| 85 |
+
return unet
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def save_model_card(repo_id: str, images=None, base_model=str, prompt=str, repo_folder=None):
|
| 89 |
+
img_str = ""
|
| 90 |
+
for i, image in enumerate(images):
|
| 91 |
+
image.save(os.path.join(repo_folder, f"image_{i}.png"))
|
| 92 |
+
img_str += f"./image_{i}.png\n"
|
| 93 |
+
|
| 94 |
+
yaml = f"""
|
| 95 |
+
---
|
| 96 |
+
license: creativeml-openrail-m
|
| 97 |
+
base_model: {base_model}
|
| 98 |
+
instance_prompt: {prompt}
|
| 99 |
+
tags:
|
| 100 |
+
- stable-diffusion
|
| 101 |
+
- stable-diffusion-diffusers
|
| 102 |
+
- text-to-image
|
| 103 |
+
- diffusers
|
| 104 |
+
- custom diffusion
|
| 105 |
+
inference: true
|
| 106 |
+
---
|
| 107 |
+
"""
|
| 108 |
+
model_card = f"""
|
| 109 |
+
# Custom Diffusion - {repo_id}
|
| 110 |
+
|
| 111 |
+
These are Custom Diffusion adaption weights for {base_model}. The weights were trained on {prompt} using [Custom Diffusion](https://www.cs.cmu.edu/~custom-diffusion). You can find some example images in the following. \n
|
| 112 |
+
{img_str[0]}
|
| 113 |
+
"""
|
| 114 |
+
with open(os.path.join(repo_folder, "README.md"), "w") as f:
|
| 115 |
+
f.write(yaml + model_card)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
|
| 119 |
+
text_encoder_config = PretrainedConfig.from_pretrained(
|
| 120 |
+
pretrained_model_name_or_path,
|
| 121 |
+
subfolder="text_encoder",
|
| 122 |
+
revision=revision,
|
| 123 |
+
)
|
| 124 |
+
model_class = text_encoder_config.architectures[0]
|
| 125 |
+
|
| 126 |
+
if model_class == "CLIPTextModel":
|
| 127 |
+
from transformers import CLIPTextModel
|
| 128 |
+
|
| 129 |
+
return CLIPTextModel
|
| 130 |
+
elif model_class == "RobertaSeriesModelWithTransformation":
|
| 131 |
+
from diffusers.pipelines.alt_diffusion.modeling_roberta_series import (
|
| 132 |
+
RobertaSeriesModelWithTransformation,
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
+
return RobertaSeriesModelWithTransformation
|
| 136 |
+
else:
|
| 137 |
+
raise ValueError(f"{model_class} is not supported.")
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def freeze_params(params):
|
| 141 |
+
for param in params:
|
| 142 |
+
param.requires_grad = False
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def parse_args(input_args=None):
|
| 146 |
+
parser = argparse.ArgumentParser(
|
| 147 |
+
description="Simple example of a training script.")
|
| 148 |
+
parser.add_argument(
|
| 149 |
+
"--pretrained_model_name_or_path",
|
| 150 |
+
type=str,
|
| 151 |
+
default=None,
|
| 152 |
+
required=True,
|
| 153 |
+
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
| 154 |
+
)
|
| 155 |
+
parser.add_argument(
|
| 156 |
+
"--revision",
|
| 157 |
+
type=str,
|
| 158 |
+
default=None,
|
| 159 |
+
required=False,
|
| 160 |
+
help="Revision of pretrained model identifier from huggingface.co/models.",
|
| 161 |
+
)
|
| 162 |
+
parser.add_argument(
|
| 163 |
+
"--tokenizer_name",
|
| 164 |
+
type=str,
|
| 165 |
+
default=None,
|
| 166 |
+
help="Pretrained tokenizer name or path if not the same as model_name",
|
| 167 |
+
)
|
| 168 |
+
parser.add_argument(
|
| 169 |
+
"--concept_type",
|
| 170 |
+
type=str,
|
| 171 |
+
required=True,
|
| 172 |
+
choices=['style', 'object', 'memorization'],
|
| 173 |
+
help='the type of removed concepts'
|
| 174 |
+
)
|
| 175 |
+
parser.add_argument(
|
| 176 |
+
"--caption_target",
|
| 177 |
+
type=str,
|
| 178 |
+
required=True,
|
| 179 |
+
help="target style to remove, used when kldiv loss",
|
| 180 |
+
)
|
| 181 |
+
parser.add_argument(
|
| 182 |
+
"--instance_data_dir",
|
| 183 |
+
type=str,
|
| 184 |
+
default=None,
|
| 185 |
+
help="A folder containing the training data of instance images.",
|
| 186 |
+
)
|
| 187 |
+
parser.add_argument(
|
| 188 |
+
"--class_data_dir",
|
| 189 |
+
type=str,
|
| 190 |
+
default=None,
|
| 191 |
+
help="A folder containing the training data of class images.",
|
| 192 |
+
)
|
| 193 |
+
parser.add_argument(
|
| 194 |
+
"--instance_prompt",
|
| 195 |
+
type=str,
|
| 196 |
+
help="The prompt with identifier specifying the instance",
|
| 197 |
+
)
|
| 198 |
+
parser.add_argument(
|
| 199 |
+
"--class_prompt",
|
| 200 |
+
type=str,
|
| 201 |
+
default=None,
|
| 202 |
+
help="The prompt to specify images in the same class as provided instance images.",
|
| 203 |
+
)
|
| 204 |
+
parser.add_argument(
|
| 205 |
+
"--mem_impath",
|
| 206 |
+
type=str,
|
| 207 |
+
default="",
|
| 208 |
+
help='the path to saved memorized image. Required when concept_type is memorization'
|
| 209 |
+
)
|
| 210 |
+
parser.add_argument(
|
| 211 |
+
"--validation_prompt",
|
| 212 |
+
type=str,
|
| 213 |
+
default=None,
|
| 214 |
+
help="A prompt that is used during validation to verify that the model is learning.",
|
| 215 |
+
)
|
| 216 |
+
parser.add_argument(
|
| 217 |
+
"--num_validation_images",
|
| 218 |
+
type=int,
|
| 219 |
+
default=2,
|
| 220 |
+
help="Number of images that should be generated during validation with `validation_prompt`.",
|
| 221 |
+
)
|
| 222 |
+
parser.add_argument(
|
| 223 |
+
"--validation_steps",
|
| 224 |
+
type=int,
|
| 225 |
+
default=500,
|
| 226 |
+
help=(
|
| 227 |
+
"Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt"
|
| 228 |
+
" `args.validation_prompt` multiple times: `args.num_validation_images`."
|
| 229 |
+
),
|
| 230 |
+
)
|
| 231 |
+
parser.add_argument(
|
| 232 |
+
"--with_prior_preservation",
|
| 233 |
+
default=False,
|
| 234 |
+
action="store_true",
|
| 235 |
+
help="Flag to add prior preservation loss.",
|
| 236 |
+
)
|
| 237 |
+
parser.add_argument("--prior_loss_weight", type=float,
|
| 238 |
+
default=1.0, help="The weight of prior preservation loss.")
|
| 239 |
+
parser.add_argument(
|
| 240 |
+
"--train_size",
|
| 241 |
+
type=int,
|
| 242 |
+
default=1000,
|
| 243 |
+
help='the number of generated images used for ablating the concept'
|
| 244 |
+
)
|
| 245 |
+
parser.add_argument(
|
| 246 |
+
"--output_dir",
|
| 247 |
+
type=str,
|
| 248 |
+
default="custom-diffusion-model",
|
| 249 |
+
help="The output directory where the model predictions and checkpoints will be written.",
|
| 250 |
+
)
|
| 251 |
+
parser.add_argument(
|
| 252 |
+
"--num_class_images",
|
| 253 |
+
type=int,
|
| 254 |
+
default=1000,
|
| 255 |
+
help=(
|
| 256 |
+
"Minimal anchor class images. If there are not enough images already present in"
|
| 257 |
+
" class_data_dir, additional images will be sampled with class_prompt."
|
| 258 |
+
),
|
| 259 |
+
)
|
| 260 |
+
parser.add_argument(
|
| 261 |
+
"--num_class_prompts",
|
| 262 |
+
type=int,
|
| 263 |
+
default=200,
|
| 264 |
+
help=(
|
| 265 |
+
"Minimal prompts used to generate anchor class images"
|
| 266 |
+
),
|
| 267 |
+
)
|
| 268 |
+
parser.add_argument("--seed", type=int, default=42,
|
| 269 |
+
help="A seed for reproducible training.")
|
| 270 |
+
parser.add_argument(
|
| 271 |
+
"--resolution",
|
| 272 |
+
type=int,
|
| 273 |
+
default=512,
|
| 274 |
+
help=(
|
| 275 |
+
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
| 276 |
+
" resolution"
|
| 277 |
+
),
|
| 278 |
+
)
|
| 279 |
+
parser.add_argument(
|
| 280 |
+
"--center_crop",
|
| 281 |
+
default=False,
|
| 282 |
+
action="store_true",
|
| 283 |
+
help=(
|
| 284 |
+
"Whether to center crop the input images to the resolution. If not set, the images will be randomly"
|
| 285 |
+
" cropped. The images will be resized to the resolution first before cropping."
|
| 286 |
+
),
|
| 287 |
+
)
|
| 288 |
+
parser.add_argument(
|
| 289 |
+
"--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
|
| 290 |
+
)
|
| 291 |
+
parser.add_argument(
|
| 292 |
+
"--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
|
| 293 |
+
)
|
| 294 |
+
parser.add_argument("--num_train_epochs", type=int, default=1)
|
| 295 |
+
parser.add_argument(
|
| 296 |
+
"--max_train_steps",
|
| 297 |
+
type=int,
|
| 298 |
+
default=None,
|
| 299 |
+
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
| 300 |
+
)
|
| 301 |
+
parser.add_argument(
|
| 302 |
+
"--checkpointing_steps",
|
| 303 |
+
type=int,
|
| 304 |
+
default=250,
|
| 305 |
+
help=(
|
| 306 |
+
"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
|
| 307 |
+
" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
|
| 308 |
+
" training using `--resume_from_checkpoint`."
|
| 309 |
+
),
|
| 310 |
+
)
|
| 311 |
+
parser.add_argument(
|
| 312 |
+
"--checkpoints_total_limit",
|
| 313 |
+
type=int,
|
| 314 |
+
default=None,
|
| 315 |
+
help=(
|
| 316 |
+
"Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
|
| 317 |
+
" See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
|
| 318 |
+
" for more docs"
|
| 319 |
+
),
|
| 320 |
+
)
|
| 321 |
+
parser.add_argument(
|
| 322 |
+
"--resume_from_checkpoint",
|
| 323 |
+
type=str,
|
| 324 |
+
default=None,
|
| 325 |
+
help=(
|
| 326 |
+
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
|
| 327 |
+
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
|
| 328 |
+
),
|
| 329 |
+
)
|
| 330 |
+
parser.add_argument(
|
| 331 |
+
"--gradient_accumulation_steps",
|
| 332 |
+
type=int,
|
| 333 |
+
default=1,
|
| 334 |
+
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
| 335 |
+
)
|
| 336 |
+
parser.add_argument(
|
| 337 |
+
"--gradient_checkpointing",
|
| 338 |
+
action="store_true",
|
| 339 |
+
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
|
| 340 |
+
)
|
| 341 |
+
parser.add_argument(
|
| 342 |
+
"--learning_rate",
|
| 343 |
+
type=float,
|
| 344 |
+
default=1e-5,
|
| 345 |
+
help="Initial learning rate (after the potential warmup period) to use.",
|
| 346 |
+
)
|
| 347 |
+
parser.add_argument(
|
| 348 |
+
"--scale_lr",
|
| 349 |
+
action="store_true",
|
| 350 |
+
default=False,
|
| 351 |
+
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
|
| 352 |
+
)
|
| 353 |
+
parser.add_argument(
|
| 354 |
+
"--dataloader_num_workers",
|
| 355 |
+
type=int,
|
| 356 |
+
default=2,
|
| 357 |
+
help=(
|
| 358 |
+
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
|
| 359 |
+
),
|
| 360 |
+
)
|
| 361 |
+
parser.add_argument(
|
| 362 |
+
"--parameter_group",
|
| 363 |
+
type=str,
|
| 364 |
+
default='cross-attn',
|
| 365 |
+
choices=['full-weight', 'cross-attn', 'embedding'],
|
| 366 |
+
help='parameter groups to finetune. Default: full-weight for memorization and cross-attn for others'
|
| 367 |
+
)
|
| 368 |
+
parser.add_argument(
|
| 369 |
+
"--loss_type_reverse",
|
| 370 |
+
type=str,
|
| 371 |
+
default='model-based',
|
| 372 |
+
help="loss type for reverse fine-tuning",
|
| 373 |
+
)
|
| 374 |
+
parser.add_argument(
|
| 375 |
+
"--lr_scheduler",
|
| 376 |
+
type=str,
|
| 377 |
+
default="constant",
|
| 378 |
+
help=(
|
| 379 |
+
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
| 380 |
+
' "constant", "constant_with_warmup"]'
|
| 381 |
+
),
|
| 382 |
+
)
|
| 383 |
+
parser.add_argument(
|
| 384 |
+
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
|
| 385 |
+
)
|
| 386 |
+
parser.add_argument(
|
| 387 |
+
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
|
| 388 |
+
)
|
| 389 |
+
parser.add_argument("--adam_beta1", type=float, default=0.9,
|
| 390 |
+
help="The beta1 parameter for the Adam optimizer.")
|
| 391 |
+
parser.add_argument("--adam_beta2", type=float, default=0.999,
|
| 392 |
+
help="The beta2 parameter for the Adam optimizer.")
|
| 393 |
+
parser.add_argument("--adam_weight_decay", type=float,
|
| 394 |
+
default=1e-2, help="Weight decay to use.")
|
| 395 |
+
parser.add_argument("--adam_epsilon", type=float, default=1e-08,
|
| 396 |
+
help="Epsilon value for the Adam optimizer")
|
| 397 |
+
parser.add_argument("--max_grad_norm", default=1.0,
|
| 398 |
+
type=float, help="Max gradient norm.")
|
| 399 |
+
parser.add_argument("--push_to_hub", action="store_true",
|
| 400 |
+
help="Whether or not to push the model to the Hub.")
|
| 401 |
+
parser.add_argument("--hub_token", type=str, default=None,
|
| 402 |
+
help="The token to use to push to the Model Hub.")
|
| 403 |
+
parser.add_argument(
|
| 404 |
+
"--hub_model_id",
|
| 405 |
+
type=str,
|
| 406 |
+
default=None,
|
| 407 |
+
help="The name of the repository to keep in sync with the local `output_dir`.",
|
| 408 |
+
)
|
| 409 |
+
parser.add_argument(
|
| 410 |
+
"--logging_dir",
|
| 411 |
+
type=str,
|
| 412 |
+
default="logs",
|
| 413 |
+
help=(
|
| 414 |
+
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
| 415 |
+
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
| 416 |
+
),
|
| 417 |
+
)
|
| 418 |
+
parser.add_argument(
|
| 419 |
+
"--allow_tf32",
|
| 420 |
+
action="store_true",
|
| 421 |
+
help=(
|
| 422 |
+
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
|
| 423 |
+
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
|
| 424 |
+
),
|
| 425 |
+
)
|
| 426 |
+
parser.add_argument(
|
| 427 |
+
"--report_to",
|
| 428 |
+
type=str,
|
| 429 |
+
default="tensorboard",
|
| 430 |
+
help=(
|
| 431 |
+
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
|
| 432 |
+
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
|
| 433 |
+
),
|
| 434 |
+
)
|
| 435 |
+
parser.add_argument(
|
| 436 |
+
"--mixed_precision",
|
| 437 |
+
type=str,
|
| 438 |
+
default=None,
|
| 439 |
+
choices=["no", "fp16", "bf16"],
|
| 440 |
+
help=(
|
| 441 |
+
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
|
| 442 |
+
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
|
| 443 |
+
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
|
| 444 |
+
),
|
| 445 |
+
)
|
| 446 |
+
parser.add_argument(
|
| 447 |
+
"--prior_generation_precision",
|
| 448 |
+
type=str,
|
| 449 |
+
default=None,
|
| 450 |
+
choices=["no", "fp32", "fp16", "bf16"],
|
| 451 |
+
help=(
|
| 452 |
+
"Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
|
| 453 |
+
" 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
|
| 454 |
+
),
|
| 455 |
+
)
|
| 456 |
+
parser.add_argument(
|
| 457 |
+
"--concepts_list",
|
| 458 |
+
type=str,
|
| 459 |
+
default=None,
|
| 460 |
+
help="Path to json containing multiple concepts, will overwrite parameters like instance_prompt, class_prompt, etc.",
|
| 461 |
+
)
|
| 462 |
+
parser.add_argument(
|
| 463 |
+
"--openai_key",
|
| 464 |
+
type=str,
|
| 465 |
+
default="",
|
| 466 |
+
help=(
|
| 467 |
+
"OPENAI API key. required for ablating objects and memorized images."
|
| 468 |
+
),
|
| 469 |
+
)
|
| 470 |
+
parser.add_argument("--local_rank", type=int, default=-1,
|
| 471 |
+
help="For distributed training: local_rank")
|
| 472 |
+
parser.add_argument(
|
| 473 |
+
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
|
| 474 |
+
)
|
| 475 |
+
parser.add_argument("--hflip", action="store_true",
|
| 476 |
+
help="Apply horizontal flip data augmentation.")
|
| 477 |
+
parser.add_argument("--noaug", action="store_true",
|
| 478 |
+
help="Dont apply augmentation during data augmentation when this flag is enabled.")
|
| 479 |
+
|
| 480 |
+
if input_args is not None:
|
| 481 |
+
args = parser.parse_args(input_args)
|
| 482 |
+
else:
|
| 483 |
+
args = parser.parse_args()
|
| 484 |
+
|
| 485 |
+
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
| 486 |
+
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
| 487 |
+
args.local_rank = env_local_rank
|
| 488 |
+
|
| 489 |
+
if args.with_prior_preservation:
|
| 490 |
+
if args.concepts_list is None:
|
| 491 |
+
if args.class_data_dir is None:
|
| 492 |
+
raise ValueError(
|
| 493 |
+
"You must specify a data directory for class images.")
|
| 494 |
+
if args.class_prompt is None:
|
| 495 |
+
raise ValueError("You must specify prompt for class images.")
|
| 496 |
+
else:
|
| 497 |
+
# logger is not available yet
|
| 498 |
+
if args.class_data_dir is not None:
|
| 499 |
+
warnings.warn(
|
| 500 |
+
"You need not use --class_data_dir without --with_prior_preservation.")
|
| 501 |
+
if args.class_prompt is not None:
|
| 502 |
+
warnings.warn(
|
| 503 |
+
"You need not use --class_prompt without --with_prior_preservation.")
|
| 504 |
+
|
| 505 |
+
return args
|
| 506 |
+
|
| 507 |
+
|
| 508 |
+
def main(args):
|
| 509 |
+
logging_dir = Path(args.output_dir, args.logging_dir)
|
| 510 |
+
|
| 511 |
+
accelerator_project_config = ProjectConfiguration(
|
| 512 |
+
total_limit=args.checkpoints_total_limit)
|
| 513 |
+
|
| 514 |
+
accelerator = Accelerator(
|
| 515 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
| 516 |
+
mixed_precision=args.mixed_precision,
|
| 517 |
+
log_with=args.report_to,
|
| 518 |
+
project_dir=logging_dir,
|
| 519 |
+
project_config=accelerator_project_config,
|
| 520 |
+
)
|
| 521 |
+
|
| 522 |
+
if args.report_to == "wandb":
|
| 523 |
+
if not is_wandb_available():
|
| 524 |
+
raise ImportError(
|
| 525 |
+
"Make sure to install wandb if you want to use it for logging during training.")
|
| 526 |
+
import wandb
|
| 527 |
+
|
| 528 |
+
# Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
|
| 529 |
+
# This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
|
| 530 |
+
# TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
|
| 531 |
+
# Make one log on every process with the configuration for debugging.
|
| 532 |
+
logging.basicConfig(
|
| 533 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
| 534 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
| 535 |
+
level=logging.INFO,
|
| 536 |
+
)
|
| 537 |
+
logger.info(accelerator.state, main_process_only=False)
|
| 538 |
+
if accelerator.is_local_main_process:
|
| 539 |
+
transformers.utils.logging.set_verbosity_warning()
|
| 540 |
+
diffusers.utils.logging.set_verbosity_info()
|
| 541 |
+
else:
|
| 542 |
+
transformers.utils.logging.set_verbosity_error()
|
| 543 |
+
diffusers.utils.logging.set_verbosity_error()
|
| 544 |
+
|
| 545 |
+
# We need to initialize the trackers we use, and also store our configuration.
|
| 546 |
+
# The trackers initializes automatically on the main process.
|
| 547 |
+
if accelerator.is_main_process:
|
| 548 |
+
print(vars(args))
|
| 549 |
+
accelerator.init_trackers("custom-diffusion", config=vars(args))
|
| 550 |
+
|
| 551 |
+
# If passed along, set the training seed now.
|
| 552 |
+
if args.seed is not None:
|
| 553 |
+
set_seed(args.seed)
|
| 554 |
+
if args.concepts_list is None:
|
| 555 |
+
args.concepts_list = [
|
| 556 |
+
{
|
| 557 |
+
"instance_prompt": args.instance_prompt,
|
| 558 |
+
"class_prompt": args.class_prompt,
|
| 559 |
+
"instance_data_dir": args.instance_data_dir,
|
| 560 |
+
"class_data_dir": args.class_data_dir,
|
| 561 |
+
"caption_target": args.caption_target,
|
| 562 |
+
}
|
| 563 |
+
]
|
| 564 |
+
else:
|
| 565 |
+
with open(args.concepts_list, "r") as f:
|
| 566 |
+
args.concepts_list = json.load(f)
|
| 567 |
+
|
| 568 |
+
# Generate class images if prior preservation is enabled.
|
| 569 |
+
for i, concept in enumerate(args.concepts_list):
|
| 570 |
+
# directly path to ablation images and its corresponding prompts is provided.
|
| 571 |
+
if (concept['instance_prompt'] is not None and concept['instance_data_dir'] is not None):
|
| 572 |
+
break
|
| 573 |
+
|
| 574 |
+
class_images_dir = Path(concept['class_data_dir'])
|
| 575 |
+
if not class_images_dir.exists():
|
| 576 |
+
class_images_dir.mkdir(parents=True, exist_ok=True)
|
| 577 |
+
os.makedirs(f'{class_images_dir}/images', exist_ok=True)
|
| 578 |
+
|
| 579 |
+
# we need to generate training images
|
| 580 |
+
if len(list(Path(os.path.join(class_images_dir, 'images')).iterdir())) < args.num_class_images:
|
| 581 |
+
torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
|
| 582 |
+
if args.prior_generation_precision == "fp32":
|
| 583 |
+
torch_dtype = torch.float32
|
| 584 |
+
elif args.prior_generation_precision == "fp16":
|
| 585 |
+
torch_dtype = torch.float16
|
| 586 |
+
elif args.prior_generation_precision == "bf16":
|
| 587 |
+
torch_dtype = torch.bfloat16
|
| 588 |
+
pipeline = DiffusionPipeline.from_pretrained(
|
| 589 |
+
args.pretrained_model_name_or_path,
|
| 590 |
+
torch_dtype=torch_dtype,
|
| 591 |
+
safety_checker=None,
|
| 592 |
+
revision=args.revision,
|
| 593 |
+
)
|
| 594 |
+
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
|
| 595 |
+
pipeline.scheduler.config)
|
| 596 |
+
|
| 597 |
+
pipeline.set_progress_bar_config(disable=True)
|
| 598 |
+
pipeline.to(accelerator.device)
|
| 599 |
+
|
| 600 |
+
# need to create prompts using class_prompt.
|
| 601 |
+
if not os.path.isfile(concept['class_prompt']):
|
| 602 |
+
# style based prompts are retrieved from laion dataset
|
| 603 |
+
if args.concept_type == 'style':
|
| 604 |
+
with open(os.path.join(class_images_dir, 'painting.txt')) as f:
|
| 605 |
+
class_prompt_collection = [
|
| 606 |
+
x.strip() for x in f.readlines()]
|
| 607 |
+
|
| 608 |
+
# LLM based prompt collection.
|
| 609 |
+
else:
|
| 610 |
+
class_prompt = concept['class_prompt']
|
| 611 |
+
# in case of object query chatGPT to generate captions containing the anchor category
|
| 612 |
+
if args.concept_type == 'object':
|
| 613 |
+
class_prompt_collection, _ = getanchorprompts(
|
| 614 |
+
pipeline, accelerator, class_prompt, args.concept_type, class_images_dir, args.openai_key, args.num_class_prompts)
|
| 615 |
+
with open(class_images_dir / 'caption_anchor.txt', 'w') as f:
|
| 616 |
+
for prompt in class_prompt_collection:
|
| 617 |
+
f.write(prompt + '\n')
|
| 618 |
+
# in case of memorization query chatGPT to generate different captions that can be paraphrase of the origianl caption
|
| 619 |
+
elif args.concept_type == 'memorization':
|
| 620 |
+
class_prompt_collection, caption_target = getanchorprompts(
|
| 621 |
+
pipeline, accelerator, class_prompt, args.concept_type, class_images_dir, args.openai_key, args.num_class_prompts, mem_impath=args.mem_impath)
|
| 622 |
+
concept['caption_target'] += f';*+{caption_target}'
|
| 623 |
+
with open(class_images_dir / 'caption_target.txt', 'w') as f:
|
| 624 |
+
f.write(concept['caption_target'])
|
| 625 |
+
print(class_prompt_collection,
|
| 626 |
+
concept['caption_target'])
|
| 627 |
+
# class_prompt is filepath to prompts.
|
| 628 |
+
else:
|
| 629 |
+
with open(concept['class_prompt']) as f:
|
| 630 |
+
class_prompt_collection = [
|
| 631 |
+
x.strip() for x in f.readlines()]
|
| 632 |
+
|
| 633 |
+
num_new_images = args.num_class_images
|
| 634 |
+
logger.info(
|
| 635 |
+
f"Number of class images to sample: {num_new_images}.")
|
| 636 |
+
|
| 637 |
+
sample_dataset = PromptDataset(
|
| 638 |
+
class_prompt_collection, num_new_images)
|
| 639 |
+
sample_dataloader = torch.utils.data.DataLoader(
|
| 640 |
+
sample_dataset, batch_size=args.sample_batch_size)
|
| 641 |
+
|
| 642 |
+
sample_dataloader = accelerator.prepare(sample_dataloader)
|
| 643 |
+
|
| 644 |
+
if os.path.exists(f'{class_images_dir}/caption.txt'):
|
| 645 |
+
os.remove(f'{class_images_dir}/caption.txt')
|
| 646 |
+
if os.path.exists(f'{class_images_dir}/images.txt'):
|
| 647 |
+
os.remove(f'{class_images_dir}/images.txt')
|
| 648 |
+
|
| 649 |
+
for example in tqdm(
|
| 650 |
+
sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
|
| 651 |
+
):
|
| 652 |
+
accelerator.wait_for_everyone()
|
| 653 |
+
with open(f'{class_images_dir}/caption.txt', 'a') as f1, open(f'{class_images_dir}/images.txt', 'a') as f2:
|
| 654 |
+
images = pipeline(example["prompt"], num_inference_steps=25, guidance_scale=6., eta=1.).images
|
| 655 |
+
|
| 656 |
+
for i, image in enumerate(images):
|
| 657 |
+
hash_image = hashlib.sha1(
|
| 658 |
+
image.tobytes()).hexdigest()
|
| 659 |
+
image_filename = class_images_dir / \
|
| 660 |
+
f"images/{example['index'][i]}-{hash_image}.jpg"
|
| 661 |
+
image.save(image_filename)
|
| 662 |
+
f2.write(str(image_filename)+'\n')
|
| 663 |
+
f1.write('\n'.join(example["prompt"]) + '\n')
|
| 664 |
+
accelerator.wait_for_everyone()
|
| 665 |
+
|
| 666 |
+
del pipeline
|
| 667 |
+
|
| 668 |
+
if args.concept_type == 'memorization':
|
| 669 |
+
filter(class_images_dir, args.mem_impath,
|
| 670 |
+
outpath=str(class_images_dir / 'filtered'))
|
| 671 |
+
if os.path.exists(class_images_dir / 'caption_target.txt'):
|
| 672 |
+
with open(class_images_dir / 'caption_target.txt', 'r') as f:
|
| 673 |
+
concept['caption_target'] = f.readlines()[0].strip()
|
| 674 |
+
class_images_dir = class_images_dir / 'filtered'
|
| 675 |
+
|
| 676 |
+
concept['class_prompt'] = os.path.join(
|
| 677 |
+
class_images_dir, 'caption.txt')
|
| 678 |
+
concept['class_data_dir'] = os.path.join(
|
| 679 |
+
class_images_dir, 'images.txt')
|
| 680 |
+
concept['instance_prompt'] = os.path.join(
|
| 681 |
+
class_images_dir, 'caption.txt')
|
| 682 |
+
concept['instance_data_dir'] = os.path.join(
|
| 683 |
+
class_images_dir, 'images.txt')
|
| 684 |
+
|
| 685 |
+
if torch.cuda.is_available():
|
| 686 |
+
torch.cuda.empty_cache()
|
| 687 |
+
|
| 688 |
+
# Handle the repository creation
|
| 689 |
+
if accelerator.is_main_process:
|
| 690 |
+
if args.output_dir is not None:
|
| 691 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 692 |
+
|
| 693 |
+
if args.push_to_hub:
|
| 694 |
+
print(args.hub_model_id or Path(args.output_dir).name)
|
| 695 |
+
repo_id = create_repo(
|
| 696 |
+
repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
|
| 697 |
+
)
|
| 698 |
+
print(repo_id)
|
| 699 |
+
repo_id = args.hub_model_id
|
| 700 |
+
|
| 701 |
+
# Load the tokenizer
|
| 702 |
+
if args.tokenizer_name:
|
| 703 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 704 |
+
args.tokenizer_name,
|
| 705 |
+
revision=args.revision,
|
| 706 |
+
use_fast=False,
|
| 707 |
+
)
|
| 708 |
+
elif args.pretrained_model_name_or_path:
|
| 709 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 710 |
+
args.pretrained_model_name_or_path,
|
| 711 |
+
subfolder="tokenizer",
|
| 712 |
+
revision=args.revision,
|
| 713 |
+
use_fast=False,
|
| 714 |
+
)
|
| 715 |
+
|
| 716 |
+
# import correct text encoder class
|
| 717 |
+
text_encoder_cls = import_model_class_from_model_name_or_path(
|
| 718 |
+
args.pretrained_model_name_or_path, args.revision)
|
| 719 |
+
|
| 720 |
+
# Load scheduler and models
|
| 721 |
+
noise_scheduler = DDPMScheduler.from_pretrained(
|
| 722 |
+
args.pretrained_model_name_or_path, subfolder="scheduler")
|
| 723 |
+
text_encoder = text_encoder_cls.from_pretrained(
|
| 724 |
+
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
|
| 725 |
+
)
|
| 726 |
+
vae = AutoencoderKL.from_pretrained(
|
| 727 |
+
args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
|
| 728 |
+
unet = UNet2DConditionModel.from_pretrained(
|
| 729 |
+
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
|
| 730 |
+
)
|
| 731 |
+
|
| 732 |
+
vae.requires_grad_(False)
|
| 733 |
+
if args.parameter_group != 'embedding':
|
| 734 |
+
text_encoder.requires_grad_(False)
|
| 735 |
+
unet = create_custom_diffusion(unet, args.parameter_group)
|
| 736 |
+
|
| 737 |
+
# For mixed precision training we cast the text_encoder and vae weights to half-precision
|
| 738 |
+
# as these models are only used for inference, keeping weights in full precision is not required.
|
| 739 |
+
weight_dtype = torch.float32
|
| 740 |
+
if accelerator.mixed_precision == "fp16":
|
| 741 |
+
weight_dtype = torch.float16
|
| 742 |
+
elif accelerator.mixed_precision == "bf16":
|
| 743 |
+
weight_dtype = torch.bfloat16
|
| 744 |
+
|
| 745 |
+
# Move unet, vae and text_encoder to device and cast to weight_dtype
|
| 746 |
+
if accelerator.mixed_precision != "fp16":
|
| 747 |
+
unet.to(accelerator.device, dtype=weight_dtype)
|
| 748 |
+
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
| 749 |
+
vae.to(accelerator.device, dtype=weight_dtype)
|
| 750 |
+
|
| 751 |
+
if args.enable_xformers_memory_efficient_attention:
|
| 752 |
+
if is_xformers_available():
|
| 753 |
+
import xformers
|
| 754 |
+
xformers_version = version.parse(xformers.__version__)
|
| 755 |
+
if xformers_version == version.parse("0.0.16"):
|
| 756 |
+
logger.warn(
|
| 757 |
+
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
| 758 |
+
)
|
| 759 |
+
unet.enable_xformers_memory_efficient_attention()
|
| 760 |
+
else:
|
| 761 |
+
raise ValueError(
|
| 762 |
+
"xformers is not available. Make sure it is installed correctly")
|
| 763 |
+
|
| 764 |
+
if args.gradient_checkpointing:
|
| 765 |
+
unet.enable_gradient_checkpointing()
|
| 766 |
+
if args.parameter_group == 'embedding':
|
| 767 |
+
text_encoder.gradient_checkpointing_enable()
|
| 768 |
+
# Enable TF32 for faster training on Ampere GPUs,
|
| 769 |
+
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
|
| 770 |
+
if args.allow_tf32:
|
| 771 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 772 |
+
|
| 773 |
+
if args.scale_lr:
|
| 774 |
+
args.learning_rate = (
|
| 775 |
+
args.learning_rate * args.gradient_accumulation_steps *
|
| 776 |
+
args.train_batch_size * accelerator.num_processes
|
| 777 |
+
)
|
| 778 |
+
if args.with_prior_preservation:
|
| 779 |
+
args.learning_rate = args.learning_rate * 2.
|
| 780 |
+
|
| 781 |
+
# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
|
| 782 |
+
if args.use_8bit_adam:
|
| 783 |
+
try:
|
| 784 |
+
import bitsandbytes as bnb
|
| 785 |
+
except ImportError:
|
| 786 |
+
raise ImportError(
|
| 787 |
+
"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
|
| 788 |
+
)
|
| 789 |
+
|
| 790 |
+
optimizer_class = bnb.optim.AdamW8bit
|
| 791 |
+
else:
|
| 792 |
+
optimizer_class = torch.optim.AdamW
|
| 793 |
+
|
| 794 |
+
# Adding a modifier token which is optimized ####
|
| 795 |
+
# Code taken from https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion.py
|
| 796 |
+
modifier_token_id = []
|
| 797 |
+
if args.parameter_group == 'embedding':
|
| 798 |
+
assert args.concept_type != 'memorization', "embedding finetuning is not supported for memorization"
|
| 799 |
+
|
| 800 |
+
for concept in args.concept_list:
|
| 801 |
+
# Convert the caption_target to ids
|
| 802 |
+
token_ids = tokenizer.encode(
|
| 803 |
+
[concept['caption_target']], add_special_tokens=False)
|
| 804 |
+
print(token_ids)
|
| 805 |
+
# Check if initializer_token is a single token or a sequence of tokens
|
| 806 |
+
modifier_token_id += token_ids
|
| 807 |
+
|
| 808 |
+
# Freeze all parameters except for the token embeddings in text encoder
|
| 809 |
+
params_to_freeze = itertools.chain(
|
| 810 |
+
text_encoder.text_model.encoder.parameters(),
|
| 811 |
+
text_encoder.text_model.final_layer_norm.parameters(),
|
| 812 |
+
text_encoder.text_model.embeddings.position_embedding.parameters(),
|
| 813 |
+
)
|
| 814 |
+
freeze_params(params_to_freeze)
|
| 815 |
+
params_to_optimize = itertools.chain(
|
| 816 |
+
text_encoder.get_input_embeddings().parameters())
|
| 817 |
+
else:
|
| 818 |
+
if args.parameter_group == 'cross-attn':
|
| 819 |
+
params_to_optimize = itertools.chain([x[1] for x in unet.named_parameters() if (
|
| 820 |
+
'attn2.to_k' in x[0] or 'attn2.to_v' in x[0])])
|
| 821 |
+
if args.parameter_group == 'full-weight':
|
| 822 |
+
params_to_optimize = itertools.chain(unet.parameters())
|
| 823 |
+
|
| 824 |
+
# Optimizer creation
|
| 825 |
+
optimizer = optimizer_class(
|
| 826 |
+
params_to_optimize,
|
| 827 |
+
lr=args.learning_rate,
|
| 828 |
+
betas=(args.adam_beta1, args.adam_beta2),
|
| 829 |
+
weight_decay=args.adam_weight_decay,
|
| 830 |
+
eps=args.adam_epsilon,
|
| 831 |
+
)
|
| 832 |
+
|
| 833 |
+
# Dataset and DataLoaders creation:
|
| 834 |
+
train_dataset = CustomDiffusionDataset(
|
| 835 |
+
concepts_list=args.concepts_list,
|
| 836 |
+
concept_type=args.concept_type,
|
| 837 |
+
tokenizer=tokenizer,
|
| 838 |
+
with_prior_preservation=args.with_prior_preservation,
|
| 839 |
+
size=args.resolution,
|
| 840 |
+
center_crop=args.center_crop,
|
| 841 |
+
num_class_images=args.num_class_images,
|
| 842 |
+
hflip=args.hflip, aug=not args.noaug,
|
| 843 |
+
)
|
| 844 |
+
|
| 845 |
+
train_dataloader = torch.utils.data.DataLoader(
|
| 846 |
+
train_dataset,
|
| 847 |
+
batch_size=args.train_batch_size,
|
| 848 |
+
shuffle=True,
|
| 849 |
+
collate_fn=lambda examples: collate_fn(
|
| 850 |
+
examples, args.with_prior_preservation),
|
| 851 |
+
num_workers=args.dataloader_num_workers,
|
| 852 |
+
)
|
| 853 |
+
|
| 854 |
+
# Scheduler and math around the number of training steps.
|
| 855 |
+
overrode_max_train_steps = False
|
| 856 |
+
num_update_steps_per_epoch = math.ceil(
|
| 857 |
+
len(train_dataloader) / args.gradient_accumulation_steps)
|
| 858 |
+
if args.max_train_steps is None:
|
| 859 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
| 860 |
+
overrode_max_train_steps = True
|
| 861 |
+
|
| 862 |
+
lr_scheduler = get_scheduler(
|
| 863 |
+
args.lr_scheduler,
|
| 864 |
+
optimizer=optimizer,
|
| 865 |
+
num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
|
| 866 |
+
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
|
| 867 |
+
)
|
| 868 |
+
|
| 869 |
+
# Prepare everything with our `accelerator`.
|
| 870 |
+
if args.parameter_group == 'embedding':
|
| 871 |
+
text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
| 872 |
+
text_encoder, optimizer, train_dataloader, lr_scheduler
|
| 873 |
+
)
|
| 874 |
+
else:
|
| 875 |
+
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
| 876 |
+
unet, optimizer, train_dataloader, lr_scheduler
|
| 877 |
+
)
|
| 878 |
+
|
| 879 |
+
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
| 880 |
+
num_update_steps_per_epoch = math.ceil(
|
| 881 |
+
len(train_dataloader) / args.gradient_accumulation_steps)
|
| 882 |
+
if overrode_max_train_steps:
|
| 883 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
| 884 |
+
# Afterwards we recalculate our number of training epochs
|
| 885 |
+
args.num_train_epochs = math.ceil(
|
| 886 |
+
args.max_train_steps / num_update_steps_per_epoch)
|
| 887 |
+
|
| 888 |
+
# Train!
|
| 889 |
+
total_batch_size = args.train_batch_size * \
|
| 890 |
+
accelerator.num_processes * args.gradient_accumulation_steps
|
| 891 |
+
|
| 892 |
+
logger.info("***** Running training *****")
|
| 893 |
+
logger.info(f" Num examples = {len(train_dataset)}")
|
| 894 |
+
logger.info(f" Num batches each epoch = {len(train_dataloader)}")
|
| 895 |
+
logger.info(f" Num Epochs = {args.num_train_epochs}")
|
| 896 |
+
logger.info(
|
| 897 |
+
f" Instantaneous batch size per device = {args.train_batch_size}")
|
| 898 |
+
logger.info(
|
| 899 |
+
f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
| 900 |
+
logger.info(
|
| 901 |
+
f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
| 902 |
+
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
| 903 |
+
global_step = 0
|
| 904 |
+
first_epoch = 0
|
| 905 |
+
|
| 906 |
+
# Potentially load in the weights and states from a previous save
|
| 907 |
+
if args.resume_from_checkpoint:
|
| 908 |
+
if args.resume_from_checkpoint != "latest":
|
| 909 |
+
path = os.path.basename(args.resume_from_checkpoint)
|
| 910 |
+
else:
|
| 911 |
+
# Get the mos recent checkpoint
|
| 912 |
+
dirs = os.listdir(args.output_dir)
|
| 913 |
+
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
| 914 |
+
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
| 915 |
+
path = dirs[-1] if len(dirs) > 0 else None
|
| 916 |
+
|
| 917 |
+
if path is None:
|
| 918 |
+
accelerator.print(
|
| 919 |
+
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
|
| 920 |
+
)
|
| 921 |
+
args.resume_from_checkpoint = None
|
| 922 |
+
else:
|
| 923 |
+
accelerator.print(f"Resuming from checkpoint {path}")
|
| 924 |
+
accelerator.load_state(os.path.join(args.output_dir, path))
|
| 925 |
+
global_step = int(path.split("-")[1])
|
| 926 |
+
|
| 927 |
+
resume_global_step = global_step * args.gradient_accumulation_steps
|
| 928 |
+
first_epoch = global_step // num_update_steps_per_epoch
|
| 929 |
+
resume_step = resume_global_step % (
|
| 930 |
+
num_update_steps_per_epoch * args.gradient_accumulation_steps)
|
| 931 |
+
|
| 932 |
+
# Only show the progress bar once on each machine.
|
| 933 |
+
progress_bar = tqdm(range(global_step, args.max_train_steps),
|
| 934 |
+
disable=not accelerator.is_local_main_process)
|
| 935 |
+
progress_bar.set_description("Steps")
|
| 936 |
+
|
| 937 |
+
for epoch in range(first_epoch, args.num_train_epochs):
|
| 938 |
+
if args.parameter_group == 'embedding':
|
| 939 |
+
text_encoder.train()
|
| 940 |
+
else:
|
| 941 |
+
unet.train()
|
| 942 |
+
for step, batch in enumerate(train_dataloader):
|
| 943 |
+
# Skip steps until we reach the resumed step
|
| 944 |
+
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
|
| 945 |
+
if step % args.gradient_accumulation_steps == 0:
|
| 946 |
+
progress_bar.update(1)
|
| 947 |
+
continue
|
| 948 |
+
|
| 949 |
+
with accelerator.accumulate(unet) if args.parameter_group != 'embedding' else accelerator.accumulate(text_encoder):
|
| 950 |
+
# Convert images to latent space
|
| 951 |
+
latents = vae.encode(batch["pixel_values"].to(
|
| 952 |
+
dtype=weight_dtype)).latent_dist.sample()
|
| 953 |
+
latents = latents * vae.config.scaling_factor
|
| 954 |
+
|
| 955 |
+
# Sample noise that we'll add to the latents
|
| 956 |
+
noise = torch.randn_like(latents)
|
| 957 |
+
bsz = latents.shape[0]
|
| 958 |
+
# Sample a random timestep for each image
|
| 959 |
+
timesteps = torch.randint(
|
| 960 |
+
0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
|
| 961 |
+
timesteps = timesteps.long()
|
| 962 |
+
|
| 963 |
+
# Add noise to the latents according to the noise magnitude at each timestep
|
| 964 |
+
# (this is the forward diffusion process)
|
| 965 |
+
noisy_latents = noise_scheduler.add_noise(
|
| 966 |
+
latents, noise, timesteps)
|
| 967 |
+
|
| 968 |
+
# Get the text embedding for conditioning
|
| 969 |
+
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
|
| 970 |
+
encoder_anchor_hidden_states = text_encoder(
|
| 971 |
+
batch["input_anchor_ids"])[0]
|
| 972 |
+
|
| 973 |
+
# Predict the noise residual
|
| 974 |
+
model_pred = unet(noisy_latents, timesteps,
|
| 975 |
+
encoder_hidden_states).sample
|
| 976 |
+
with torch.no_grad():
|
| 977 |
+
model_pred_anchor = unet(noisy_latents[:encoder_anchor_hidden_states.size(
|
| 978 |
+
0)], timesteps[:encoder_anchor_hidden_states.size(0)], encoder_anchor_hidden_states).sample
|
| 979 |
+
|
| 980 |
+
# Get the target for loss depending on the prediction type
|
| 981 |
+
if args.loss_type_reverse == 'model-based':
|
| 982 |
+
if args.with_prior_preservation:
|
| 983 |
+
target_prior = torch.chunk(noise, 2, dim=0)[1]
|
| 984 |
+
target = model_pred_anchor
|
| 985 |
+
else:
|
| 986 |
+
if noise_scheduler.config.prediction_type == "epsilon":
|
| 987 |
+
target = noise
|
| 988 |
+
elif noise_scheduler.config.prediction_type == "v_prediction":
|
| 989 |
+
target = noise_scheduler.get_velocity(
|
| 990 |
+
latents, noise, timesteps)
|
| 991 |
+
else:
|
| 992 |
+
raise ValueError(
|
| 993 |
+
f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
| 994 |
+
if args.with_prior_preservation:
|
| 995 |
+
target, target_prior = torch.chunk(target, 2, dim=0)
|
| 996 |
+
|
| 997 |
+
if args.with_prior_preservation:
|
| 998 |
+
# Chunk the noise and model_pred into two parts and compute the loss on each part separately.
|
| 999 |
+
model_pred, model_pred_prior = torch.chunk(
|
| 1000 |
+
model_pred, 2, dim=0)
|
| 1001 |
+
mask = torch.chunk(batch["mask"], 2, dim=0)[0]
|
| 1002 |
+
# Compute instance loss
|
| 1003 |
+
loss = F.mse_loss(model_pred.float(),
|
| 1004 |
+
target.float(), reduction="none")
|
| 1005 |
+
loss = (
|
| 1006 |
+
(loss * mask).sum([1, 2, 3]) / mask.sum([1, 2, 3])).mean()
|
| 1007 |
+
|
| 1008 |
+
# Compute prior loss
|
| 1009 |
+
prior_loss = F.mse_loss(
|
| 1010 |
+
model_pred_prior.float(), target_prior.float(), reduction="mean")
|
| 1011 |
+
|
| 1012 |
+
# Add the prior loss to the instance loss.
|
| 1013 |
+
loss = loss + args.prior_loss_weight * prior_loss
|
| 1014 |
+
else:
|
| 1015 |
+
mask = batch["mask"]
|
| 1016 |
+
loss = F.mse_loss(model_pred.float(),
|
| 1017 |
+
target.float(), reduction="none")
|
| 1018 |
+
loss = (
|
| 1019 |
+
(loss * mask).sum([1, 2, 3]) / mask.sum([1, 2, 3])).mean()
|
| 1020 |
+
|
| 1021 |
+
accelerator.backward(loss)
|
| 1022 |
+
# Zero out the gradients for all token embeddings except the newly added
|
| 1023 |
+
# embeddings for the concept, as we only want to optimize the concept embeddings
|
| 1024 |
+
if args.parameter_group == 'embedding':
|
| 1025 |
+
if accelerator.num_processes > 1:
|
| 1026 |
+
grads_text_encoder = text_encoder.module.get_input_embeddings().weight.grad
|
| 1027 |
+
else:
|
| 1028 |
+
grads_text_encoder = text_encoder.get_input_embeddings().weight.grad
|
| 1029 |
+
# Get the index for tokens that we want to zero the grads for
|
| 1030 |
+
index_grads_to_zero = torch.arange(
|
| 1031 |
+
len(tokenizer)) != modifier_token_id[0]
|
| 1032 |
+
for i in range(len(modifier_token_id[1:])):
|
| 1033 |
+
index_grads_to_zero = index_grads_to_zero & (
|
| 1034 |
+
torch.arange(len(tokenizer)) != modifier_token_id[i])
|
| 1035 |
+
grads_text_encoder.data[index_grads_to_zero,
|
| 1036 |
+
:] = grads_text_encoder.data[index_grads_to_zero, :].fill_(0)
|
| 1037 |
+
|
| 1038 |
+
if accelerator.sync_gradients:
|
| 1039 |
+
params_to_clip = (
|
| 1040 |
+
itertools.chain(text_encoder.parameters())
|
| 1041 |
+
if args.parameter_group == 'embedding'
|
| 1042 |
+
else itertools.chain([x[1] for x in unet.named_parameters() if ('attn2' in x[0])])
|
| 1043 |
+
if args.parameter_group == 'cross-attn'
|
| 1044 |
+
else itertools.chain(unet.parameters())
|
| 1045 |
+
)
|
| 1046 |
+
accelerator.clip_grad_norm_(
|
| 1047 |
+
params_to_clip, args.max_grad_norm)
|
| 1048 |
+
optimizer.step()
|
| 1049 |
+
lr_scheduler.step()
|
| 1050 |
+
optimizer.zero_grad()
|
| 1051 |
+
|
| 1052 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
| 1053 |
+
if accelerator.sync_gradients:
|
| 1054 |
+
progress_bar.update(1)
|
| 1055 |
+
global_step += 1
|
| 1056 |
+
|
| 1057 |
+
if global_step % args.checkpointing_steps == 0:
|
| 1058 |
+
if accelerator.is_main_process:
|
| 1059 |
+
pipeline = CustomDiffusionPipeline.from_pretrained(
|
| 1060 |
+
args.pretrained_model_name_or_path,
|
| 1061 |
+
unet=accelerator.unwrap_model(unet),
|
| 1062 |
+
text_encoder=accelerator.unwrap_model(
|
| 1063 |
+
text_encoder),
|
| 1064 |
+
tokenizer=tokenizer,
|
| 1065 |
+
revision=args.revision,
|
| 1066 |
+
modifier_token_id=modifier_token_id,
|
| 1067 |
+
)
|
| 1068 |
+
save_path = os.path.join(
|
| 1069 |
+
args.output_dir, f"delta-{global_step}")
|
| 1070 |
+
pipeline.save_pretrained(
|
| 1071 |
+
save_path, parameter_group=args.parameter_group)
|
| 1072 |
+
logger.info(f"Saved state to {save_path}")
|
| 1073 |
+
|
| 1074 |
+
logs = {"loss": loss.detach().item(
|
| 1075 |
+
), "lr": lr_scheduler.get_last_lr()[0]}
|
| 1076 |
+
progress_bar.set_postfix(**logs)
|
| 1077 |
+
accelerator.log(logs, step=global_step)
|
| 1078 |
+
|
| 1079 |
+
if global_step >= args.max_train_steps:
|
| 1080 |
+
break
|
| 1081 |
+
|
| 1082 |
+
if accelerator.is_main_process:
|
| 1083 |
+
if args.validation_prompt is not None and global_step % args.validation_steps == 0:
|
| 1084 |
+
logger.info(
|
| 1085 |
+
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
|
| 1086 |
+
f" {args.validation_prompt}."
|
| 1087 |
+
)
|
| 1088 |
+
# create pipeline
|
| 1089 |
+
pipeline = CustomDiffusionPipeline.from_pretrained(
|
| 1090 |
+
args.pretrained_model_name_or_path,
|
| 1091 |
+
unet=accelerator.unwrap_model(unet),
|
| 1092 |
+
text_encoder=accelerator.unwrap_model(text_encoder),
|
| 1093 |
+
tokenizer=tokenizer,
|
| 1094 |
+
revision=args.revision,
|
| 1095 |
+
modifier_token_id=modifier_token_id,
|
| 1096 |
+
)
|
| 1097 |
+
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
|
| 1098 |
+
pipeline.scheduler.config)
|
| 1099 |
+
pipeline = pipeline.to(accelerator.device)
|
| 1100 |
+
pipeline.set_progress_bar_config(disable=True)
|
| 1101 |
+
|
| 1102 |
+
# run inference
|
| 1103 |
+
generator = torch.Generator(
|
| 1104 |
+
device=accelerator.device).manual_seed(args.seed)
|
| 1105 |
+
images = [
|
| 1106 |
+
pipeline(args.validation_prompt, num_inference_steps=25,
|
| 1107 |
+
generator=generator, eta=1.).images[0]
|
| 1108 |
+
for _ in range(args.num_validation_images)
|
| 1109 |
+
]
|
| 1110 |
+
|
| 1111 |
+
for tracker in accelerator.trackers:
|
| 1112 |
+
if tracker.name == "tensorboard":
|
| 1113 |
+
np_images = np.stack([np.asarray(img)
|
| 1114 |
+
for img in images])
|
| 1115 |
+
tracker.writer.add_images(
|
| 1116 |
+
"validation", np_images, epoch, dataformats="NHWC")
|
| 1117 |
+
if tracker.name == "wandb":
|
| 1118 |
+
tracker.log(
|
| 1119 |
+
{
|
| 1120 |
+
"validation": [
|
| 1121 |
+
wandb.Image(
|
| 1122 |
+
image, caption=f"{i}: {args.validation_prompt}")
|
| 1123 |
+
for i, image in enumerate(images)
|
| 1124 |
+
]
|
| 1125 |
+
}
|
| 1126 |
+
)
|
| 1127 |
+
|
| 1128 |
+
del pipeline
|
| 1129 |
+
torch.cuda.empty_cache()
|
| 1130 |
+
|
| 1131 |
+
accelerator.wait_for_everyone()
|
| 1132 |
+
if accelerator.is_main_process:
|
| 1133 |
+
unet = unet.to(torch.float32)
|
| 1134 |
+
pipeline = CustomDiffusionPipeline.from_pretrained(
|
| 1135 |
+
args.pretrained_model_name_or_path,
|
| 1136 |
+
unet=accelerator.unwrap_model(unet),
|
| 1137 |
+
text_encoder=accelerator.unwrap_model(text_encoder),
|
| 1138 |
+
tokenizer=tokenizer,
|
| 1139 |
+
revision=args.revision,
|
| 1140 |
+
modifier_token_id=modifier_token_id,
|
| 1141 |
+
)
|
| 1142 |
+
save_path = os.path.join(args.output_dir, "delta.bin")
|
| 1143 |
+
pipeline.save_pretrained(
|
| 1144 |
+
save_path, parameter_group=args.parameter_group)
|
| 1145 |
+
|
| 1146 |
+
# run inference
|
| 1147 |
+
if args.validation_prompt and args.num_validation_images > 0:
|
| 1148 |
+
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
|
| 1149 |
+
pipeline.scheduler.config)
|
| 1150 |
+
pipeline = pipeline.to(accelerator.device)
|
| 1151 |
+
pipeline.set_progress_bar_config(disable=True)
|
| 1152 |
+
|
| 1153 |
+
# run inference
|
| 1154 |
+
generator = torch.Generator(
|
| 1155 |
+
device=accelerator.device).manual_seed(args.seed)
|
| 1156 |
+
images = [
|
| 1157 |
+
pipeline(args.validation_prompt, num_inference_steps=25,
|
| 1158 |
+
generator=generator, eta=1.).images[0]
|
| 1159 |
+
for _ in range(args.num_validation_images)
|
| 1160 |
+
]
|
| 1161 |
+
|
| 1162 |
+
for tracker in accelerator.trackers:
|
| 1163 |
+
if tracker.name == "tensorboard":
|
| 1164 |
+
np_images = np.stack([np.asarray(img) for img in images])
|
| 1165 |
+
tracker.writer.add_images(
|
| 1166 |
+
"test", np_images, epoch, dataformats="NHWC")
|
| 1167 |
+
if tracker.name == "wandb":
|
| 1168 |
+
tracker.log(
|
| 1169 |
+
{
|
| 1170 |
+
"test": [
|
| 1171 |
+
wandb.Image(
|
| 1172 |
+
image, caption=f"{i}: {args.validation_prompt}")
|
| 1173 |
+
for i, image in enumerate(images)
|
| 1174 |
+
]
|
| 1175 |
+
}
|
| 1176 |
+
)
|
| 1177 |
+
|
| 1178 |
+
if args.push_to_hub:
|
| 1179 |
+
save_model_card(
|
| 1180 |
+
repo_id,
|
| 1181 |
+
images=images,
|
| 1182 |
+
base_model=args.pretrained_model_name_or_path,
|
| 1183 |
+
prompt=args.instance_prompt,
|
| 1184 |
+
repo_folder=args.output_dir,
|
| 1185 |
+
)
|
| 1186 |
+
api = HfApi(token=args.hub_token)
|
| 1187 |
+
api.upload_folder(
|
| 1188 |
+
repo_id=repo_id,
|
| 1189 |
+
folder_path=args.output_dir,
|
| 1190 |
+
path_in_repo='.',
|
| 1191 |
+
repo_type='model'
|
| 1192 |
+
)
|
| 1193 |
+
|
| 1194 |
+
accelerator.end_training()
|
| 1195 |
+
|
| 1196 |
+
|
| 1197 |
+
if __name__ == "__main__":
|
| 1198 |
+
args = parse_args()
|
| 1199 |
+
main(args)
|
concept-ablation-diffusers/utils.py
ADDED
|
@@ -0,0 +1,443 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import random
|
| 3 |
+
import shutil
|
| 4 |
+
from io import BytesIO
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import openai
|
| 9 |
+
import regex as re
|
| 10 |
+
import requests
|
| 11 |
+
import torch
|
| 12 |
+
from PIL import Image
|
| 13 |
+
from torch.utils.data import Dataset
|
| 14 |
+
from torchvision import transforms
|
| 15 |
+
from tqdm.auto import tqdm
|
| 16 |
+
|
| 17 |
+
from diffusers import DPMSolverMultistepScheduler
|
| 18 |
+
|
| 19 |
+
normalize = transforms.Normalize(
|
| 20 |
+
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225],
|
| 21 |
+
)
|
| 22 |
+
small_288 = transforms.Compose([
|
| 23 |
+
transforms.Resize(288),
|
| 24 |
+
transforms.ToTensor(),
|
| 25 |
+
normalize,
|
| 26 |
+
])
|
| 27 |
+
|
| 28 |
+
def collate_fn(examples, with_prior_preservation):
|
| 29 |
+
input_ids = [example["instance_prompt_ids"] for example in examples]
|
| 30 |
+
input_anchor_ids = [example["instance_anchor_prompt_ids"]
|
| 31 |
+
for example in examples]
|
| 32 |
+
pixel_values = [example["instance_images"] for example in examples]
|
| 33 |
+
mask = [example["mask"] for example in examples]
|
| 34 |
+
# Concat class and instance examples for prior preservation.
|
| 35 |
+
# We do this to avoid doing two forward passes.
|
| 36 |
+
if with_prior_preservation:
|
| 37 |
+
input_ids += [example["class_prompt_ids"] for example in examples]
|
| 38 |
+
pixel_values += [example["class_images"] for example in examples]
|
| 39 |
+
mask += [example["class_mask"] for example in examples]
|
| 40 |
+
|
| 41 |
+
input_ids = torch.cat(input_ids, dim=0)
|
| 42 |
+
input_anchor_ids = torch.cat(input_anchor_ids, dim=0)
|
| 43 |
+
pixel_values = torch.stack(pixel_values)
|
| 44 |
+
mask = torch.stack(mask)
|
| 45 |
+
pixel_values = pixel_values.to(
|
| 46 |
+
memory_format=torch.contiguous_format).float()
|
| 47 |
+
mask = mask.to(memory_format=torch.contiguous_format).float()
|
| 48 |
+
|
| 49 |
+
batch = {
|
| 50 |
+
"input_ids": input_ids,
|
| 51 |
+
"input_anchor_ids": input_anchor_ids,
|
| 52 |
+
"pixel_values": pixel_values,
|
| 53 |
+
"mask": mask.unsqueeze(1)
|
| 54 |
+
}
|
| 55 |
+
return batch
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class PromptDataset(Dataset):
|
| 59 |
+
"A simple dataset to prepare the prompts to generate class images on multiple GPUs."
|
| 60 |
+
|
| 61 |
+
def __init__(self, prompt, num_samples):
|
| 62 |
+
self.prompt = prompt
|
| 63 |
+
self.num_samples = num_samples
|
| 64 |
+
|
| 65 |
+
def __len__(self):
|
| 66 |
+
return self.num_samples
|
| 67 |
+
|
| 68 |
+
def __getitem__(self, index):
|
| 69 |
+
example = {}
|
| 70 |
+
example["prompt"] = self.prompt[index % len(self.prompt)]
|
| 71 |
+
example["index"] = index
|
| 72 |
+
return example
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class CustomDiffusionDataset(Dataset):
|
| 76 |
+
"""
|
| 77 |
+
A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
|
| 78 |
+
It pre-processes the images and the tokenizes prompts.
|
| 79 |
+
"""
|
| 80 |
+
|
| 81 |
+
def __init__(
|
| 82 |
+
self,
|
| 83 |
+
concepts_list,
|
| 84 |
+
concept_type,
|
| 85 |
+
tokenizer,
|
| 86 |
+
size=512,
|
| 87 |
+
center_crop=False,
|
| 88 |
+
with_prior_preservation=False,
|
| 89 |
+
num_class_images=200,
|
| 90 |
+
hflip=False,
|
| 91 |
+
aug=True,
|
| 92 |
+
):
|
| 93 |
+
self.size = size
|
| 94 |
+
self.center_crop = center_crop
|
| 95 |
+
self.tokenizer = tokenizer
|
| 96 |
+
self.interpolation = Image.BILINEAR
|
| 97 |
+
self.aug = aug
|
| 98 |
+
self.concept_type = concept_type
|
| 99 |
+
|
| 100 |
+
self.instance_images_path = []
|
| 101 |
+
self.class_images_path = []
|
| 102 |
+
self.with_prior_preservation = with_prior_preservation
|
| 103 |
+
for concept in concepts_list:
|
| 104 |
+
with open(concept["instance_data_dir"], "r") as f:
|
| 105 |
+
inst_images_path = f.read().splitlines()
|
| 106 |
+
with open(concept["instance_prompt"], "r") as f:
|
| 107 |
+
inst_prompt = f.read().splitlines()
|
| 108 |
+
inst_img_path = [(x, y, concept['caption_target'])
|
| 109 |
+
for (x, y) in zip(inst_images_path, inst_prompt)]
|
| 110 |
+
self.instance_images_path.extend(inst_img_path)
|
| 111 |
+
|
| 112 |
+
if with_prior_preservation:
|
| 113 |
+
class_data_root = Path(concept["class_data_dir"])
|
| 114 |
+
if os.path.isdir(class_data_root):
|
| 115 |
+
class_images_path = list(class_data_root.iterdir())
|
| 116 |
+
class_prompt = [concept["class_prompt"]
|
| 117 |
+
for _ in range(len(class_images_path))]
|
| 118 |
+
else:
|
| 119 |
+
with open(class_data_root, "r") as f:
|
| 120 |
+
class_images_path = f.read().splitlines()
|
| 121 |
+
with open(concept["class_prompt"], "r") as f:
|
| 122 |
+
class_prompt = f.read().splitlines()
|
| 123 |
+
|
| 124 |
+
class_img_path = [(x, y) for (x, y) in zip(
|
| 125 |
+
class_images_path, class_prompt)]
|
| 126 |
+
self.class_images_path.extend(
|
| 127 |
+
class_img_path[:num_class_images])
|
| 128 |
+
|
| 129 |
+
random.shuffle(self.instance_images_path)
|
| 130 |
+
self.num_instance_images = len(self.instance_images_path)
|
| 131 |
+
self.num_class_images = len(self.class_images_path)
|
| 132 |
+
self._length = max(self.num_class_images, self.num_instance_images)
|
| 133 |
+
self.flip = transforms.RandomHorizontalFlip(0.5 * hflip)
|
| 134 |
+
|
| 135 |
+
self.image_transforms = transforms.Compose(
|
| 136 |
+
[
|
| 137 |
+
self.flip,
|
| 138 |
+
transforms.Resize(
|
| 139 |
+
size, interpolation=transforms.InterpolationMode.BILINEAR),
|
| 140 |
+
transforms.CenterCrop(
|
| 141 |
+
size) if center_crop else transforms.RandomCrop(size),
|
| 142 |
+
transforms.ToTensor(),
|
| 143 |
+
transforms.Normalize([0.5], [0.5]),
|
| 144 |
+
]
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
def __len__(self):
|
| 148 |
+
return self._length
|
| 149 |
+
|
| 150 |
+
def preprocess(self, image, scale, resample):
|
| 151 |
+
outer, inner = self.size, scale
|
| 152 |
+
if scale > self.size:
|
| 153 |
+
outer, inner = scale, self.size
|
| 154 |
+
top, left = np.random.randint(
|
| 155 |
+
0, outer - inner + 1), np.random.randint(0, outer - inner + 1)
|
| 156 |
+
image = image.resize((scale, scale), resample=resample)
|
| 157 |
+
image = np.array(image).astype(np.uint8)
|
| 158 |
+
image = (image / 127.5 - 1.0).astype(np.float32)
|
| 159 |
+
instance_image = np.zeros((self.size, self.size, 3), dtype=np.float32)
|
| 160 |
+
mask = np.zeros((self.size // 8, self.size // 8))
|
| 161 |
+
if scale > self.size:
|
| 162 |
+
instance_image = image[top: top + inner, left: left + inner, :]
|
| 163 |
+
mask = np.ones((self.size // 8, self.size // 8))
|
| 164 |
+
else:
|
| 165 |
+
instance_image[top: top + inner, left: left + inner, :] = image
|
| 166 |
+
mask[top // 8 + 1: (top + scale) // 8 - 1, left //
|
| 167 |
+
8 + 1: (left + scale) // 8 - 1] = 1.
|
| 168 |
+
return instance_image, mask
|
| 169 |
+
|
| 170 |
+
def __getprompt__(self, instance_prompt, instance_target):
|
| 171 |
+
if self.concept_type == 'style':
|
| 172 |
+
r = np.random.choice([0, 1, 2])
|
| 173 |
+
instance_prompt = f'{instance_prompt}, in the style of {instance_target}' if r == 0 else f'in {instance_target}\'s style, {instance_prompt}' if r == 1 else f'in {instance_target}\'s style, {instance_prompt}'
|
| 174 |
+
elif self.concept_type == 'object':
|
| 175 |
+
anchor, target = instance_target.split('+')
|
| 176 |
+
instance_prompt = instance_prompt.replace(anchor, target)
|
| 177 |
+
elif self.concept_type == 'memorization':
|
| 178 |
+
instance_prompt = instance_target.split('+')[1]
|
| 179 |
+
return instance_prompt
|
| 180 |
+
|
| 181 |
+
def __getitem__(self, index):
|
| 182 |
+
example = {}
|
| 183 |
+
instance_image, instance_prompt, instance_target = self.instance_images_path[
|
| 184 |
+
index % self.num_instance_images]
|
| 185 |
+
instance_image = Image.open(instance_image)
|
| 186 |
+
if not instance_image.mode == "RGB":
|
| 187 |
+
instance_image = instance_image.convert("RGB")
|
| 188 |
+
instance_image = self.flip(instance_image)
|
| 189 |
+
# modify instance prompt according to the concept_type to include target concept
|
| 190 |
+
# multiple style/object fine-tuning
|
| 191 |
+
if ';' in instance_target:
|
| 192 |
+
instance_target = instance_target.split(';')
|
| 193 |
+
instance_target = instance_target[index % len(instance_target)]
|
| 194 |
+
|
| 195 |
+
instance_anchor_prompt = instance_prompt
|
| 196 |
+
instance_prompt = self.__getprompt__(instance_prompt, instance_target)
|
| 197 |
+
# apply resize augmentation and create a valid image region mask
|
| 198 |
+
random_scale = self.size
|
| 199 |
+
if self.aug:
|
| 200 |
+
random_scale = np.random.randint(self.size // 3, self.size + 1) if np.random.uniform(
|
| 201 |
+
) < 0.66 else np.random.randint(int(1.2 * self.size), int(1.4 * self.size))
|
| 202 |
+
instance_image, mask = self.preprocess(
|
| 203 |
+
instance_image, random_scale, self.interpolation)
|
| 204 |
+
|
| 205 |
+
if random_scale < 0.6 * self.size:
|
| 206 |
+
instance_prompt = np.random.choice(
|
| 207 |
+
["a far away ", "very small "]) + instance_prompt
|
| 208 |
+
elif random_scale > self.size:
|
| 209 |
+
instance_prompt = np.random.choice(
|
| 210 |
+
["zoomed in ", "close up "]) + instance_prompt
|
| 211 |
+
|
| 212 |
+
example["instance_images"] = torch.from_numpy(
|
| 213 |
+
instance_image).permute(2, 0, 1)
|
| 214 |
+
example["mask"] = torch.from_numpy(mask)
|
| 215 |
+
|
| 216 |
+
example["instance_prompt_ids"] = self.tokenizer(
|
| 217 |
+
instance_prompt,
|
| 218 |
+
truncation=True,
|
| 219 |
+
padding="max_length",
|
| 220 |
+
max_length=self.tokenizer.model_max_length,
|
| 221 |
+
return_tensors="pt",
|
| 222 |
+
).input_ids
|
| 223 |
+
example["instance_anchor_prompt_ids"] = self.tokenizer(
|
| 224 |
+
instance_anchor_prompt,
|
| 225 |
+
truncation=True,
|
| 226 |
+
padding="max_length",
|
| 227 |
+
max_length=self.tokenizer.model_max_length,
|
| 228 |
+
return_tensors="pt",
|
| 229 |
+
).input_ids
|
| 230 |
+
|
| 231 |
+
if self.with_prior_preservation:
|
| 232 |
+
class_image, class_prompt = self.class_images_path[index %
|
| 233 |
+
self.num_class_images]
|
| 234 |
+
class_image = Image.open(class_image)
|
| 235 |
+
if not class_image.mode == "RGB":
|
| 236 |
+
class_image = class_image.convert("RGB")
|
| 237 |
+
example["class_images"] = self.image_transforms(class_image)
|
| 238 |
+
example["class_mask"] = torch.ones_like(example["mask"])
|
| 239 |
+
example["class_prompt_ids"] = self.tokenizer(
|
| 240 |
+
class_prompt,
|
| 241 |
+
truncation=True,
|
| 242 |
+
padding="max_length",
|
| 243 |
+
max_length=self.tokenizer.model_max_length,
|
| 244 |
+
return_tensors="pt",
|
| 245 |
+
).input_ids
|
| 246 |
+
|
| 247 |
+
return example
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def isimage(path):
|
| 251 |
+
if 'png' in path.lower() or 'jpg' in path.lower() or 'jpeg' in path.lower():
|
| 252 |
+
return True
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def filter(folder, impath, outpath=None, unfiltered_path=None, threshold=0.15,
|
| 256 |
+
image_threshold=0.5, anchor_size=10, target_size=3, return_score=False):
|
| 257 |
+
model = torch.jit.load(
|
| 258 |
+
"./assets/sscd_imagenet_mixup.torchscript.pt")
|
| 259 |
+
if isinstance(folder, list):
|
| 260 |
+
image_paths = folder
|
| 261 |
+
image_captions = ["None" for _ in range(len(image_paths))]
|
| 262 |
+
elif Path(folder / 'images.txt').exists():
|
| 263 |
+
with open(f'{folder}/images.txt', "r") as f:
|
| 264 |
+
image_paths = f.read().splitlines()
|
| 265 |
+
with open(f'{folder}/caption.txt', "r") as f:
|
| 266 |
+
image_captions = f.read().splitlines()
|
| 267 |
+
else:
|
| 268 |
+
image_paths = [os.path.join(str(folder), file_path)
|
| 269 |
+
for file_path in os.listdir(folder) if isimage(file_path)]
|
| 270 |
+
image_captions = ["None" for _ in range(len(image_paths))]
|
| 271 |
+
|
| 272 |
+
batch = small_288(Image.open(impath).convert('RGB')).unsqueeze(0)
|
| 273 |
+
embedding_target = model(batch)[0, :]
|
| 274 |
+
|
| 275 |
+
filtered_paths = []
|
| 276 |
+
filtered_captions = []
|
| 277 |
+
unfiltered_paths = []
|
| 278 |
+
unfiltered_captions = []
|
| 279 |
+
count_dict = {}
|
| 280 |
+
for im, c in zip(image_paths, image_captions):
|
| 281 |
+
if c not in count_dict:
|
| 282 |
+
count_dict[c] = 0
|
| 283 |
+
if isinstance(folder, list):
|
| 284 |
+
batch = small_288(im).unsqueeze(0)
|
| 285 |
+
else:
|
| 286 |
+
batch = small_288(Image.open(im).convert('RGB')).unsqueeze(0)
|
| 287 |
+
embedding = model(batch)[0, :]
|
| 288 |
+
|
| 289 |
+
diff_sscd = (embedding * embedding_target).sum()
|
| 290 |
+
|
| 291 |
+
if diff_sscd <= image_threshold:
|
| 292 |
+
filtered_paths.append(im)
|
| 293 |
+
filtered_captions.append(c)
|
| 294 |
+
count_dict[c] += 1
|
| 295 |
+
else:
|
| 296 |
+
unfiltered_paths.append(im)
|
| 297 |
+
unfiltered_captions.append(c)
|
| 298 |
+
|
| 299 |
+
# only return score
|
| 300 |
+
if return_score:
|
| 301 |
+
score = len(unfiltered_paths) / \
|
| 302 |
+
(len(unfiltered_paths)+len(filtered_paths))
|
| 303 |
+
return score
|
| 304 |
+
|
| 305 |
+
os.makedirs(outpath, exist_ok=True)
|
| 306 |
+
os.makedirs(f'{outpath}/samples', exist_ok=True)
|
| 307 |
+
with open(f'{outpath}/caption.txt', 'w') as f:
|
| 308 |
+
for each in filtered_captions:
|
| 309 |
+
f.write(each.strip() + '\n')
|
| 310 |
+
|
| 311 |
+
with open(f'{outpath}/images.txt', 'w') as f:
|
| 312 |
+
for each in filtered_paths:
|
| 313 |
+
f.write(each.strip() + '\n')
|
| 314 |
+
imbase = Path(each).name
|
| 315 |
+
shutil.copy(each, f'{outpath}/samples/{imbase}')
|
| 316 |
+
|
| 317 |
+
print('++++++++++++++++++++++++++++++++++++++++++++++++')
|
| 318 |
+
print('+ Filter Summary +')
|
| 319 |
+
print(f'+ Remained images: {len(filtered_paths)}')
|
| 320 |
+
print(f'+ Filtered images: {len(unfiltered_paths)}')
|
| 321 |
+
print('++++++++++++++++++++++++++++++++++++++++++++++++')
|
| 322 |
+
|
| 323 |
+
sorted_list = sorted(list(count_dict.items()),
|
| 324 |
+
key=lambda x: x[1], reverse=True)
|
| 325 |
+
anchor_prompts = [c[0] for c in sorted_list[:anchor_size]]
|
| 326 |
+
target_prompts = [c[0] for c in sorted_list[-target_size:]]
|
| 327 |
+
return anchor_prompts, target_prompts, len(filtered_paths)
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
def getanchorprompts(pipeline, accelerator, class_prompt, concept_type, class_images_dir, api_key, num_class_images=200, mem_impath=None):
|
| 331 |
+
openai.api_key = api_key
|
| 332 |
+
class_prompt_collection = []
|
| 333 |
+
caption_target = []
|
| 334 |
+
if concept_type == 'object':
|
| 335 |
+
messages = [{"role": "system", "content": "You can describe any image via text and provide captions for wide variety of images that is possible to generate."}]
|
| 336 |
+
messages = [{"role": "user", "content": f"Generate {num_class_images} captions for images containing a {class_prompt}. The caption should also contain the word \"{class_prompt}\" "}]
|
| 337 |
+
while True:
|
| 338 |
+
completion = openai.ChatCompletion.create(
|
| 339 |
+
model="gpt-3.5-turbo",
|
| 340 |
+
messages=messages
|
| 341 |
+
)
|
| 342 |
+
class_prompt_collection += [x for x in completion.choices[0].message.content.lower(
|
| 343 |
+
).split('\n') if class_prompt in x]
|
| 344 |
+
messages.append(
|
| 345 |
+
{"role": "assistant", "content": completion.choices[0].message.content})
|
| 346 |
+
messages.append(
|
| 347 |
+
{"role": "user", "content": f"Generate {num_class_images-len(class_prompt_collection)} more captions"})
|
| 348 |
+
if len(class_prompt_collection) >= num_class_images:
|
| 349 |
+
break
|
| 350 |
+
class_prompt_collection = clean_prompt(class_prompt_collection)[
|
| 351 |
+
:num_class_images]
|
| 352 |
+
|
| 353 |
+
elif concept_type == 'memorization':
|
| 354 |
+
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
|
| 355 |
+
pipeline.scheduler.config)
|
| 356 |
+
num_prompts_firstpass = 5
|
| 357 |
+
num_prompts_secondpass = 2
|
| 358 |
+
threshold = 0.3
|
| 359 |
+
# Generate num_prompts_firstpass paraphrases which generate different content at least 1-threshold % of the times.
|
| 360 |
+
os.makedirs(class_images_dir / 'temp/', exist_ok=True)
|
| 361 |
+
class_prompt_collection_counter = []
|
| 362 |
+
caption_target = []
|
| 363 |
+
prev_captions = []
|
| 364 |
+
messages = [{"role": "user", "content": f"Generate {4*num_prompts_firstpass} different paraphrase of the caption: {class_prompt}. Preserve the meaning when paraphrasing."}]
|
| 365 |
+
while True:
|
| 366 |
+
completion = openai.ChatCompletion.create(
|
| 367 |
+
model="gpt-3.5-turbo",
|
| 368 |
+
messages=messages
|
| 369 |
+
)
|
| 370 |
+
# print(completion.choices[0].message.content.lower().split('\n'))
|
| 371 |
+
class_prompt_collection_ = [x.strip(
|
| 372 |
+
) for x in completion.choices[0].message.content.lower().split('\n') if x.strip() != '']
|
| 373 |
+
class_prompt_collection_ = clean_prompt(class_prompt_collection_)
|
| 374 |
+
# print(class_prompt_collection_)
|
| 375 |
+
for prompt in tqdm(
|
| 376 |
+
class_prompt_collection_, desc="Generating anchor and target prompts ", disable=not accelerator.is_local_main_process
|
| 377 |
+
):
|
| 378 |
+
print(f'Prompt: {prompt}')
|
| 379 |
+
images = pipeline([prompt]*10, num_inference_steps=25,).images
|
| 380 |
+
|
| 381 |
+
score = filter(images, mem_impath, return_score=True)
|
| 382 |
+
print(f'Memorization rate: {score}')
|
| 383 |
+
if score <= threshold and prompt not in class_prompt_collection and len(class_prompt_collection) < num_prompts_firstpass:
|
| 384 |
+
class_prompt_collection += [prompt]
|
| 385 |
+
class_prompt_collection_counter += [score]
|
| 386 |
+
elif score >= 0.6 and prompt not in caption_target and len(caption_target) < 2:
|
| 387 |
+
caption_target += [prompt]
|
| 388 |
+
if len(class_prompt_collection) >= num_prompts_firstpass and len(caption_target) >= 2:
|
| 389 |
+
break
|
| 390 |
+
|
| 391 |
+
if len(class_prompt_collection) >= num_prompts_firstpass:
|
| 392 |
+
break
|
| 393 |
+
# print("prompts till now", class_prompt_collection, caption_target)
|
| 394 |
+
# print("prompts till now", len(
|
| 395 |
+
# class_prompt_collection), len(caption_target))
|
| 396 |
+
prev_captions += class_prompt_collection_
|
| 397 |
+
prev_captions_ = ','.join(prev_captions[-40:])
|
| 398 |
+
|
| 399 |
+
messages = [
|
| 400 |
+
{"role": "user", "content": f"Generate {4*(num_prompts_firstpass- len(class_prompt_collection))} different paraphrase of the caption: {class_prompt}. Preserve the meaning the most when paraphrasing. Also make sure that the new captions are different from the following captions: {prev_captions_[:4000]}"}]
|
| 401 |
+
|
| 402 |
+
# Generate more paraphrases using the captions we retrieved above.
|
| 403 |
+
for prompt in class_prompt_collection[:num_prompts_firstpass]:
|
| 404 |
+
completion = openai.ChatCompletion.create(
|
| 405 |
+
model="gpt-3.5-turbo",
|
| 406 |
+
messages=[
|
| 407 |
+
{"role": "user", "content": f"Generate {num_prompts_secondpass} different paraphrases of: {prompt}. "}]
|
| 408 |
+
|
| 409 |
+
)
|
| 410 |
+
class_prompt_collection += clean_prompt(
|
| 411 |
+
[x.strip() for x in completion.choices[0].message.content.lower().split('\n') if x.strip() != ''])
|
| 412 |
+
|
| 413 |
+
for prompt in tqdm(class_prompt_collection[num_prompts_firstpass:], desc="Memorization rate for final prompts"):
|
| 414 |
+
images = pipeline([prompt]*10, num_inference_steps=25,).images
|
| 415 |
+
|
| 416 |
+
class_prompt_collection_counter += [
|
| 417 |
+
filter(images, mem_impath, return_score=True)]
|
| 418 |
+
|
| 419 |
+
# select least ten and most memorized text prompts to be selected as anchor and target prompts.
|
| 420 |
+
class_prompt_collection = sorted(
|
| 421 |
+
zip(class_prompt_collection, class_prompt_collection_counter), key=lambda x: x[1])
|
| 422 |
+
caption_target += [x for (x, y) in class_prompt_collection if y >= 0.6]
|
| 423 |
+
class_prompt_collection = [
|
| 424 |
+
x for (x, y) in class_prompt_collection if y <= threshold][:10]
|
| 425 |
+
print("Anchor prompts:", class_prompt_collection)
|
| 426 |
+
print("Target prompts:", caption_target)
|
| 427 |
+
return class_prompt_collection, ';*+'.join(caption_target)
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
def clean_prompt(class_prompt_collection):
|
| 431 |
+
class_prompt_collection = [re.sub(
|
| 432 |
+
r"[0-9]+", lambda num: '' * len(num.group(0)), prompt) for prompt in class_prompt_collection]
|
| 433 |
+
class_prompt_collection = [re.sub(
|
| 434 |
+
r"^\.+", lambda dots: '' * len(dots.group(0)), prompt) for prompt in class_prompt_collection]
|
| 435 |
+
class_prompt_collection = [x.strip() for x in class_prompt_collection]
|
| 436 |
+
class_prompt_collection = [x.replace('"', '') for x in class_prompt_collection]
|
| 437 |
+
return class_prompt_collection
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
def safe_dir(dir):
|
| 441 |
+
if not dir.exists():
|
| 442 |
+
dir.mkdir()
|
| 443 |
+
return dir
|
images/applications.png
ADDED
|
Git LFS Details
|
models/greg_rutkowski_ablation_delta.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5e096c941d90652bbed08e912f899de7aaeb895a3797a11bcaf49a2733580a9d
|
| 3 |
+
size 76685761
|
models/vangogh_ablation_delta.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2b92f3c3a7aa8887c6f358ac73eb6a4d41387c4b4044dceda179a21846948f50
|
| 3 |
+
size 76685761
|
requirements.txt
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio
|
| 2 |
+
torch
|
| 3 |
+
torchvision
|
| 4 |
+
scipy
|
| 5 |
+
accelerate
|
| 6 |
+
modelcards
|
| 7 |
+
transformers>=4.25.1
|
| 8 |
+
diffusers==0.19.0
|
| 9 |
+
tqdm
|
| 10 |
+
openai
|
| 11 |
+
triton
|
| 12 |
+
xformers>=0.0.20
|
| 13 |
+
bitsandbytes
|
trainer.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import PIL.Image
|
| 3 |
+
import shlex
|
| 4 |
+
import shutil
|
| 5 |
+
import subprocess
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
import os
|
| 8 |
+
import torch
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def pad_image(image: PIL.Image.Image) -> PIL.Image.Image:
|
| 14 |
+
w, h = image.size
|
| 15 |
+
if w == h:
|
| 16 |
+
return image
|
| 17 |
+
elif w > h:
|
| 18 |
+
new_image = PIL.Image.new(image.mode, (w, w), (0, 0, 0))
|
| 19 |
+
new_image.paste(image, (0, (w - h) // 2))
|
| 20 |
+
return new_image
|
| 21 |
+
else:
|
| 22 |
+
new_image = PIL.Image.new(image.mode, (h, h), (0, 0, 0))
|
| 23 |
+
new_image.paste(image, ((h - w) // 2, 0))
|
| 24 |
+
return new_image
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def train_submit(
|
| 29 |
+
prompt, anchor_prompt, concept_type, reg_lambda, iterations, lr, openai_key, save_path, mem_impath=None
|
| 30 |
+
):
|
| 31 |
+
if not torch.cuda.is_available():
|
| 32 |
+
raise gr.Error('CUDA is not available.')
|
| 33 |
+
|
| 34 |
+
torch.cuda.empty_cache()
|
| 35 |
+
original_prompt = prompt
|
| 36 |
+
parameter_group = "cross-attn"
|
| 37 |
+
train_batch_size = 4
|
| 38 |
+
if concept_type == 'style':
|
| 39 |
+
class_data_dir = f'./data/samples_painting/'
|
| 40 |
+
anchor_prompt = f'./assets/painting.txt'
|
| 41 |
+
openai_key = ''
|
| 42 |
+
elif concept_type == 'object':
|
| 43 |
+
os.makedirs('temp', exist_ok=True)
|
| 44 |
+
class_data_dir = f'./temp/{anchor_prompt}'
|
| 45 |
+
name = save_path.split('/')[-1]
|
| 46 |
+
prompt = f'{anchor_prompt}+{prompt}'
|
| 47 |
+
assert openai_key is not None
|
| 48 |
+
|
| 49 |
+
if len(openai_key.split('\n')) > 1:
|
| 50 |
+
openai_key = openai_key.split('\n')
|
| 51 |
+
with open(f'./temp/{name}.txt', 'w') as f:
|
| 52 |
+
for prompt_ in openai_key:
|
| 53 |
+
f.write(prompt_.strip()+'\n')
|
| 54 |
+
openai_key = ''
|
| 55 |
+
anchor_prompt = f'./temp/{name}.txt'
|
| 56 |
+
elif concept_type == 'memorization':
|
| 57 |
+
os.system("wget https://dl.fbaipublicfiles.com/sscd-copy-detection/sscd_imagenet_mixup.torchscript.pt -P assets/")
|
| 58 |
+
os.makedirs('temp', exist_ok=True)
|
| 59 |
+
prompt = f'*+{prompt}'
|
| 60 |
+
name = save_path.split('/')[-1]
|
| 61 |
+
train_batch_size = 1
|
| 62 |
+
lr = 5e-7
|
| 63 |
+
parameter_group = "full-weight"
|
| 64 |
+
|
| 65 |
+
assert openai_key is not None
|
| 66 |
+
assert mem_impath is not None
|
| 67 |
+
|
| 68 |
+
if len(openai_key.split('\n')) > 1:
|
| 69 |
+
openai_key = openai_key.split('\n')
|
| 70 |
+
with open(f'./temp/{name}.txt', 'w') as f:
|
| 71 |
+
for prompt_ in openai_key:
|
| 72 |
+
f.write(prompt_.strip()+'\n')
|
| 73 |
+
openai_key = ''
|
| 74 |
+
anchor_prompt = f'./temp/{name}.txt'
|
| 75 |
+
else:
|
| 76 |
+
anchor_prompt = prompt
|
| 77 |
+
|
| 78 |
+
print(mem_impath)
|
| 79 |
+
image = PIL.Image.open(mem_impath[0][0].name)
|
| 80 |
+
image = pad_image(image)
|
| 81 |
+
image = image.convert('RGB')
|
| 82 |
+
mem_impath = f"./temp/{original_prompt.lower().replace(' ', '')}.jpg"
|
| 83 |
+
image.save(mem_impath, format='JPEG', quality=100)
|
| 84 |
+
|
| 85 |
+
class_data_dir = f"./temp/{original_prompt.lower().replace(' ', '')}"
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
command = f'''
|
| 89 |
+
accelerate launch concept-ablation-diffusers/train.py \
|
| 90 |
+
--pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4" \
|
| 91 |
+
--output_dir={save_path} \
|
| 92 |
+
--class_data_dir={class_data_dir} \
|
| 93 |
+
--class_prompt="{anchor_prompt}" \
|
| 94 |
+
--caption_target "{prompt}" \
|
| 95 |
+
--concept_type {concept_type} \
|
| 96 |
+
--resolution=512 \
|
| 97 |
+
--train_batch_size={train_batch_size} \
|
| 98 |
+
--learning_rate={lr} \
|
| 99 |
+
--max_train_steps={iterations} \
|
| 100 |
+
--scale_lr --hflip \
|
| 101 |
+
--parameter_group {parameter_group} \
|
| 102 |
+
--openai_key "{openai_key}" \
|
| 103 |
+
--enable_xformers_memory_efficient_attention --num_class_images 500
|
| 104 |
+
'''
|
| 105 |
+
|
| 106 |
+
if concept_type == 'style':
|
| 107 |
+
command += f' --noaug'
|
| 108 |
+
|
| 109 |
+
if concept_type == 'memorization':
|
| 110 |
+
command += f' --use_8bit_adam --with_prior_preservation --prior_loss_weight=1.0 --mem_impath {mem_impath}'
|
| 111 |
+
|
| 112 |
+
with open(f'{save_path}/train.sh', 'w') as f:
|
| 113 |
+
command_s = ' '.join(command.split())
|
| 114 |
+
f.write(command_s)
|
| 115 |
+
|
| 116 |
+
res = subprocess.run(shlex.split(command))
|
| 117 |
+
|
| 118 |
+
if res.returncode == 0:
|
| 119 |
+
result_message = 'Training Completed!'
|
| 120 |
+
else:
|
| 121 |
+
result_message = 'Training Failed!'
|
| 122 |
+
weight_paths = sorted(Path(save_path).glob('*.bin'))
|
| 123 |
+
print(weight_paths)
|
| 124 |
+
return gr.update(value=result_message), weight_paths[0]
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def inference(model_path, prompt, n_steps, generator):
|
| 128 |
+
import sys
|
| 129 |
+
sys.path.append('concept-ablation/diffusers/.')
|
| 130 |
+
from model_pipeline import CustomDiffusionPipeline
|
| 131 |
+
import torch
|
| 132 |
+
|
| 133 |
+
pipe = CustomDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16).to("cuda")
|
| 134 |
+
image1 = pipe(prompt, num_inference_steps=n_steps, guidance_scale=6., eta=1., generator=generator).images[0]
|
| 135 |
+
|
| 136 |
+
pipe.load_model(model_path)
|
| 137 |
+
image2 = pipe(prompt, num_inference_steps=n_steps, guidance_scale=6., eta=1., generator=generator).images[0]
|
| 138 |
+
|
| 139 |
+
return image1, image2
|