Text-to-Image
Diffusers
Safetensors
recoilme committed on
Commit
77f9f15
1 Parent(s): 3d9b043
README.md CHANGED
@@ -12,15 +12,30 @@ datasets:
12
 
13
  At AiArtLab, we strive to create a free, compact and fast model that can be trained on consumer graphics cards.
14
 
15
- - Unet: 1.5b parameters
16
- - Qwen3.5: 1.8b parameters
17
- - VAE: 32ch8x16x
18
- - Speed: Sampling: 100%|██████████| 40/40 [00:01<00:00, 29.98it/s]
19
-
20
- ### Random samples
21
  ![promo](media/result_grid.jpg)
22
 
23
- ### Example
24
 
25
  ```
26
  import torch
@@ -46,9 +61,41 @@ image = pipe(
46
 image.show()
47
  ```
48
 
49
  ### VAE
50
 
51
- The VAE in Simple Diffusion utilizes an asymmetric architecture: an 8x encoder paired with a 16x decoder. A compression factor of 8 is maintained during training, while at inference an additional upscaling block effectively doubles the output resolution. This reduces training costs by an order of magnitude and boosts inference speed with no perceptual quality loss; in effect, it acts as an integrated latent upscaler. To ensure a fair comparison with other VAEs, we downsampled the generated images to the input resolution before computing metrics. The SDXS VAE was not trained from scratch; it was initialized from the FLUX.2 VAE weights, then redesigned and retrained.
52
 
53
  [eval.py](src/eval.py)
54
  ```
@@ -61,11 +108,16 @@ FLUX.2 | MSE=2.425e-04 PSNR=38.33 LPIPS=0.023 Edge=0.065 KL=2.160
61
  Wan2.2-TI2V-5B (2Gb) | MSE=7.034e-04 PSNR=34.65 LPIPS=0.050 Edge=0.115 KL=9.429
62
  sdxs-1b (200Mb) | MSE=2.655e-04 PSNR=37.83 LPIPS=0.026 Edge=0.066 KL=2.170
63
  ```
64
  ### Unet
65
 
66
  The UNet architecture in Simple Diffusion is a direct descendant and conceptual continuation of the ideas introduced in the first version of Stable Diffusion. Key distinctions include a relatively small, yet sufficient, number of transformer blocks that ensure an even distribution of attention. Additionally, the number of channels in the final layer has been significantly increased to improve detail rendering. Overall, however, it remains a UNet, similar to SD 1.5.
67
 
68
- Throughout the experiments, we tested hundreds of different configurations and trained dozens of models. Notably, we initially started from the SDXL architecture, assuming it would be a stronger baseline, but ultimately abandoned all of its proposed innovations: uneven attention distribution with increased transformer block depth in the lower layers, a reduced number of blocks in the channel pyramid, micro-conditioning, the dual text encoder, text-time conditioning, and so on. In our experiments, all of these changes increased training time and cost while having a near-zero or negative impact on the final result. In total, investigating architectures and searching for the most efficient configuration took over a year.
69
 
70
 Unfortunately, we were unable to secure grants for model training, apart from a Google TPU grant that we could not utilize due to insufficient preparation and time constraints. As a result, training and experiments were financed primarily from our own funds and user donations. This left a significant mark on the model’s architecture.
71
  We aimed to make it as small and cost-effective to train as possible while maintaining our quality generation requirements. So perhaps the limited budget even worked to our advantage.
@@ -87,14 +139,18 @@ Additionally, the use of a full-fledged language model allowed us to integrate a
87
  This adventure started in December 2024 after the release of the SANA model. We received a donation from Stan for fine-tuning SANA and, together with Stas, began fine-tuning and further developing it. Despite spending the entire budget, we did not achieve significant improvements. However, we were shocked by how poorly the model was trained and designed, and we became convinced that we could do better—though we were wrong.
88
  Shifting Gears
89
  By February 2025, we split our efforts and began designing our own architectures—which we are still doing today. Stas favored the DiT architecture, while I believed in UNet. Despite some differences in architectural views, we maintained close communication, shared our work, and supported each other throughout the process. We also engaged with the AIArtLab community (a virtual Telegram chat for those contributing to model development)—thank you all for your support.
90
- ## Lessons Learned
91
- One of my key mistakes was relying too heavily on LLMs and research papers. Research often presents minor improvements as groundbreaking innovations, and LLMs, trained on such content, can draw incorrect conclusions due to the abundance of clickbait. From autumn 2025, I radically changed my strategy, switching to training simpler models (VAEs), where simple fine-tuning yielded more substantial improvements than expensive research projects—including fine-tuning a VAE to a quality level comparable to Flux-1 at the time.
92
- This shift led me to adopt a zero-trust policy toward any external information not personally verified. As a result, I focused on building a strong local benchmark for rapid, cost-effective experiments. This led me to train models on the "Butterflies" dataset—a set of 1,000 images of butterflies—where a model could be trained from scratch in just an hour to assess the impact of a hypothesis or improvement.
 
 
93
  ## The Evolutionary Path
94
  The second turning point was the transition to a continuous evolutionary improvement strategy. Unfortunately, the Butterflies dataset does not allow for evaluating prompt-following or anatomical generation capabilities. As a result, the model evolved incrementally rather than through revolutionary changes. The same model, from December 2025, underwent around 10 changes, including radical architectural shifts—while always preserving the pre-trained weights. It’s remarkable how well and quickly pre-trained models adapt to changes in architecture and external factors, even radical ones (e.g., switching VAE models, text encoders, or their combinations).
95
  In addition to saving on training costs, this approach helped maintain minimal model size—for example, adding extra transformer blocks followed by an assessment of necessity and rolling back if the changes had no significant impact.
 
 
96
  ## The Role of Hyperparameters
97
- One of the initial mistakes was an excessive focus on hyperparameters during training. Ironically, about 80% of training speed and quality depends on the model architecture (UNet) and the quality of the embeddings (VAE), while the other 20% is influenced by the text encoder’s embeddings. What little remains is determined by the Adam optimizer, and the irony is that Adam is surprisingly forgiving of hyperparameter errors, so I won’t even list them.
98
  ## Tools and Optimization
99
  The model comes with two scripts:
100
 
@@ -102,8 +158,7 @@ A dataset script to convert a folder of image-text pairs into latent representat
102
  A training script provided as a single monolithic file.
103
  Additionally, there’s a script that can be pasted directly into the terminal to automatically train the model with optimized parameters.
104
  ## Training Optimization
105
- All pre-training was done using the AdamW8bit optimizer, which significantly reduced training costs. The final fine-tuning was performed using a more complex optimizer, [Muon + AdamW8bit](https://github.com/recoilme/muon_adamw8bit).
106
-
107
 
108
  ### Train:
109
 
 
12
 
13
  At AiArtLab, we strive to create a free, compact and fast model that can be trained on consumer graphics cards.
14
 
15
+ - Unet: 1.6b parameters
16
+ - Qwen3.5: 1.8b parameters
17
+ - VAE: 32ch8x16x
18
+ - Speed: Sampling: 100%|██████████| 40/40 [00:01<00:00, 29.98it/s]
19
+ - Resolution: from 768px to 1404px, in 64px steps
20
+ - Limitations: trained on a small dataset (~1-2M images), focused on illustrations
21
+
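The resolution bullet above can be made concrete. A small sketch (`valid_sizes` is a hypothetical helper, not part of this repository) enumerating the 64px grid shows that the largest valid size within the stated range is 1344px, which matches the 1088×1344 resolution used in the examples below:

```
# Enumerate the 64px resolution grid implied by the bullet above
# (hypothetical helper, not part of the repository).
def valid_sizes(lo=768, hi=1404, step=64):
    return list(range(lo, hi + 1, step))

sizes = valid_sizes()
print(sizes[0], sizes[-1])  # 768 1344 (1404 itself is not a multiple of 64)
```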
22
+ ### Train in progress
23
+
24
+ Key points
25
+
26
+ - Dec 24: Started research on Linear Transformers.
27
+ - Feb 25: Started research on UNet-based diffusion models.
28
+ - Aug 25: Started research on different VAEs.
29
+ - Sep 25: Created a simple VAE and a [vae collection](https://huggingface.co/AiArtLab/collections).
30
+ - Dec 25: Trained SDXS-1B (0.8B at that point), featuring an SD1.5-like UNet, Long CLIP, a 16-channel simple VAE, and a flow matching target.
31
+ - Jan 26: Implemented a dual text encoder (SDXL-like style). Total rework.
32
+ - Feb 26: Reverted to the classic architecture; tested all the SDXL innovations and went back to simple diffusion. Total rework.
33
+ - Mar 26: Created a 32ch 8x/16x asymmetric VAE and switched to Qwen3.5 2B as the text encoder.
34
+
35
+ ### Samples with seed 0
36
  ![promo](media/result_grid.jpg)
37
 
38
+ ### Text-to-image
39
 
40
  ```
41
  import torch
 
61
 image.show()
62
  ```
63
 
64
+ ### Image upscale
65
+ ```
66
+ upscaled = pipe.image_upscale("media/girl.jpg")
67
+ upscaled[0].show()
68
+ ```
69
+
70
+ ### Prompt refine
71
+ ```
72
+ refined = pipe.refine_prompts("girl")
73
+
74
+ print(refined)
75
+ ```
76
+
77
+ ### Encode image (experimental)
78
+ ```
79
+ emb, mask = pipe.encode_image("media/girl.jpg")
80
+
81
+ # Inspect the result
82
+ print("Pooled vector shape:", emb[:, 0, :].shape)
83
+ image = pipe(
84
+ prompt_embeds = emb,
85
+ prompt_attention_mask = mask,
86
+ negative_prompt = negative_prompt,
87
+ guidance_scale = 4,
88
+ width = 1088,
89
+ height = 1344,
90
+ seed = 0,
91
+ batch_size = 1,
92
+ )[0]
93
+ image[0].show()
94
+ ```
95
+
96
  ### VAE
97
 
98
+ The VAE in Simple Diffusion utilizes an asymmetric architecture: an 8x encoder paired with a 16x decoder. A compression factor of 8 is maintained during training, while at inference an additional upscaling block effectively doubles the output resolution. This reduces training costs by an order of magnitude and boosts inference speed with no perceptual quality loss; in effect, it acts as an integrated latent upscaler. To ensure a fair comparison with other VAEs, we downsampled the generated images to the input resolution before computing metrics. The SDXS VAE was not trained from scratch; it was initialized from the FLUX.2 VAE weights, then redesigned and retrained. We also trained a [16-channel VAE](https://huggingface.co/AiArtLab/simplevae) with FLUX.1-level quality, based on the Aura VAE.
99
 
100
  [eval.py](src/eval.py)
101
  ```
 
108
  Wan2.2-TI2V-5B (2Gb) | MSE=7.034e-04 PSNR=34.65 LPIPS=0.050 Edge=0.115 KL=9.429
109
  sdxs-1b (200Mb) | MSE=2.655e-04 PSNR=37.83 LPIPS=0.026 Edge=0.066 KL=2.170
110
  ```
111
+
112
+ ### Image upscale
113
+
114
+ One interesting feature of the asymmetric VAE is that it can be used as a standalone image and video upscaler. The VAE was trained at resolutions of 512–768 pixels and is effective within that range. Note that it is a latent upscaler, which makes it simple and fast. It is also a "blind" upscaler: unlike model-based upscalers, it interferes with the image minimally and does not alter its essence. This may be useful if you dislike upscalers that change the image's style or even objects in it (say, the model of a phone), inventing something new based on the original. On the other hand, you might dislike it for exactly the same reason: it changes the original only minimally.
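A back-of-the-envelope sketch of why the asymmetry yields an upscale (`upscaled_size` is a hypothetical helper, not code from this repo): encoding divides the resolution by 8 while decoding multiplies it by 16, so every image comes back at twice its input size. This is, presumably, what the `pipe.image_upscale(...)` call shown above relies on.

```
# Hypothetical helper illustrating the 8x-encode / 16x-decode arithmetic;
# the real work is done by the asymmetric VAE, not by this function.
def upscaled_size(h, w, enc_factor=8, dec_factor=16):
    lh, lw = h // enc_factor, w // enc_factor   # encoder: pixels -> latents
    return lh * dec_factor, lw * dec_factor     # decoder: latents -> pixels

print(upscaled_size(768, 512))  # (1536, 1024): a free 2x upscale
```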
115
+
116
  ### Unet
117
 
118
  The UNet architecture in Simple Diffusion is a direct descendant and conceptual continuation of the ideas introduced in the first version of Stable Diffusion. Key distinctions include a relatively small, yet sufficient, number of transformer blocks that ensure an even distribution of attention. Additionally, the number of channels in the final layer has been significantly increased to improve detail rendering. Overall, however, it remains a UNet, similar to SD 1.5.
119
 
120
+ Throughout the experiments, we tested [hundreds](https://wandb.ai/recoilme) of different configurations and trained dozens of [models](https://huggingface.co/AiArtLab/sdxs). Notably, we initially started from the SDXL architecture, assuming it would be a stronger baseline, but ultimately abandoned all of its proposed innovations: uneven attention distribution with increased transformer block depth in the lower layers, a reduced number of blocks in the channel pyramid, micro-conditioning, the dual text encoder, text-time conditioning, and so on. In our experiments, all of these changes increased training time and cost while having a near-zero or negative impact on the final result. In total, investigating architectures and searching for the most efficient configuration took over a year.
121
 
122
 Unfortunately, we were unable to secure grants for model training, apart from a Google TPU grant that we could not utilize due to insufficient preparation and time constraints. As a result, training and experiments were financed primarily from our own funds and user donations. This left a significant mark on the model’s architecture.
123
  We aimed to make it as small and cost-effective to train as possible while maintaining our quality generation requirements. So perhaps the limited budget even worked to our advantage.
 
139
  This adventure started in December 2024 after the release of the SANA model. We received a donation from Stan for fine-tuning SANA and, together with Stas, began fine-tuning and further developing it. Despite spending the entire budget, we did not achieve significant improvements. However, we were shocked by how poorly the model was trained and designed, and we became convinced that we could do better—though we were wrong.
140
  Shifting Gears
141
  By February 2025, we split our efforts and began designing our own architectures—which we are still doing today. Stas favored the DiT architecture, while I believed in UNet. Despite some differences in architectural views, we maintained close communication, shared our work, and supported each other throughout the process. We also engaged with the AIArtLab community (a virtual Telegram chat for those contributing to model development)—thank you all for your support.
142
+ ## The Main Mistake
143
+ One of my key mistakes was relying too heavily on LLMs and research papers. Research often presents minor improvements as groundbreaking innovations, and LLMs, trained on such content, can draw incorrect conclusions. From autumn 2025, I radically changed my strategy, switching to training simpler models (VAEs), where simple fine-tuning yielded more substantial improvements than expensive research projects—including fine-tuning a VAE to a quality level comparable to Flux-1 at the time.
144
+ This shift led me to adopt a zero-trust policy toward any external information I had not personally verified. This does not mean you should not read papers, but I urge you not to trust the conclusions presented in them. It is an extremely radical approach, and I radicalized it intentionally, but it allowed me to move from reading papers and implementing other people's ideas to generating my own and training models.
145
+
146
+ As a result, I focused on building a strong local benchmark for rapid, cost-effective experiments on a single RTX 4080. This led me to train models on the "Butterflies" dataset, a set of 1,000 butterfly images, on which a model can be trained from scratch in about an hour to assess the impact of a hypothesis or improvement ([example](https://www.comet.com/recoilme/unet/356142c52c314078914d0c0db409e1f3?experiment-tab=images&viewId=new)).
147
  ## The Evolutionary Path
148
  The second turning point was the transition to a continuous evolutionary improvement strategy. Unfortunately, the Butterflies dataset does not allow for evaluating prompt-following or anatomical generation capabilities. As a result, the model evolved incrementally rather than through revolutionary changes. The same model, from December 2025, underwent around 10 changes, including radical architectural shifts—while always preserving the pre-trained weights. It’s remarkable how well and quickly pre-trained models adapt to changes in architecture and external factors, even radical ones (e.g., switching VAE models, text encoders, or their combinations).
149
  In addition to saving on training costs, this approach helped maintain minimal model size—for example, adding extra transformer blocks followed by an assessment of necessity and rolling back if the changes had no significant impact.
150
+ ## tldr;
151
+ Stop reading, start training
152
  ## The Role of Hyperparameters
153
+ One of the initial mistakes was an excessive focus on hyperparameters during training. Ironically, about 80% of training speed and quality depends on the model architecture (UNet) and the quality of the embeddings (VAE), while the other 20% is influenced by the text encoder’s embeddings. What little remains is the actual role of hyperparameters. The irony is that Adam (AdamW8bit) is surprisingly forgiving of hyperparameter errors, so I won’t even list them: the defaults are fine.
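Why Adam is so forgiving is easiest to see from the update rule itself. Below is one AdamW step in plain Python (an illustration only; the project trains with the bitsandbytes AdamW8bit variant, and `adamw_step` is not from this repo). The normalized step `m_hat / sqrt(v_hat)` is close to ±1 regardless of the raw gradient scale, so the effective step size is bounded by the learning rate:

```
import math

# One AdamW update in plain Python (illustration; the project itself
# trains with the bitsandbytes AdamW8bit optimizer, not this code).
def adamw_step(p, g, m, v, t, lr=1e-4, b1=0.9, b2=0.999, eps=1e-8, wd=1e-2):
    m = b1 * m + (1 - b1) * g                  # first-moment EMA
    v = b2 * v + (1 - b2) * g * g              # second-moment EMA
    m_hat = m / (1 - b1 ** t)                  # bias correction
    v_hat = v / (1 - b2 ** t)
    p = p - lr * (m_hat / (math.sqrt(v_hat) + eps) + wd * p)  # decoupled decay
    return p, m, v

p, m, v = adamw_step(p=1.0, g=0.5, m=0.0, v=0.0, t=1)
print(round(p, 6))  # 0.999899: the step is ~lr, whatever the gradient scale
```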
154
  ## Tools and Optimization
155
  The model comes with two scripts:
156
 
 
158
  A training script provided as a single monolithic file.
159
  Additionally, there’s a script that can be pasted directly into the terminal to automatically train the model with optimized parameters.
160
  ## Training Optimization
161
+ All pre-training was done using the AdamW8bit optimizer, which significantly reduced training costs.
 
162
 
163
  ### Train:
164
 
girl.jpg CHANGED

Git LFS Details

  • SHA256: 1c805d884786deb953a5473e672f5ab8c9ccf616dcf2811011885d7c7ef767ba
  • Pointer size: 131 Bytes
  • Size of remote file: 133 kB

Git LFS Details

  • SHA256: 2def6f65476e848fc6076e7715421d9cf308fde93f998a478b1d169788548916
  • Pointer size: 131 Bytes
  • Size of remote file: 143 kB
media/girl.jpg CHANGED

Git LFS Details

  • SHA256: 9d9c7aac3206c22e5e40c29fa5f1ed2203af161921e302ae538f6cc9a20437f3
  • Pointer size: 130 Bytes
  • Size of remote file: 49.6 kB

Git LFS Details

  • SHA256: 2def6f65476e848fc6076e7715421d9cf308fde93f998a478b1d169788548916
  • Pointer size: 131 Bytes
  • Size of remote file: 143 kB
media/result_grid.jpg CHANGED

Git LFS Details

  • SHA256: 83795de2023af3ef0b99472dbcb9805c7c138573d475dae44506ceebcda808a9
  • Pointer size: 132 Bytes
  • Size of remote file: 7.34 MB

Git LFS Details

  • SHA256: a9c9cdd8c1fcb06b9abf6fad5b80043e216194a1f89187d1d42b40ad078cdd2c
  • Pointer size: 131 Bytes
  • Size of remote file: 455 kB
model_index.json CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:b9fe0891c1d3f4f0b2a8cbca077be3533f28306768c3ea8d5256924fc677a4b1
3
- size 438
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d717995bd1e270fd4694a62255b159df8ec189022ac70567e1d888ac8959161b
3
+ size 503
pipeline_sdxs.py CHANGED
@@ -7,7 +7,6 @@ from dataclasses import dataclass
7
  from diffusers import DiffusionPipeline
8
  from diffusers.utils import BaseOutput
9
  from tqdm import tqdm
10
- from transformers import Qwen3_5ForConditionalGeneration, Qwen3_5Tokenizer
11
 
12
  @dataclass
13
  class SdxsPipelineOutput(BaseOutput):
@@ -15,11 +14,14 @@ class SdxsPipelineOutput(BaseOutput):
15
  prompt: Optional[Union[str, List[str]]] = None
16
 
17
  class SdxsPipeline(DiffusionPipeline):
18
- def __init__(self, vae, text_encoder, tokenizer, unet, scheduler):
 
 
19
  super().__init__()
20
  self.register_modules(
21
  vae=vae,
22
  text_encoder=text_encoder,
 
23
  tokenizer=tokenizer,
24
  unet=unet,
25
  scheduler=scheduler
@@ -30,109 +32,252 @@ class SdxsPipeline(DiffusionPipeline):
30
  if mean is not None and std is not None:
31
  self.vae_latents_std = torch.tensor(std, device=self.unet.device, dtype=self.unet.dtype).view(1, len(std), 1, 1)
32
  self.vae_latents_mean = torch.tensor(mean, device=self.unet.device, dtype=self.unet.dtype).view(1, len(mean), 1, 1)
33
 
34
- def preprocess_image(self, image: Image.Image, width: int, height: int):
35
- """Ресайз и центрированный кроп изображения для асимметричного VAE."""
36
- # For the encoder with scale factor 8
37
- target_height = ((height // self.vae_scale_factor) * self.vae_scale_factor)
38
- target_width = ((width // self.vae_scale_factor) * self.vae_scale_factor)
39
 
40
- w, h = image.size
41
- aspect_ratio = target_width / target_height
42
 
43
- if w / h > aspect_ratio:
44
- new_w = int(h * aspect_ratio)
45
- left = (w - new_w) // 2
46
- image = image.crop((left, 0, left + new_w, h))
47
- else:
48
- new_h = int(w / aspect_ratio)
49
- top = (h - new_h) // 2
50
- image = image.crop((0, top, w, top + new_h))
51
 
52
- image = image.resize((target_width, target_height), resample=Image.LANCZOS)
53
- image = np.array(image).astype(np.float32) / 255.0
54
- image = image[None].transpose(0, 3, 1, 2) # [1, C, H, W]
55
- image = torch.from_numpy(image)
56
- return 2.0 * image - 1.0 # [-1, 1]
57
 
58
 
59
- def encode_prompt(self, prompt, negative_prompt, device, dtype):
60
- def get_encode(texts):
61
- if texts is None:
62
- texts = ""
63
 
64
- if isinstance(texts, str):
65
- texts = [texts]
66
 
67
- with torch.no_grad():
68
- # 1. Build the text prompts, wrapping them in the chat template
69
- formatted_prompts = []
70
- for t in texts:
71
- messages = [{"role": "user", "content": [{"type": "text", "text": t}]}]
72
- res_text = self.tokenizer.apply_chat_template(
73
- messages,
74
- add_generation_prompt=True,
75
- tokenize=False
76
- )
77
- formatted_prompts.append(res_text)
78
 
79
- # 2. Tokenize, truncate and pad in a single pass
80
- toks = self.tokenizer(
81
- formatted_prompts,
82
- padding="max_length",
83
- max_length=248,
84
- truncation=True, # truncate in case the input is longer
85
- return_tensors="pt"
86
- ).to(device)
87
 
88
- # 3. Run through the model
89
- outputs = self.text_encoder(
90
- input_ids=toks.input_ids,
91
- attention_mask=toks.attention_mask,
92
- output_hidden_states=True
93
- )
94
 
95
- layer_index = -2
96
- last_hidden = outputs.hidden_states[layer_index]
97
- seq_len = toks.attention_mask.sum(dim=1) - 1
98
- pooled = last_hidden[torch.arange(len(last_hidden)), seq_len.clamp(min=0)]
99
 
100
- # --- NEW LOGIC: CONCATENATION FOR CROSS-ATTENTION ---
101
- # 1. Expand the pooled vector into a sequence [B, 1, 1024]
102
- pooled_expanded = pooled.unsqueeze(1)
103
-
104
- # 2. Concatenate the token sequence and the pooled vector
105
- # !!! CHANGE HERE !!!: the pooled token goes FIRST
106
- # Now: [B, 1 + L, 1024]. The pooled vector becomes a token at the START.
107
- new_encoder_hidden_states = torch.cat([pooled_expanded, last_hidden], dim=1)
108
-
109
- # 3. Update the attention mask for the new token
110
- # Attention mask: [B, 1 + L]. Prepend a 1 at the START.
111
- # torch.ones((batch_size, 1), device=device) creates a [B, 1] mask of ones.
112
- new_attention_mask = torch.cat([torch.ones((last_hidden.shape[0], 1), device=device), toks.attention_mask], dim=1)
113
 
114
- return new_encoder_hidden_states, new_attention_mask
 
 
 
 
115
 
116
- pos_embeds, pos_mask = get_encode(prompt)
117
- neg_embeds, neg_mask = get_encode(negative_prompt)
 
118
 
119
- batch_size = pos_embeds.shape[0]
120
- if neg_embeds.shape[0] != batch_size:
121
- neg_embeds = neg_embeds.repeat(batch_size, 1, 1)
122
- neg_mask = neg_mask.repeat(batch_size, 1)
123
 
124
- text_embeddings = torch.cat([neg_embeds, pos_embeds], dim=0)
125
- final_mask = torch.cat([neg_mask, pos_mask], dim=0)
 
 
 
 
 
126
 
127
- return text_embeddings.to(dtype=dtype), final_mask.to(dtype=torch.int64)
128
 
129
  @torch.no_grad()
130
  def __call__(
131
  self,
132
- prompt: Union[str, List[str]],
133
- image: Optional[Union[Image.Image, List[Image.Image]]] = None,
134
- coef: float = 0.97, # strength (0.0 = original, 1.0 = pure noise)
135
  negative_prompt: Optional[Union[str, List[str]]] = None,
136
  height: int = 1024,
137
  width: int = 1024,
138
  num_inference_steps: int = 40,
@@ -141,7 +286,6 @@ class SdxsPipeline(DiffusionPipeline):
141
  seed: Optional[int] = None,
142
  output_type: str = "pil",
143
  return_dict: bool = True,
144
- refine_prompt: bool = False,
145
  **kwargs,
146
  ):
147
  device = self.device
@@ -149,115 +293,81 @@ class SdxsPipeline(DiffusionPipeline):
149
 
150
  if generator is None and seed is not None:
151
  generator = torch.Generator(device=device).manual_seed(seed)
 
 
152
 
153
- # ==================== REFINE PROMPT (INLINE) ====================
154
- if refine_prompt and prompt:
155
- sys_msg = (
156
- "You are a skilled text-to-image prompt engineer whose sole function is to transform the user's input into an aesthetically optimized, detailed, and visually descriptive three-sentence output. "
157
- "**The primary subject (e.g., 'girl', 'dog', 'house') MUST be the main focus of the revised prompt and MUST be described in rich detail within the first sentence or two.** "
158
- "Output **only** the final revised prompt, with absolutely no commentary.\n Don't use cliches like warm,soft,vibrant, wildflowers. Be creative. User input prompt: "
159
- )
160
- prompts_list = [prompt] if isinstance(prompt, str) else prompt
161
- refined_list = []
 
 
 
 
162
 
163
- for p in prompts_list:
164
- messages = [{"role": "user", "content": [{"type": "text", "text": sys_msg + p}]}]
165
-
166
- # Use the Qwen-Instruct format (apply_chat_template inserts the system/user/assistant tokens itself)
167
- inputs = self.tokenizer.apply_chat_template(
168
- messages,
169
- tokenize=True,
170
- add_generation_prompt=True,
171
- return_dict=True,
172
- return_tensors="pt"
173
- ).to(device)
174
-
175
- generated_ids = self.text_encoder.generate(
176
- **inputs, max_new_tokens=248, do_sample=True,temperature = 0.7
177
- )
178
-
179
- # Trim the input tokens from the response
180
- generated_ids_trimmed = [
181
- out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
182
- ]
183
- output_text = self.tokenizer.batch_decode(
184
- generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
185
- )
186
- refined_list.append(output_text)
187
-
188
- prompt = refined_list[0] if isinstance(prompt, str) else refined_list
189
 
190
- # ==================== ENCODE PROMPTS ====================
191
- text_embeddings, attention_mask = self.encode_prompt(
192
- prompt, negative_prompt, device, dtype
193
- )
194
- batch_size = 1 if isinstance(prompt, str) else len(prompt)
195
 
196
- # 2. Scheduler timesteps
197
  self.scheduler.set_timesteps(num_inference_steps, device=device)
198
  timesteps = self.scheduler.timesteps
199
-
200
- # ==================== IMG2IMG BLOCK (NEW VERSION) ====================
201
- if image is not None:
202
- # --- Prepare the image ---
203
- if isinstance(image, Image.Image):
204
- image_tensor = self.preprocess_image(image, width, height).to(device, self.vae.dtype)
205
- else:
206
- image_tensor = self.preprocess_image(image[0], width, height).to(device, self.vae.dtype)
207
-
208
- # --- Encode to latents ---
209
- latents_clean = self.vae.encode(image_tensor).latent_dist.sample(generator=generator)
210
- latents_clean = (latents_clean - self.vae_latents_mean.to(device, self.vae.dtype)) / self.vae_latents_std.to(device, self.vae.dtype)
211
- latents_clean = latents_clean.to(dtype)
212
-
213
- # --- Add noise via the rectified flow formula ---
214
- noise = torch.randn_like(latents_clean)
215
-
216
- # coef = strength (0.0 → original, 1.0 → pure noise)
217
- sigma = coef # in flow matching, sigma = t
218
- if hasattr(self.scheduler, "sigma_shift"): # if the scheduler has a shift (Flux-style)
219
- sigma = self.scheduler.sigma_shift(sigma)
220
-
221
- latents = (1.0 - sigma) * latents_clean + sigma * noise
222
-
223
- # Trim the timesteps starting from the current sigma
224
- init_timestep = int(num_inference_steps * coef)
225
- t_start = max(num_inference_steps - init_timestep, 0)
226
- timesteps = timesteps[t_start:]
227
-
228
  else:
229
- # txt2img
230
- latent_h = height // self.vae_scale_factor
231
- latent_w = width // self.vae_scale_factor
232
-
233
- latents = torch.randn(
234
- (batch_size, self.unet.config.in_channels, latent_h, latent_w),
235
- generator=generator, device=device, dtype=dtype
236
- )
237
 
238
- # ==================== DENOISING LOOP (одинаковый для txt2img и img2img) ====================
239
  for i, t in enumerate(tqdm(timesteps, desc="Sampling")):
240
- latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1.0 else latents
 
241
 
242
  model_out = self.unet(
243
- latent_model_input,
244
- t,
245
  encoder_hidden_states=text_embeddings,
246
  encoder_attention_mask=attention_mask,
247
  return_dict=False,
248
  )[0]
249
 
250
- if guidance_scale > 1.0:
 
251
  flow_uncond, flow_cond = model_out.chunk(2)
252
  model_out = flow_uncond + guidance_scale * (flow_cond - flow_uncond)
253
 
254
- # Important: use scheduler.step; it knows how to handle the velocity
255
  latents = self.scheduler.step(model_out, t, latents, return_dict=False)[0]
256
 
257
- # ==================== DECODE ====================
258
  if output_type == "latent":
259
  if not return_dict: return (latents, prompt)
260
- return SdxsPipelineOutput(images=latents, prompt=prompt)
261
 
262
  latents = latents * self.vae_latents_std.to(device, self.vae.dtype) + self.vae_latents_mean.to(device, self.vae.dtype)
263
  image_output = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0]
@@ -271,5 +381,5 @@ class SdxsPipeline(DiffusionPipeline):
271
  images = image_np
272
 
273
  if not return_dict:
274
- return (images, prompt)
275
- return SdxsPipelineOutput(images=images, prompt=prompt)
 
7
  from diffusers import DiffusionPipeline
8
  from diffusers.utils import BaseOutput
9
  from tqdm import tqdm
 
10
 
11
  @dataclass
12
  class SdxsPipelineOutput(BaseOutput):
 
14
  prompt: Optional[Union[str, List[str]]] = None
15
 
16
  class SdxsPipeline(DiffusionPipeline):
17
+ MAX_TEXT_TOKENS = 248
18
+
19
+ def __init__(self, vae, text_encoder, processor, tokenizer, unet, scheduler):
20
  super().__init__()
21
  self.register_modules(
22
  vae=vae,
23
  text_encoder=text_encoder,
24
+ processor=processor,
25
  tokenizer=tokenizer,
26
  unet=unet,
27
  scheduler=scheduler
 
32
  if mean is not None and std is not None:
33
  self.vae_latents_std = torch.tensor(std, device=self.unet.device, dtype=self.unet.dtype).view(1, len(std), 1, 1)
34
  self.vae_latents_mean = torch.tensor(mean, device=self.unet.device, dtype=self.unet.dtype).view(1, len(mean), 1, 1)
35
+
36
+ @staticmethod
37
+ def _pad_tensor_to_length(tensor: torch.Tensor, target_len: int, dim: int = 1, pad_value: float = 0) -> torch.Tensor:
38
+ current_len = tensor.shape[dim]
39
+ if current_len >= target_len:
40
+ return tensor
41
+ pad_size = target_len - current_len
42
+ if tensor.dim() == 3:
43
+ padding = (0, 0, 0, pad_size, 0, 0)
44
+ elif tensor.dim() == 2:
45
+ padding = (0, pad_size, 0, 0)
46
+ else:
47
+ raise ValueError(f"Unsupported tensor dimension: {tensor.dim()}")
48
+ return torch.nn.functional.pad(tensor, padding, value=pad_value)
49
+
+
+    @torch.no_grad()
+    def refine_prompts(
+        self,
+        prompts: Union[str, List[str]],
+        system_prompt: Optional[str] = None,
+        temperature: float = 0.7
+    ) -> List[str]:
+        """
+        Refines a list of prompts using the Text Encoder (LLM).
+
+        Args:
+            prompts: Single prompt string or list of prompts.
+            system_prompt: Custom instruction for the LLM. If None, uses default aesthetic enhancer.
+            temperature: Sampling temperature for generation.
+
+        Returns:
+            List of refined prompts.
+        """
+        device = self.device
+
+        # Default system prompt if none provided
+        if system_prompt is None:
+            system_prompt = (
+                "You are a skilled text-to-image prompt engineer whose sole function is to transform "
+                "the user's input into an aesthetically optimized, detailed, and visually descriptive three-sentence output. "
+                "**The primary subject (e.g., 'girl', 'dog', 'house') MUST be the main focus of the revised prompt "
+                "and MUST be described in rich detail within the first sentence or two.** "
+                "Output **only** the final revised prompt, with absolutely no commentary. "
+                "Don't use cliches like warm, soft, vibrant, wildflowers. Be creative. User input prompt: "
+            )
+
+        pad_id = getattr(self.text_encoder.config, "pad_token_id", None) or \
+                 getattr(self.text_encoder.config, "eos_token_id", None)
+
+        prompts_list = [prompts] if isinstance(prompts, str) else prompts
+        refined_list = []
+
+        for p in prompts_list:
+            # Prepend system prompt to user input
+            full_text = system_prompt + p
+            messages = [{"role": "user", "content": [{"type": "text", "text": full_text}]}]
+
+            inputs = self.tokenizer.apply_chat_template(
+                messages, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt"
+            ).to(device)
+
+            generated_ids = self.text_encoder.generate(
+                **inputs,
+                max_new_tokens=self.MAX_TEXT_TOKENS,
+                do_sample=True,
+                temperature=temperature,
+                pad_token_id=pad_id
+            )
+
+            generated_ids_trimmed = [
+                out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+            ]
+            output_text = self.tokenizer.batch_decode(
+                generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+            )
+            refined_list.append(output_text[0])
+
+        return refined_list
+
+    @torch.no_grad()
+    def encode_text(self, text: Union[str, List[str]]) -> Tuple[torch.Tensor, torch.Tensor]:
+        device = self.device
+        dtype = self.unet.dtype
+        if text is None: text = ""
+        if isinstance(text, str): text = [text]
+
+        formatted_prompts = []
+        for t in text:
+            messages = [{"role": "user", "content": [{"type": "text", "text": t}]}]
+            formatted_prompts.append(self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False))
+
+        toks = self.tokenizer(formatted_prompts, padding="max_length", max_length=self.MAX_TEXT_TOKENS, truncation=True, return_tensors="pt").to(device)
+        outputs = self.text_encoder(input_ids=toks.input_ids, attention_mask=toks.attention_mask, output_hidden_states=True)

+        last_hidden = outputs.hidden_states[-2]
+        seq_len = toks.attention_mask.sum(dim=1) - 1
+        pooled = last_hidden[torch.arange(len(last_hidden)), seq_len.clamp(min=0)]

+        pooled_expanded = pooled.unsqueeze(1)
+        encoder_hidden_states = torch.cat([pooled_expanded, last_hidden], dim=1)
+        attention_mask = torch.cat([torch.ones((last_hidden.shape[0], 1), device=device, dtype=toks.attention_mask.dtype), toks.attention_mask], dim=1)
+
+        return encoder_hidden_states.to(dtype=dtype), attention_mask.to(dtype=torch.int64)
+
+    @torch.no_grad()
+    def encode_image(self, image: Union[Image.Image, str, List[Union[Image.Image, str]]]) -> Tuple[torch.Tensor, torch.Tensor]:
+        device = self.device
+        dtype = self.unet.dtype
+        if isinstance(image, (str, Image.Image)): image = [image]
+        batch_size = len(image)

+        all_messages = [[{"role": "user", "content": [{"type": "image", "image": img}]}] for img in image]
+        formatted_prompts = [self.processor.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True) for msgs in all_messages]

+        inputs = self.processor(text=formatted_prompts, images=image, return_tensors="pt", padding=True, truncation=False).to(device)
+        outputs = self.text_encoder(**inputs, output_hidden_states=True)

+        last_hidden = outputs.hidden_states[-2]
+        seq_lens = inputs.attention_mask.sum(dim=1) - 1
+        pooled = last_hidden[torch.arange(batch_size), seq_lens.clamp(min=0)]

+        final_embeddings = torch.cat([pooled.unsqueeze(1), last_hidden], dim=1)
+        final_mask = torch.cat([torch.ones((batch_size, 1), device=device, dtype=inputs.attention_mask.dtype), inputs.attention_mask], dim=1)

+        return final_embeddings.to(dtype=dtype), final_mask.to(dtype=torch.int64)
+
+    @torch.no_grad()
+    def encode_text_and_image_naive(self, text: Union[str, List[str]], image: Optional[Union[Image.Image, List[Image.Image], str, List[str]]] = None, scale=0.5) -> Tuple[torch.Tensor, torch.Tensor]:
+        # 1. Get the text embeddings
+        text_embeds, text_mask = self.encode_text(text)
+
+        if image is not None:
+            if isinstance(image, (str, Image.Image)):
+                image = [image]
+
+            # If there is one image but several texts, replicate the image
+            if len(image) == 1 and text_embeds.shape[0] > 1:
+                image = image * text_embeds.shape[0]
+
+            # --- Begin inlined logic from encode_image ---
+            device = self.device
+            dtype = self.unet.dtype
+            batch_size = len(image)

+            all_messages = [[{"role": "user", "content": [{"type": "image", "image": img}]}] for img in image]
+            formatted_prompts = [self.processor.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True) for msgs in all_messages]

+            inputs = self.processor(text=formatted_prompts, images=image, return_tensors="pt", padding=True, truncation=False).to(device)
+            outputs = self.text_encoder(**inputs, output_hidden_states=True)

+            # Take the desired hidden state (layer -2)
+            img_hidden_states = outputs.hidden_states[-2]
+            # Take the attention mask from the processor
+            img_mask = inputs.attention_mask
+            # --- End inlined logic ---

+            # Apply scaling
+            if scale != 1.0:
+                img_hidden_states = img_hidden_states * scale

+            # Match the mask and dtypes to the text embeddings
+            img_mask = img_mask.to(text_mask.dtype)
+            img_hidden_states = img_hidden_states.to(dtype=dtype)

+            # Concatenate the text and image token sequences
+            final_embeds = torch.cat([text_embeds, img_hidden_states], dim=1)
+            final_mask = torch.cat([text_mask, img_mask], dim=1)
+
+            return final_embeds, final_mask
+
+        return text_embeds, text_mask

+    @torch.no_grad()
+    def image_upscale(
+        self,
+        image: Union[str, Image.Image, List[Union[str, Image.Image]]],
+        batch_size: int = 1
+    ) -> List[Image.Image]:
+        """
+        Upscales images using asymmetric VAE (x2).
+        Uses smart batching: processes in parallel if sizes match, else falls back to sequential.
+        """
+        images = [image] if isinstance(image, (str, Image.Image)) else image
+
+        # 1. Preprocess: load, handle alpha, pad to a multiple of 8, normalize
+        batch_data = []
+        for img in images:
+            if isinstance(img, str): img = Image.open(img)
+            if img.mode == "RGBA":
+                img = Image.alpha_composite(Image.new("RGBA", img.size, (255, 255, 255)), img)
+            img = img.convert("RGB")
+
+            w, h = img.size
+            pw, ph = (8 - w % 8) % 8, (8 - h % 8) % 8
+            if pw or ph:
+                padded = Image.new("RGB", (w + pw, h + ph), (255, 255, 255))
+                padded.paste(img)
+                img = padded
+
+            t = torch.from_numpy(np.array(img).astype(np.float32) / 127.5 - 1.0).permute(2, 0, 1)
+            batch_data.append((t.to(self.device, torch.float16), w, h))

+        # 2. Determine execution strategy:
+        #    if all shapes are identical, use batch_size, else fall back to 1.
+        unique_shapes = {t.shape for t, _, _ in batch_data}
+        step = batch_size if len(unique_shapes) == 1 else 1
+
+        output_images = []
+
+        # 3. Process batches
+        for i in range(0, len(batch_data), step):
+            chunk = batch_data[i : i + step]
+
+            # Stack tensors [B, C, H, W]
+            tensors = torch.stack([c[0] for c in chunk])
+
+            # Encode -> decode (using the mean for a deterministic upscale)
+            latents = self.vae.encode(tensors).latent_dist.mean
+            latents = latents * self.vae_latents_std.to(latents) + self.vae_latents_mean.to(latents)
+            decoded = self.vae.decode(latents.to(self.vae.dtype))[0]
+
+            # 4. Post-process: denormalize and crop
+            decoded = (decoded.clamp(-1, 1) + 1) / 2
+            for j, tensor in enumerate(decoded):
+                w, h = chunk[j][1], chunk[j][2]  # original sizes
+
+                # Crop to exactly 2x the original size
+                arr = tensor.cpu().permute(1, 2, 0).float().numpy()
+                arr = arr[:h * 2, :w * 2]
+
+                output_images.append(Image.fromarray((arr * 255).astype("uint8")))
+
+        return output_images
+
     @torch.no_grad()
     def __call__(
         self,
+        prompt: Optional[Union[str, List[str]]] = None,
         negative_prompt: Optional[Union[str, List[str]]] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        prompt_attention_mask: Optional[torch.Tensor] = None,
+        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+        latents: Optional[torch.Tensor] = None,
         height: int = 1024,
         width: int = 1024,
         num_inference_steps: int = 40,

         seed: Optional[int] = None,
         output_type: str = "pil",
         return_dict: bool = True,
         **kwargs,
     ):
         device = self.device

         if generator is None and seed is not None:
             generator = torch.Generator(device=device).manual_seed(seed)
+
+        do_classifier_free_guidance = guidance_scale > 1.0

+        # 1. Encode positive prompt
+        if prompt_embeds is None:
+            if prompt is None: raise ValueError("`prompt` or `prompt_embeds` required.")
+            prompt_embeds, prompt_attention_mask = self.encode_text(prompt)
+        prompt_embeds = prompt_embeds.to(device=device, dtype=dtype)
+        prompt_attention_mask = prompt_attention_mask.to(device=device, dtype=torch.int64)
+        batch_size = prompt_embeds.shape[0]
+
+        # 2. Encode negative prompt (only if CFG is enabled)
+        if do_classifier_free_guidance:
+            if negative_prompt_embeds is None:
+                neg_text = negative_prompt if negative_prompt is not None else ("" if isinstance(prompt, str) else [""] * len(prompt))
+                negative_prompt_embeds, negative_prompt_attention_mask = self.encode_text(neg_text)

+            negative_prompt_embeds = negative_prompt_embeds.to(device=device, dtype=dtype)
+            negative_prompt_attention_mask = negative_prompt_attention_mask.to(device=device, dtype=torch.int64)

+            # Batch size matching
+            if negative_prompt_embeds.shape[0] != batch_size:
+                negative_prompt_embeds = negative_prompt_embeds.repeat(batch_size, 1, 1)
+                negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(batch_size, 1)
+
+            # 3. Align sequence lengths (padding) before concatenation
+            max_len = max(prompt_embeds.shape[1], negative_prompt_embeds.shape[1])
+            prompt_embeds = self._pad_tensor_to_length(prompt_embeds, max_len, dim=1, pad_value=0)
+            negative_prompt_embeds = self._pad_tensor_to_length(negative_prompt_embeds, max_len, dim=1, pad_value=0)
+            prompt_attention_mask = self._pad_tensor_to_length(prompt_attention_mask, max_len, dim=1, pad_value=0)
+            negative_prompt_attention_mask = self._pad_tensor_to_length(negative_prompt_attention_mask, max_len, dim=1, pad_value=0)
+
+            # 4. Concatenate for CFG: [neg, pos]
+            text_embeddings = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+            attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
+        else:
+            # Without CFG, use the positive embeddings as-is
+            text_embeddings = prompt_embeds
+            attention_mask = prompt_attention_mask

+        # 5. Scheduler & latents
         self.scheduler.set_timesteps(num_inference_steps, device=device)
         timesteps = self.scheduler.timesteps
+
+        latent_h = height // self.vae_scale_factor
+        latent_w = width // self.vae_scale_factor
+
+        if latents is None:
+            latents = torch.randn((batch_size, self.unet.config.in_channels, latent_h, latent_w), generator=generator, device=device, dtype=dtype)
         else:
+            latents = latents.to(device=device, dtype=dtype)

+        # 6. Denoising loop
         for i, t in enumerate(tqdm(timesteps, desc="Sampling")):
+            # Duplicate latents only when doing CFG
+            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

             model_out = self.unet(
+                latent_model_input, t,
                 encoder_hidden_states=text_embeddings,
                 encoder_attention_mask=attention_mask,
                 return_dict=False,
             )[0]

+            # Perform CFG guidance
+            if do_classifier_free_guidance:
                 flow_uncond, flow_cond = model_out.chunk(2)
                 model_out = flow_uncond + guidance_scale * (flow_cond - flow_uncond)

             latents = self.scheduler.step(model_out, t, latents, return_dict=False)[0]

+        # 7. Decode
         if output_type == "latent":
             if not return_dict: return (latents, prompt)
+            return SdxsPipelineOutput(images=latents)

         latents = latents * self.vae_latents_std.to(device, self.vae.dtype) + self.vae_latents_mean.to(device, self.vae.dtype)
         image_output = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0]

         images = image_np

         if not return_dict:
+            return (images,)
+        return SdxsPipelineOutput(images=images)
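
The `image_upscale` preprocessing above pads each side up to the next multiple of 8 before VAE encoding, then crops back to exactly 2x after decoding. A minimal sketch of that padding arithmetic (the helper name is ours, not part of the pipeline):

```python
def pad_to_multiple_of_8(w: int, h: int) -> tuple:
    # (8 - x % 8) % 8 adds just enough pixels to reach the next
    # multiple of 8, and adds nothing when x is already divisible by 8.
    pw, ph = (8 - w % 8) % 8, (8 - h % 8) % 8
    return w + pw, h + ph

print(pad_to_multiple_of_8(703, 705))  # → (704, 712)
```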
processor/chat_template.jinja ADDED
@@ -0,0 +1,154 @@
+{%- set image_count = namespace(value=0) %}
+{%- set video_count = namespace(value=0) %}
+{%- macro render_content(content, do_vision_count, is_system_content=false) %}
+{%- if content is string %}
+{{- content }}
+{%- elif content is iterable and content is not mapping %}
+{%- for item in content %}
+{%- if 'image' in item or 'image_url' in item or item.type == 'image' %}
+{%- if is_system_content %}
+{{- raise_exception('System message cannot contain images.') }}
+{%- endif %}
+{%- if do_vision_count %}
+{%- set image_count.value = image_count.value + 1 %}
+{%- endif %}
+{%- if add_vision_id %}
+{{- 'Picture ' ~ image_count.value ~ ': ' }}
+{%- endif %}
+{{- '<|vision_start|><|image_pad|><|vision_end|>' }}
+{%- elif 'video' in item or item.type == 'video' %}
+{%- if is_system_content %}
+{{- raise_exception('System message cannot contain videos.') }}
+{%- endif %}
+{%- if do_vision_count %}
+{%- set video_count.value = video_count.value + 1 %}
+{%- endif %}
+{%- if add_vision_id %}
+{{- 'Video ' ~ video_count.value ~ ': ' }}
+{%- endif %}
+{{- '<|vision_start|><|video_pad|><|vision_end|>' }}
+{%- elif 'text' in item %}
+{{- item.text }}
+{%- else %}
+{{- raise_exception('Unexpected item type in content.') }}
+{%- endif %}
+{%- endfor %}
+{%- elif content is none or content is undefined %}
+{{- '' }}
+{%- else %}
+{{- raise_exception('Unexpected content type.') }}
+{%- endif %}
+{%- endmacro %}
+{%- if not messages %}
+{{- raise_exception('No messages provided.') }}
+{%- endif %}
+{%- if tools and tools is iterable and tools is not mapping %}
+{{- '<|im_start|>system\n' }}
+{{- "# Tools\n\nYou have access to the following functions:\n\n<tools>" }}
+{%- for tool in tools %}
+{{- "\n" }}
+{{- tool | tojson }}
+{%- endfor %}
+{{- "\n</tools>" }}
+{{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n</IMPORTANT>' }}
+{%- if messages[0].role == 'system' %}
+{%- set content = render_content(messages[0].content, false, true)|trim %}
+{%- if content %}
+{{- '\n\n' + content }}
+{%- endif %}
+{%- endif %}
+{{- '<|im_end|>\n' }}
+{%- else %}
+{%- if messages[0].role == 'system' %}
+{%- set content = render_content(messages[0].content, false, true)|trim %}
+{{- '<|im_start|>system\n' + content + '<|im_end|>\n' }}
+{%- endif %}
+{%- endif %}
+{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
+{%- for message in messages[::-1] %}
+{%- set index = (messages|length - 1) - loop.index0 %}
+{%- if ns.multi_step_tool and message.role == "user" %}
+{%- set content = render_content(message.content, false)|trim %}
+{%- if not(content.startswith('<tool_response>') and content.endswith('</tool_response>')) %}
+{%- set ns.multi_step_tool = false %}
+{%- set ns.last_query_index = index %}
+{%- endif %}
+{%- endif %}
+{%- endfor %}
+{%- if ns.multi_step_tool %}
+{{- raise_exception('No user query found in messages.') }}
+{%- endif %}
+{%- for message in messages %}
+{%- set content = render_content(message.content, true)|trim %}
+{%- if message.role == "system" %}
+{%- if not loop.first %}
+{{- raise_exception('System message must be at the beginning.') }}
+{%- endif %}
+{%- elif message.role == "user" %}
+{{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
+{%- elif message.role == "assistant" %}
+{%- set reasoning_content = '' %}
+{%- if message.reasoning_content is string %}
+{%- set reasoning_content = message.reasoning_content %}
+{%- else %}
+{%- if '</think>' in content %}
+{%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
+{%- set content = content.split('</think>')[-1].lstrip('\n') %}
+{%- endif %}
+{%- endif %}
+{%- set reasoning_content = reasoning_content|trim %}
+{%- if loop.index0 > ns.last_query_index %}
+{{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content + '\n</think>\n\n' + content }}
+{%- else %}
+{{- '<|im_start|>' + message.role + '\n' + content }}
+{%- endif %}
+{%- if message.tool_calls and message.tool_calls is iterable and message.tool_calls is not mapping %}
+{%- for tool_call in message.tool_calls %}
+{%- if tool_call.function is defined %}
+{%- set tool_call = tool_call.function %}
+{%- endif %}
+{%- if loop.first %}
+{%- if content|trim %}
+{{- '\n\n<tool_call>\n<function=' + tool_call.name + '>\n' }}
+{%- else %}
+{{- '<tool_call>\n<function=' + tool_call.name + '>\n' }}
+{%- endif %}
+{%- else %}
+{{- '\n<tool_call>\n<function=' + tool_call.name + '>\n' }}
+{%- endif %}
+{%- if tool_call.arguments is defined %}
+{%- for args_name, args_value in tool_call.arguments|items %}
+{{- '<parameter=' + args_name + '>\n' }}
+{%- set args_value = args_value | tojson | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %}
+{{- args_value }}
+{{- '\n</parameter>\n' }}
+{%- endfor %}
+{%- endif %}
+{{- '</function>\n</tool_call>' }}
+{%- endfor %}
+{%- endif %}
+{{- '<|im_end|>\n' }}
+{%- elif message.role == "tool" %}
+{%- if loop.previtem and loop.previtem.role != "tool" %}
+{{- '<|im_start|>user' }}
+{%- endif %}
+{{- '\n<tool_response>\n' }}
+{{- content }}
+{{- '\n</tool_response>' }}
+{%- if not loop.last and loop.nextitem.role != "tool" %}
+{{- '<|im_end|>\n' }}
+{%- elif loop.last %}
+{{- '<|im_end|>\n' }}
+{%- endif %}
+{%- else %}
+{{- raise_exception('Unexpected message role.') }}
+{%- endif %}
+{%- endfor %}
+{%- if add_generation_prompt %}
+{{- '<|im_start|>assistant\n' }}
+{%- if enable_thinking is defined and enable_thinking is true %}
+{{- '<think>\n' }}
+{%- else %}
+{{- '<think>\n\n</think>\n\n' }}
+{%- endif %}
+{%- endif %}
processor/processor_config.json ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:14932921ca485d458a04dafd8069fbb0a4505622a48208d19ed247115801385b
+size 1300
processor/tokenizer.json ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:87a7830d63fcf43bf241c3c5242e96e62dd3fdc29224ca26fed8ea333db72de4
+size 19989343
processor/tokenizer_config.json ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e98f1901ac6f0adff67b1d540bfa0c36ac1a0cf59eb72ed78146ef89aafa1182
+size 1139
samples/unet_384x704_0.jpg CHANGED

Git LFS Details

  • SHA256: 759eb0dd6eee67ad062e24cd39511d24498e90c765d555bc0d051231790d2612
  • Pointer size: 131 Bytes
  • Size of remote file: 480 kB

Git LFS Details

  • SHA256: 62e5e25dd1f3bbe5369d14da45241bd4b8a00852b33b452e3c2f9574c9073d55
  • Pointer size: 131 Bytes
  • Size of remote file: 253 kB
samples/unet_416x704_0.jpg CHANGED

Git LFS Details

  • SHA256: cd420618126075ebfdd645b9ecc354e9a87eaffd86d677a84234b166669fa998
  • Pointer size: 131 Bytes
  • Size of remote file: 383 kB

Git LFS Details

  • SHA256: f3a8b1425b4857fb890e4879cbf3427aa4d8e60ecd3a1a78033ef17aec8c4347
  • Pointer size: 131 Bytes
  • Size of remote file: 322 kB
samples/unet_448x704_0.jpg CHANGED

Git LFS Details

  • SHA256: 545d587255b90cc6e580d4baa91a0edd1b75f90387dc9b261524e0d23d2ab840
  • Pointer size: 131 Bytes
  • Size of remote file: 335 kB

Git LFS Details

  • SHA256: f853ab52729edebbd5e858743bb2ae5429b91721d105643ebcf6a18467a58515
  • Pointer size: 131 Bytes
  • Size of remote file: 465 kB
samples/unet_480x704_0.jpg CHANGED

Git LFS Details

  • SHA256: fe203ec5c6e73e7dde25ba4ab4acc231cf0fb81591b37efa2a01e14c35014460
  • Pointer size: 131 Bytes
  • Size of remote file: 307 kB

Git LFS Details

  • SHA256: c18f53822c498a3c1a7d2cc71a1428490950fad5355847baf8985f3f6d6dd8a4
  • Pointer size: 131 Bytes
  • Size of remote file: 369 kB
samples/unet_512x704_0.jpg CHANGED

Git LFS Details

  • SHA256: 05de9327a514a0117d8e40488a30e0c29692bd970fae214729d4542ae7772eef
  • Pointer size: 131 Bytes
  • Size of remote file: 417 kB

Git LFS Details

  • SHA256: 8480233d5ceeb9701f0c041189ee93cbcba1a91b21befe6c3a9bfd349597d2b0
  • Pointer size: 131 Bytes
  • Size of remote file: 468 kB
samples/unet_544x704_0.jpg CHANGED

Git LFS Details

  • SHA256: 839170e8bf0ffe85f70220e74e136c850a139844471acbb47ed376f23f5794a2
  • Pointer size: 131 Bytes
  • Size of remote file: 338 kB

Git LFS Details

  • SHA256: 8b2d9f7a4e4b37d394c68a9ebe016f3d18c250fde2618ebc35ff015b81cf0892
  • Pointer size: 131 Bytes
  • Size of remote file: 336 kB
samples/unet_576x704_0.jpg CHANGED

Git LFS Details

  • SHA256: 86eee23cf19ed84e8e89074c34f3a76a49db51f44c2f2548d52143560e416f76
  • Pointer size: 131 Bytes
  • Size of remote file: 409 kB

Git LFS Details

  • SHA256: de45c9182cdd06d8af8470262203d3fa4a7d9533ee8e2e73e99102072b967352
  • Pointer size: 131 Bytes
  • Size of remote file: 561 kB
samples/unet_608x704_0.jpg CHANGED

Git LFS Details

  • SHA256: c06fd992f4f896caab882dd82120205ce80bac7586650c82a0b68f7a7800dd77
  • Pointer size: 131 Bytes
  • Size of remote file: 800 kB

Git LFS Details

  • SHA256: 8f834749ea32033fb639f5c98437a5a55d0eb5bdf2f4747dde47e46c8fa75ab5
  • Pointer size: 131 Bytes
  • Size of remote file: 767 kB
samples/unet_640x704_0.jpg CHANGED

Git LFS Details

  • SHA256: dd5eab7fbe0c4be16c93db31e28574d19351d9ccd16ed4e387b598899473f645
  • Pointer size: 131 Bytes
  • Size of remote file: 647 kB

Git LFS Details

  • SHA256: 95417d4b52af113a45acde9cd81554e3bab496bc50ef11fd5555ca183c1559bf
  • Pointer size: 131 Bytes
  • Size of remote file: 554 kB
samples/unet_672x704_0.jpg CHANGED

Git LFS Details

  • SHA256: f7e1797579d0782b5230120add98cd1271593e63789bf3733aa8f09935d8e535
  • Pointer size: 131 Bytes
  • Size of remote file: 229 kB

Git LFS Details

  • SHA256: 23218ba170a001403cce9cfd950fac911a8ecb283acc50d9a0d863e68c331186
  • Pointer size: 131 Bytes
  • Size of remote file: 340 kB
samples/unet_704x384_0.jpg CHANGED

Git LFS Details

  • SHA256: c0e67b88e5436ed81c51bc1c446e15dad7ba7ae71213fc749d93249c94dc9cac
  • Pointer size: 131 Bytes
  • Size of remote file: 451 kB

Git LFS Details

  • SHA256: 587d91d13f7e0761379fa8ba84e7fb34cb398885626f1cfda759c8324776680b
  • Pointer size: 131 Bytes
  • Size of remote file: 239 kB
samples/unet_704x416_0.jpg CHANGED

Git LFS Details

  • SHA256: ac8f291ae9c5cab96d4beea4581e87fbf8da2c608a6ac3b0210280b0445e6770
  • Pointer size: 131 Bytes
  • Size of remote file: 233 kB

Git LFS Details

  • SHA256: e43513c302b52131886438ba790044a2aeeb9afdfe5254a80389c5a3a165bcaf
  • Pointer size: 131 Bytes
  • Size of remote file: 229 kB
samples/unet_704x448_0.jpg CHANGED

Git LFS Details

  • SHA256: 955d1de974df209495430e6e1822a4d741983a6575dbb7b6e498465a4c8c4a3e
  • Pointer size: 131 Bytes
  • Size of remote file: 208 kB

Git LFS Details

  • SHA256: 86e8bbe945a0b148f7d7508e1a093e8ad95cc7ed440cc724da526f5472a1af79
  • Pointer size: 131 Bytes
  • Size of remote file: 369 kB
samples/unet_704x480_0.jpg CHANGED

Git LFS Details

  • SHA256: 0c492b7890843baaf0bc0555b37982b0f728c99ad2f91b992b8c394d5c9c2ff4
  • Pointer size: 131 Bytes
  • Size of remote file: 346 kB

Git LFS Details

  • SHA256: 615eeff210b8971ffc6b18f4a84193b543056f77ba9b10f70cd1fb6f9e2ca574
  • Pointer size: 131 Bytes
  • Size of remote file: 293 kB
samples/unet_704x512_0.jpg CHANGED

Git LFS Details

  • SHA256: 42df9462770f2a75d17682def582478c0a44f1fe5cb1c945089fc21c046f1ea8
  • Pointer size: 131 Bytes
  • Size of remote file: 473 kB

Git LFS Details

  • SHA256: ff5aff046047598406af9a8a55857457b2c5e0a59e3426453c636157edc75f63
  • Pointer size: 131 Bytes
  • Size of remote file: 261 kB
samples/unet_704x544_0.jpg CHANGED

Git LFS Details

  • SHA256: b904cb5b5a3e3dc1d094db9cd18a2f4a1e9aa10de50462e7c07c029f1399cbcc
  • Pointer size: 131 Bytes
  • Size of remote file: 197 kB

Git LFS Details

  • SHA256: 943a8babb233d5bca7c76664cf9d9b9a75c1e3d240157c565a80ad337de95624
  • Pointer size: 131 Bytes
  • Size of remote file: 351 kB
samples/unet_704x576_0.jpg CHANGED

Git LFS Details

  • SHA256: edace3a43408b1f3d0d35dfa4d189fe72bcb3d8706e06226125b64124c810cef
  • Pointer size: 131 Bytes
  • Size of remote file: 455 kB

Git LFS Details

  • SHA256: c2782beac2d5d0aa02b089980159d5b588070c30d5de2093f47c5353b70ade3e
  • Pointer size: 131 Bytes
  • Size of remote file: 505 kB
samples/unet_704x608_0.jpg CHANGED

Git LFS Details

  • SHA256: 87b573f47d0316643de6b2cf15bad4bea5cb3c1e0a15319d7250e514ef0ce34c
  • Pointer size: 131 Bytes
  • Size of remote file: 377 kB

Git LFS Details

  • SHA256: b4ca3d1e1767a81e369e8d05b83f040eb8141bc2d2d59432f314e906cc0d6a72
  • Pointer size: 131 Bytes
  • Size of remote file: 450 kB
samples/unet_704x640_0.jpg CHANGED

Git LFS Details

  • SHA256: 95bfe90417f6660fc9a610e334365a23311d1086b5f862f1f25c74f2ceb9ae98
  • Pointer size: 131 Bytes
  • Size of remote file: 348 kB

Git LFS Details

  • SHA256: be8a1e4e0040438ccf15e6686739d9a9f441c9665b286096a64d49b8bf93ce4e
  • Pointer size: 131 Bytes
  • Size of remote file: 243 kB
samples/unet_704x672_0.jpg CHANGED

Git LFS Details

  • SHA256: d44747d9162bcece687bec005622f43b3b457e065b8b136728487d39e63db440
  • Pointer size: 131 Bytes
  • Size of remote file: 787 kB

Git LFS Details

  • SHA256: ed33a29dfa17e16a1a6123fc7244762c993de55608d6f4e37b0951659193cdf3
  • Pointer size: 131 Bytes
  • Size of remote file: 419 kB
samples/unet_704x704_0.jpg CHANGED

Git LFS Details

  • SHA256: 3bc6f506a521d092de01061966a3da1833eb658325d02c5c9bc5193a7689e5f9
  • Pointer size: 131 Bytes
  • Size of remote file: 899 kB

Git LFS Details

  • SHA256: 472f745d5230e5e23eb0f818a6d6f0fb52538fa31e66cc77fc370087d74b0279
  • Pointer size: 131 Bytes
  • Size of remote file: 489 kB
src/unet1.5b.ipynb CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:0e8e3028e9acfe5c8bf1cf2cb3a371eb91405c8080a120510172768bd86009ba
+oid sha256:bc106e53f10b9fd143231839045c4fab5413c64a4d3f096304d1c689682299a8
 size 45191
test.ipynb CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:ba1d74b4649631d547cda5b137cb609e39ee1359d7025b2fb2a68e19424f041a
-size 9035778
+oid sha256:c290bd0cd5bed79850d835c7cbd8c556ef02fc9cbf01cdf8f75229db879fe710
+size 13053363
train.py CHANGED
@@ -35,13 +35,13 @@ ds_path = "datasets/ds1234_noanime_704_vae8x16x"
 project = "unet"

 gpu_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
-local_bs = max(1, int((gpu_mem_gb / 32) * 8))
+local_bs = max(1, int((gpu_mem_gb / 32) * 7))
 num_gpus = torch.cuda.device_count()
 batch_size = local_bs * num_gpus

 base_learning_rate = 4e-5
 min_learning_rate = 4e-6
-learning_rate_scale = 5 # 5 - finetune (small details), 1 - pretrain
+learning_rate_scale = 1 # 5 - finetune (small details), 1 - pretrain
 base_learning_rate = base_learning_rate / learning_rate_scale
 min_learning_rate = min_learning_rate / learning_rate_scale
 print(f"Calculated params max-lr:{base_learning_rate} min-lr:{min_learning_rate} GPUs: {num_gpus}, Global BS: {batch_size}")
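
The train.py change above lowers the VRAM heuristic from 8 to 7 samples per 32 GB and switches `learning_rate_scale` back to the pretraining value. A standalone sketch of the same arithmetic, with GPU memory and count passed in rather than queried from CUDA (the function name is ours, for illustration only):

```python
def training_batch_config(gpu_mem_gb: float, num_gpus: int, learning_rate_scale: float = 1):
    # ~7 samples per 32 GB of VRAM, at least 1 per GPU
    local_bs = max(1, int((gpu_mem_gb / 32) * 7))
    batch_size = local_bs * num_gpus
    # learning_rate_scale: 5 for finetuning (small details), 1 for pretraining
    base_lr = 4e-5 / learning_rate_scale
    min_lr = 4e-6 / learning_rate_scale
    return local_bs, batch_size, base_lr, min_lr

print(training_batch_config(80.0, 4)[:2])  # → (17, 68)
```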
unet/config.json CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:0a85ea1867dbee11485b2de5f5777cf16f5c5a2ed261dba0a465f5c649092299
+oid sha256:fbb20721f35fd23f45183d6c2341c319ac059296734c98c786b278c7a42e2f50
 size 1879
unet/diffusion_pytorch_model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:cb299eef5f0c3e0e3ed02691466c474ef240eb5059a2742a6ca94c8c744234f8
-size 3147092928
+oid sha256:35324c7f8ccdc476548954c82f76f6a38528201b7a514094dab2e8810519f47e
+size 6420443856