1.6b
- README.md +70 -15
- girl.jpg +2 -2
- media/girl.jpg +2 -2
- media/result_grid.jpg +2 -2
- model_index.json +2 -2
- pipeline_sdxs.py +282 -172
- processor/chat_template.jinja +154 -0
- processor/processor_config.json +3 -0
- processor/tokenizer.json +3 -0
- processor/tokenizer_config.json +3 -0
- samples/unet_384x704_0.jpg +2 -2
- samples/unet_416x704_0.jpg +2 -2
- samples/unet_448x704_0.jpg +2 -2
- samples/unet_480x704_0.jpg +2 -2
- samples/unet_512x704_0.jpg +2 -2
- samples/unet_544x704_0.jpg +2 -2
- samples/unet_576x704_0.jpg +2 -2
- samples/unet_608x704_0.jpg +2 -2
- samples/unet_640x704_0.jpg +2 -2
- samples/unet_672x704_0.jpg +2 -2
- samples/unet_704x384_0.jpg +2 -2
- samples/unet_704x416_0.jpg +2 -2
- samples/unet_704x448_0.jpg +2 -2
- samples/unet_704x480_0.jpg +2 -2
- samples/unet_704x512_0.jpg +2 -2
- samples/unet_704x544_0.jpg +2 -2
- samples/unet_704x576_0.jpg +2 -2
- samples/unet_704x608_0.jpg +2 -2
- samples/unet_704x640_0.jpg +2 -2
- samples/unet_704x672_0.jpg +2 -2
- samples/unet_704x704_0.jpg +2 -2
- src/unet1.5b.ipynb +1 -1
- test.ipynb +2 -2
- train.py +2 -2
- unet/config.json +1 -1
- unet/diffusion_pytorch_model.safetensors +2 -2
README.md
CHANGED
|
@@ -12,15 +12,30 @@ datasets:
|
|
| 12 |
|
| 13 |
At AiArtLab, we strive to create a free, compact and fast model that can be trained on consumer graphics cards.
|
| 14 |
|
| 15 |
-
- Unet:
|
| 16 |
-
- Qwen3.5:
|
| 17 |
-
- VAE:
|
| 18 |
-
- Speed:
|
| 19 |
-
|
| 20 |
-
|
|
| 21 |

|
| 22 |
|
| 23 |
-
###
|
| 24 |
|
| 25 |
```
|
| 26 |
import torch
|
|
@@ -46,9 +61,41 @@ image = pipe(
|
|
| 46 |
image.images[0].show()
|
| 47 |
```
|
| 48 |
|
| 49 |
### VAE
|
| 50 |
|
| 51 |
-
The VAE in Simple Diffusion utilizes an asymmetric VAE architecture featuring an 8x encoder and a 16x decoder. While a compression factor of 8 is maintained during training, the resolution is effectively doubled during inference through an additional upscaling block. This strategy reduces training costs by an order of magnitude and boosts inference speed without perceptual quality loss. Effectively, this acts as an integrated latent upscaler. To ensure a fair comparison with other VAEs, we downsampled the generated images to match the input resolution for metric evaluation. The SDXS VAE was not trained from scratch but was initialized from the weights of the FLUX 2 VAE, then redesigned and retrained.
|
| 52 |
|
| 53 |
[eval.py](src/eval.py)
|
| 54 |
```
|
|
@@ -61,11 +108,16 @@ FLUX.2 | MSE=2.425e-04 PSNR=38.33 LPIPS=0.023 Edge=0.065 KL=2.160
|
|
| 61 |
Wan2.2-TI2V-5B (2Gb) | MSE=7.034e-04 PSNR=34.65 LPIPS=0.050 Edge=0.115 KL=9.429
|
| 62 |
sdxs-1b (200Mb) | MSE=2.655e-04 PSNR=37.83 LPIPS=0.026 Edge=0.066 KL=2.170
|
| 63 |
```
|
| 64 |
### Unet
|
| 65 |
|
| 66 |
The UNet architecture in Simple Diffusion is a direct descendant and conceptual continuation of the ideas introduced in the first version of Stable Diffusion. Key distinctions include a relatively small, yet sufficient, number of transformer blocks that ensure an even distribution of attention. Additionally, the number of channels in the final layer has been significantly increased to improve detail rendering. Overall, however, it remains a UNet, similar to SD 1.5.
|
| 67 |
|
| 68 |
-
Throughout the experiments, we tested hundreds of different configurations and trained dozens of models. Notably, we initially started from the SDXL architecture, assuming it would be a stronger baseline, but ultimately abandoned all the innovations proposed in it. These included uneven attention distribution with increased transformer block depth in the lower layers, a reduced number of blocks in the channel pyramid, micro-conditioning, the dual text encoder, text-time embeddings, and so on. According to our experiments, all of these changes lead to increased training time and costs while having a near-zero or negative impact on the final result. In total, the investigation of various architectures and the search for the most efficient and optimal configuration took over a year.
|
| 69 |
|
| 70 |
Unfortunately, we were unable to secure grants for model training, with the exception of a Google TPU grant, which we were unable to utilize due to insufficient preparation and time constraints. As a result, training and experiments were financed primarily from our own funds and user donations. This left a significant mark on the model’s architecture.
|
| 71 |
We aimed to make it as small and cost-effective to train as possible while maintaining our quality generation requirements. So perhaps the limited budget even worked to our advantage.
|
|
@@ -87,14 +139,18 @@ Additionally, the use of a full-fledged language model allowed us to integrate a
|
|
| 87 |
This adventure started in December 2024 after the release of the SANA model. We received a donation from Stan for fine-tuning SANA and, together with Stas, began fine-tuning and further developing it. Despite spending the entire budget, we did not achieve significant improvements. However, we were shocked by how poorly the model was trained and designed, and we became convinced that we could do better—though we were wrong.
|
| 88 |
Shifting Gears
|
| 89 |
By February 2025, we split our efforts and began designing our own architectures—which we are still doing today. Stas favored the DiT architecture, while I believed in UNet. Despite some differences in architectural views, we maintained close communication, shared our work, and supported each other throughout the process. We also engaged with the AIArtLab community (a virtual Telegram chat for those contributing to model development)—thank you all for your support.
|
| 90 |
-
##
|
| 91 |
-
One of my key mistakes was relying too heavily on LLMs and research papers. Research often presents minor improvements as groundbreaking innovations, and LLMs, trained on such content, can draw incorrect conclusions.
|
| 92 |
-
This shift led me to adopt a zero-trust policy toward any external information not personally verified.
|
| 93 |
## The Evolutionary Path
|
| 94 |
The second turning point was the transition to a continuous evolutionary improvement strategy. Unfortunately, the Butterflies dataset does not allow for evaluating prompt-following or anatomical generation capabilities. As a result, the model evolved incrementally rather than through revolutionary changes. The same model, from December 2025, underwent around 10 changes, including radical architectural shifts—while always preserving the pre-trained weights. It’s remarkable how well and quickly pre-trained models adapt to changes in architecture and external factors, even radical ones (e.g., switching VAE models, text encoders, or their combinations).
|
| 95 |
In addition to saving on training costs, this approach helped maintain minimal model size—for example, adding extra transformer blocks followed by an assessment of necessity and rolling back if the changes had no significant impact.
|
| 96 |
## The Role of Hyperparameters
|
| 97 |
-
One of the initial mistakes was an excessive focus on hyperparameters during training. Ironically, 80% of training speed and quality depend on the model architecture (UNet) and the quality of embeddings (VAE), while other 20% is influenced by the text encoder’s embeddings. The rest is
|
| 98 |
## Tools and Optimization
|
| 99 |
The model comes with two scripts:
|
| 100 |
|
|
@@ -102,8 +158,7 @@ A dataset script to convert a folder of image-text pairs into latent representat
|
|
| 102 |
A training script provided as a single monolithic file.
|
| 103 |
Additionally, there’s a script that can be pasted directly into the terminal to automatically train the model with optimized parameters.
|
| 104 |
## Training Optimization
|
| 105 |
-
All pre-training was done using the AdamW8bit optimizer, which significantly reduced training costs.
|
| 106 |
-
|
| 107 |
|
| 108 |
### Train:
|
| 109 |
|
| 12 |
|
| 13 |
At AiArtLab, we strive to create a free, compact and fast model that can be trained on consumer graphics cards.
|
| 14 |
|
| 15 |
+
- Unet: 1.6b parameters
|
| 16 |
+
- Qwen3.5: 1.8b parameters
|
| 17 |
+
- VAE: 32ch8x16x
|
| 18 |
+
- Speed: Sampling: 100%|██████████| 40/40 [00:01<00:00, 29.98it/s]
|
| 19 |
+
- Resolution: from 768px to 1404px, in 64px steps
|
| 20 |
+
- Limitations: trained on a small dataset (~1-2M images), focused on illustrations
|
| 21 |
+
|
| 22 |
+
### Train in progress
|
| 23 |
+
|
| 24 |
+
Key points
|
| 25 |
+
|
| 26 |
+
- Dec 24: Started research on Linear Transformers.
|
| 27 |
+
- Feb 25: Started research on UNet-based diffusion models.
|
| 28 |
+
- Aug 25: Started research on different VAEs.
|
| 29 |
+
- Sep 25: Created a simple VAE and a [vae collection](https://huggingface.co/AiArtLab/collections).
|
| 30 |
+
- Dec 25: Trained SDXS-1B (0.8B at this moment), featuring an SD1.5-like UNet, Long CLIP, 16-channel simple VAE, and flow matching target.
|
| 31 |
+
- Jan 25: Implemented a dual text encoder (SDXL-like style). Total rework.
|
| 32 |
+
- Feb 25: Reverted to classic architecture; tested all SDXL innovations and went back to simple diffusion. Total rework.
|
| 33 |
+
- Mar 25: Created a 32ch 8x/16x asymmetric VAE and switched to Qwen3.5 2B as the text encoder.
|
| 34 |
+
|
| 35 |
+
### Samples with seed 0
|
| 36 |

|
| 37 |
|
| 38 |
+
### Text 2 image
|
| 39 |
|
| 40 |
```
|
| 41 |
import torch
|
| 61 |
image.images[0].show()
|
| 62 |
```
|
| 63 |
|
| 64 |
+
### Image upscale
|
| 65 |
+
```
|
| 66 |
+
upscaled = pipe.image_upscale("media/girl.jpg")
|
| 67 |
+
upscaled[0].show()
|
| 68 |
+
```
|
| 69 |
+
|
| 70 |
+
### Prompt refine
|
| 71 |
+
```
|
| 72 |
+
refined = pipe.refine_prompts("girl")
|
| 73 |
+
|
| 74 |
+
print(refined)
|
| 75 |
+
```
|
| 76 |
+
|
| 77 |
+
### Encode image (experimental)
|
| 78 |
+
```
|
| 79 |
+
emb, mask = pipe.encode_image("media/girl.jpg")
|
| 80 |
+
|
| 81 |
+
# Check the embedding shape
|
| 82 |
+
print("Pooled vector shape:", emb[:, 0, :].shape)
|
| 83 |
+
image = pipeline(
|
| 84 |
+
prompt_embeds = emb,
|
| 85 |
+
prompt_attention_mask = mask,
|
| 86 |
+
negative_prompt = negative_prompt,
|
| 87 |
+
guidance_scale = 4,
|
| 88 |
+
width = 1088,
|
| 89 |
+
height = 1344,
|
| 90 |
+
seed = 0,
|
| 91 |
+
batch_size = 1,
|
| 92 |
+
)[0]
|
| 93 |
+
image[0].show()
|
| 94 |
+
```
|
| 95 |
+
|
| 96 |
### VAE
|
| 97 |
|
| 98 |
+
The VAE in Simple Diffusion utilizes an asymmetric VAE architecture featuring an 8x encoder and a 16x decoder. While a compression factor of 8 is maintained during training, the resolution is effectively doubled during inference through an additional upscaling block. This strategy reduces training costs by an order of magnitude and boosts inference speed without perceptual quality loss. Effectively, this acts as an integrated latent upscaler. To ensure a fair comparison with other VAEs, we downsampled the generated images to match the input resolution for metric evaluation. The SDXS VAE was not trained from scratch but was initialized from the weights of the FLUX 2 VAE, then redesigned and retrained. We also trained a [16ch VAE](https://huggingface.co/AiArtLab/simplevae) with Flux.1-level quality, based on the Aura VAE.
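The resolution arithmetic of the asymmetric design can be sketched in a few lines (the factors are the ones stated above; the function names are illustrative, not part of the repo):

```
# Asymmetric VAE: 8x-downsampling encoder, 16x-upsampling decoder.
# Training sees the 8x latent; at inference the decoder's extra
# upscaling block doubles the output resolution "for free".
ENCODER_FACTOR = 8   # encoder compression (pixels -> latent)
DECODER_FACTOR = 16  # decoder expansion (latent -> pixels)

def latent_side(pixels: int) -> int:
    """Spatial side of the latent for a given input resolution."""
    return pixels // ENCODER_FACTOR

def decoded_side(pixels: int) -> int:
    """Output resolution after a full encode/decode round trip."""
    return latent_side(pixels) * DECODER_FACTOR

print(latent_side(512), decoded_side(512))  # 64 1024 -> a 2x upscale
```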
|
| 99 |
|
| 100 |
[eval.py](src/eval.py)
|
| 101 |
```
|
| 108 |
Wan2.2-TI2V-5B (2Gb) | MSE=7.034e-04 PSNR=34.65 LPIPS=0.050 Edge=0.115 KL=9.429
|
| 109 |
sdxs-1b (200Mb) | MSE=2.655e-04 PSNR=37.83 LPIPS=0.026 Edge=0.066 KL=2.170
|
| 110 |
```
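As a sanity check on the table, PSNR and MSE are tied by a fixed formula. A minimal sketch, assuming images normalized to [0, 1] (eval.py may use a different dynamic range or averaging, so the table's exact figures need not reproduce):

```
import math

def psnr_from_mse(mse: float, max_val: float = 1.0) -> float:
    """Peak signal-to-noise ratio (dB) implied by a mean squared error."""
    return 10.0 * math.log10(max_val ** 2 / mse)

# Roughly in the ballpark of the sdxs-1b row above
print(round(psnr_from_mse(2.655e-04), 2))
```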
|
| 111 |
+
|
| 112 |
+
### Image upscale
|
| 113 |
+
|
| 114 |
+
One interesting feature of the asymmetric VAE is that it can be used as a standalone image and video upscaler. The VAE was trained at resolutions of 512–768 pixels and is most effective within this range. Note that it is a latent upscaler, which makes it simple and fast. It is also a "blind" upscaler: unlike model-based upscalers, it interferes with the image minimally and does not alter its essence. This is useful if you dislike upscalers that change the image style or invent new details on top of the original; on the other hand, you may dislike it for the same reason, since it changes the original only minimally.
|
| 115 |
+
|
| 116 |
### Unet
|
| 117 |
|
| 118 |
The UNet architecture in Simple Diffusion is a direct descendant and conceptual continuation of the ideas introduced in the first version of Stable Diffusion. Key distinctions include a relatively small, yet sufficient, number of transformer blocks that ensure an even distribution of attention. Additionally, the number of channels in the final layer has been significantly increased to improve detail rendering. Overall, however, it remains a UNet, similar to SD 1.5.
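The "even distribution of attention" can be pictured as transformer depth per UNet level; the counts below are purely illustrative, not the actual SDXS or SDXL configurations:

```
# Hypothetical depth profiles, for illustration only.
even_depths   = [2, 2, 2]    # SDXS-style: the same transformer depth at every level
uneven_depths = [0, 2, 10]   # SDXL-style: depth concentrated in the lowest level

print(even_depths, uneven_depths)
```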
|
| 119 |
|
| 120 |
+
Throughout the experiments, we tested [hundreds](https://wandb.ai/recoilme) of different configurations and trained dozens of [models](https://huggingface.co/AiArtLab/sdxs). Notably, we initially started from the SDXL architecture, assuming it would be a stronger baseline, but ultimately abandoned all the innovations proposed in it. These included uneven attention distribution with increased transformer block depth in the lower layers, a reduced number of blocks in the channel pyramid, micro-conditioning, the dual text encoder, text-time embeddings, and so on. According to our experiments, all of these changes lead to increased training time and costs while having a near-zero or negative impact on the final result. In total, the investigation of various architectures and the search for the most efficient and optimal configuration took over a year.
|
| 121 |
|
| 122 |
Unfortunately, we were unable to secure grants for model training, with the exception of a Google TPU grant, which we were unable to utilize due to insufficient preparation and time constraints. As a result, training and experiments were financed primarily from our own funds and user donations. This left a significant mark on the model’s architecture.
|
| 123 |
We aimed to make it as small and cost-effective to train as possible while maintaining our quality generation requirements. So perhaps the limited budget even worked to our advantage.
|
| 139 |
This adventure started in December 2024 after the release of the SANA model. We received a donation from Stan for fine-tuning SANA and, together with Stas, began fine-tuning and further developing it. Despite spending the entire budget, we did not achieve significant improvements. However, we were shocked by how poorly the model was trained and designed, and we became convinced that we could do better—though we were wrong.
|
| 140 |
Shifting Gears
|
| 141 |
By February 2025, we split our efforts and began designing our own architectures—which we are still doing today. Stas favored the DiT architecture, while I believed in UNet. Despite some differences in architectural views, we maintained close communication, shared our work, and supported each other throughout the process. We also engaged with the AIArtLab community (a virtual Telegram chat for those contributing to model development)—thank you all for your support.
|
| 142 |
+
## Main mistake
|
| 143 |
+
One of my key mistakes was relying too heavily on LLMs and research papers. Research often presents minor improvements as groundbreaking innovations, and LLMs, trained on such content, can draw incorrect conclusions. From autumn 2025, I radically changed my strategy, switching to training simpler models (VAEs), where simple fine-tuning yielded more substantial improvements than expensive research projects—including fine-tuning a VAE to a quality level comparable to Flux-1 at the time.
|
| 144 |
+
This shift led me to adopt a zero-trust policy toward any external information not personally verified. This does not mean that you should not read papers, but I urge you not to trust the conclusions presented in them. This is an extremely radical approach, and I have intentionally radicalized it, but it allowed me to transition from reading papers and implementing other people's ideas to generating my own and training models.
|
| 145 |
+
|
| 146 |
+
As a result, I focused on building a strong local benchmark for rapid, cost-effective experiments on a single RTX 4080. This led me to train models on the "Butterflies" dataset—a set of 1,000 images of butterflies—where a model could be trained from scratch in just an hour to assess the impact of a hypothesis or improvement, [example](https://www.comet.com/recoilme/unet/356142c52c314078914d0c0db409e1f3?experiment-tab=images&viewId=new).
|
| 147 |
## The Evolutionary Path
|
| 148 |
The second turning point was the transition to a continuous evolutionary improvement strategy. Unfortunately, the Butterflies dataset does not allow for evaluating prompt-following or anatomical generation capabilities. As a result, the model evolved incrementally rather than through revolutionary changes. The same model, from December 2025, underwent around 10 changes, including radical architectural shifts—while always preserving the pre-trained weights. It’s remarkable how well and quickly pre-trained models adapt to changes in architecture and external factors, even radical ones (e.g., switching VAE models, text encoders, or their combinations).
|
| 149 |
In addition to saving on training costs, this approach helped maintain minimal model size—for example, adding extra transformer blocks followed by an assessment of necessity and rolling back if the changes had no significant impact.
|
| 150 |
+
## tldr;
|
| 151 |
+
Stop reading, start training
|
| 152 |
## The Role of Hyperparameters
|
| 153 |
+
One of the initial mistakes was an excessive focus on hyperparameters during training. Ironically, 80% of training speed and quality depend on the model architecture (UNet) and the quality of embeddings (VAE), while the other 20% is influenced by the text encoder’s embeddings. What remains is the role of hyperparameters. The irony here is that Adam (AdamW8bit) is surprisingly forgiving of hyperparameter errors, so I won’t even list them; the defaults are fine.
|
| 154 |
## Tools and Optimization
|
| 155 |
The model comes with two scripts:
|
| 156 |
|
| 158 |
A training script provided as a single monolithic file.
|
| 159 |
Additionally, there’s a script that can be pasted directly into the terminal to automatically train the model with optimized parameters.
|
| 160 |
## Training Optimization
|
| 161 |
+
All pre-training was done using the AdamW8bit optimizer, which significantly reduced training costs.
|
| 162 |
|
| 163 |
### Train:
|
| 164 |
|
girl.jpg
CHANGED
|
Git LFS Details
|
|
media/girl.jpg
CHANGED
|
Git LFS Details
|
|
media/result_grid.jpg
CHANGED
|
Git LFS Details
|
|
model_index.json
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d717995bd1e270fd4694a62255b159df8ec189022ac70567e1d888ac8959161b
|
| 3 |
+
size 503
|
pipeline_sdxs.py
CHANGED
|
@@ -7,7 +7,6 @@ from dataclasses import dataclass
|
|
| 7 |
from diffusers import DiffusionPipeline
|
| 8 |
from diffusers.utils import BaseOutput
|
| 9 |
from tqdm import tqdm
|
| 10 |
-
from transformers import Qwen3_5ForConditionalGeneration, Qwen3_5Tokenizer
|
| 11 |
|
| 12 |
@dataclass
|
| 13 |
class SdxsPipelineOutput(BaseOutput):
|
|
@@ -15,11 +14,14 @@ class SdxsPipelineOutput(BaseOutput):
|
|
| 15 |
prompt: Optional[Union[str, List[str]]] = None
|
| 16 |
|
| 17 |
class SdxsPipeline(DiffusionPipeline):
|
| 18 |
-
|
| 19 |
super().__init__()
|
| 20 |
self.register_modules(
|
| 21 |
vae=vae,
|
| 22 |
text_encoder=text_encoder,
|
|
|
|
| 23 |
tokenizer=tokenizer,
|
| 24 |
unet=unet,
|
| 25 |
scheduler=scheduler
|
|
@@ -30,109 +32,252 @@ class SdxsPipeline(DiffusionPipeline):
|
|
| 30 |
if mean is not None and std is not None:
|
| 31 |
self.vae_latents_std = torch.tensor(std, device=self.unet.device, dtype=self.unet.dtype).view(1, len(std), 1, 1)
|
| 32 |
self.vae_latents_mean = torch.tensor(mean, device=self.unet.device, dtype=self.unet.dtype).view(1, len(mean), 1, 1)
|
| 33 |
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
|
| 40 |
-
|
| 41 |
-
|
| 42 |
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
image = torch.from_numpy(image)
|
| 56 |
-
return 2.0 * image - 1.0 # [-1, 1]
|
| 57 |
|
| 58 |
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
if texts is None:
|
| 62 |
-
texts = ""
|
| 63 |
|
| 64 |
-
|
| 65 |
-
|
| 66 |
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
for t in texts:
|
| 71 |
-
messages = [{"role": "user", "content": [{"type": "text", "text": t}]}]
|
| 72 |
-
res_text = self.tokenizer.apply_chat_template(
|
| 73 |
-
messages,
|
| 74 |
-
add_generation_prompt=True,
|
| 75 |
-
tokenize=False
|
| 76 |
-
)
|
| 77 |
-
formatted_prompts.append(res_text)
|
| 78 |
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
formatted_prompts,
|
| 82 |
-
padding="max_length",
|
| 83 |
-
max_length=248,
|
| 84 |
-
truncation=True, # truncate if the prompt exceeds max_length
|
| 85 |
-
return_tensors="pt"
|
| 86 |
-
).to(device)
|
| 87 |
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
seq_len = toks.attention_mask.sum(dim=1) - 1
|
| 98 |
-
pooled = last_hidden[torch.arange(len(last_hidden)), seq_len.clamp(min=0)]
|
| 99 |
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
pooled_expanded = pooled.unsqueeze(1)
|
| 103 |
-
|
| 104 |
-
# 2. Concatenate the token sequence with the pooled vector
|
| 105 |
-
# NOTE: the pooled vector goes FIRST
|
| 106 |
-
# Result: [B, 1 + L, 1024]; the pooled vector is now a token at the START.
|
| 107 |
-
new_encoder_hidden_states = torch.cat([pooled_expanded, last_hidden], dim=1)
|
| 108 |
-
|
| 109 |
-
# 3. Update the attention mask for the new token
|
| 110 |
-
# Attention mask: [B, 1 + L]; a 1 is prepended.
|
| 111 |
-
# torch.ones((batch_size, 1), device=device) builds a [B, 1] mask of ones.
|
| 112 |
-
new_attention_mask = torch.cat([torch.ones((last_hidden.shape[0], 1), device=device), toks.attention_mask], dim=1)
|
| 113 |
|
| 114 |
-
|
|
| 115 |
|
| 116 |
-
|
| 117 |
-
|
| 118 |
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
neg_mask = neg_mask.repeat(batch_size, 1)
|
| 123 |
|
| 124 |
-
|
| 125 |
-
|
| 126 |
|
| 127 |
-
|
| 128 |
|
| 129 |
@torch.no_grad()
|
| 130 |
def __call__(
|
| 131 |
self,
|
| 132 |
-
prompt: Union[str, List[str]],
|
| 133 |
-
image: Optional[Union[Image.Image, List[Image.Image]]] = None,
|
| 134 |
-
coef: float = 0.97, # strength (0.0 = original image, 1.0 = pure noise)
|
| 135 |
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 136 |
height: int = 1024,
|
| 137 |
width: int = 1024,
|
| 138 |
num_inference_steps: int = 40,
|
|
@@ -141,7 +286,6 @@ class SdxsPipeline(DiffusionPipeline):
|
|
| 141 |
seed: Optional[int] = None,
|
| 142 |
output_type: str = "pil",
|
| 143 |
return_dict: bool = True,
|
| 144 |
-
refine_prompt: bool = False,
|
| 145 |
**kwargs,
|
| 146 |
):
|
| 147 |
device = self.device
|
|
@@ -149,115 +293,81 @@ class SdxsPipeline(DiffusionPipeline):
|
|
| 149 |
|
| 150 |
if generator is None and seed is not None:
|
| 151 |
generator = torch.Generator(device=device).manual_seed(seed)
|
| 152 |
|
| 153 |
-
#
|
| 154 |
-
if
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
# Use the Qwen-Instruct chat format (apply_chat_template inserts the system/user/assistant tokens itself)
|
| 167 |
-
inputs = self.tokenizer.apply_chat_template(
|
| 168 |
-
messages,
|
| 169 |
-
tokenize=True,
|
| 170 |
-
add_generation_prompt=True,
|
| 171 |
-
return_dict=True,
|
| 172 |
-
return_tensors="pt"
|
| 173 |
-
).to(device)
|
| 174 |
-
|
| 175 |
-
generated_ids = self.text_encoder.generate(
|
| 176 |
-
**inputs, max_new_tokens=248, do_sample=True,temperature = 0.7
|
| 177 |
-
)
|
| 178 |
-
|
| 179 |
-
# Strip the input tokens from the generated response
|
| 180 |
-
generated_ids_trimmed = [
|
| 181 |
-
out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
| 182 |
-
]
|
| 183 |
-
output_text = self.tokenizer.batch_decode(
|
| 184 |
-
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
| 185 |
-
)
|
| 186 |
-
refined_list.append(output_text)
|
| 187 |
-
|
| 188 |
-
prompt = refined_list[0] if isinstance(prompt, str) else refined_list
|
| 189 |
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
|
| 196 |
-
#
|
| 197 |
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
| 198 |
timesteps = self.scheduler.timesteps
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
else:
|
| 206 |
-
image_tensor = self.preprocess_image(image[0], width, height).to(device, self.vae.dtype)
|
| 207 |
-
|
| 208 |
-
# --- Encode to latent space ---
|
| 209 |
-
latents_clean = self.vae.encode(image_tensor).latent_dist.sample(generator=generator)
|
| 210 |
-
latents_clean = (latents_clean - self.vae_latents_mean.to(device, self.vae.dtype)) / self.vae_latents_std.to(device, self.vae.dtype)
|
| 211 |
-
latents_clean = latents_clean.to(dtype)
|
| 212 |
-
|
| 213 |
-
# --- Add noise via the Rectified Flow formula ---
|
| 214 |
-
noise = torch.randn_like(latents_clean)
|
| 215 |
-
|
| 216 |
-
# coef = strength (0.0 → original, 1.0 → pure noise)
|
| 217 |
-
sigma = coef # in Flow Matching, sigma = t
|
| 218 |
-
if hasattr(self.scheduler, "sigma_shift"): # if the scheduler defines a shift (Flux-style)
|
| 219 |
-
sigma = self.scheduler.sigma_shift(sigma)
|
| 220 |
-
|
| 221 |
-
latents = (1.0 - sigma) * latents_clean + sigma * noise
|
| 222 |
-
|
| 223 |
-
# Truncate the timesteps starting from the current sigma
|
| 224 |
-
init_timestep = int(num_inference_steps * coef)
|
| 225 |
-
t_start = max(num_inference_steps - init_timestep, 0)
|
| 226 |
-
timesteps = timesteps[t_start:]
|
| 227 |
-
|
| 228 |
else:
|
| 229 |
-
|
| 230 |
-
latent_h = height // self.vae_scale_factor
|
| 231 |
-
latent_w = width // self.vae_scale_factor
|
| 232 |
-
|
| 233 |
-
latents = torch.randn(
|
| 234 |
-
(batch_size, self.unet.config.in_channels, latent_h, latent_w),
|
| 235 |
-
generator=generator, device=device, dtype=dtype
|
| 236 |
-
)
|
| 237 |
|
| 238 |
-
#
|
| 239 |
for i, t in enumerate(tqdm(timesteps, desc="Sampling")):
|
| 240 |
-
|
| 241 |
|
| 242 |
model_out = self.unet(
|
| 243 |
-
latent_model_input,
|
| 244 |
-
t,
|
| 245 |
encoder_hidden_states=text_embeddings,
|
| 246 |
encoder_attention_mask=attention_mask,
|
| 247 |
return_dict=False,
|
| 248 |
)[0]
|
| 249 |
|
| 250 |
-
|
| 251 |
flow_uncond, flow_cond = model_out.chunk(2)
|
| 252 |
model_out = flow_uncond + guidance_scale * (flow_cond - flow_uncond)
|
| 253 |
|
| 254 |
-
# Important: use scheduler.step; it knows how to handle the velocity prediction
|
| 255 |
latents = self.scheduler.step(model_out, t, latents, return_dict=False)[0]
|
| 256 |
|
| 257 |
-
#
|
| 258 |
if output_type == "latent":
|
| 259 |
if not return_dict: return (latents, prompt)
|
| 260 |
-
return SdxsPipelineOutput(images=latents
|
| 261 |
|
| 262 |
latents = latents * self.vae_latents_std.to(device, self.vae.dtype) + self.vae_latents_mean.to(device, self.vae.dtype)
|
| 263 |
image_output = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0]
|
|
@@ -271,5 +381,5 @@ class SdxsPipeline(DiffusionPipeline):
|
|
| 271 |
images = image_np
|
| 272 |
|
| 273 |
if not return_dict:
|
| 274 |
-
return (images,
|
| 275 |
-
return SdxsPipelineOutput(images=images
|
|
| 7 |
from diffusers import DiffusionPipeline
|
| 8 |
from diffusers.utils import BaseOutput
|
| 9 |
from tqdm import tqdm
|
| 10 |
|
| 11 |
@dataclass
|
| 12 |
class SdxsPipelineOutput(BaseOutput):
|
| 14 |
prompt: Optional[Union[str, List[str]]] = None
|
| 15 |
|
| 16 |
class SdxsPipeline(DiffusionPipeline):
|
| 17 |
+
MAX_TEXT_TOKENS = 248
|
| 18 |
+
|
| 19 |
+
def __init__(self, vae, text_encoder, processor, tokenizer, unet, scheduler):
|
| 20 |
super().__init__()
|
| 21 |
self.register_modules(
|
| 22 |
vae=vae,
|
| 23 |
text_encoder=text_encoder,
|
| 24 |
+
processor=processor,
|
| 25 |
tokenizer=tokenizer,
|
| 26 |
unet=unet,
|
| 27 |
scheduler=scheduler
|
|
|
|
| 32 |
if mean is not None and std is not None:
|
| 33 |
self.vae_latents_std = torch.tensor(std, device=self.unet.device, dtype=self.unet.dtype).view(1, len(std), 1, 1)
|
| 34 |
self.vae_latents_mean = torch.tensor(mean, device=self.unet.device, dtype=self.unet.dtype).view(1, len(mean), 1, 1)
|
| 35 |
+
|
| 36 |
+
@staticmethod
|
| 37 |
+
def _pad_tensor_to_length(tensor: torch.Tensor, target_len: int, dim: int = 1, pad_value: float = 0) -> torch.Tensor:
|
| 38 |
+
current_len = tensor.shape[dim]
|
| 39 |
+
if current_len >= target_len:
|
| 40 |
+
return tensor
|
| 41 |
+
pad_size = target_len - current_len
|
| 42 |
+
if tensor.dim() == 3:
|
| 43 |
+
padding = (0, 0, 0, pad_size, 0, 0)
|
| 44 |
+
elif tensor.dim() == 2:
|
| 45 |
+
padding = (0, pad_size, 0, 0)
|
| 46 |
+
else:
|
| 47 |
+
raise ValueError(f"Unsupported tensor dimension: {tensor.dim()}")
|
| 48 |
+
return torch.nn.functional.pad(tensor, padding, value=pad_value)
|
| 49 |
+
|
| 50 |
+
|
+    @torch.no_grad()
+    def refine_prompts(
+        self,
+        prompts: Union[str, List[str]],
+        system_prompt: Optional[str] = None,
+        temperature: float = 0.7
+    ) -> List[str]:
+        """
+        Refines a list of prompts using the Text Encoder (LLM).
+
+        Args:
+            prompts: Single prompt string or list of prompts.
+            system_prompt: Custom instruction for the LLM. If None, uses default aesthetic enhancer.
+            temperature: Sampling temperature for generation.
+
+        Returns:
+            List of refined prompts.
+        """
+        device = self.device
+
+        # Default system prompt if none provided
+        if system_prompt is None:
+            system_prompt = (
+                "You are a skilled text-to-image prompt engineer whose sole function is to transform "
+                "the user's input into an aesthetically optimized, detailed, and visually descriptive three-sentence output. "
+                "**The primary subject (e.g., 'girl', 'dog', 'house') MUST be the main focus of the revised prompt "
+                "and MUST be described in rich detail within the first sentence or two.** "
+                "Output **only** the final revised prompt, with absolutely no commentary. "
+                "Don't use cliches like warm, soft, vibrant, wildflowers. Be creative. User input prompt: "
+            )
+
+        pad_id = getattr(self.text_encoder.config, "pad_token_id", None) or \
+                 getattr(self.text_encoder.config, "eos_token_id", None)
+
+        prompts_list = [prompts] if isinstance(prompts, str) else prompts
+        refined_list = []
+
+        for p in prompts_list:
+            # Prepend system prompt to user input
+            full_text = system_prompt + p
+            messages = [{"role": "user", "content": [{"type": "text", "text": full_text}]}]
+
+            inputs = self.tokenizer.apply_chat_template(
+                messages, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt"
+            ).to(device)
+
+            generated_ids = self.text_encoder.generate(
+                **inputs,
+                max_new_tokens=self.MAX_TEXT_TOKENS,
+                do_sample=True,
+                temperature=temperature,
+                pad_token_id=pad_id
+            )
+
+            generated_ids_trimmed = [
+                out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+            ]
+            output_text = self.tokenizer.batch_decode(
+                generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+            )
+            refined_list.append(output_text[0])
+
+        return refined_list
+
+    @torch.no_grad()
+    def encode_text(self, text: Union[str, List[str]]) -> Tuple[torch.Tensor, torch.Tensor]:
+        device = self.device
+        dtype = self.unet.dtype
+        if text is None: text = ""
+        if isinstance(text, str): text = [text]
+
+        formatted_prompts = []
+        for t in text:
+            messages = [{"role": "user", "content": [{"type": "text", "text": t}]}]
+            formatted_prompts.append(self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False))
+
+        toks = self.tokenizer(formatted_prompts, padding="max_length", max_length=self.MAX_TEXT_TOKENS, truncation=True, return_tensors="pt").to(device)
+        outputs = self.text_encoder(input_ids=toks.input_ids, attention_mask=toks.attention_mask, output_hidden_states=True)
+
+        last_hidden = outputs.hidden_states[-2]
+        seq_len = toks.attention_mask.sum(dim=1) - 1
+        pooled = last_hidden[torch.arange(len(last_hidden)), seq_len.clamp(min=0)]
+
+        pooled_expanded = pooled.unsqueeze(1)
+        encoder_hidden_states = torch.cat([pooled_expanded, last_hidden], dim=1)
+        attention_mask = torch.cat([torch.ones((last_hidden.shape[0], 1), device=device, dtype=toks.attention_mask.dtype), toks.attention_mask], dim=1)
+
+        return encoder_hidden_states.to(dtype=dtype), attention_mask.to(dtype=torch.int64)
+
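The pooled-token trick in `encode_text` (take the hidden state at the last non-padded position and prepend it as an extra token) can be illustrated with toy shapes; the hidden size of 8 and sequence length of 16 here are arbitrary stand-ins:

```python
import torch

batch, seq, hidden = 2, 16, 8
last_hidden = torch.randn(batch, seq, hidden)
attention_mask = torch.ones(batch, seq, dtype=torch.int64)
attention_mask[1, 10:] = 0  # second sample is padded after 10 tokens

# Index of the last real token per sample, used as a pooled summary.
seq_len = attention_mask.sum(dim=1) - 1          # -> [15, 9]
pooled = last_hidden[torch.arange(batch), seq_len.clamp(min=0)]

# Prepend the pooled vector as one extra token and extend the mask to match.
encoder_hidden_states = torch.cat([pooled.unsqueeze(1), last_hidden], dim=1)
mask = torch.cat([torch.ones(batch, 1, dtype=torch.int64), attention_mask], dim=1)

print(tuple(encoder_hidden_states.shape))  # (2, 17, 8)
print(tuple(mask.shape))                   # (2, 17)
```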
+    @torch.no_grad()
+    def encode_image(self, image: Union[Image.Image, str, List[Union[Image.Image, str]]]) -> Tuple[torch.Tensor, torch.Tensor]:
+        device = self.device
+        dtype = self.unet.dtype
+        if isinstance(image, (str, Image.Image)): image = [image]
+        batch_size = len(image)
+
+        all_messages = [[{"role": "user", "content": [{"type": "image", "image": img}]}] for img in image]
+        formatted_prompts = [self.processor.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True) for msgs in all_messages]
+
+        inputs = self.processor(text=formatted_prompts, images=image, return_tensors="pt", padding=True, truncation=False).to(device)
+        outputs = self.text_encoder(**inputs, output_hidden_states=True)
+
+        last_hidden = outputs.hidden_states[-2]
+        seq_lens = inputs.attention_mask.sum(dim=1) - 1
+        pooled = last_hidden[torch.arange(batch_size), seq_lens.clamp(min=0)]
+
+        final_embeddings = torch.cat([pooled.unsqueeze(1), last_hidden], dim=1)
+        final_mask = torch.cat([torch.ones((batch_size, 1), device=device, dtype=inputs.attention_mask.dtype), inputs.attention_mask], dim=1)
+
+        return final_embeddings.to(dtype=dtype), final_mask.to(dtype=torch.int64)
+
+    @torch.no_grad()
+    def encode_text_and_image_naive(self, text: Union[str, List[str]], image: Optional[Union[Image.Image, List[Image.Image], str, List[str]]] = None, scale = 0.5) -> Tuple[torch.Tensor, torch.Tensor]:
+        # 1. Get the text embeddings
+        text_embeds, text_mask = self.encode_text(text)
+
+        if image is not None:
+            if isinstance(image, (str, Image.Image)):
+                image = [image]
+
+            # If there is one image but several texts, replicate the image
+            if len(image) == 1 and text_embeds.shape[0] > 1:
+                image = image * text_embeds.shape[0]
+
+            # --- Begin inlined logic from encode_image ---
+            device = self.device
+            dtype = self.unet.dtype
+            batch_size = len(image)
+
+            all_messages = [[{"role": "user", "content": [{"type": "image", "image": img}]}] for img in image]
+            formatted_prompts = [self.processor.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True) for msgs in all_messages]
+
+            inputs = self.processor(text=formatted_prompts, images=image, return_tensors="pt", padding=True, truncation=False).to(device)
+            outputs = self.text_encoder(**inputs, output_hidden_states=True)
+
+            # Take the penultimate hidden layer (-2)
+            img_hidden_states = outputs.hidden_states[-2]
+            # Take the attention mask from the processor
+            img_mask = inputs.attention_mask
+            # --- End inlined logic ---
+
+            # Apply scaling
+            if scale != 1.0:
+                img_hidden_states = img_hidden_states * scale
+
+            # Match the mask and dtypes to the text embeddings
+            img_mask = img_mask.to(text_mask.dtype)
+            img_hidden_states = img_hidden_states.to(dtype=dtype)
+
+            # Concatenate the text and image token sequences
+            final_embeds = torch.cat([text_embeds, img_hidden_states], dim=1)
+            final_mask = torch.cat([text_mask, img_mask], dim=1)
+
+            return final_embeds, final_mask
+
+        return text_embeds, text_mask
+
+    @torch.no_grad()
+    def image_upscale(
+        self,
+        image: Union[str, Image.Image, List[Union[str, Image.Image]]],
+        batch_size: int = 1
+    ) -> List[Image.Image]:
+        """
+        Upscales images using asymmetric VAE (x2).
+        Uses smart batching: processes in parallel if sizes match, else falls back to sequential.
+        """
+        images = [image] if isinstance(image, (str, Image.Image)) else image
+
+        # 1. Preprocess: Load, Handle Alpha, Pad to %8, Normalize
+        batch_data = []
+        for img in images:
+            if isinstance(img, str): img = Image.open(img)
+            if img.mode == "RGBA":
+                img = Image.alpha_composite(Image.new("RGBA", img.size, (255, 255, 255)), img)
+            img = img.convert("RGB")
+
+            w, h = img.size
+            pw, ph = (8 - w % 8) % 8, (8 - h % 8) % 8
+            if pw or ph:
+                padded = Image.new("RGB", (w + pw, h + ph), (255, 255, 255))
+                padded.paste(img)
+                img = padded
+
+            t = torch.from_numpy(np.array(img).astype(np.float32) / 127.5 - 1.0).permute(2, 0, 1)
+            batch_data.append((t.to(self.device, torch.float16), w, h))
+
+        # 2. Determine Execution Strategy
+        # If all shapes are identical, use batch_size. Else fallback to 1.
+        unique_shapes = {t.shape for t, _, _ in batch_data}
+        step = batch_size if len(unique_shapes) == 1 else 1
+
+        output_images = []
+
+        # 3. Process Batches
+        for i in range(0, len(batch_data), step):
+            chunk = batch_data[i : i + step]
+
+            # Stack tensors [B, C, H, W]
+            tensors = torch.stack([c[0] for c in chunk])
+
+            # Encode -> Decode (using mean for deterministic upscale)
+            latents = self.vae.encode(tensors).latent_dist.mean
+            latents = latents * self.vae_latents_std.to(latents) + self.vae_latents_mean.to(latents)
+            decoded = self.vae.decode(latents.to(self.vae.dtype))[0]
+
+            # 4. Post-process: Denormalize and Crop
+            decoded = (decoded.clamp(-1, 1) + 1) / 2
+            for j, tensor in enumerate(decoded):
+                w, h = chunk[j][1], chunk[j][2]  # Original sizes
+
+                # Crop to exact 2x
+                arr = tensor.cpu().permute(1, 2, 0).float().numpy()
+                arr = arr[:h * 2, :w * 2]
+
+                output_images.append(Image.fromarray((arr * 255).astype("uint8")))
+
+        return output_images
+
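`image_upscale` pads each side up to the next multiple of 8 before encoding; the `(8 - x % 8) % 8` arithmetic can be checked on its own:

```python
def pad_to_multiple_of_8(w: int, h: int) -> tuple:
    # Right/bottom padding needed so both sides divide by 8; the outer % 8
    # keeps the padding at 0 when a side is already aligned.
    return (8 - w % 8) % 8, (8 - h % 8) % 8

print(pad_to_multiple_of_8(704, 704))  # (0, 0), already aligned
print(pad_to_multiple_of_8(701, 699))  # (3, 5)
```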
     @torch.no_grad()
     def __call__(
         self,
+        prompt: Optional[Union[str, List[str]]] = None,
         negative_prompt: Optional[Union[str, List[str]]] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        prompt_attention_mask: Optional[torch.Tensor] = None,
+        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+        latents: Optional[torch.Tensor] = None,
         height: int = 1024,
         width: int = 1024,
         num_inference_steps: int = 40,
         …
         seed: Optional[int] = None,
         output_type: str = "pil",
         return_dict: bool = True,
         **kwargs,
     ):
         device = self.device
         …

         if generator is None and seed is not None:
             generator = torch.Generator(device=device).manual_seed(seed)
+
+        do_classifier_free_guidance = guidance_scale > 1.0

+        # 1. Encode Positive
+        if prompt_embeds is None:
+            if prompt is None: raise ValueError("`prompt` or `prompt_embeds` required.")
+            prompt_embeds, prompt_attention_mask = self.encode_text(prompt)
+        prompt_embeds = prompt_embeds.to(device=device, dtype=dtype)
+        prompt_attention_mask = prompt_attention_mask.to(device=device, dtype=torch.int64)
+        batch_size = prompt_embeds.shape[0]
+
+        # 2. Encode Negative (only if CFG is enabled)
+        if do_classifier_free_guidance:
+            if negative_prompt_embeds is None:
+                neg_text = negative_prompt if negative_prompt is not None else ("" if isinstance(prompt, str) else [""] * len(prompt))
+                negative_prompt_embeds, negative_prompt_attention_mask = self.encode_text(neg_text)

+            negative_prompt_embeds = negative_prompt_embeds.to(device=device, dtype=dtype)
+            negative_prompt_attention_mask = negative_prompt_attention_mask.to(device=device, dtype=torch.int64)

+            # Batch size matching
+            if negative_prompt_embeds.shape[0] != batch_size:
+                negative_prompt_embeds = negative_prompt_embeds.repeat(batch_size, 1, 1)
+                negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(batch_size, 1)
+
+            # 3. Align Length (Padding) for Concat
+            max_len = max(prompt_embeds.shape[1], negative_prompt_embeds.shape[1])
+            prompt_embeds = self._pad_tensor_to_length(prompt_embeds, max_len, dim=1, pad_value=0)
+            negative_prompt_embeds = self._pad_tensor_to_length(negative_prompt_embeds, max_len, dim=1, pad_value=0)
+            prompt_attention_mask = self._pad_tensor_to_length(prompt_attention_mask, max_len, dim=1, pad_value=0)
+            negative_prompt_attention_mask = self._pad_tensor_to_length(negative_prompt_attention_mask, max_len, dim=1, pad_value=0)
+
+            # 4. Concatenate for CFG: [Neg, Pos]
+            text_embeddings = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+            attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
+        else:
+            # If no CFG, we just use positive embeddings as is
+            text_embeddings = prompt_embeds
+            attention_mask = prompt_attention_mask

+        # 5. Scheduler & Latents
         self.scheduler.set_timesteps(num_inference_steps, device=device)
         timesteps = self.scheduler.timesteps
+
+        latent_h = height // self.vae_scale_factor
+        latent_w = width // self.vae_scale_factor
+
+        if latents is None:
+            latents = torch.randn((batch_size, self.unet.config.in_channels, latent_h, latent_w), generator=generator, device=device, dtype=dtype)
         else:
+            latents = latents.to(device=device, dtype=dtype)

+        # 6. Denoising Loop
         for i, t in enumerate(tqdm(timesteps, desc="Sampling")):
+            # Duplicate latents only if doing CFG
+            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

             model_out = self.unet(
+                latent_model_input, t,
                 encoder_hidden_states=text_embeddings,
                 encoder_attention_mask=attention_mask,
                 return_dict=False,
             )[0]

+            # Perform CFG guidance
+            if do_classifier_free_guidance:
                 flow_uncond, flow_cond = model_out.chunk(2)
                 model_out = flow_uncond + guidance_scale * (flow_cond - flow_uncond)

             latents = self.scheduler.step(model_out, t, latents, return_dict=False)[0]

+        # 7. Decode
         if output_type == "latent":
             if not return_dict: return (latents, prompt)
+            return SdxsPipelineOutput(images=latents)

         latents = latents * self.vae_latents_std.to(device, self.vae.dtype) + self.vae_latents_mean.to(device, self.vae.dtype)
         image_output = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0]
         …
         images = image_np

         if not return_dict:
+            return (images,)
+        return SdxsPipelineOutput(images=images)
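The guidance step in the denoising loop is a linear extrapolation from the unconditional prediction toward the conditional one; a scalar sketch of the same formula:

```python
def cfg(flow_uncond: float, flow_cond: float, guidance_scale: float) -> float:
    # Same combination as in the denoising loop, applied here to scalars
    # instead of the model's output tensors.
    return flow_uncond + guidance_scale * (flow_cond - flow_uncond)

print(cfg(0.0, 1.0, 5.0))  # 5.0, pushes past the conditional prediction
print(cfg(0.0, 1.0, 1.0))  # 1.0, scale 1.0 reduces to the conditional prediction
```

This also shows why the pipeline only enables CFG for `guidance_scale > 1.0`: at 1.0 the unconditional branch cancels out, so the second forward pass would be wasted.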
processor/chat_template.jinja
ADDED
@@ -0,0 +1,154 @@
+{%- set image_count = namespace(value=0) %}
+{%- set video_count = namespace(value=0) %}
+{%- macro render_content(content, do_vision_count, is_system_content=false) %}
+    {%- if content is string %}
+        {{- content }}
+    {%- elif content is iterable and content is not mapping %}
+        {%- for item in content %}
+            {%- if 'image' in item or 'image_url' in item or item.type == 'image' %}
+                {%- if is_system_content %}
+                    {{- raise_exception('System message cannot contain images.') }}
+                {%- endif %}
+                {%- if do_vision_count %}
+                    {%- set image_count.value = image_count.value + 1 %}
+                {%- endif %}
+                {%- if add_vision_id %}
+                    {{- 'Picture ' ~ image_count.value ~ ': ' }}
+                {%- endif %}
+                {{- '<|vision_start|><|image_pad|><|vision_end|>' }}
+            {%- elif 'video' in item or item.type == 'video' %}
+                {%- if is_system_content %}
+                    {{- raise_exception('System message cannot contain videos.') }}
+                {%- endif %}
+                {%- if do_vision_count %}
+                    {%- set video_count.value = video_count.value + 1 %}
+                {%- endif %}
+                {%- if add_vision_id %}
+                    {{- 'Video ' ~ video_count.value ~ ': ' }}
+                {%- endif %}
+                {{- '<|vision_start|><|video_pad|><|vision_end|>' }}
+            {%- elif 'text' in item %}
+                {{- item.text }}
+            {%- else %}
+                {{- raise_exception('Unexpected item type in content.') }}
+            {%- endif %}
+        {%- endfor %}
+    {%- elif content is none or content is undefined %}
+        {{- '' }}
+    {%- else %}
+        {{- raise_exception('Unexpected content type.') }}
+    {%- endif %}
+{%- endmacro %}
+{%- if not messages %}
+    {{- raise_exception('No messages provided.') }}
+{%- endif %}
+{%- if tools and tools is iterable and tools is not mapping %}
+    {{- '<|im_start|>system\n' }}
+    {{- "# Tools\n\nYou have access to the following functions:\n\n<tools>" }}
+    {%- for tool in tools %}
+        {{- "\n" }}
+        {{- tool | tojson }}
+    {%- endfor %}
+    {{- "\n</tools>" }}
+    {{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n</IMPORTANT>' }}
+    {%- if messages[0].role == 'system' %}
+        {%- set content = render_content(messages[0].content, false, true)|trim %}
+        {%- if content %}
+            {{- '\n\n' + content }}
+        {%- endif %}
+    {%- endif %}
+    {{- '<|im_end|>\n' }}
+{%- else %}
+    {%- if messages[0].role == 'system' %}
+        {%- set content = render_content(messages[0].content, false, true)|trim %}
+        {{- '<|im_start|>system\n' + content + '<|im_end|>\n' }}
+    {%- endif %}
+{%- endif %}
+{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
+{%- for message in messages[::-1] %}
+    {%- set index = (messages|length - 1) - loop.index0 %}
+    {%- if ns.multi_step_tool and message.role == "user" %}
+        {%- set content = render_content(message.content, false)|trim %}
+        {%- if not(content.startswith('<tool_response>') and content.endswith('</tool_response>')) %}
+            {%- set ns.multi_step_tool = false %}
+            {%- set ns.last_query_index = index %}
+        {%- endif %}
+    {%- endif %}
+{%- endfor %}
+{%- if ns.multi_step_tool %}
+    {{- raise_exception('No user query found in messages.') }}
+{%- endif %}
+{%- for message in messages %}
+    {%- set content = render_content(message.content, true)|trim %}
+    {%- if message.role == "system" %}
+        {%- if not loop.first %}
+            {{- raise_exception('System message must be at the beginning.') }}
+        {%- endif %}
+    {%- elif message.role == "user" %}
+        {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
+    {%- elif message.role == "assistant" %}
+        {%- set reasoning_content = '' %}
+        {%- if message.reasoning_content is string %}
+            {%- set reasoning_content = message.reasoning_content %}
+        {%- else %}
+            {%- if '</think>' in content %}
+                {%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
+                {%- set content = content.split('</think>')[-1].lstrip('\n') %}
+            {%- endif %}
+        {%- endif %}
+        {%- set reasoning_content = reasoning_content|trim %}
+        {%- if loop.index0 > ns.last_query_index %}
+            {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content + '\n</think>\n\n' + content }}
+        {%- else %}
+            {{- '<|im_start|>' + message.role + '\n' + content }}
+        {%- endif %}
+        {%- if message.tool_calls and message.tool_calls is iterable and message.tool_calls is not mapping %}
+            {%- for tool_call in message.tool_calls %}
+                {%- if tool_call.function is defined %}
+                    {%- set tool_call = tool_call.function %}
+                {%- endif %}
+                {%- if loop.first %}
+                    {%- if content|trim %}
+                        {{- '\n\n<tool_call>\n<function=' + tool_call.name + '>\n' }}
+                    {%- else %}
+                        {{- '<tool_call>\n<function=' + tool_call.name + '>\n' }}
+                    {%- endif %}
+                {%- else %}
+                    {{- '\n<tool_call>\n<function=' + tool_call.name + '>\n' }}
+                {%- endif %}
+                {%- if tool_call.arguments is defined %}
+                    {%- for args_name, args_value in tool_call.arguments|items %}
+                        {{- '<parameter=' + args_name + '>\n' }}
+                        {%- set args_value = args_value | tojson | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %}
+                        {{- args_value }}
+                        {{- '\n</parameter>\n' }}
+                    {%- endfor %}
+                {%- endif %}
+                {{- '</function>\n</tool_call>' }}
+            {%- endfor %}
+        {%- endif %}
+        {{- '<|im_end|>\n' }}
+    {%- elif message.role == "tool" %}
+        {%- if loop.previtem and loop.previtem.role != "tool" %}
+            {{- '<|im_start|>user' }}
+        {%- endif %}
+        {{- '\n<tool_response>\n' }}
+        {{- content }}
+        {{- '\n</tool_response>' }}
+        {%- if not loop.last and loop.nextitem.role != "tool" %}
+            {{- '<|im_end|>\n' }}
+        {%- elif loop.last %}
+            {{- '<|im_end|>\n' }}
+        {%- endif %}
+    {%- else %}
+        {{- raise_exception('Unexpected message role.') }}
+    {%- endif %}
+{%- endfor %}
+{%- if add_generation_prompt %}
+    {{- '<|im_start|>assistant\n' }}
+    {%- if enable_thinking is defined and enable_thinking is true %}
+        {{- '<think>\n' }}
+    {%- else %}
+        {{- '<think>\n\n</think>\n\n' }}
+    {%- endif %}
+{%- endif %}
processor/processor_config.json
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:14932921ca485d458a04dafd8069fbb0a4505622a48208d19ed247115801385b
+size 1300

processor/tokenizer.json
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:87a7830d63fcf43bf241c3c5242e96e62dd3fdc29224ca26fed8ea333db72de4
+size 19989343

processor/tokenizer_config.json
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e98f1901ac6f0adff67b1d540bfa0c36ac1a0cf59eb72ed78146ef89aafa1182
+size 1139
samples/unet_384x704_0.jpg CHANGED (Git LFS)
samples/unet_416x704_0.jpg CHANGED (Git LFS)
samples/unet_448x704_0.jpg CHANGED (Git LFS)
samples/unet_480x704_0.jpg CHANGED (Git LFS)
samples/unet_512x704_0.jpg CHANGED (Git LFS)
samples/unet_544x704_0.jpg CHANGED (Git LFS)
samples/unet_576x704_0.jpg CHANGED (Git LFS)
samples/unet_608x704_0.jpg CHANGED (Git LFS)
samples/unet_640x704_0.jpg CHANGED (Git LFS)
samples/unet_672x704_0.jpg CHANGED (Git LFS)
samples/unet_704x384_0.jpg CHANGED (Git LFS)
samples/unet_704x416_0.jpg CHANGED (Git LFS)
samples/unet_704x448_0.jpg CHANGED (Git LFS)
samples/unet_704x480_0.jpg CHANGED (Git LFS)
samples/unet_704x512_0.jpg CHANGED (Git LFS)
samples/unet_704x544_0.jpg CHANGED (Git LFS)
samples/unet_704x576_0.jpg CHANGED (Git LFS)
samples/unet_704x608_0.jpg CHANGED (Git LFS)
samples/unet_704x640_0.jpg CHANGED (Git LFS)
samples/unet_704x672_0.jpg CHANGED (Git LFS)
samples/unet_704x704_0.jpg CHANGED (Git LFS)
src/unet1.5b.ipynb
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:bc106e53f10b9fd143231839045c4fab5413c64a4d3f096304d1c689682299a8
 size 45191
test.ipynb
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:c290bd0cd5bed79850d835c7cbd8c556ef02fc9cbf01cdf8f75229db879fe710
+size 13053363
train.py
CHANGED
@@ -35,13 +35,13 @@ ds_path = "datasets/ds1234_noanime_704_vae8x16x"
 project = "unet"
 
 gpu_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
-local_bs = max(1, int((gpu_mem_gb / 32) *
+local_bs = max(1, int((gpu_mem_gb / 32) * 7))
 num_gpus = torch.cuda.device_count()
 batch_size = local_bs * num_gpus
 
 base_learning_rate = 4e-5
 min_learning_rate = 4e-6
-learning_rate_scale =
+learning_rate_scale = 1 # 5 - finetune (small details), 1 - pretrain
 base_learning_rate = base_learning_rate / learning_rate_scale
 min_learning_rate = min_learning_rate / learning_rate_scale
 print(f"Calculated params max-lr:{base_learning_rate} min-lr:{min_learning_rate} GPUs: {num_gpus}, Global BS: {batch_size}")
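The batch-size heuristic in the train.py hunk above allocates roughly 7 samples per 32 GB of VRAM, floored, never below 1; the same formula in isolation:

```python
def local_batch_size(gpu_mem_gb: float) -> int:
    # max(1, ...) guards small GPUs; int() floors the fractional estimate.
    return max(1, int((gpu_mem_gb / 32) * 7))

print(local_batch_size(24))  # 5  (e.g. a 24 GB consumer card)
print(local_batch_size(80))  # 17 (e.g. an 80 GB datacenter GPU)
print(local_batch_size(8))   # 1
```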
unet/config.json
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:fbb20721f35fd23f45183d6c2341c319ac059296734c98c786b278c7a42e2f50
 size 1879

unet/diffusion_pytorch_model.safetensors
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:35324c7f8ccdc476548954c82f76f6a38528201b7a514094dab2e8810519f47e
+size 6420443856