fffiloni committed
Commit 26557da · verified · 1 Parent(s): b2d49bb

Migrated from GitHub

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. .gitattributes +4 -0
  2. LICENSE +201 -0
  3. ORIGINAL_README.md +238 -0
  4. app.py +163 -0
  5. assets/Stand-In.png +0 -0
  6. configs/model_config.py +1809 -0
  7. data/video.py +158 -0
  8. distributed/__init__.py +0 -0
  9. distributed/xdit_context_parallel.py +154 -0
  10. download_models.py +21 -0
  11. infer.py +85 -0
  12. infer_face_swap.py +119 -0
  13. infer_with_lora.py +94 -0
  14. infer_with_vace.py +106 -0
  15. lora/__init__.py +91 -0
  16. models/__init__.py +1 -0
  17. models/attention.py +130 -0
  18. models/downloader.py +122 -0
  19. models/model_manager.py +610 -0
  20. models/set_condition_branch.py +41 -0
  21. models/tiler.py +333 -0
  22. models/utils.py +219 -0
  23. models/wan_video_camera_controller.py +290 -0
  24. models/wan_video_dit.py +952 -0
  25. models/wan_video_image_encoder.py +957 -0
  26. models/wan_video_motion_controller.py +41 -0
  27. models/wan_video_text_encoder.py +289 -0
  28. models/wan_video_vace.py +140 -0
  29. models/wan_video_vae.py +1634 -0
  30. pipelines/base.py +173 -0
  31. pipelines/wan_video.py +1793 -0
  32. pipelines/wan_video_face_swap.py +1786 -0
  33. preprocessor/__init__.py +2 -0
  34. preprocessor/image_input_preprocessor.py +181 -0
  35. preprocessor/videomask_generator.py +242 -0
  36. prompters/__init__.py +3 -0
  37. prompters/base_prompter.py +68 -0
  38. prompters/omost.py +472 -0
  39. prompters/prompt_refiners.py +131 -0
  40. prompters/wan_prompter.py +112 -0
  41. requirements.txt +17 -0
  42. schedulers/__init__.py +3 -0
  43. schedulers/continuous_ode.py +61 -0
  44. schedulers/ddim.py +136 -0
  45. schedulers/flow_match.py +100 -0
  46. test/input/first_frame.png +3 -0
  47. test/input/lecun.jpg +0 -0
  48. test/input/pose.mp4 +3 -0
  49. test/input/ruonan.jpg +3 -0
  50. test/input/woman.mp4 +3 -0
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+ test/input/first_frame.png filter=lfs diff=lfs merge=lfs -text
+ test/input/pose.mp4 filter=lfs diff=lfs merge=lfs -text
+ test/input/ruonan.jpg filter=lfs diff=lfs merge=lfs -text
+ test/input/woman.mp4 filter=lfs diff=lfs merge=lfs -text
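For reference, the same four entries could be produced with `git lfs track`, which appends matching filter rules to `.gitattributes`. This is only an illustrative sketch and assumes Git LFS is installed and the repository is already initialized:

```bash
# Hypothetical equivalent of the hunk above: register the new test assets with Git LFS.
# `git lfs track` appends "filter=lfs diff=lfs merge=lfs -text" rules to .gitattributes.
git lfs install
git lfs track "test/input/first_frame.png" "test/input/pose.mp4" \
              "test/input/ruonan.jpg" "test/input/woman.mp4"
git add .gitattributes test/input/
```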
LICENSE ADDED
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
ORIGINAL_README.md ADDED
@@ -0,0 +1,238 @@
+ <div align="center">
+
+ <h1>
+ <img src="assets/Stand-In.png" width="85" alt="Logo" valign="middle">
+ Stand-In
+ </h1>
+
+ <h3>A Lightweight and Plug-and-Play Identity Control for Video Generation</h3>
+
+ [![arXiv](https://img.shields.io/badge/arXiv-2508.07901-b31b1b)](https://arxiv.org/abs/2508.07901)
+ [![Project Page](https://img.shields.io/badge/Project_Page-Link-green)](https://www.stand-in.tech)
+ [![🤗 HuggingFace](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Model-orange)](https://huggingface.co/BowenXue/Stand-In)
+
+ </div>
+
+ <img width="5333" height="2983" alt="Image" src="https://github.com/user-attachments/assets/2fe1e505-bcf7-4eb6-8628-f23e70020966" />
+
+ > **Stand-In** is a lightweight, plug-and-play framework for identity-preserving video generation. By training only **1%** additional parameters compared to the base video generation model, we achieve state-of-the-art results in both Face Similarity and Naturalness, outperforming various full-parameter training methods. Moreover, **Stand-In** can be seamlessly integrated into other tasks such as subject-driven video generation, pose-controlled video generation, video stylization, and face swapping.
+
+ ---
+
+ ## 🔥 News
+ * **[2025.08.18]** We have released a version compatible with VACE. Beyond pose control, you can also try other control methods, such as depth maps, combined with Stand-In to preserve identity at the same time.
+
+ * **[2025.08.16]** We have released an experimental version of the face-swapping feature. Feel free to try it out!
+
+ * **[2025.08.13]** Special thanks to @kijai for integrating Stand-In into the custom ComfyUI node **WanVideoWrapper**. However, that implementation differs from the official version, which may affect Stand-In's performance.
+ To address part of the issue, we have released the official Stand-In preprocessing ComfyUI node:
+ 👉 https://github.com/WeChatCV/Stand-In_Preprocessor_ComfyUI
+ If you wish to use Stand-In within ComfyUI, please use **our official preprocessing node** in place of the one implemented by kijai.
+ For the best results, we recommend waiting for the release of our full **official Stand-In ComfyUI** integration.
+
+ * **[2025.08.12]** Released Stand-In v1.0 (153M parameters); the Wan2.1-14B-T2V-adapted weights and inference code are now open-sourced.
+
+ ---
+
+ ## 🌟 Showcase
+
+ ### Identity-Preserving Text-to-Video Generation
+
+ | Reference Image | Prompt | Generated Video |
+ | :---: | :---: | :---: |
+ |![Image](https://github.com/user-attachments/assets/86ce50d7-8ccb-45bf-9538-aea7f167a541)| "In a corridor where the walls ripple like water, a woman reaches out to touch the flowing surface, causing circles of ripples to spread. The camera moves from a medium shot to a close-up, capturing her curious expression as she sees her distorted reflection." |![Image](https://github.com/user-attachments/assets/c3c80bbf-a1cc-46a1-b47b-1b28bcad34a3) |
+ |![Image](https://github.com/user-attachments/assets/de10285e-7983-42bb-8534-80ac02210172)| "A young man dressed in traditional attire draws the long sword from his waist and begins to wield it. The blade flashes with light as he moves—his eyes sharp, his actions swift and powerful, with his flowing robes dancing in the wind." |![Image](https://github.com/user-attachments/assets/1532c701-ef01-47be-86da-d33c8c6894ab)|
+
+ ---
+ ### Non-Human Subject-Preserving Video Generation
+
+ | Reference Image | Prompt | Generated Video |
+ | :---: | :---: | :---: |
+ |<img width="415" height="415" alt="Image" src="https://github.com/user-attachments/assets/b929444d-d724-4cf9-b422-be82b380ff78" />|"A chibi-style boy speeding on a skateboard, holding a detective novel in one hand. The background features city streets, with trees, streetlights, and billboards along the roads."|![Image](https://github.com/user-attachments/assets/a7239232-77bc-478b-a0d9-ecc77db97aa5) |
+
+ ---
+
+ ### Identity-Preserving Stylized Video Generation
+
+ | Reference Image | LoRA | Generated Video |
+ | :---: | :---: | :---: |
+ |![Image](https://github.com/user-attachments/assets/9c0687f9-e465-4bc5-bc62-8ac46d5f38b1)|Ghibli LoRA|![Image](https://github.com/user-attachments/assets/c6ca1858-de39-4fff-825a-26e6d04e695f)|
+
+ ---
+
+ ### Video Face Swapping
+
+ | Reference Video | Identity | Generated Video |
+ | :---: | :---: | :---: |
+ |![Image](https://github.com/user-attachments/assets/33370ac7-364a-4f97-8ba9-14e1009cd701)|<img width="415" height="415" alt="Image" src="https://github.com/user-attachments/assets/d2cd8da0-7aa0-4ee4-a61d-b52718c33756" />|![Image](https://github.com/user-attachments/assets/0db8aedd-411f-414a-9227-88f4e4050b50)|
+
+ ---
+
+ ### Pose-Guided Video Generation (With VACE)
+
+ | Reference Pose | First Frame | Generated Video |
+ | :---: | :---: | :---: |
+ |![Image](https://github.com/user-attachments/assets/5df5eec8-b71c-4270-8a78-906a488f9a94)|<img width="719" height="415" alt="Image" src="https://github.com/user-attachments/assets/1c2a69e1-e530-4164-848b-e7ea85a99763" />|![Image](https://github.com/user-attachments/assets/1c8a54da-01d6-43c1-a5fd-cab0c9e32c44)|
+
+ ---
+ ### For more results, please visit [https://stand-in-video.github.io/](https://www.Stand-In.tech)
+
+ ## 📖 Key Features
+ - Efficient Training: Only 1% of the base model parameters need to be trained.
+ - High Fidelity: Outstanding identity consistency without sacrificing video generation quality.
+ - Plug-and-Play: Easily integrates into existing T2V (Text-to-Video) models.
+ - Highly Extensible: Compatible with community models such as LoRA, and supports various downstream video tasks.
+
+ ---
+
+ ## ✅ Todo List
+ - [x] Release IP2V inference script (compatible with community LoRA).
+ - [x] Open-source model weights compatible with Wan2.1-14B-T2V: `Stand-In_Wan2.1-T2V-14B_153M_v1.0`.
+ - [ ] Open-source model weights compatible with Wan2.2-T2V-A14B.
+ - [ ] Release training dataset, data preprocessing scripts, and training code.
+
+ ---
+
+ ## 🚀 Quick Start
+
+ ### 1. Environment Setup
+ ```bash
+ # Clone the project repository
+ git clone https://github.com/WeChatCV/Stand-In.git
+ cd Stand-In
+
+ # Create and activate Conda environment
+ conda create -n Stand-In python=3.11 -y
+ conda activate Stand-In
+
+ # Install dependencies
+ pip install -r requirements.txt
+
+ # (Optional) Install Flash Attention for faster inference
+ # Note: Make sure your GPU and CUDA version are compatible with Flash Attention
+ pip install flash-attn --no-build-isolation
+ ```
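As a quick sanity check after the optional Flash Attention step, something like the following can confirm that the packages import inside the `Stand-In` environment. This is an illustrative snippet, not part of the repository:

```bash
# Illustrative check: verify PyTorch sees a CUDA GPU and flash-attn (if installed) imports cleanly.
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
python -c "import flash_attn; print(flash_attn.__version__)"  # only relevant if flash-attn was installed
```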
+
+ ### 2. Model Download
+ We provide an automatic download script that fetches all required model weights into the `checkpoints` directory.
+ ```bash
+ python download_models.py
+ ```
+ This script downloads the following models:
+ * `wan2.1-T2V-14B` (base text-to-video model)
+ * `antelopev2` (face recognition model)
+ * `Stand-In` (our Stand-In model)
+
+ > Note: If you already have the `wan2.1-T2V-14B` model locally, you can edit `download_models.py` to comment out the corresponding download code and place the model in the `checkpoints/wan2.1-T2V-14B` directory yourself, as sketched below.
+
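For example, reusing an existing local copy might look like the following. This is only a sketch: the source path is hypothetical, and the destination directory name must match `checkpoints/wan2.1-T2V-14B` as described in the note above.

```bash
# Hypothetical example: point the expected checkpoint directory at an existing local copy
# of the base model instead of re-downloading it. Adjust /path/to/local/Wan2.1-T2V-14B.
mkdir -p checkpoints
ln -s /path/to/local/Wan2.1-T2V-14B checkpoints/wan2.1-T2V-14B
# Then comment out the corresponding download call in download_models.py before running it.
```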
+ ---
+
+ ## 🧪 Usage
+
+ ### Standard Inference
+
+ Use the `infer.py` script for standard identity-preserving text-to-video generation.
+
+ ```bash
+ python infer.py \
+     --prompt "A man sits comfortably at a desk, facing the camera as if talking to a friend or family member on the screen. His gaze is focused and gentle, with a natural smile. The background is his carefully decorated personal space, with photos and a world map on the wall, conveying a sense of intimate and modern communication." \
+     --ip_image "test/input/lecun.jpg" \
+     --output "test/output/lecun.mp4"
+ ```
+ **Prompt Writing Tip:** If you do not wish to alter the subject's facial features, simply use *"a man"* or *"a woman"* without adding extra descriptions of their appearance. Prompts support both Chinese and English input and are best suited to generating frontal, medium-to-close-up videos.
+
+ **Input Image Recommendation:** For best results, use a high-resolution frontal face image. There are no restrictions on resolution or file extension, as our built-in preprocessing pipeline handles them automatically.
+
+ ---
+
+ ### Inference with Community LoRA
+
+ Use the `infer_with_lora.py` script to load one or more community LoRA models alongside Stand-In.
+
+ ```bash
+ python infer_with_lora.py \
+     --prompt "A man sits comfortably at a desk, facing the camera as if talking to a friend or family member on the screen. His gaze is focused and gentle, with a natural smile. The background is his carefully decorated personal space, with photos and a world map on the wall, conveying a sense of intimate and modern communication." \
+     --ip_image "test/input/lecun.jpg" \
+     --output "test/output/lecun.mp4" \
+     --lora_path "path/to/your/lora.safetensors" \
+     --lora_scale 1.0
+ ```
+
+ We recommend using this stylization LoRA: [https://civitai.com/models/1404755/studio-ghibli-wan21-t2v-14b](https://civitai.com/models/1404755/studio-ghibli-wan21-t2v-14b)
+
+ ---
+
+ ### Video Face Swapping
+
+ Use the `infer_face_swap.py` script to perform video face swapping with Stand-In.
+
+ ```bash
+ python infer_face_swap.py \
+     --prompt "The video features a woman standing in front of a large screen displaying the words \"Tech Minute\" and the logo for CNET. She is wearing a purple top and appears to be presenting or speaking about technology-related topics. The background includes a cityscape with tall buildings, suggesting an urban setting. The woman seems to be engaged in a discussion or providing information on technology news or trends. The overall atmosphere is professional and informative, likely aimed at educating viewers about the latest developments in the tech industry." \
+     --ip_image "test/input/ruonan.jpg" \
+     --output "test/output/ruonan.mp4" \
+     --denoising_strength 0.85
+ ```
+ **Note**: Since Wan2.1 itself does not have an inpainting function, the face-swapping feature is still experimental.
+
+ The higher the `denoising_strength`, the more of the background is redrawn and the more natural the face region becomes. Conversely, the lower the `denoising_strength`, the less the background is redrawn and the more the face region tends to overfit the reference identity.
+
+ You can pass `--force_background_consistency` to keep the background fully consistent, but this may introduce noticeable contour artifacts around the face. Enabling it requires experimenting with different `denoising_strength` values to find the most natural result; if slight background changes are acceptable, leave it disabled. An example run is shown below.
+
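As an illustration of the trade-off described above, a run that locks the background might look like this. It is a hedged sketch: the prompt is shortened, `--force_background_consistency` is assumed to be a boolean flag, and the chosen `denoising_strength` is only a starting point to tune.

```bash
# Hypothetical tuning example: keep the background fixed and search for a natural
# denoising_strength. Lower values preserve more background but risk an overfitted face.
python infer_face_swap.py \
    --prompt "A woman presents technology news in front of a large screen." \
    --ip_image "test/input/ruonan.jpg" \
    --output "test/output/ruonan_consistent.mp4" \
    --denoising_strength 0.7 \
    --force_background_consistency
```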
+ ### Inference with VACE
+ Use the `infer_with_vace.py` script to perform identity-preserving video generation with Stand-In in combination with VACE.
+ ```bash
+ python infer_with_vace.py \
+     --prompt "A woman raises her hands." \
+     --vace_path "checkpoints/VACE/" \
+     --ip_image "test/input/first_frame.png" \
+     --reference_video "test/input/pose.mp4" \
+     --reference_image "test/input/first_frame.png" \
+     --output "test/output/woman.mp4" \
+     --vace_scale 0.8
+ ```
+ You need to download the corresponding weights from the `VACE` repository, or point the `vace_path` parameter at an existing copy of the `VACE` weights.
+
+ ```bash
+ python download_models.py --vace
+ ```
+
+ The input control video must be preprocessed with VACE's preprocessing tool. Both `reference_video` and `reference_image` are optional and can be provided together. Note that VACE's control has a preset bias towards faces, which affects identity preservation, so lower `vace_scale` to a balance point where both motion and identity are preserved. When only `ip_image` and `reference_video` are provided, the scale can be reduced to 0.5 (see the example below).
+
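For instance, a pose-only run without a reference image could use the lower scale mentioned above. This is a sketch assembled from the documented arguments; the identity image and output path are placeholders.

```bash
# Hypothetical variant: only ip_image and reference_video are given, so vace_scale is
# lowered to 0.5 to keep identity preservation strong while still following the pose.
python infer_with_vace.py \
    --prompt "A woman raises her hands." \
    --vace_path "checkpoints/VACE/" \
    --ip_image "test/input/lecun.jpg" \
    --reference_video "test/input/pose.mp4" \
    --output "test/output/pose_only.mp4" \
    --vace_scale 0.5
```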
+ Using Stand-In together with VACE is more challenging than using Stand-In alone. We are still maintaining this feature, so if you encounter unexpected outputs or have other questions, feel free to open a GitHub issue.
+
+
+ ## 🤝 Acknowledgements
+
+ This project is built upon the following excellent open-source projects:
+ * [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio) (training/inference framework)
+ * [Wan2.1](https://github.com/Wan-Video/Wan2.1) (base video generation model)
+
+ We sincerely thank the authors and contributors of these projects.
+
+ The original raw material of our dataset was collected with the help of our team member [Binxin Yang](https://binxinyang.github.io/), and we appreciate his contribution!
+
+ ---
+
+ ## ✏ Citation
+
+ If you find our work helpful for your research, please consider citing our paper:
+
+ ```bibtex
+ @article{xue2025standin,
+   title={Stand-In: A Lightweight and Plug-and-Play Identity Control for Video Generation},
+   author={Bowen Xue and Qixin Yan and Wenjing Wang and Hao Liu and Chen Li},
+   journal={arXiv preprint arXiv:2508.07901},
+   year={2025},
+ }
+ ```
+
+ ---
+
+ ## 📬 Contact Us
+
+ If you have any questions or suggestions, feel free to reach out via [GitHub Issues](https://github.com/WeChatCV/Stand-In/issues). We look forward to your feedback!
app.py ADDED
@@ -0,0 +1,163 @@
+ import gradio as gr
+ import torch
+ import time
+ from PIL import Image
+ import tempfile
+ import os
+
+ from data.video import save_video
+ from wan_loader import load_wan_pipe
+ from models.set_condition_branch import set_stand_in
+ from preprocessor import FaceProcessor
+
+ # Load the face preprocessor, the base Wan pipeline, and the Stand-In conditioning weights once at startup.
+ print("Loading model, please wait...")
+ try:
+     ANTELOPEV2_PATH = "checkpoints/antelopev2"
+     BASE_MODEL_PATH = "checkpoints/base_model/"
+     LORA_MODEL_PATH = "checkpoints/Stand-In/Stand-In_wan2.1_T2V_14B_ver1.0.ckpt"
+
+     if not os.path.exists(ANTELOPEV2_PATH):
+         raise FileNotFoundError(
+             f"AntelopeV2 checkpoint not found at: {ANTELOPEV2_PATH}"
+         )
+     if not os.path.exists(BASE_MODEL_PATH):
+         raise FileNotFoundError(f"Base model not found at: {BASE_MODEL_PATH}")
+     if not os.path.exists(LORA_MODEL_PATH):
+         raise FileNotFoundError(f"LoRA model not found at: {LORA_MODEL_PATH}")
+
+     face_processor = FaceProcessor(antelopv2_path=ANTELOPEV2_PATH)
+     pipe = load_wan_pipe(base_path=BASE_MODEL_PATH, torch_dtype=torch.bfloat16)
+     set_stand_in(pipe, model_path=LORA_MODEL_PATH)
+     print("Model loaded successfully!")
+ except Exception as e:
+     # If loading fails, show a minimal error page instead of the main UI and exit.
+     print(f"Model loading failed: {e}")
+     with gr.Blocks() as demo:
+         gr.Markdown("# Error: Model Loading Failed")
+         gr.Markdown(f"""
+         Please check the following:
+         1. Make sure the checkpoint files are placed in the correct directory.
+         2. Ensure all dependencies are properly installed.
+         3. Check the console output for detailed error information.
+
+         **Error details**: {e}
+         """)
+     demo.launch()
+     exit()
+
+
+ def generate_video(
+     pil_image: Image.Image,
+     prompt: str,
+     seed: int,
+     negative_prompt: str,
+     num_steps: int,
+     fps: int,
+     quality: int,
+ ):
+     if pil_image is None:
+         raise gr.Error("Please upload a face image first!")
+
+     # Crop/align the uploaded face into the conditioning image expected by the pipeline.
+     print("Processing face...")
+     ip_image = face_processor.process(pil_image)
+     print("Face processing completed.")
+
+     print("Generating video...")
+     start_time = time.time()
+     video = pipe(
+         prompt=prompt,
+         negative_prompt=negative_prompt,
+         seed=int(seed),
+         ip_image=ip_image,
+         num_inference_steps=int(num_steps),
+         tiled=False,
+     )
+     end_time = time.time()
+     print(f"Video generated in {end_time - start_time:.2f} seconds.")
+
+     # Write the frames to a temporary .mp4 file and return its path to Gradio.
+     with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_file:
+         video_path = temp_file.name
+     save_video(video, video_path, fps=int(fps), quality=quality)
+     print(f"Video saved to: {video_path}")
+     return video_path
+
+
+ with gr.Blocks(theme=gr.themes.Soft(), css="footer {display: none !important}") as demo:
+     gr.Markdown(
+         """
+         # Stand-In IP2V
+         """
+     )
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             gr.Markdown("### 1. Upload a Face Image")
+             input_image = gr.Image(
+                 label="Upload Image",
+                 type="pil",
+                 image_mode="RGB",
+                 height=300,
+             )
+
+             gr.Markdown("### 2. Enter Core Parameters")
+             # Default prompt (Chinese): a man sits at a desk facing the camera, matching the README's English example.
+             input_prompt = gr.Textbox(
+                 label="Prompt",
+                 lines=4,
+                 value="一位男性舒适地坐在书桌前,正对着镜头,仿佛在与屏幕前的亲友对话。他的眼神专注而温柔,嘴角带着自然的笑意。背景是他精心布置的个人空间,墙上贴着照片和一张世界地图,传达出一种亲密而现代的沟通感。",
+                 placeholder="Please enter a detailed description of the scene, character actions, expressions, etc...",
+             )
+
+             input_seed = gr.Slider(
+                 label="Seed",
+                 minimum=0,
+                 maximum=100000,
+                 step=1,
+                 value=0,
+                 info="The same seed and parameters will generate the same result.",
+             )
+
+             with gr.Accordion("Advanced Options", open=False):
+                 # Default negative prompt (Chinese): discourages overexposure, blur, deformed hands/faces, cluttered backgrounds, etc.
+                 input_negative_prompt = gr.Textbox(
+                     label="Negative Prompt",
+                     lines=3,
+                     value="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
+                 )
+                 input_steps = gr.Slider(
+                     label="Inference Steps",
+                     minimum=10,
+                     maximum=50,
+                     step=1,
+                     value=20,
+                     info="More steps may improve details but will take longer to generate.",
+                 )
+                 output_fps = gr.Slider(
+                     label="Video FPS", minimum=10, maximum=30, step=1, value=25
+                 )
+                 output_quality = gr.Slider(
+                     label="Video Quality", minimum=1, maximum=10, step=1, value=9
+                 )
+
+             generate_btn = gr.Button("Generate Video", variant="primary")
+
+         with gr.Column(scale=1):
+             gr.Markdown("### 3. View Generated Result")
+             output_video = gr.Video(
+                 label="Generated Video",
+                 height=480,
+             )
+
+     generate_btn.click(
+         fn=generate_video,
+         inputs=[
+             input_image,
+             input_prompt,
+             input_seed,
+             input_negative_prompt,
+             input_steps,
+             output_fps,
+             output_quality,
+         ],
+         outputs=output_video,
+         api_name="generate_video",
+     )
+
+ if __name__ == "__main__":
+     demo.launch(share=True, server_name="0.0.0.0", server_port=8080)
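Running this file starts the Gradio demo directly; the details below are taken from the launch call and the click handler above, and the command assumes the `checkpoints/` directory has already been populated.

```bash
# Start the Gradio demo defined in app.py.
# It binds to 0.0.0.0:8080 and also prints a temporary public share link (share=True);
# the click handler is additionally exposed under the API name "generate_video".
python app.py
```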
assets/Stand-In.png ADDED
configs/model_config.py ADDED
@@ -0,0 +1,1809 @@
1
+ from typing_extensions import Literal, TypeAlias
2
+
3
+ from models.wan_video_dit import WanModel
4
+ from models.wan_video_text_encoder import WanTextEncoder
5
+ from models.wan_video_image_encoder import WanImageEncoder
6
+ from models.wan_video_vae import WanVideoVAE, WanVideoVAE38
7
+ from models.wan_video_motion_controller import WanMotionControllerModel
8
+ from models.wan_video_vace import VaceWanModel
9
+
10
+
11
+ model_loader_configs = [
12
+ (
13
+ None,
14
+ "9269f8db9040a9d860eaca435be61814",
15
+ ["wan_video_dit"],
16
+ [WanModel],
17
+ "civitai",
18
+ ),
19
+ (
20
+ None,
21
+ "aafcfd9672c3a2456dc46e1cb6e52c70",
22
+ ["wan_video_dit"],
23
+ [WanModel],
24
+ "civitai",
25
+ ),
26
+ (
27
+ None,
28
+ "6bfcfb3b342cb286ce886889d519a77e",
29
+ ["wan_video_dit"],
30
+ [WanModel],
31
+ "civitai",
32
+ ),
33
+ (
34
+ None,
35
+ "6d6ccde6845b95ad9114ab993d917893",
36
+ ["wan_video_dit"],
37
+ [WanModel],
38
+ "civitai",
39
+ ),
40
+ (
41
+ None,
42
+ "6bfcfb3b342cb286ce886889d519a77e",
43
+ ["wan_video_dit"],
44
+ [WanModel],
45
+ "civitai",
46
+ ),
47
+ (
48
+ None,
49
+ "349723183fc063b2bfc10bb2835cf677",
50
+ ["wan_video_dit"],
51
+ [WanModel],
52
+ "civitai",
53
+ ),
54
+ (
55
+ None,
56
+ "efa44cddf936c70abd0ea28b6cbe946c",
57
+ ["wan_video_dit"],
58
+ [WanModel],
59
+ "civitai",
60
+ ),
61
+ (
62
+ None,
63
+ "3ef3b1f8e1dab83d5b71fd7b617f859f",
64
+ ["wan_video_dit"],
65
+ [WanModel],
66
+ "civitai",
67
+ ),
68
+ (
69
+ None,
70
+ "70ddad9d3a133785da5ea371aae09504",
71
+ ["wan_video_dit"],
72
+ [WanModel],
73
+ "civitai",
74
+ ),
75
+ (
76
+ None,
77
+ "26bde73488a92e64cc20b0a7485b9e5b",
78
+ ["wan_video_dit"],
79
+ [WanModel],
80
+ "civitai",
81
+ ),
82
+ (
83
+ None,
84
+ "ac6a5aa74f4a0aab6f64eb9a72f19901",
85
+ ["wan_video_dit"],
86
+ [WanModel],
87
+ "civitai",
88
+ ),
89
+ (
90
+ None,
91
+ "b61c605c2adbd23124d152ed28e049ae",
92
+ ["wan_video_dit"],
93
+ [WanModel],
94
+ "civitai",
95
+ ),
96
+ (
97
+ None,
98
+ "1f5ab7703c6fc803fdded85ff040c316",
99
+ ["wan_video_dit"],
100
+ [WanModel],
101
+ "civitai",
102
+ ),
103
+ (
104
+ None,
105
+ "5b013604280dd715f8457c6ed6d6a626",
106
+ ["wan_video_dit"],
107
+ [WanModel],
108
+ "civitai",
109
+ ),
110
+ (
111
+ None,
112
+ "a61453409b67cd3246cf0c3bebad47ba",
113
+ ["wan_video_dit", "wan_video_vace"],
114
+ [WanModel, VaceWanModel],
115
+ "civitai",
116
+ ),
117
+ (
118
+ None,
119
+ "7a513e1f257a861512b1afd387a8ecd9",
120
+ ["wan_video_dit", "wan_video_vace"],
121
+ [WanModel, VaceWanModel],
122
+ "civitai",
123
+ ),
124
+ (
125
+ None,
126
+ "cb104773c6c2cb6df4f9529ad5c60d0b",
127
+ ["wan_video_dit"],
128
+ [WanModel],
129
+ "diffusers",
130
+ ),
131
+ (
132
+ None,
133
+ "9c8818c2cbea55eca56c7b447df170da",
134
+ ["wan_video_text_encoder"],
135
+ [WanTextEncoder],
136
+ "civitai",
137
+ ),
138
+ (
139
+ None,
140
+ "5941c53e207d62f20f9025686193c40b",
141
+ ["wan_video_image_encoder"],
142
+ [WanImageEncoder],
143
+ "civitai",
144
+ ),
145
+ (
146
+ None,
147
+ "1378ea763357eea97acdef78e65d6d96",
148
+ ["wan_video_vae"],
149
+ [WanVideoVAE],
150
+ "civitai",
151
+ ),
152
+ (
153
+ None,
154
+ "ccc42284ea13e1ad04693284c7a09be6",
155
+ ["wan_video_vae"],
156
+ [WanVideoVAE],
157
+ "civitai",
158
+ ),
159
+ (
160
+ None,
161
+ "e1de6c02cdac79f8b739f4d3698cd216",
162
+ ["wan_video_vae"],
163
+ [WanVideoVAE38],
164
+ "civitai",
165
+ ),
166
+ (
167
+ None,
168
+ "dbd5ec76bbf977983f972c151d545389",
169
+ ["wan_video_motion_controller"],
170
+ [WanMotionControllerModel],
171
+ "civitai",
172
+ ),
173
+ ]
174
+ huggingface_model_loader_configs = [
175
+ # These configs are provided for detecting model type automatically.
176
+ # The format is (architecture_in_huggingface_config, huggingface_lib, model_name, redirected_architecture)
177
+ (
178
+ "ChatGLMModel",
179
+ "diffsynth.models.kolors_text_encoder",
180
+ "kolors_text_encoder",
181
+ None,
182
+ ),
183
+ ("MarianMTModel", "transformers.models.marian.modeling_marian", "translator", None),
184
+ (
185
+ "BloomForCausalLM",
186
+ "transformers.models.bloom.modeling_bloom",
187
+ "beautiful_prompt",
188
+ None,
189
+ ),
190
+ (
191
+ "Qwen2ForCausalLM",
192
+ "transformers.models.qwen2.modeling_qwen2",
193
+ "qwen_prompt",
194
+ None,
195
+ ),
196
+ # ("LlamaForCausalLM", "transformers.models.llama.modeling_llama", "omost_prompt", None),
197
+ (
198
+ "T5EncoderModel",
199
+ "diffsynth.models.flux_text_encoder",
200
+ "flux_text_encoder_2",
201
+ "FluxTextEncoder2",
202
+ ),
203
+ ("CogVideoXTransformer3DModel", "diffsynth.models.cog_dit", "cog_dit", "CogDiT"),
204
+ (
205
+ "SiglipModel",
206
+ "transformers.models.siglip.modeling_siglip",
207
+ "siglip_vision_model",
208
+ "SiglipVisionModel",
209
+ ),
210
+ (
211
+ "LlamaForCausalLM",
212
+ "diffsynth.models.hunyuan_video_text_encoder",
213
+ "hunyuan_video_text_encoder_2",
214
+ "HunyuanVideoLLMEncoder",
215
+ ),
216
+ (
217
+ "LlavaForConditionalGeneration",
218
+ "diffsynth.models.hunyuan_video_text_encoder",
219
+ "hunyuan_video_text_encoder_2",
220
+ "HunyuanVideoMLLMEncoder",
221
+ ),
222
+ (
223
+ "Step1Model",
224
+ "diffsynth.models.stepvideo_text_encoder",
225
+ "stepvideo_text_encoder_2",
226
+ "STEP1TextEncoder",
227
+ ),
228
+ (
229
+ "Qwen2_5_VLForConditionalGeneration",
230
+ "diffsynth.models.qwenvl",
231
+ "qwenvl",
232
+ "Qwen25VL_7b_Embedder",
233
+ ),
234
+ ]
235
+ patch_model_loader_configs = [
236
+ # These configs are provided for detecting model type automatically.
237
+ # The format is (state_dict_keys_hash_with_shape, model_name, model_class, extra_kwargs)
238
+ ]
239
+
240
+ preset_models_on_huggingface = {
241
+ "HunyuanDiT": [
242
+ (
243
+ "Tencent-Hunyuan/HunyuanDiT",
244
+ "t2i/clip_text_encoder/pytorch_model.bin",
245
+ "models/HunyuanDiT/t2i/clip_text_encoder",
246
+ ),
247
+ (
248
+ "Tencent-Hunyuan/HunyuanDiT",
249
+ "t2i/mt5/pytorch_model.bin",
250
+ "models/HunyuanDiT/t2i/mt5",
251
+ ),
252
+ (
253
+ "Tencent-Hunyuan/HunyuanDiT",
254
+ "t2i/model/pytorch_model_ema.pt",
255
+ "models/HunyuanDiT/t2i/model",
256
+ ),
257
+ (
258
+ "Tencent-Hunyuan/HunyuanDiT",
259
+ "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin",
260
+ "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix",
261
+ ),
262
+ ],
263
+ "stable-video-diffusion-img2vid-xt": [
264
+ (
265
+ "stabilityai/stable-video-diffusion-img2vid-xt",
266
+ "svd_xt.safetensors",
267
+ "models/stable_video_diffusion",
268
+ ),
269
+ ],
270
+ "ExVideo-SVD-128f-v1": [
271
+ (
272
+ "ECNU-CILab/ExVideo-SVD-128f-v1",
273
+ "model.fp16.safetensors",
274
+ "models/stable_video_diffusion",
275
+ ),
276
+ ],
277
+ # Stable Diffusion
278
+ "StableDiffusion_v15": [
279
+ (
280
+ "benjamin-paine/stable-diffusion-v1-5",
281
+ "v1-5-pruned-emaonly.safetensors",
282
+ "models/stable_diffusion",
283
+ ),
284
+ ],
285
+ "DreamShaper_8": [
286
+ ("Yntec/Dreamshaper8", "dreamshaper_8.safetensors", "models/stable_diffusion"),
287
+ ],
288
+ # Textual Inversion
289
+ "TextualInversion_VeryBadImageNegative_v1.3": [
290
+ (
291
+ "gemasai/verybadimagenegative_v1.3",
292
+ "verybadimagenegative_v1.3.pt",
293
+ "models/textual_inversion",
294
+ ),
295
+ ],
296
+ # Stable Diffusion XL
297
+ "StableDiffusionXL_v1": [
298
+ (
299
+ "stabilityai/stable-diffusion-xl-base-1.0",
300
+ "sd_xl_base_1.0.safetensors",
301
+ "models/stable_diffusion_xl",
302
+ ),
303
+ ],
304
+ "BluePencilXL_v200": [
305
+ (
306
+ "frankjoshua/bluePencilXL_v200",
307
+ "bluePencilXL_v200.safetensors",
308
+ "models/stable_diffusion_xl",
309
+ ),
310
+ ],
311
+ "StableDiffusionXL_Turbo": [
312
+ (
313
+ "stabilityai/sdxl-turbo",
314
+ "sd_xl_turbo_1.0_fp16.safetensors",
315
+ "models/stable_diffusion_xl_turbo",
316
+ ),
317
+ ],
318
+ # Stable Diffusion 3
319
+ "StableDiffusion3": [
320
+ (
321
+ "stabilityai/stable-diffusion-3-medium",
322
+ "sd3_medium_incl_clips_t5xxlfp16.safetensors",
323
+ "models/stable_diffusion_3",
324
+ ),
325
+ ],
326
+ "StableDiffusion3_without_T5": [
327
+ (
328
+ "stabilityai/stable-diffusion-3-medium",
329
+ "sd3_medium_incl_clips.safetensors",
330
+ "models/stable_diffusion_3",
331
+ ),
332
+ ],
333
+ # ControlNet
334
+ "ControlNet_v11f1p_sd15_depth": [
335
+ (
336
+ "lllyasviel/ControlNet-v1-1",
337
+ "control_v11f1p_sd15_depth.pth",
338
+ "models/ControlNet",
339
+ ),
340
+ ("lllyasviel/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators"),
341
+ ],
342
+ "ControlNet_v11p_sd15_softedge": [
343
+ (
344
+ "lllyasviel/ControlNet-v1-1",
345
+ "control_v11p_sd15_softedge.pth",
346
+ "models/ControlNet",
347
+ ),
348
+ ("lllyasviel/Annotators", "ControlNetHED.pth", "models/Annotators"),
349
+ ],
350
+ "ControlNet_v11f1e_sd15_tile": [
351
+ (
352
+ "lllyasviel/ControlNet-v1-1",
353
+ "control_v11f1e_sd15_tile.pth",
354
+ "models/ControlNet",
355
+ )
356
+ ],
357
+ "ControlNet_v11p_sd15_lineart": [
358
+ (
359
+ "lllyasviel/ControlNet-v1-1",
360
+ "control_v11p_sd15_lineart.pth",
361
+ "models/ControlNet",
362
+ ),
363
+ ("lllyasviel/Annotators", "sk_model.pth", "models/Annotators"),
364
+ ("lllyasviel/Annotators", "sk_model2.pth", "models/Annotators"),
365
+ ],
366
+ "ControlNet_union_sdxl_promax": [
367
+ (
368
+ "xinsir/controlnet-union-sdxl-1.0",
369
+ "diffusion_pytorch_model_promax.safetensors",
370
+ "models/ControlNet/controlnet_union",
371
+ ),
372
+ ("lllyasviel/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators"),
373
+ ],
374
+ # AnimateDiff
375
+ "AnimateDiff_v2": [
376
+ ("guoyww/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
377
+ ],
378
+ "AnimateDiff_xl_beta": [
379
+ ("guoyww/animatediff", "mm_sdxl_v10_beta.ckpt", "models/AnimateDiff"),
380
+ ],
381
+ # Qwen Prompt
382
+ "QwenPrompt": [
383
+ (
384
+ "Qwen/Qwen2-1.5B-Instruct",
385
+ "config.json",
386
+ "models/QwenPrompt/qwen2-1.5b-instruct",
387
+ ),
388
+ (
389
+ "Qwen/Qwen2-1.5B-Instruct",
390
+ "generation_config.json",
391
+ "models/QwenPrompt/qwen2-1.5b-instruct",
392
+ ),
393
+ (
394
+ "Qwen/Qwen2-1.5B-Instruct",
395
+ "model.safetensors",
396
+ "models/QwenPrompt/qwen2-1.5b-instruct",
397
+ ),
398
+ (
399
+ "Qwen/Qwen2-1.5B-Instruct",
400
+ "special_tokens_map.json",
401
+ "models/QwenPrompt/qwen2-1.5b-instruct",
402
+ ),
403
+ (
404
+ "Qwen/Qwen2-1.5B-Instruct",
405
+ "tokenizer.json",
406
+ "models/QwenPrompt/qwen2-1.5b-instruct",
407
+ ),
408
+ (
409
+ "Qwen/Qwen2-1.5B-Instruct",
410
+ "tokenizer_config.json",
411
+ "models/QwenPrompt/qwen2-1.5b-instruct",
412
+ ),
413
+ (
414
+ "Qwen/Qwen2-1.5B-Instruct",
415
+ "merges.txt",
416
+ "models/QwenPrompt/qwen2-1.5b-instruct",
417
+ ),
418
+ (
419
+ "Qwen/Qwen2-1.5B-Instruct",
420
+ "vocab.json",
421
+ "models/QwenPrompt/qwen2-1.5b-instruct",
422
+ ),
423
+ ],
424
+ # Beautiful Prompt
425
+ "BeautifulPrompt": [
426
+ (
427
+ "alibaba-pai/pai-bloom-1b1-text2prompt-sd",
428
+ "config.json",
429
+ "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd",
430
+ ),
431
+ (
432
+ "alibaba-pai/pai-bloom-1b1-text2prompt-sd",
433
+ "generation_config.json",
434
+ "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd",
435
+ ),
436
+ (
437
+ "alibaba-pai/pai-bloom-1b1-text2prompt-sd",
438
+ "model.safetensors",
439
+ "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd",
440
+ ),
441
+ (
442
+ "alibaba-pai/pai-bloom-1b1-text2prompt-sd",
443
+ "special_tokens_map.json",
444
+ "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd",
445
+ ),
446
+ (
447
+ "alibaba-pai/pai-bloom-1b1-text2prompt-sd",
448
+ "tokenizer.json",
449
+ "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd",
450
+ ),
451
+ (
452
+ "alibaba-pai/pai-bloom-1b1-text2prompt-sd",
453
+ "tokenizer_config.json",
454
+ "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd",
455
+ ),
456
+ ],
457
+ # Omost prompt
458
+ "OmostPrompt": [
459
+ (
460
+ "lllyasviel/omost-llama-3-8b-4bits",
461
+ "model-00001-of-00002.safetensors",
462
+ "models/OmostPrompt/omost-llama-3-8b-4bits",
463
+ ),
464
+ (
465
+ "lllyasviel/omost-llama-3-8b-4bits",
466
+ "model-00002-of-00002.safetensors",
467
+ "models/OmostPrompt/omost-llama-3-8b-4bits",
468
+ ),
469
+ (
470
+ "lllyasviel/omost-llama-3-8b-4bits",
471
+ "tokenizer.json",
472
+ "models/OmostPrompt/omost-llama-3-8b-4bits",
473
+ ),
474
+ (
475
+ "lllyasviel/omost-llama-3-8b-4bits",
476
+ "tokenizer_config.json",
477
+ "models/OmostPrompt/omost-llama-3-8b-4bits",
478
+ ),
479
+ (
480
+ "lllyasviel/omost-llama-3-8b-4bits",
481
+ "config.json",
482
+ "models/OmostPrompt/omost-llama-3-8b-4bits",
483
+ ),
484
+ (
485
+ "lllyasviel/omost-llama-3-8b-4bits",
486
+ "generation_config.json",
487
+ "models/OmostPrompt/omost-llama-3-8b-4bits",
488
+ ),
489
+ (
490
+ "lllyasviel/omost-llama-3-8b-4bits",
491
+ "model.safetensors.index.json",
492
+ "models/OmostPrompt/omost-llama-3-8b-4bits",
493
+ ),
494
+ (
495
+ "lllyasviel/omost-llama-3-8b-4bits",
496
+ "special_tokens_map.json",
497
+ "models/OmostPrompt/omost-llama-3-8b-4bits",
498
+ ),
499
+ ],
500
+ # Translator
501
+ "opus-mt-zh-en": [
502
+ (
503
+ "Helsinki-NLP/opus-mt-zh-en",
504
+ "config.json",
505
+ "models/translator/opus-mt-zh-en",
506
+ ),
507
+ (
508
+ "Helsinki-NLP/opus-mt-zh-en",
509
+ "generation_config.json",
510
+ "models/translator/opus-mt-zh-en",
511
+ ),
512
+ (
513
+ "Helsinki-NLP/opus-mt-zh-en",
514
+ "metadata.json",
515
+ "models/translator/opus-mt-zh-en",
516
+ ),
517
+ (
518
+ "Helsinki-NLP/opus-mt-zh-en",
519
+ "pytorch_model.bin",
520
+ "models/translator/opus-mt-zh-en",
521
+ ),
522
+ ("Helsinki-NLP/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
523
+ ("Helsinki-NLP/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
524
+ (
525
+ "Helsinki-NLP/opus-mt-zh-en",
526
+ "tokenizer_config.json",
527
+ "models/translator/opus-mt-zh-en",
528
+ ),
529
+ ("Helsinki-NLP/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
530
+ ],
531
+ # IP-Adapter
532
+ "IP-Adapter-SD": [
533
+ (
534
+ "h94/IP-Adapter",
535
+ "models/image_encoder/model.safetensors",
536
+ "models/IpAdapter/stable_diffusion/image_encoder",
537
+ ),
538
+ (
539
+ "h94/IP-Adapter",
540
+ "models/ip-adapter_sd15.bin",
541
+ "models/IpAdapter/stable_diffusion",
542
+ ),
543
+ ],
544
+ "IP-Adapter-SDXL": [
545
+ (
546
+ "h94/IP-Adapter",
547
+ "sdxl_models/image_encoder/model.safetensors",
548
+ "models/IpAdapter/stable_diffusion_xl/image_encoder",
549
+ ),
550
+ (
551
+ "h94/IP-Adapter",
552
+ "sdxl_models/ip-adapter_sdxl.bin",
553
+ "models/IpAdapter/stable_diffusion_xl",
554
+ ),
555
+ ],
556
+ "SDXL-vae-fp16-fix": [
557
+ (
558
+ "madebyollin/sdxl-vae-fp16-fix",
559
+ "diffusion_pytorch_model.safetensors",
560
+ "models/sdxl-vae-fp16-fix",
561
+ )
562
+ ],
563
+ # Kolors
564
+ "Kolors": [
565
+ (
566
+ "Kwai-Kolors/Kolors",
567
+ "text_encoder/config.json",
568
+ "models/kolors/Kolors/text_encoder",
569
+ ),
570
+ (
571
+ "Kwai-Kolors/Kolors",
572
+ "text_encoder/pytorch_model.bin.index.json",
573
+ "models/kolors/Kolors/text_encoder",
574
+ ),
575
+ (
576
+ "Kwai-Kolors/Kolors",
577
+ "text_encoder/pytorch_model-00001-of-00007.bin",
578
+ "models/kolors/Kolors/text_encoder",
579
+ ),
580
+ (
581
+ "Kwai-Kolors/Kolors",
582
+ "text_encoder/pytorch_model-00002-of-00007.bin",
583
+ "models/kolors/Kolors/text_encoder",
584
+ ),
585
+ (
586
+ "Kwai-Kolors/Kolors",
587
+ "text_encoder/pytorch_model-00003-of-00007.bin",
588
+ "models/kolors/Kolors/text_encoder",
589
+ ),
590
+ (
591
+ "Kwai-Kolors/Kolors",
592
+ "text_encoder/pytorch_model-00004-of-00007.bin",
593
+ "models/kolors/Kolors/text_encoder",
594
+ ),
595
+ (
596
+ "Kwai-Kolors/Kolors",
597
+ "text_encoder/pytorch_model-00005-of-00007.bin",
598
+ "models/kolors/Kolors/text_encoder",
599
+ ),
600
+ (
601
+ "Kwai-Kolors/Kolors",
602
+ "text_encoder/pytorch_model-00006-of-00007.bin",
603
+ "models/kolors/Kolors/text_encoder",
604
+ ),
605
+ (
606
+ "Kwai-Kolors/Kolors",
607
+ "text_encoder/pytorch_model-00007-of-00007.bin",
608
+ "models/kolors/Kolors/text_encoder",
609
+ ),
610
+ (
611
+ "Kwai-Kolors/Kolors",
612
+ "unet/diffusion_pytorch_model.safetensors",
613
+ "models/kolors/Kolors/unet",
614
+ ),
615
+ (
616
+ "Kwai-Kolors/Kolors",
617
+ "vae/diffusion_pytorch_model.safetensors",
618
+ "models/kolors/Kolors/vae",
619
+ ),
620
+ ],
621
+ # FLUX
622
+ "FLUX.1-dev": [
623
+ (
624
+ "black-forest-labs/FLUX.1-dev",
625
+ "text_encoder/model.safetensors",
626
+ "models/FLUX/FLUX.1-dev/text_encoder",
627
+ ),
628
+ (
629
+ "black-forest-labs/FLUX.1-dev",
630
+ "text_encoder_2/config.json",
631
+ "models/FLUX/FLUX.1-dev/text_encoder_2",
632
+ ),
633
+ (
634
+ "black-forest-labs/FLUX.1-dev",
635
+ "text_encoder_2/model-00001-of-00002.safetensors",
636
+ "models/FLUX/FLUX.1-dev/text_encoder_2",
637
+ ),
638
+ (
639
+ "black-forest-labs/FLUX.1-dev",
640
+ "text_encoder_2/model-00002-of-00002.safetensors",
641
+ "models/FLUX/FLUX.1-dev/text_encoder_2",
642
+ ),
643
+ (
644
+ "black-forest-labs/FLUX.1-dev",
645
+ "text_encoder_2/model.safetensors.index.json",
646
+ "models/FLUX/FLUX.1-dev/text_encoder_2",
647
+ ),
648
+ ("black-forest-labs/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
649
+ (
650
+ "black-forest-labs/FLUX.1-dev",
651
+ "flux1-dev.safetensors",
652
+ "models/FLUX/FLUX.1-dev",
653
+ ),
654
+ ],
655
+ "InstantX/FLUX.1-dev-IP-Adapter": {
656
+ "file_list": [
657
+ (
658
+ "InstantX/FLUX.1-dev-IP-Adapter",
659
+ "ip-adapter.bin",
660
+ "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter",
661
+ ),
662
+ (
663
+ "google/siglip-so400m-patch14-384",
664
+ "model.safetensors",
665
+ "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
666
+ ),
667
+ (
668
+ "google/siglip-so400m-patch14-384",
669
+ "config.json",
670
+ "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
671
+ ),
672
+ ],
673
+ "load_path": [
674
+ "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin",
675
+ "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
676
+ ],
677
+ },
678
+ # RIFE
679
+ "RIFE": [
680
+ ("AlexWortega/RIFE", "flownet.pkl", "models/RIFE"),
681
+ ],
682
+ # CogVideo
683
+ "CogVideoX-5B": [
684
+ (
685
+ "THUDM/CogVideoX-5b",
686
+ "text_encoder/config.json",
687
+ "models/CogVideo/CogVideoX-5b/text_encoder",
688
+ ),
689
+ (
690
+ "THUDM/CogVideoX-5b",
691
+ "text_encoder/model.safetensors.index.json",
692
+ "models/CogVideo/CogVideoX-5b/text_encoder",
693
+ ),
694
+ (
695
+ "THUDM/CogVideoX-5b",
696
+ "text_encoder/model-00001-of-00002.safetensors",
697
+ "models/CogVideo/CogVideoX-5b/text_encoder",
698
+ ),
699
+ (
700
+ "THUDM/CogVideoX-5b",
701
+ "text_encoder/model-00002-of-00002.safetensors",
702
+ "models/CogVideo/CogVideoX-5b/text_encoder",
703
+ ),
704
+ (
705
+ "THUDM/CogVideoX-5b",
706
+ "transformer/config.json",
707
+ "models/CogVideo/CogVideoX-5b/transformer",
708
+ ),
709
+ (
710
+ "THUDM/CogVideoX-5b",
711
+ "transformer/diffusion_pytorch_model.safetensors.index.json",
712
+ "models/CogVideo/CogVideoX-5b/transformer",
713
+ ),
714
+ (
715
+ "THUDM/CogVideoX-5b",
716
+ "transformer/diffusion_pytorch_model-00001-of-00002.safetensors",
717
+ "models/CogVideo/CogVideoX-5b/transformer",
718
+ ),
719
+ (
720
+ "THUDM/CogVideoX-5b",
721
+ "transformer/diffusion_pytorch_model-00002-of-00002.safetensors",
722
+ "models/CogVideo/CogVideoX-5b/transformer",
723
+ ),
724
+ (
725
+ "THUDM/CogVideoX-5b",
726
+ "vae/diffusion_pytorch_model.safetensors",
727
+ "models/CogVideo/CogVideoX-5b/vae",
728
+ ),
729
+ ],
730
+ # Stable Diffusion 3.5
731
+ "StableDiffusion3.5-large": [
732
+ (
733
+ "stabilityai/stable-diffusion-3.5-large",
734
+ "sd3.5_large.safetensors",
735
+ "models/stable_diffusion_3",
736
+ ),
737
+ (
738
+ "stabilityai/stable-diffusion-3.5-large",
739
+ "text_encoders/clip_l.safetensors",
740
+ "models/stable_diffusion_3/text_encoders",
741
+ ),
742
+ (
743
+ "stabilityai/stable-diffusion-3.5-large",
744
+ "text_encoders/clip_g.safetensors",
745
+ "models/stable_diffusion_3/text_encoders",
746
+ ),
747
+ (
748
+ "stabilityai/stable-diffusion-3.5-large",
749
+ "text_encoders/t5xxl_fp16.safetensors",
750
+ "models/stable_diffusion_3/text_encoders",
751
+ ),
752
+ ],
753
+ }
754
+ preset_models_on_modelscope = {
755
+ # Hunyuan DiT
756
+ "HunyuanDiT": [
757
+ (
758
+ "modelscope/HunyuanDiT",
759
+ "t2i/clip_text_encoder/pytorch_model.bin",
760
+ "models/HunyuanDiT/t2i/clip_text_encoder",
761
+ ),
762
+ (
763
+ "modelscope/HunyuanDiT",
764
+ "t2i/mt5/pytorch_model.bin",
765
+ "models/HunyuanDiT/t2i/mt5",
766
+ ),
767
+ (
768
+ "modelscope/HunyuanDiT",
769
+ "t2i/model/pytorch_model_ema.pt",
770
+ "models/HunyuanDiT/t2i/model",
771
+ ),
772
+ (
773
+ "modelscope/HunyuanDiT",
774
+ "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin",
775
+ "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix",
776
+ ),
777
+ ],
778
+ # Stable Video Diffusion
779
+ "stable-video-diffusion-img2vid-xt": [
780
+ (
781
+ "AI-ModelScope/stable-video-diffusion-img2vid-xt",
782
+ "svd_xt.safetensors",
783
+ "models/stable_video_diffusion",
784
+ ),
785
+ ],
786
+ # ExVideo
787
+ "ExVideo-SVD-128f-v1": [
788
+ (
789
+ "ECNU-CILab/ExVideo-SVD-128f-v1",
790
+ "model.fp16.safetensors",
791
+ "models/stable_video_diffusion",
792
+ ),
793
+ ],
794
+ "ExVideo-CogVideoX-LoRA-129f-v1": [
795
+ (
796
+ "ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1",
797
+ "ExVideo-CogVideoX-LoRA-129f-v1.safetensors",
798
+ "models/lora",
799
+ ),
800
+ ],
801
+ # Stable Diffusion
802
+ "StableDiffusion_v15": [
803
+ (
804
+ "AI-ModelScope/stable-diffusion-v1-5",
805
+ "v1-5-pruned-emaonly.safetensors",
806
+ "models/stable_diffusion",
807
+ ),
808
+ ],
809
+ "DreamShaper_8": [
810
+ (
811
+ "sd_lora/dreamshaper_8",
812
+ "dreamshaper_8.safetensors",
813
+ "models/stable_diffusion",
814
+ ),
815
+ ],
816
+ "AingDiffusion_v12": [
817
+ (
818
+ "sd_lora/aingdiffusion_v12",
819
+ "aingdiffusion_v12.safetensors",
820
+ "models/stable_diffusion",
821
+ ),
822
+ ],
823
+ "Flat2DAnimerge_v45Sharp": [
824
+ (
825
+ "sd_lora/Flat-2D-Animerge",
826
+ "flat2DAnimerge_v45Sharp.safetensors",
827
+ "models/stable_diffusion",
828
+ ),
829
+ ],
830
+ # Textual Inversion
831
+ "TextualInversion_VeryBadImageNegative_v1.3": [
832
+ (
833
+ "sd_lora/verybadimagenegative_v1.3",
834
+ "verybadimagenegative_v1.3.pt",
835
+ "models/textual_inversion",
836
+ ),
837
+ ],
838
+ # Stable Diffusion XL
839
+ "StableDiffusionXL_v1": [
840
+ (
841
+ "AI-ModelScope/stable-diffusion-xl-base-1.0",
842
+ "sd_xl_base_1.0.safetensors",
843
+ "models/stable_diffusion_xl",
844
+ ),
845
+ ],
846
+ "BluePencilXL_v200": [
847
+ (
848
+ "sd_lora/bluePencilXL_v200",
849
+ "bluePencilXL_v200.safetensors",
850
+ "models/stable_diffusion_xl",
851
+ ),
852
+ ],
853
+ "StableDiffusionXL_Turbo": [
854
+ (
855
+ "AI-ModelScope/sdxl-turbo",
856
+ "sd_xl_turbo_1.0_fp16.safetensors",
857
+ "models/stable_diffusion_xl_turbo",
858
+ ),
859
+ ],
860
+ "SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0": [
861
+ (
862
+ "sd_lora/zyd232_ChineseInkStyle_SDXL_v1_0",
863
+ "zyd232_ChineseInkStyle_SDXL_v1_0.safetensors",
864
+ "models/lora",
865
+ ),
866
+ ],
867
+ # Stable Diffusion 3
868
+ "StableDiffusion3": [
869
+ (
870
+ "AI-ModelScope/stable-diffusion-3-medium",
871
+ "sd3_medium_incl_clips_t5xxlfp16.safetensors",
872
+ "models/stable_diffusion_3",
873
+ ),
874
+ ],
875
+ "StableDiffusion3_without_T5": [
876
+ (
877
+ "AI-ModelScope/stable-diffusion-3-medium",
878
+ "sd3_medium_incl_clips.safetensors",
879
+ "models/stable_diffusion_3",
880
+ ),
881
+ ],
882
+ # ControlNet
883
+ "ControlNet_v11f1p_sd15_depth": [
884
+ (
885
+ "AI-ModelScope/ControlNet-v1-1",
886
+ "control_v11f1p_sd15_depth.pth",
887
+ "models/ControlNet",
888
+ ),
889
+ ("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators"),
890
+ ],
891
+ "ControlNet_v11p_sd15_softedge": [
892
+ (
893
+ "AI-ModelScope/ControlNet-v1-1",
894
+ "control_v11p_sd15_softedge.pth",
895
+ "models/ControlNet",
896
+ ),
897
+ ("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators"),
898
+ ],
899
+ "ControlNet_v11f1e_sd15_tile": [
900
+ (
901
+ "AI-ModelScope/ControlNet-v1-1",
902
+ "control_v11f1e_sd15_tile.pth",
903
+ "models/ControlNet",
904
+ )
905
+ ],
906
+ "ControlNet_v11p_sd15_lineart": [
907
+ (
908
+ "AI-ModelScope/ControlNet-v1-1",
909
+ "control_v11p_sd15_lineart.pth",
910
+ "models/ControlNet",
911
+ ),
912
+ ("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
913
+ ("sd_lora/Annotators", "sk_model2.pth", "models/Annotators"),
914
+ ],
915
+ "ControlNet_union_sdxl_promax": [
916
+ (
917
+ "AI-ModelScope/controlnet-union-sdxl-1.0",
918
+ "diffusion_pytorch_model_promax.safetensors",
919
+ "models/ControlNet/controlnet_union",
920
+ ),
921
+ ("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators"),
922
+ ],
923
+ "Annotators:Depth": [
924
+ ("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators"),
925
+ ],
926
+ "Annotators:Softedge": [
927
+ ("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators"),
928
+ ],
929
+ "Annotators:Lineart": [
930
+ ("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
931
+ ("sd_lora/Annotators", "sk_model2.pth", "models/Annotators"),
932
+ ],
933
+ "Annotators:Normal": [
934
+ ("sd_lora/Annotators", "scannet.pt", "models/Annotators"),
935
+ ],
936
+ "Annotators:Openpose": [
937
+ ("sd_lora/Annotators", "body_pose_model.pth", "models/Annotators"),
938
+ ("sd_lora/Annotators", "facenet.pth", "models/Annotators"),
939
+ ("sd_lora/Annotators", "hand_pose_model.pth", "models/Annotators"),
940
+ ],
941
+ # AnimateDiff
942
+ "AnimateDiff_v2": [
943
+ (
944
+ "Shanghai_AI_Laboratory/animatediff",
945
+ "mm_sd_v15_v2.ckpt",
946
+ "models/AnimateDiff",
947
+ ),
948
+ ],
949
+ "AnimateDiff_xl_beta": [
950
+ (
951
+ "Shanghai_AI_Laboratory/animatediff",
952
+ "mm_sdxl_v10_beta.ckpt",
953
+ "models/AnimateDiff",
954
+ ),
955
+ ],
956
+ # RIFE
957
+ "RIFE": [
958
+ ("Damo_XR_Lab/cv_rife_video-frame-interpolation", "flownet.pkl", "models/RIFE"),
959
+ ],
960
+ # Qwen Prompt
961
+ "QwenPrompt": {
962
+ "file_list": [
963
+ (
964
+ "qwen/Qwen2-1.5B-Instruct",
965
+ "config.json",
966
+ "models/QwenPrompt/qwen2-1.5b-instruct",
967
+ ),
968
+ (
969
+ "qwen/Qwen2-1.5B-Instruct",
970
+ "generation_config.json",
971
+ "models/QwenPrompt/qwen2-1.5b-instruct",
972
+ ),
973
+ (
974
+ "qwen/Qwen2-1.5B-Instruct",
975
+ "model.safetensors",
976
+ "models/QwenPrompt/qwen2-1.5b-instruct",
977
+ ),
978
+ (
979
+ "qwen/Qwen2-1.5B-Instruct",
980
+ "special_tokens_map.json",
981
+ "models/QwenPrompt/qwen2-1.5b-instruct",
982
+ ),
983
+ (
984
+ "qwen/Qwen2-1.5B-Instruct",
985
+ "tokenizer.json",
986
+ "models/QwenPrompt/qwen2-1.5b-instruct",
987
+ ),
988
+ (
989
+ "qwen/Qwen2-1.5B-Instruct",
990
+ "tokenizer_config.json",
991
+ "models/QwenPrompt/qwen2-1.5b-instruct",
992
+ ),
993
+ (
994
+ "qwen/Qwen2-1.5B-Instruct",
995
+ "merges.txt",
996
+ "models/QwenPrompt/qwen2-1.5b-instruct",
997
+ ),
998
+ (
999
+ "qwen/Qwen2-1.5B-Instruct",
1000
+ "vocab.json",
1001
+ "models/QwenPrompt/qwen2-1.5b-instruct",
1002
+ ),
1003
+ ],
1004
+ "load_path": [
1005
+ "models/QwenPrompt/qwen2-1.5b-instruct",
1006
+ ],
1007
+ },
1008
+ # Beautiful Prompt
1009
+ "BeautifulPrompt": {
1010
+ "file_list": [
1011
+ (
1012
+ "AI-ModelScope/pai-bloom-1b1-text2prompt-sd",
1013
+ "config.json",
1014
+ "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd",
1015
+ ),
1016
+ (
1017
+ "AI-ModelScope/pai-bloom-1b1-text2prompt-sd",
1018
+ "generation_config.json",
1019
+ "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd",
1020
+ ),
1021
+ (
1022
+ "AI-ModelScope/pai-bloom-1b1-text2prompt-sd",
1023
+ "model.safetensors",
1024
+ "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd",
1025
+ ),
1026
+ (
1027
+ "AI-ModelScope/pai-bloom-1b1-text2prompt-sd",
1028
+ "special_tokens_map.json",
1029
+ "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd",
1030
+ ),
1031
+ (
1032
+ "AI-ModelScope/pai-bloom-1b1-text2prompt-sd",
1033
+ "tokenizer.json",
1034
+ "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd",
1035
+ ),
1036
+ (
1037
+ "AI-ModelScope/pai-bloom-1b1-text2prompt-sd",
1038
+ "tokenizer_config.json",
1039
+ "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd",
1040
+ ),
1041
+ ],
1042
+ "load_path": [
1043
+ "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd",
1044
+ ],
1045
+ },
1046
+ # Omost prompt
1047
+ "OmostPrompt": {
1048
+ "file_list": [
1049
+ (
1050
+ "Omost/omost-llama-3-8b-4bits",
1051
+ "model-00001-of-00002.safetensors",
1052
+ "models/OmostPrompt/omost-llama-3-8b-4bits",
1053
+ ),
1054
+ (
1055
+ "Omost/omost-llama-3-8b-4bits",
1056
+ "model-00002-of-00002.safetensors",
1057
+ "models/OmostPrompt/omost-llama-3-8b-4bits",
1058
+ ),
1059
+ (
1060
+ "Omost/omost-llama-3-8b-4bits",
1061
+ "tokenizer.json",
1062
+ "models/OmostPrompt/omost-llama-3-8b-4bits",
1063
+ ),
1064
+ (
1065
+ "Omost/omost-llama-3-8b-4bits",
1066
+ "tokenizer_config.json",
1067
+ "models/OmostPrompt/omost-llama-3-8b-4bits",
1068
+ ),
1069
+ (
1070
+ "Omost/omost-llama-3-8b-4bits",
1071
+ "config.json",
1072
+ "models/OmostPrompt/omost-llama-3-8b-4bits",
1073
+ ),
1074
+ (
1075
+ "Omost/omost-llama-3-8b-4bits",
1076
+ "generation_config.json",
1077
+ "models/OmostPrompt/omost-llama-3-8b-4bits",
1078
+ ),
1079
+ (
1080
+ "Omost/omost-llama-3-8b-4bits",
1081
+ "model.safetensors.index.json",
1082
+ "models/OmostPrompt/omost-llama-3-8b-4bits",
1083
+ ),
1084
+ (
1085
+ "Omost/omost-llama-3-8b-4bits",
1086
+ "special_tokens_map.json",
1087
+ "models/OmostPrompt/omost-llama-3-8b-4bits",
1088
+ ),
1089
+ ],
1090
+ "load_path": [
1091
+ "models/OmostPrompt/omost-llama-3-8b-4bits",
1092
+ ],
1093
+ },
1094
+ # Translator
1095
+ "opus-mt-zh-en": {
1096
+ "file_list": [
1097
+ ("moxying/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
1098
+ (
1099
+ "moxying/opus-mt-zh-en",
1100
+ "generation_config.json",
1101
+ "models/translator/opus-mt-zh-en",
1102
+ ),
1103
+ (
1104
+ "moxying/opus-mt-zh-en",
1105
+ "metadata.json",
1106
+ "models/translator/opus-mt-zh-en",
1107
+ ),
1108
+ (
1109
+ "moxying/opus-mt-zh-en",
1110
+ "pytorch_model.bin",
1111
+ "models/translator/opus-mt-zh-en",
1112
+ ),
1113
+ ("moxying/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
1114
+ ("moxying/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
1115
+ (
1116
+ "moxying/opus-mt-zh-en",
1117
+ "tokenizer_config.json",
1118
+ "models/translator/opus-mt-zh-en",
1119
+ ),
1120
+ ("moxying/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
1121
+ ],
1122
+ "load_path": [
1123
+ "models/translator/opus-mt-zh-en",
1124
+ ],
1125
+ },
1126
+ # IP-Adapter
1127
+ "IP-Adapter-SD": [
1128
+ (
1129
+ "AI-ModelScope/IP-Adapter",
1130
+ "models/image_encoder/model.safetensors",
1131
+ "models/IpAdapter/stable_diffusion/image_encoder",
1132
+ ),
1133
+ (
1134
+ "AI-ModelScope/IP-Adapter",
1135
+ "models/ip-adapter_sd15.bin",
1136
+ "models/IpAdapter/stable_diffusion",
1137
+ ),
1138
+ ],
1139
+ "IP-Adapter-SDXL": [
1140
+ (
1141
+ "AI-ModelScope/IP-Adapter",
1142
+ "sdxl_models/image_encoder/model.safetensors",
1143
+ "models/IpAdapter/stable_diffusion_xl/image_encoder",
1144
+ ),
1145
+ (
1146
+ "AI-ModelScope/IP-Adapter",
1147
+ "sdxl_models/ip-adapter_sdxl.bin",
1148
+ "models/IpAdapter/stable_diffusion_xl",
1149
+ ),
1150
+ ],
1151
+ # Kolors
1152
+ "Kolors": {
1153
+ "file_list": [
1154
+ (
1155
+ "Kwai-Kolors/Kolors",
1156
+ "text_encoder/config.json",
1157
+ "models/kolors/Kolors/text_encoder",
1158
+ ),
1159
+ (
1160
+ "Kwai-Kolors/Kolors",
1161
+ "text_encoder/pytorch_model.bin.index.json",
1162
+ "models/kolors/Kolors/text_encoder",
1163
+ ),
1164
+ (
1165
+ "Kwai-Kolors/Kolors",
1166
+ "text_encoder/pytorch_model-00001-of-00007.bin",
1167
+ "models/kolors/Kolors/text_encoder",
1168
+ ),
1169
+ (
1170
+ "Kwai-Kolors/Kolors",
1171
+ "text_encoder/pytorch_model-00002-of-00007.bin",
1172
+ "models/kolors/Kolors/text_encoder",
1173
+ ),
1174
+ (
1175
+ "Kwai-Kolors/Kolors",
1176
+ "text_encoder/pytorch_model-00003-of-00007.bin",
1177
+ "models/kolors/Kolors/text_encoder",
1178
+ ),
1179
+ (
1180
+ "Kwai-Kolors/Kolors",
1181
+ "text_encoder/pytorch_model-00004-of-00007.bin",
1182
+ "models/kolors/Kolors/text_encoder",
1183
+ ),
1184
+ (
1185
+ "Kwai-Kolors/Kolors",
1186
+ "text_encoder/pytorch_model-00005-of-00007.bin",
1187
+ "models/kolors/Kolors/text_encoder",
1188
+ ),
1189
+ (
1190
+ "Kwai-Kolors/Kolors",
1191
+ "text_encoder/pytorch_model-00006-of-00007.bin",
1192
+ "models/kolors/Kolors/text_encoder",
1193
+ ),
1194
+ (
1195
+ "Kwai-Kolors/Kolors",
1196
+ "text_encoder/pytorch_model-00007-of-00007.bin",
1197
+ "models/kolors/Kolors/text_encoder",
1198
+ ),
1199
+ (
1200
+ "Kwai-Kolors/Kolors",
1201
+ "unet/diffusion_pytorch_model.safetensors",
1202
+ "models/kolors/Kolors/unet",
1203
+ ),
1204
+ (
1205
+ "Kwai-Kolors/Kolors",
1206
+ "vae/diffusion_pytorch_model.safetensors",
1207
+ "models/kolors/Kolors/vae",
1208
+ ),
1209
+ ],
1210
+ "load_path": [
1211
+ "models/kolors/Kolors/text_encoder",
1212
+ "models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors",
1213
+ "models/kolors/Kolors/vae/diffusion_pytorch_model.safetensors",
1214
+ ],
1215
+ },
1216
+ "SDXL-vae-fp16-fix": [
1217
+ (
1218
+ "AI-ModelScope/sdxl-vae-fp16-fix",
1219
+ "diffusion_pytorch_model.safetensors",
1220
+ "models/sdxl-vae-fp16-fix",
1221
+ )
1222
+ ],
1223
+ # FLUX
1224
+ "FLUX.1-dev": {
1225
+ "file_list": [
1226
+ (
1227
+ "AI-ModelScope/FLUX.1-dev",
1228
+ "text_encoder/model.safetensors",
1229
+ "models/FLUX/FLUX.1-dev/text_encoder",
1230
+ ),
1231
+ (
1232
+ "AI-ModelScope/FLUX.1-dev",
1233
+ "text_encoder_2/config.json",
1234
+ "models/FLUX/FLUX.1-dev/text_encoder_2",
1235
+ ),
1236
+ (
1237
+ "AI-ModelScope/FLUX.1-dev",
1238
+ "text_encoder_2/model-00001-of-00002.safetensors",
1239
+ "models/FLUX/FLUX.1-dev/text_encoder_2",
1240
+ ),
1241
+ (
1242
+ "AI-ModelScope/FLUX.1-dev",
1243
+ "text_encoder_2/model-00002-of-00002.safetensors",
1244
+ "models/FLUX/FLUX.1-dev/text_encoder_2",
1245
+ ),
1246
+ (
1247
+ "AI-ModelScope/FLUX.1-dev",
1248
+ "text_encoder_2/model.safetensors.index.json",
1249
+ "models/FLUX/FLUX.1-dev/text_encoder_2",
1250
+ ),
1251
+ ("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
1252
+ (
1253
+ "AI-ModelScope/FLUX.1-dev",
1254
+ "flux1-dev.safetensors",
1255
+ "models/FLUX/FLUX.1-dev",
1256
+ ),
1257
+ ],
1258
+ "load_path": [
1259
+ "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
1260
+ "models/FLUX/FLUX.1-dev/text_encoder_2",
1261
+ "models/FLUX/FLUX.1-dev/ae.safetensors",
1262
+ "models/FLUX/FLUX.1-dev/flux1-dev.safetensors",
1263
+ ],
1264
+ },
1265
+ "FLUX.1-schnell": {
1266
+ "file_list": [
1267
+ (
1268
+ "AI-ModelScope/FLUX.1-dev",
1269
+ "text_encoder/model.safetensors",
1270
+ "models/FLUX/FLUX.1-dev/text_encoder",
1271
+ ),
1272
+ (
1273
+ "AI-ModelScope/FLUX.1-dev",
1274
+ "text_encoder_2/config.json",
1275
+ "models/FLUX/FLUX.1-dev/text_encoder_2",
1276
+ ),
1277
+ (
1278
+ "AI-ModelScope/FLUX.1-dev",
1279
+ "text_encoder_2/model-00001-of-00002.safetensors",
1280
+ "models/FLUX/FLUX.1-dev/text_encoder_2",
1281
+ ),
1282
+ (
1283
+ "AI-ModelScope/FLUX.1-dev",
1284
+ "text_encoder_2/model-00002-of-00002.safetensors",
1285
+ "models/FLUX/FLUX.1-dev/text_encoder_2",
1286
+ ),
1287
+ (
1288
+ "AI-ModelScope/FLUX.1-dev",
1289
+ "text_encoder_2/model.safetensors.index.json",
1290
+ "models/FLUX/FLUX.1-dev/text_encoder_2",
1291
+ ),
1292
+ ("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
1293
+ (
1294
+ "AI-ModelScope/FLUX.1-schnell",
1295
+ "flux1-schnell.safetensors",
1296
+ "models/FLUX/FLUX.1-schnell",
1297
+ ),
1298
+ ],
1299
+ "load_path": [
1300
+ "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
1301
+ "models/FLUX/FLUX.1-dev/text_encoder_2",
1302
+ "models/FLUX/FLUX.1-dev/ae.safetensors",
1303
+ "models/FLUX/FLUX.1-schnell/flux1-schnell.safetensors",
1304
+ ],
1305
+ },
1306
+ "InstantX/FLUX.1-dev-Controlnet-Union-alpha": [
1307
+ (
1308
+ "InstantX/FLUX.1-dev-Controlnet-Union-alpha",
1309
+ "diffusion_pytorch_model.safetensors",
1310
+ "models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha",
1311
+ ),
1312
+ ],
1313
+ "jasperai/Flux.1-dev-Controlnet-Depth": [
1314
+ (
1315
+ "jasperai/Flux.1-dev-Controlnet-Depth",
1316
+ "diffusion_pytorch_model.safetensors",
1317
+ "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Depth",
1318
+ ),
1319
+ ],
1320
+ "jasperai/Flux.1-dev-Controlnet-Surface-Normals": [
1321
+ (
1322
+ "jasperai/Flux.1-dev-Controlnet-Surface-Normals",
1323
+ "diffusion_pytorch_model.safetensors",
1324
+ "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Surface-Normals",
1325
+ ),
1326
+ ],
1327
+ "jasperai/Flux.1-dev-Controlnet-Upscaler": [
1328
+ (
1329
+ "jasperai/Flux.1-dev-Controlnet-Upscaler",
1330
+ "diffusion_pytorch_model.safetensors",
1331
+ "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Upscaler",
1332
+ ),
1333
+ ],
1334
+ "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha": [
1335
+ (
1336
+ "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha",
1337
+ "diffusion_pytorch_model.safetensors",
1338
+ "models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha",
1339
+ ),
1340
+ ],
1341
+ "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta": [
1342
+ (
1343
+ "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta",
1344
+ "diffusion_pytorch_model.safetensors",
1345
+ "models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta",
1346
+ ),
1347
+ ],
1348
+ "Shakker-Labs/FLUX.1-dev-ControlNet-Depth": [
1349
+ (
1350
+ "Shakker-Labs/FLUX.1-dev-ControlNet-Depth",
1351
+ "diffusion_pytorch_model.safetensors",
1352
+ "models/ControlNet/Shakker-Labs/FLUX.1-dev-ControlNet-Depth",
1353
+ ),
1354
+ ],
1355
+ "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro": [
1356
+ (
1357
+ "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
1358
+ "diffusion_pytorch_model.safetensors",
1359
+ "models/ControlNet/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
1360
+ ),
1361
+ ],
1362
+ "InstantX/FLUX.1-dev-IP-Adapter": {
1363
+ "file_list": [
1364
+ (
1365
+ "InstantX/FLUX.1-dev-IP-Adapter",
1366
+ "ip-adapter.bin",
1367
+ "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter",
1368
+ ),
1369
+ (
1370
+ "AI-ModelScope/siglip-so400m-patch14-384",
1371
+ "model.safetensors",
1372
+ "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
1373
+ ),
1374
+ (
1375
+ "AI-ModelScope/siglip-so400m-patch14-384",
1376
+ "config.json",
1377
+ "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
1378
+ ),
1379
+ ],
1380
+ "load_path": [
1381
+ "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin",
1382
+ "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
1383
+ ],
1384
+ },
1385
+ "InfiniteYou": {
1386
+ "file_list": [
1387
+ (
1388
+ "ByteDance/InfiniteYou",
1389
+ "infu_flux_v1.0/aes_stage2/InfuseNetModel/diffusion_pytorch_model-00001-of-00002.safetensors",
1390
+ "models/InfiniteYou/InfuseNetModel",
1391
+ ),
1392
+ (
1393
+ "ByteDance/InfiniteYou",
1394
+ "infu_flux_v1.0/aes_stage2/InfuseNetModel/diffusion_pytorch_model-00002-of-00002.safetensors",
1395
+ "models/InfiniteYou/InfuseNetModel",
1396
+ ),
1397
+ (
1398
+ "ByteDance/InfiniteYou",
1399
+ "infu_flux_v1.0/aes_stage2/image_proj_model.bin",
1400
+ "models/InfiniteYou",
1401
+ ),
1402
+ (
1403
+ "ByteDance/InfiniteYou",
1404
+ "supports/insightface/models/antelopev2/1k3d68.onnx",
1405
+ "models/InfiniteYou/insightface/models/antelopev2",
1406
+ ),
1407
+ (
1408
+ "ByteDance/InfiniteYou",
1409
+ "supports/insightface/models/antelopev2/2d106det.onnx",
1410
+ "models/InfiniteYou/insightface/models/antelopev2",
1411
+ ),
1412
+ (
1413
+ "ByteDance/InfiniteYou",
1414
+ "supports/insightface/models/antelopev2/genderage.onnx",
1415
+ "models/InfiniteYou/insightface/models/antelopev2",
1416
+ ),
1417
+ (
1418
+ "ByteDance/InfiniteYou",
1419
+ "supports/insightface/models/antelopev2/glintr100.onnx",
1420
+ "models/InfiniteYou/insightface/models/antelopev2",
1421
+ ),
1422
+ (
1423
+ "ByteDance/InfiniteYou",
1424
+ "supports/insightface/models/antelopev2/scrfd_10g_bnkps.onnx",
1425
+ "models/InfiniteYou/insightface/models/antelopev2",
1426
+ ),
1427
+ ],
1428
+ "load_path": [
1429
+ [
1430
+ "models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00001-of-00002.safetensors",
1431
+ "models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00002-of-00002.safetensors",
1432
+ ],
1433
+ "models/InfiniteYou/image_proj_model.bin",
1434
+ ],
1435
+ },
1436
+ # ESRGAN
1437
+ "ESRGAN_x4": [
1438
+ ("AI-ModelScope/Real-ESRGAN", "RealESRGAN_x4.pth", "models/ESRGAN"),
1439
+ ],
1440
+ # RIFE
1441
+ "RIFE": [
1442
+ ("AI-ModelScope/RIFE", "flownet.pkl", "models/RIFE"),
1443
+ ],
1444
+ # Omnigen
1445
+ "OmniGen-v1": {
1446
+ "file_list": [
1447
+ (
1448
+ "BAAI/OmniGen-v1",
1449
+ "vae/diffusion_pytorch_model.safetensors",
1450
+ "models/OmniGen/OmniGen-v1/vae",
1451
+ ),
1452
+ ("BAAI/OmniGen-v1", "model.safetensors", "models/OmniGen/OmniGen-v1"),
1453
+ ("BAAI/OmniGen-v1", "config.json", "models/OmniGen/OmniGen-v1"),
1454
+ ("BAAI/OmniGen-v1", "special_tokens_map.json", "models/OmniGen/OmniGen-v1"),
1455
+ ("BAAI/OmniGen-v1", "tokenizer_config.json", "models/OmniGen/OmniGen-v1"),
1456
+ ("BAAI/OmniGen-v1", "tokenizer.json", "models/OmniGen/OmniGen-v1"),
1457
+ ],
1458
+ "load_path": [
1459
+ "models/OmniGen/OmniGen-v1/vae/diffusion_pytorch_model.safetensors",
1460
+ "models/OmniGen/OmniGen-v1/model.safetensors",
1461
+ ],
1462
+ },
1463
+ # CogVideo
1464
+ "CogVideoX-5B": {
1465
+ "file_list": [
1466
+ (
1467
+ "ZhipuAI/CogVideoX-5b",
1468
+ "text_encoder/config.json",
1469
+ "models/CogVideo/CogVideoX-5b/text_encoder",
1470
+ ),
1471
+ (
1472
+ "ZhipuAI/CogVideoX-5b",
1473
+ "text_encoder/model.safetensors.index.json",
1474
+ "models/CogVideo/CogVideoX-5b/text_encoder",
1475
+ ),
1476
+ (
1477
+ "ZhipuAI/CogVideoX-5b",
1478
+ "text_encoder/model-00001-of-00002.safetensors",
1479
+ "models/CogVideo/CogVideoX-5b/text_encoder",
1480
+ ),
1481
+ (
1482
+ "ZhipuAI/CogVideoX-5b",
1483
+ "text_encoder/model-00002-of-00002.safetensors",
1484
+ "models/CogVideo/CogVideoX-5b/text_encoder",
1485
+ ),
1486
+ (
1487
+ "ZhipuAI/CogVideoX-5b",
1488
+ "transformer/config.json",
1489
+ "models/CogVideo/CogVideoX-5b/transformer",
1490
+ ),
1491
+ (
1492
+ "ZhipuAI/CogVideoX-5b",
1493
+ "transformer/diffusion_pytorch_model.safetensors.index.json",
1494
+ "models/CogVideo/CogVideoX-5b/transformer",
1495
+ ),
1496
+ (
1497
+ "ZhipuAI/CogVideoX-5b",
1498
+ "transformer/diffusion_pytorch_model-00001-of-00002.safetensors",
1499
+ "models/CogVideo/CogVideoX-5b/transformer",
1500
+ ),
1501
+ (
1502
+ "ZhipuAI/CogVideoX-5b",
1503
+ "transformer/diffusion_pytorch_model-00002-of-00002.safetensors",
1504
+ "models/CogVideo/CogVideoX-5b/transformer",
1505
+ ),
1506
+ (
1507
+ "ZhipuAI/CogVideoX-5b",
1508
+ "vae/diffusion_pytorch_model.safetensors",
1509
+ "models/CogVideo/CogVideoX-5b/vae",
1510
+ ),
1511
+ ],
1512
+ "load_path": [
1513
+ "models/CogVideo/CogVideoX-5b/text_encoder",
1514
+ "models/CogVideo/CogVideoX-5b/transformer",
1515
+ "models/CogVideo/CogVideoX-5b/vae/diffusion_pytorch_model.safetensors",
1516
+ ],
1517
+ },
1518
+ # Stable Diffusion 3.5
1519
+ "StableDiffusion3.5-large": [
1520
+ (
1521
+ "AI-ModelScope/stable-diffusion-3.5-large",
1522
+ "sd3.5_large.safetensors",
1523
+ "models/stable_diffusion_3",
1524
+ ),
1525
+ (
1526
+ "AI-ModelScope/stable-diffusion-3.5-large",
1527
+ "text_encoders/clip_l.safetensors",
1528
+ "models/stable_diffusion_3/text_encoders",
1529
+ ),
1530
+ (
1531
+ "AI-ModelScope/stable-diffusion-3.5-large",
1532
+ "text_encoders/clip_g.safetensors",
1533
+ "models/stable_diffusion_3/text_encoders",
1534
+ ),
1535
+ (
1536
+ "AI-ModelScope/stable-diffusion-3.5-large",
1537
+ "text_encoders/t5xxl_fp16.safetensors",
1538
+ "models/stable_diffusion_3/text_encoders",
1539
+ ),
1540
+ ],
1541
+ "StableDiffusion3.5-medium": [
1542
+ (
1543
+ "AI-ModelScope/stable-diffusion-3.5-medium",
1544
+ "sd3.5_medium.safetensors",
1545
+ "models/stable_diffusion_3",
1546
+ ),
1547
+ (
1548
+ "AI-ModelScope/stable-diffusion-3.5-large",
1549
+ "text_encoders/clip_l.safetensors",
1550
+ "models/stable_diffusion_3/text_encoders",
1551
+ ),
1552
+ (
1553
+ "AI-ModelScope/stable-diffusion-3.5-large",
1554
+ "text_encoders/clip_g.safetensors",
1555
+ "models/stable_diffusion_3/text_encoders",
1556
+ ),
1557
+ (
1558
+ "AI-ModelScope/stable-diffusion-3.5-large",
1559
+ "text_encoders/t5xxl_fp16.safetensors",
1560
+ "models/stable_diffusion_3/text_encoders",
1561
+ ),
1562
+ ],
1563
+ "StableDiffusion3.5-large-turbo": [
1564
+ (
1565
+ "AI-ModelScope/stable-diffusion-3.5-large-turbo",
1566
+ "sd3.5_large_turbo.safetensors",
1567
+ "models/stable_diffusion_3",
1568
+ ),
1569
+ (
1570
+ "AI-ModelScope/stable-diffusion-3.5-large",
1571
+ "text_encoders/clip_l.safetensors",
1572
+ "models/stable_diffusion_3/text_encoders",
1573
+ ),
1574
+ (
1575
+ "AI-ModelScope/stable-diffusion-3.5-large",
1576
+ "text_encoders/clip_g.safetensors",
1577
+ "models/stable_diffusion_3/text_encoders",
1578
+ ),
1579
+ (
1580
+ "AI-ModelScope/stable-diffusion-3.5-large",
1581
+ "text_encoders/t5xxl_fp16.safetensors",
1582
+ "models/stable_diffusion_3/text_encoders",
1583
+ ),
1584
+ ],
1585
+ "HunyuanVideo": {
1586
+ "file_list": [
1587
+ (
1588
+ "AI-ModelScope/clip-vit-large-patch14",
1589
+ "model.safetensors",
1590
+ "models/HunyuanVideo/text_encoder",
1591
+ ),
1592
+ (
1593
+ "DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder",
1594
+ "model-00001-of-00004.safetensors",
1595
+ "models/HunyuanVideo/text_encoder_2",
1596
+ ),
1597
+ (
1598
+ "DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder",
1599
+ "model-00002-of-00004.safetensors",
1600
+ "models/HunyuanVideo/text_encoder_2",
1601
+ ),
1602
+ (
1603
+ "DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder",
1604
+ "model-00003-of-00004.safetensors",
1605
+ "models/HunyuanVideo/text_encoder_2",
1606
+ ),
1607
+ (
1608
+ "DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder",
1609
+ "model-00004-of-00004.safetensors",
1610
+ "models/HunyuanVideo/text_encoder_2",
1611
+ ),
1612
+ (
1613
+ "DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder",
1614
+ "config.json",
1615
+ "models/HunyuanVideo/text_encoder_2",
1616
+ ),
1617
+ (
1618
+ "DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder",
1619
+ "model.safetensors.index.json",
1620
+ "models/HunyuanVideo/text_encoder_2",
1621
+ ),
1622
+ (
1623
+ "AI-ModelScope/HunyuanVideo",
1624
+ "hunyuan-video-t2v-720p/vae/pytorch_model.pt",
1625
+ "models/HunyuanVideo/vae",
1626
+ ),
1627
+ (
1628
+ "AI-ModelScope/HunyuanVideo",
1629
+ "hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt",
1630
+ "models/HunyuanVideo/transformers",
1631
+ ),
1632
+ ],
1633
+ "load_path": [
1634
+ "models/HunyuanVideo/text_encoder/model.safetensors",
1635
+ "models/HunyuanVideo/text_encoder_2",
1636
+ "models/HunyuanVideo/vae/pytorch_model.pt",
1637
+ "models/HunyuanVideo/transformers/mp_rank_00_model_states.pt",
1638
+ ],
1639
+ },
1640
+ "HunyuanVideoI2V": {
1641
+ "file_list": [
1642
+ (
1643
+ "AI-ModelScope/clip-vit-large-patch14",
1644
+ "model.safetensors",
1645
+ "models/HunyuanVideoI2V/text_encoder",
1646
+ ),
1647
+ (
1648
+ "AI-ModelScope/llava-llama-3-8b-v1_1-transformers",
1649
+ "model-00001-of-00004.safetensors",
1650
+ "models/HunyuanVideoI2V/text_encoder_2",
1651
+ ),
1652
+ (
1653
+ "AI-ModelScope/llava-llama-3-8b-v1_1-transformers",
1654
+ "model-00002-of-00004.safetensors",
1655
+ "models/HunyuanVideoI2V/text_encoder_2",
1656
+ ),
1657
+ (
1658
+ "AI-ModelScope/llava-llama-3-8b-v1_1-transformers",
1659
+ "model-00003-of-00004.safetensors",
1660
+ "models/HunyuanVideoI2V/text_encoder_2",
1661
+ ),
1662
+ (
1663
+ "AI-ModelScope/llava-llama-3-8b-v1_1-transformers",
1664
+ "model-00004-of-00004.safetensors",
1665
+ "models/HunyuanVideoI2V/text_encoder_2",
1666
+ ),
1667
+ (
1668
+ "AI-ModelScope/llava-llama-3-8b-v1_1-transformers",
1669
+ "config.json",
1670
+ "models/HunyuanVideoI2V/text_encoder_2",
1671
+ ),
1672
+ (
1673
+ "AI-ModelScope/llava-llama-3-8b-v1_1-transformers",
1674
+ "model.safetensors.index.json",
1675
+ "models/HunyuanVideoI2V/text_encoder_2",
1676
+ ),
1677
+ (
1678
+ "AI-ModelScope/HunyuanVideo-I2V",
1679
+ "hunyuan-video-i2v-720p/vae/pytorch_model.pt",
1680
+ "models/HunyuanVideoI2V/vae",
1681
+ ),
1682
+ (
1683
+ "AI-ModelScope/HunyuanVideo-I2V",
1684
+ "hunyuan-video-i2v-720p/transformers/mp_rank_00_model_states.pt",
1685
+ "models/HunyuanVideoI2V/transformers",
1686
+ ),
1687
+ ],
1688
+ "load_path": [
1689
+ "models/HunyuanVideoI2V/text_encoder/model.safetensors",
1690
+ "models/HunyuanVideoI2V/text_encoder_2",
1691
+ "models/HunyuanVideoI2V/vae/pytorch_model.pt",
1692
+ "models/HunyuanVideoI2V/transformers/mp_rank_00_model_states.pt",
1693
+ ],
1694
+ },
1695
+ "HunyuanVideo-fp8": {
1696
+ "file_list": [
1697
+ (
1698
+ "AI-ModelScope/clip-vit-large-patch14",
1699
+ "model.safetensors",
1700
+ "models/HunyuanVideo/text_encoder",
1701
+ ),
1702
+ (
1703
+ "DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder",
1704
+ "model-00001-of-00004.safetensors",
1705
+ "models/HunyuanVideo/text_encoder_2",
1706
+ ),
1707
+ (
1708
+ "DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder",
1709
+ "model-00002-of-00004.safetensors",
1710
+ "models/HunyuanVideo/text_encoder_2",
1711
+ ),
1712
+ (
1713
+ "DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder",
1714
+ "model-00003-of-00004.safetensors",
1715
+ "models/HunyuanVideo/text_encoder_2",
1716
+ ),
1717
+ (
1718
+ "DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder",
1719
+ "model-00004-of-00004.safetensors",
1720
+ "models/HunyuanVideo/text_encoder_2",
1721
+ ),
1722
+ (
1723
+ "DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder",
1724
+ "config.json",
1725
+ "models/HunyuanVideo/text_encoder_2",
1726
+ ),
1727
+ (
1728
+ "DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder",
1729
+ "model.safetensors.index.json",
1730
+ "models/HunyuanVideo/text_encoder_2",
1731
+ ),
1732
+ (
1733
+ "AI-ModelScope/HunyuanVideo",
1734
+ "hunyuan-video-t2v-720p/vae/pytorch_model.pt",
1735
+ "models/HunyuanVideo/vae",
1736
+ ),
1737
+ (
1738
+ "DiffSynth-Studio/HunyuanVideo-safetensors",
1739
+ "model.fp8.safetensors",
1740
+ "models/HunyuanVideo/transformers",
1741
+ ),
1742
+ ],
1743
+ "load_path": [
1744
+ "models/HunyuanVideo/text_encoder/model.safetensors",
1745
+ "models/HunyuanVideo/text_encoder_2",
1746
+ "models/HunyuanVideo/vae/pytorch_model.pt",
1747
+ "models/HunyuanVideo/transformers/model.fp8.safetensors",
1748
+ ],
1749
+ },
1750
+ }
1751
+ Preset_model_id: TypeAlias = Literal[
1752
+ "HunyuanDiT",
1753
+ "stable-video-diffusion-img2vid-xt",
1754
+ "ExVideo-SVD-128f-v1",
1755
+ "ExVideo-CogVideoX-LoRA-129f-v1",
1756
+ "StableDiffusion_v15",
1757
+ "DreamShaper_8",
1758
+ "AingDiffusion_v12",
1759
+ "Flat2DAnimerge_v45Sharp",
1760
+ "TextualInversion_VeryBadImageNegative_v1.3",
1761
+ "StableDiffusionXL_v1",
1762
+ "BluePencilXL_v200",
1763
+ "StableDiffusionXL_Turbo",
1764
+ "ControlNet_v11f1p_sd15_depth",
1765
+ "ControlNet_v11p_sd15_softedge",
1766
+ "ControlNet_v11f1e_sd15_tile",
1767
+ "ControlNet_v11p_sd15_lineart",
1768
+ "AnimateDiff_v2",
1769
+ "AnimateDiff_xl_beta",
1770
+ "RIFE",
1771
+ "BeautifulPrompt",
1772
+ "opus-mt-zh-en",
1773
+ "IP-Adapter-SD",
1774
+ "IP-Adapter-SDXL",
1775
+ "StableDiffusion3",
1776
+ "StableDiffusion3_without_T5",
1777
+ "Kolors",
1778
+ "SDXL-vae-fp16-fix",
1779
+ "ControlNet_union_sdxl_promax",
1780
+ "FLUX.1-dev",
1781
+ "FLUX.1-schnell",
1782
+ "InstantX/FLUX.1-dev-Controlnet-Union-alpha",
1783
+ "jasperai/Flux.1-dev-Controlnet-Depth",
1784
+ "jasperai/Flux.1-dev-Controlnet-Surface-Normals",
1785
+ "jasperai/Flux.1-dev-Controlnet-Upscaler",
1786
+ "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha",
1787
+ "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta",
1788
+ "Shakker-Labs/FLUX.1-dev-ControlNet-Depth",
1789
+ "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
1790
+ "InstantX/FLUX.1-dev-IP-Adapter",
1791
+ "InfiniteYou",
1792
+ "SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0",
1793
+ "QwenPrompt",
1794
+ "OmostPrompt",
1795
+ "ESRGAN_x4",
1796
+ "RIFE",
1797
+ "OmniGen-v1",
1798
+ "CogVideoX-5B",
1799
+ "Annotators:Depth",
1800
+ "Annotators:Softedge",
1801
+ "Annotators:Lineart",
1802
+ "Annotators:Normal",
1803
+ "Annotators:Openpose",
1804
+ "StableDiffusion3.5-large",
1805
+ "StableDiffusion3.5-medium",
1806
+ "HunyuanVideo",
1807
+ "HunyuanVideo-fp8",
1808
+ "HunyuanVideoI2V",
1809
+ ]
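
Each preset above maps a model name to (repo_id, file_path, local_dir) triples, optionally wrapped in a dict that also records the load_path entries a pipeline should load afterwards. The sketch below is illustrative only (the helper name fetch_preset is not part of the repository): it walks one Hugging Face-style entry with huggingface_hub; ModelScope-style entries would use modelscope's download API instead. Note that hf_hub_download keeps the repo-relative subfolder under local_dir, so a real downloader may still need to move files into the flat layout the load_path entries expect.

# Illustrative only: walk one preset entry and fetch its files from the Hugging Face Hub.
from huggingface_hub import hf_hub_download

def fetch_preset(preset_name, table):
    entry = table[preset_name]
    file_list = entry["file_list"] if isinstance(entry, dict) else entry
    for repo_id, filename, local_dir in file_list:
        # hf_hub_download preserves the repo-relative path below local_dir
        hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir)
    # dict-style entries also tell the caller which paths to load into the pipeline
    return entry["load_path"] if isinstance(entry, dict) else None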
data/video.py ADDED
@@ -0,0 +1,158 @@
1
+ import imageio, os
2
+ import numpy as np
3
+ from PIL import Image
4
+ from tqdm import tqdm
5
+
6
+
7
+ class LowMemoryVideo:
8
+ def __init__(self, file_name):
9
+ self.reader = imageio.get_reader(file_name)
10
+
11
+ def __len__(self):
12
+ return self.reader.count_frames()
13
+
14
+ def __getitem__(self, item):
15
+ return Image.fromarray(np.array(self.reader.get_data(item))).convert("RGB")
16
+
17
+ def __del__(self):
18
+ self.reader.close()
19
+
20
+
21
+ def split_file_name(file_name):
22
+ result = []
23
+ number = -1
24
+ for i in file_name:
25
+ if ord(i) >= ord("0") and ord(i) <= ord("9"):
26
+ if number == -1:
27
+ number = 0
28
+ number = number * 10 + ord(i) - ord("0")
29
+ else:
30
+ if number != -1:
31
+ result.append(number)
32
+ number = -1
33
+ result.append(i)
34
+ if number != -1:
35
+ result.append(number)
36
+ result = tuple(result)
37
+ return result
38
+
39
+
40
+ def search_for_images(folder):
41
+ file_list = [
42
+ i for i in os.listdir(folder) if i.endswith(".jpg") or i.endswith(".png")
43
+ ]
44
+ file_list = [(split_file_name(file_name), file_name) for file_name in file_list]
45
+ file_list = [i[1] for i in sorted(file_list)]
46
+ file_list = [os.path.join(folder, i) for i in file_list]
47
+ return file_list
48
+
49
+
50
+ class LowMemoryImageFolder:
51
+ def __init__(self, folder, file_list=None):
52
+ if file_list is None:
53
+ self.file_list = search_for_images(folder)
54
+ else:
55
+ self.file_list = [
56
+ os.path.join(folder, file_name) for file_name in file_list
57
+ ]
58
+
59
+ def __len__(self):
60
+ return len(self.file_list)
61
+
62
+ def __getitem__(self, item):
63
+ return Image.open(self.file_list[item]).convert("RGB")
64
+
65
+ def __del__(self):
66
+ pass
67
+
68
+
69
+ def crop_and_resize(image, height, width):
70
+ image = np.array(image)
71
+ image_height, image_width, _ = image.shape
72
+ if image_height / image_width < height / width:
73
+ croped_width = int(image_height / height * width)
74
+ left = (image_width - croped_width) // 2
75
+ image = image[:, left : left + croped_width]
76
+ image = Image.fromarray(image).resize((width, height))
77
+ else:
78
+ croped_height = int(image_width / width * height)
79
+ left = (image_height - croped_height) // 2
80
+ image = image[left : left + croped_height, :]
81
+ image = Image.fromarray(image).resize((width, height))
82
+ return image
83
+
84
+
85
+ class VideoData:
86
+ def __init__(
87
+ self, video_file=None, image_folder=None, height=None, width=None, **kwargs
88
+ ):
89
+ if video_file is not None:
90
+ self.data_type = "video"
91
+ self.data = LowMemoryVideo(video_file, **kwargs)
92
+ elif image_folder is not None:
93
+ self.data_type = "images"
94
+ self.data = LowMemoryImageFolder(image_folder, **kwargs)
95
+ else:
96
+ raise ValueError("Cannot open video or image folder")
97
+ self.length = None
98
+ self.set_shape(height, width)
99
+
100
+ def raw_data(self):
101
+ frames = []
102
+ for i in range(self.__len__()):
103
+ frames.append(self.__getitem__(i))
104
+ return frames
105
+
106
+ def set_length(self, length):
107
+ self.length = length
108
+
109
+ def set_shape(self, height, width):
110
+ self.height = height
111
+ self.width = width
112
+
113
+ def __len__(self):
114
+ if self.length is None:
115
+ return len(self.data)
116
+ else:
117
+ return self.length
118
+
119
+ def shape(self):
120
+ if self.height is not None and self.width is not None:
121
+ return self.height, self.width
122
+ else:
123
+ height, width, _ = self.__getitem__(0).shape
124
+ return height, width
125
+
126
+ def __getitem__(self, item):
127
+ frame = self.data.__getitem__(item)
128
+ width, height = frame.size
129
+ if self.height is not None and self.width is not None:
130
+ if self.height != height or self.width != width:
131
+ frame = crop_and_resize(frame, self.height, self.width)
132
+ return frame
133
+
134
+ def __del__(self):
135
+ pass
136
+
137
+ def save_images(self, folder):
138
+ os.makedirs(folder, exist_ok=True)
139
+ for i in tqdm(range(self.__len__()), desc="Saving images"):
140
+ frame = self.__getitem__(i)
141
+ frame.save(os.path.join(folder, f"{i}.png"))
142
+
143
+
144
+ def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None):
145
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
146
+ writer = imageio.get_writer(
147
+ save_path, fps=fps, quality=quality, ffmpeg_params=ffmpeg_params
148
+ )
149
+ for frame in tqdm(frames, desc="Saving video"):
150
+ frame = np.array(frame)
151
+ writer.append_data(frame)
152
+ writer.close()
153
+
154
+
155
+ def save_frames(frames, save_path):
156
+ os.makedirs(save_path, exist_ok=True)
157
+ for i, frame in enumerate(tqdm(frames, desc="Saving images")):
158
+ frame.save(os.path.join(save_path, f"{i}.png"))
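
A minimal usage sketch for the helpers above (paths are placeholders): VideoData decodes frames lazily, center-crops and resizes them to the requested shape, and save_video re-encodes the resulting PIL frames.

from data.video import VideoData, save_video

video = VideoData(video_file="test/input/woman.mp4", height=480, width=832)
frames = video.raw_data()  # list of PIL images, cropped and resized to 832x480
save_video(frames, "test/output/woman_480p.mp4", fps=25, quality=9)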
distributed/__init__.py ADDED
File without changes
distributed/xdit_context_parallel.py ADDED
@@ -0,0 +1,154 @@
1
+ import torch
2
+ from typing import Optional
3
+ from einops import rearrange
4
+ from xfuser.core.distributed import (
5
+ get_sequence_parallel_rank,
6
+ get_sequence_parallel_world_size,
7
+ get_sp_group,
8
+ )
9
+ from xfuser.core.long_ctx_attention import xFuserLongContextAttention
10
+
11
+
12
+ def sinusoidal_embedding_1d(dim, position):
13
+ sinusoid = torch.outer(
14
+ position.type(torch.float64),
15
+ torch.pow(
16
+ 10000,
17
+ -torch.arange(dim // 2, dtype=torch.float64, device=position.device).div(
18
+ dim // 2
19
+ ),
20
+ ),
21
+ )
22
+ x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
23
+ return x.to(position.dtype)
24
+
25
+
26
+ def pad_freqs(original_tensor, target_len):
27
+ seq_len, s1, s2 = original_tensor.shape
28
+ pad_size = target_len - seq_len
29
+ padding_tensor = torch.ones(
30
+ pad_size, s1, s2, dtype=original_tensor.dtype, device=original_tensor.device
31
+ )
32
+ padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
33
+ return padded_tensor
34
+
35
+
36
+ def rope_apply(x, freqs, num_heads):
37
+ x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
38
+ s_per_rank = x.shape[1]
39
+
40
+ x_out = torch.view_as_complex(
41
+ x.to(torch.float64).reshape(x.shape[0], x.shape[1], x.shape[2], -1, 2)
42
+ )
43
+
44
+ sp_size = get_sequence_parallel_world_size()
45
+ sp_rank = get_sequence_parallel_rank()
46
+ freqs = pad_freqs(freqs, s_per_rank * sp_size)
47
+ freqs_rank = freqs[(sp_rank * s_per_rank) : ((sp_rank + 1) * s_per_rank), :, :]
48
+
49
+ x_out = torch.view_as_real(x_out * freqs_rank).flatten(2)
50
+ return x_out.to(x.dtype)
51
+
52
+
53
+ def usp_dit_forward(
54
+ self,
55
+ x: torch.Tensor,
56
+ timestep: torch.Tensor,
57
+ context: torch.Tensor,
58
+ clip_feature: Optional[torch.Tensor] = None,
59
+ y: Optional[torch.Tensor] = None,
60
+ use_gradient_checkpointing: bool = False,
61
+ use_gradient_checkpointing_offload: bool = False,
62
+ **kwargs,
63
+ ):
64
+ t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep))
65
+ t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
66
+ context = self.text_embedding(context)
67
+
68
+ if self.has_image_input:
69
+ x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w)
70
+ clip_embedding = self.img_emb(clip_feature)
71
+ context = torch.cat([clip_embedding, context], dim=1)
72
+
73
+ x, (f, h, w) = self.patchify(x)
74
+
75
+ freqs = (
76
+ torch.cat(
77
+ [
78
+ self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
79
+ self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
80
+ self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
81
+ ],
82
+ dim=-1,
83
+ )
84
+ .reshape(f * h * w, 1, -1)
85
+ .to(x.device)
86
+ )
87
+
88
+ def create_custom_forward(module):
89
+ def custom_forward(*inputs):
90
+ return module(*inputs)
91
+
92
+ return custom_forward
93
+
94
+ # Context Parallel
95
+ x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[
96
+ get_sequence_parallel_rank()
97
+ ]
98
+
99
+ for block in self.blocks:
100
+ if self.training and use_gradient_checkpointing:
101
+ if use_gradient_checkpointing_offload:
102
+ with torch.autograd.graph.save_on_cpu():
103
+ x = torch.utils.checkpoint.checkpoint(
104
+ create_custom_forward(block),
105
+ x,
106
+ context,
107
+ t_mod,
108
+ freqs,
109
+ use_reentrant=False,
110
+ )
111
+ else:
112
+ x = torch.utils.checkpoint.checkpoint(
113
+ create_custom_forward(block),
114
+ x,
115
+ context,
116
+ t_mod,
117
+ freqs,
118
+ use_reentrant=False,
119
+ )
120
+ else:
121
+ x = block(x, context, t_mod, freqs)
122
+
123
+ x = self.head(x, t)
124
+
125
+ # Context Parallel
126
+ x = get_sp_group().all_gather(x, dim=1)
127
+
128
+ # unpatchify
129
+ x = self.unpatchify(x, (f, h, w))
130
+ return x
131
+
132
+
133
+ def usp_attn_forward(self, x, freqs):
134
+ q = self.norm_q(self.q(x))
135
+ k = self.norm_k(self.k(x))
136
+ v = self.v(x)
137
+
138
+ q = rope_apply(q, freqs, self.num_heads)
139
+ k = rope_apply(k, freqs, self.num_heads)
140
+ q = rearrange(q, "b s (n d) -> b s n d", n=self.num_heads)
141
+ k = rearrange(k, "b s (n d) -> b s n d", n=self.num_heads)
142
+ v = rearrange(v, "b s (n d) -> b s n d", n=self.num_heads)
143
+
144
+ x = xFuserLongContextAttention()(
145
+ None,
146
+ query=q,
147
+ key=k,
148
+ value=v,
149
+ )
150
+ x = x.flatten(2)
151
+
152
+ del q, k, v
153
+ torch.cuda.empty_cache()
154
+ return self.o(x)
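
These two functions are drop-in forwards for sequence-parallel inference: usp_dit_forward shards the patchified token sequence across the sequence-parallel group before the transformer blocks and all-gathers it again before unpatchifying, while usp_attn_forward applies each rank's slice of the RoPE frequencies and runs xFuser's long-context attention. A hedged sketch of how they might be bound onto a loaded DiT follows; the attribute name self_attn and the enable_usp helper are assumptions, and torch.distributed plus xfuser's parallel groups must already be initialized before the patched forward is called.

import types
from distributed.xdit_context_parallel import usp_dit_forward, usp_attn_forward

def enable_usp(dit):
    # Assumed layout: the DiT exposes `blocks`, each with a `self_attn` module
    for block in dit.blocks:
        block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
    dit.forward = types.MethodType(usp_dit_forward, dit)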
download_models.py ADDED
@@ -0,0 +1,21 @@
1
+ import argparse
2
+ from huggingface_hub import snapshot_download
3
+
4
+ def main(use_vace: bool):
5
+ if use_vace:
6
+ snapshot_download("Wan-AI/Wan2.1-VACE-14B", local_dir="checkpoints/VACE/")
7
+ else:
8
+ snapshot_download("Wan-AI/Wan2.1-T2V-14B", local_dir="checkpoints/base_model/")
9
+
10
+ snapshot_download(
11
+ "DIAMONIK7777/antelopev2",
12
+ local_dir="checkpoints/antelopev2/models/antelopev2"
13
+ )
14
+ snapshot_download("BowenXue/Stand-In", local_dir="checkpoints/Stand-In/")
15
+
16
+ if __name__ == "__main__":
17
+ parser = argparse.ArgumentParser(description="Download models with or without VACE.")
18
+ parser.add_argument("--vace", action="store_true", help="Use VACE model instead of T2V.")
19
+ args = parser.parse_args()
20
+
21
+ main(args.vace)
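
The script above only fetches checkpoints; the directory layout it produces is what the inference scripts below expect by default. A small, illustrative pre-flight check (not part of the repository) could look like this:

from pathlib import Path

expected = [
    "checkpoints/base_model",                    # or checkpoints/VACE when --vace was used
    "checkpoints/antelopev2/models/antelopev2",  # AntelopeV2 face models
    "checkpoints/Stand-In",                      # Stand-In conditioning weights
]
missing = [p for p in expected if not Path(p).is_dir()]
if missing:
    raise FileNotFoundError(f"Run download_models.py first; missing: {missing}")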
infer.py ADDED
@@ -0,0 +1,85 @@
1
+ import torch
2
+ from data.video import save_video
3
+ from wan_loader import load_wan_pipe
4
+ from models.set_condition_branch import set_stand_in
5
+ from preprocessor import FaceProcessor
6
+ import argparse
7
+
8
+ parser = argparse.ArgumentParser()
9
+
10
+ parser.add_argument(
11
+ "--ip_image",
12
+ type=str,
13
+ default="test/input/lecun.jpg",
14
+ help="Input face image path or URL",
15
+ )
16
+ parser.add_argument(
17
+ "--prompt",
18
+ type=str,
19
+ default="一位男性舒适地坐在书桌前,正对着镜头,仿佛在与屏幕前的亲友对话。他的眼神专注而温柔,嘴角带着自然的笑意。背景是他精心布置的个人空间,墙上贴着照片和一张世界地图,传达出一种亲密而现代的沟通感。",
20
+ help="Text prompt for video generation",
21
+ )
22
+ parser.add_argument(
23
+ "--output", type=str, default="test/output/lecun.mp4", help="Output video file path"
24
+ )
25
+ parser.add_argument(
26
+ "--seed", type=int, default=0, help="Random seed for reproducibility"
27
+ )
28
+ parser.add_argument(
29
+ "--num_inference_steps", type=int, default=20, help="Number of inference steps"
30
+ )
31
+
32
+ parser.add_argument(
33
+ "--negative_prompt",
34
+ type=str,
35
+ default="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
36
+ help="Negative prompt to avoid unwanted features",
37
+ )
38
+ parser.add_argument("--tiled", action="store_true", help="Enable tiled mode")
39
+ parser.add_argument(
40
+ "--fps", type=int, default=25, help="Frames per second for output video"
41
+ )
42
+ parser.add_argument(
43
+ "--quality", type=int, default=9, help="Output video quality (1-9)"
44
+ )
45
+ parser.add_argument(
46
+ "--base_path",
47
+ type=str,
48
+ default="checkpoints/base_model/",
49
+ help="Path to base model checkpoint",
50
+ )
51
+ parser.add_argument(
52
+ "--stand_in_path",
53
+ type=str,
54
+ default="checkpoints/Stand-In/Stand-In_wan2.1_T2V_14B_ver1.0.ckpt",
55
+ help="Path to LoRA weights checkpoint",
56
+ )
57
+ parser.add_argument(
58
+ "--antelopv2_path",
59
+ type=str,
60
+ default="checkpoints/antelopev2",
61
+ help="Path to AntelopeV2 model checkpoint",
62
+ )
63
+
64
+ args = parser.parse_args()
65
+
66
+
67
+ face_processor = FaceProcessor(antelopv2_path=args.antelopv2_path)
68
+ ip_image = face_processor.process(args.ip_image)
69
+
70
+ pipe = load_wan_pipe(base_path=args.base_path, torch_dtype=torch.bfloat16)
71
+
72
+ set_stand_in(
73
+ pipe,
74
+ model_path=args.stand_in_path,
75
+ )
76
+
77
+ video = pipe(
78
+ prompt=args.prompt,
79
+ negative_prompt=args.negative_prompt,
80
+ seed=args.seed,
81
+ ip_image=ip_image,
82
+ num_inference_steps=args.num_inference_steps,
83
+ tiled=args.tiled,
84
+ )
85
+ save_video(video, args.output, fps=args.fps, quality=args.quality)
infer_face_swap.py ADDED
@@ -0,0 +1,119 @@
1
+ import torch
2
+ from data.video import save_video
3
+ from wan_loader import load_wan_pipe
4
+ from models.set_condition_branch import set_stand_in
5
+ from preprocessor import FaceProcessor, VideoMaskGenerator
6
+ import argparse
7
+
8
+ parser = argparse.ArgumentParser()
9
+
10
+ parser.add_argument(
11
+ "--ip_image",
12
+ type=str,
13
+ default="test/input/ruonan.jpg",
14
+ help="Input face image path or URL",
15
+ )
16
+ parser.add_argument(
17
+ "--input_video",
18
+ type=str,
19
+ default="test/input/woman.mp4",
20
+ help="Input video path",
21
+ )
22
+ parser.add_argument(
23
+ "--denoising_strength",
24
+ type=float,
25
+ default=0.85,
26
+ help="The lower denoising strength represents a higher similarity to the original video.",
27
+ )
28
+ parser.add_argument(
29
+ "--prompt",
30
+ type=str,
31
+ default="The video features a woman standing in front of a large screen displaying the words "
32
+ "Tech Minute"
33
+ " and the logo for CNET. She is wearing a purple top and appears to be presenting or speaking about technology-related topics. The background includes a cityscape with tall buildings, suggesting an urban setting. The woman seems to be engaged in a discussion or providing information on technology news or trends. The overall atmosphere is professional and informative, likely aimed at educating viewers about the latest developments in the tech industry.",
34
+ help="Text prompt for video generation",
35
+ )
36
+ parser.add_argument(
37
+ "--output",
38
+ type=str,
39
+ default="test/output/ruonan.mp4",
40
+ help="Output video file path",
41
+ )
42
+ parser.add_argument(
43
+ "--seed", type=int, default=0, help="Random seed for reproducibility"
44
+ )
45
+ parser.add_argument(
46
+ "--num_inference_steps", type=int, default=20, help="Number of inference steps"
47
+ )
48
+ parser.add_argument(
49
+ "--force_background_consistency",
50
+ action="store_true",
51
+ default=False,
52
+ help="Force background consistency across generated frames (off by default).",
53
+ )
54
+
55
+ parser.add_argument(
56
+ "--negative_prompt",
57
+ type=str,
58
+ default="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
59
+ help="Negative prompt to avoid unwanted features",
60
+ )
61
+ parser.add_argument("--tiled", action="store_true", help="Enable tiled mode")
62
+ parser.add_argument(
63
+ "--fps", type=int, default=25, help="Frames per second for output video"
64
+ )
65
+ parser.add_argument(
66
+ "--quality", type=int, default=9, help="Output video quality (1-9)"
67
+ )
68
+ parser.add_argument(
69
+ "--base_path",
70
+ type=str,
71
+ default="checkpoints/base_model/",
72
+ help="Path to base model checkpoint",
73
+ )
74
+ parser.add_argument(
75
+ "--stand_in_path",
76
+ type=str,
77
+ default="checkpoints/Stand-In/Stand-In_wan2.1_T2V_14B_ver1.0.ckpt",
78
+ help="Path to LoRA weights checkpoint",
79
+ )
80
+ parser.add_argument(
81
+ "--antelopv2_path",
82
+ type=str,
83
+ default="checkpoints/antelopev2",
84
+ help="Path to AntelopeV2 model checkpoint",
85
+ )
86
+
87
+ args = parser.parse_args()
88
+
89
+ face_processor = FaceProcessor(antelopv2_path=args.antelopv2_path)
90
+ videomask_generator = VideoMaskGenerator(antelopv2_path=args.antelopv2_path)
91
+
92
+ ip_image, ip_image_rgba = face_processor.process(args.ip_image, extra_input=True)
93
+ input_video, face_mask, width, height, num_frames = videomask_generator.process(args.input_video, ip_image_rgba, random_horizontal_flip_chance=0.05, dilation_kernel_size=10)
94
+
95
+ pipe = load_wan_pipe(
96
+ base_path=args.base_path, face_swap=True, torch_dtype=torch.bfloat16
97
+ )
98
+
99
+ set_stand_in(
100
+ pipe,
101
+ model_path=args.stand_in_path,
102
+ )
103
+
104
+ video = pipe(
105
+ prompt=args.prompt,
106
+ negative_prompt=args.negative_prompt,
107
+ seed=args.seed,
108
+ width=width,
109
+ height=height,
110
+ num_frames=num_frames,
111
+ denoising_strength=args.denoising_strength,
112
+ ip_image=ip_image,
113
+ face_mask=face_mask,
114
+ input_video=input_video,
115
+ num_inference_steps=args.num_inference_steps,
116
+ tiled=args.tiled,
117
+ force_background_consistency=args.force_background_consistency
118
+ )
119
+ save_video(video, args.output, fps=args.fps, quality=args.quality)
infer_with_lora.py ADDED
@@ -0,0 +1,94 @@
1
+ import torch
2
+ from data.video import save_video
3
+ from wan_loader import load_wan_pipe
4
+ from models.set_condition_branch import set_stand_in
5
+ from preprocessor import FaceProcessor
6
+ import argparse
7
+
8
+ parser = argparse.ArgumentParser()
9
+
10
+ parser.add_argument(
11
+ "--ip_image",
12
+ type=str,
13
+ default="test/input/lecun.jpg",
14
+ help="Input face image path or URL",
15
+ )
16
+ parser.add_argument(
17
+ "--lora_path", type=str, required=True, help="Text prompt for video generation"
18
+ )
19
+ parser.add_argument(
20
+ "--prompt",
21
+ type=str,
22
+ default="Close-up of a young man with dark hair tied back, wearing a white kimono adorned with a red floral pattern. He sits against a backdrop of sliding doors with blue accents. His expression shifts from neutral to a slight smile, then to a surprised look. The camera remains static, focusing on his face and upper body as he appears to be reacting to something off-screen. The lighting is soft and natural, suggesting daytime.",
23
+ help="Text prompt for video generation",
24
+ )
25
+ parser.add_argument(
26
+ "--output", type=str, default="test/output/lecun.mp4", help="Output video file path"
27
+ )
28
+ parser.add_argument(
29
+ "--seed", type=int, default=0, help="Random seed for reproducibility"
30
+ )
31
+ parser.add_argument(
32
+ "--num_inference_steps", type=int, default=20, help="Number of inference steps"
33
+ )
34
+ parser.add_argument("--lora_scale", type=float, default=1.0, help="Lora Scale")
35
+
36
+ parser.add_argument(
37
+ "--negative_prompt",
38
+ type=str,
39
+ default="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
40
+ help="Negative prompt to avoid unwanted features",
41
+ )
42
+ parser.add_argument("--tiled", action="store_true", help="Enable tiled mode")
43
+ parser.add_argument(
44
+ "--fps", type=int, default=25, help="Frames per second for output video"
45
+ )
46
+ parser.add_argument(
47
+ "--quality", type=int, default=9, help="Output video quality (1-9)"
48
+ )
49
+ parser.add_argument(
50
+ "--base_path",
51
+ type=str,
52
+ default="checkpoints/base_model/",
53
+ help="Path to base model checkpoint",
54
+ )
55
+ parser.add_argument(
56
+ "--stand_in_path",
57
+ type=str,
58
+ default="checkpoints/Stand-In/Stand-In_wan2.1_T2V_14B_ver1.0.ckpt",
59
+ help="Path to LoRA weights checkpoint",
60
+ )
61
+ parser.add_argument(
62
+ "--antelopv2_path",
63
+ type=str,
64
+ default="checkpoints/antelopev2",
65
+ help="Path to AntelopeV2 model checkpoint",
66
+ )
67
+
68
+ args = parser.parse_args()
69
+
70
+ face_processor = FaceProcessor(antelopv2_path=args.antelopv2_path)
71
+ ip_image = face_processor.process(args.ip_image)
72
+
73
+ pipe = load_wan_pipe(base_path=args.base_path, torch_dtype=torch.bfloat16)
74
+
75
+ pipe.load_lora(
76
+ pipe.dit,
77
+ args.lora_path,
78
+ alpha=args.lora_scale,
79
+ )
80
+
81
+ set_stand_in(
82
+ pipe,
83
+ model_path=args.stand_in_path,
84
+ )
85
+
86
+ video = pipe(
87
+ prompt=args.prompt,
88
+ negative_prompt=args.negative_prompt,
89
+ seed=args.seed,
90
+ ip_image=ip_image,
91
+ num_inference_steps=args.num_inference_steps,
92
+ tiled=args.tiled,
93
+ )
94
+ save_video(video, args.output, fps=args.fps, quality=args.quality)
infer_with_vace.py ADDED
@@ -0,0 +1,106 @@
1
+ import torch
2
+ from data.video import save_video
3
+ from wan_loader import load_wan_pipe
4
+ from models.set_condition_branch import set_stand_in
5
+ from preprocessor import FaceProcessor
6
+ import argparse
7
+
8
+ parser = argparse.ArgumentParser()
9
+
10
+ parser.add_argument(
11
+ "--ip_image",
12
+ type=str,
13
+ default="test/input/first_frame.png",
14
+ help="Input face image path or URL",
15
+ )
16
+ parser.add_argument(
17
+ "--reference_video",
18
+ type=str,
19
+ default="test/input/pose.mp4",
20
+ help="reference_video path",
21
+ )
22
+ parser.add_argument(
23
+ "--reference_image",
24
+ default="test/input/first_frame.png",
25
+ type=str,
26
+ help="reference_video path",
27
+ )
28
+ parser.add_argument(
29
+ "--vace_scale",
30
+ type=float,
31
+ default=0.8,
32
+ help="Scaling factor for VACE.",
33
+ )
34
+ parser.add_argument(
35
+ "--prompt",
36
+ type=str,
37
+ default="一个女人举起双手",
38
+ help="Text prompt for video generation",
39
+ )
40
+ parser.add_argument(
41
+ "--output", type=str, default="test/output/woman.mp4", help="Output video file path"
42
+ )
43
+ parser.add_argument(
44
+ "--seed", type=int, default=0, help="Random seed for reproducibility"
45
+ )
46
+ parser.add_argument(
47
+ "--num_inference_steps", type=int, default=20, help="Number of inference steps"
48
+ )
49
+ parser.add_argument(
50
+ "--vace_path",
51
+ type=str,
52
+ default="checkpoints/VACE/",
53
+ help="Path to the VACE model checkpoint",
54
+ )
55
+
56
+ parser.add_argument(
57
+ "--negative_prompt",
58
+ type=str,
59
+ default="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
60
+ help="Negative prompt to avoid unwanted features",
61
+ )
62
+ parser.add_argument("--tiled", action="store_true", help="Enable tiled mode")
63
+ parser.add_argument(
64
+ "--fps", type=int, default=25, help="Frames per second for output video"
65
+ )
66
+ parser.add_argument(
67
+ "--quality", type=int, default=9, help="Output video quality (1-9)"
68
+ )
69
+ parser.add_argument(
70
+ "--stand_in_path",
71
+ type=str,
72
+ default="checkpoints/Stand-In/Stand-In_wan2.1_T2V_14B_ver1.0.ckpt",
73
+ help="Path to the Stand-In weights checkpoint",
74
+ )
75
+ parser.add_argument(
76
+ "--antelopv2_path",
77
+ type=str,
78
+ default="checkpoints/antelopev2",
79
+ help="Path to AntelopeV2 model checkpoint",
80
+ )
81
+
82
+ args = parser.parse_args()
83
+
84
+
85
+ face_processor = FaceProcessor(antelopv2_path=args.antelopv2_path)
86
+ ip_image = face_processor.process(args.ip_image)
87
+
88
+ pipe = load_wan_pipe(base_path=args.vace_path, use_vace=True, torch_dtype=torch.bfloat16)
89
+
90
+ set_stand_in(
91
+ pipe,
92
+ model_path=args.stand_in_path,
93
+ )
94
+
95
+ video = pipe(
96
+ prompt=args.prompt,
97
+ vace_video=args.reference_video,
98
+ vace_reference_image=args.reference_image,
99
+ negative_prompt=args.negative_prompt,
100
+ vace_scale=args.vace_scale,
101
+ seed=args.seed,
102
+ ip_image=ip_image,
103
+ num_inference_steps=args.num_inference_steps,
104
+ tiled=args.tiled,
105
+ )
106
+ save_video(video, args.output, fps=args.fps, quality=args.quality)
lora/__init__.py ADDED
@@ -0,0 +1,91 @@
1
+ import torch
2
+
3
+
4
+ class GeneralLoRALoader:
5
+ def __init__(self, device="cpu", torch_dtype=torch.float32):
6
+ self.device = device
7
+ self.torch_dtype = torch_dtype
8
+
9
+ def get_name_dict(self, lora_state_dict):
10
+ lora_name_dict = {}
11
+
12
+ has_lora_A = any(k.endswith(".lora_A.weight") for k in lora_state_dict)
13
+ has_lora_down = any(k.endswith(".lora_down.weight") for k in lora_state_dict)
14
+
15
+ if has_lora_A:
16
+ lora_a_keys = [k for k in lora_state_dict if k.endswith(".lora_A.weight")]
17
+ for lora_a_key in lora_a_keys:
18
+ base_name = lora_a_key.replace(".lora_A.weight", "")
19
+ lora_b_key = base_name + ".lora_B.weight"
20
+
21
+ if lora_b_key in lora_state_dict:
22
+ target_name = base_name.replace("diffusion_model.", "", 1)
23
+ lora_name_dict[target_name] = (lora_b_key, lora_a_key)
24
+
25
+ elif has_lora_down:
26
+ lora_down_keys = [
27
+ k for k in lora_state_dict if k.endswith(".lora_down.weight")
28
+ ]
29
+ for lora_down_key in lora_down_keys:
30
+ base_name = lora_down_key.replace(".lora_down.weight", "")
31
+ lora_up_key = base_name + ".lora_up.weight"
32
+
33
+ if lora_up_key in lora_state_dict:
34
+ target_name = base_name.replace("lora_unet_", "").replace("_", ".")
35
+ target_name = target_name.replace(".attn.", "_attn.")
36
+ lora_name_dict[target_name] = (lora_up_key, lora_down_key)
37
+
38
+ else:
39
+ print(
40
+ "Warning: No recognizable LoRA key names found in state_dict (neither 'lora_A' nor 'lora_down')."
41
+ )
42
+
43
+ return lora_name_dict
44
+
45
+ def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):
46
+ lora_name_dict = self.get_name_dict(state_dict_lora)
47
+ updated_num = 0
48
+
49
+ lora_target_names = set(lora_name_dict.keys())
50
+ model_layer_names = {
51
+ name for name, module in model.named_modules() if hasattr(module, "weight")
52
+ }
53
+ matched_names = lora_target_names.intersection(model_layer_names)
54
+ unmatched_lora_names = lora_target_names - model_layer_names
55
+
56
+ print(f"Successfully matched {len(matched_names)} layers.")
57
+ if unmatched_lora_names:
58
+ print(
59
+ f"Warning: {len(unmatched_lora_names)} LoRA layers not matched and will be ignored."
60
+ )
61
+
62
+ for name, module in model.named_modules():
63
+ if name in matched_names:
64
+ lora_b_key, lora_a_key = lora_name_dict[name]
65
+ weight_up = state_dict_lora[lora_b_key].to(
66
+ device=self.device, dtype=self.torch_dtype
67
+ )
68
+ weight_down = state_dict_lora[lora_a_key].to(
69
+ device=self.device, dtype=self.torch_dtype
70
+ )
71
+
72
+ if len(weight_up.shape) == 4:
73
+ weight_up = weight_up.squeeze(3).squeeze(2)
74
+ weight_down = weight_down.squeeze(3).squeeze(2)
75
+ weight_lora = alpha * torch.mm(weight_up, weight_down).unsqueeze(
76
+ 2
77
+ ).unsqueeze(3)
78
+ else:
79
+ weight_lora = alpha * torch.mm(weight_up, weight_down)
80
+
81
+ if module.weight.shape != weight_lora.shape:
82
+ print(f"Error: Shape mismatch for layer '{name}'! Skipping update.")
83
+ continue
84
+
85
+ module.weight.data = (
86
+ module.weight.data.to(weight_lora.device, dtype=weight_lora.dtype)
87
+ + weight_lora
88
+ )
89
+ updated_num += 1
90
+
91
+ print(f"LoRA loading complete, updated {updated_num} tensors in total.\n")
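For reference, a minimal usage sketch of the GeneralLoRALoader defined above. The toy module and the fake checkpoint keys are hypothetical; they only illustrate the `lora_A`/`lora_B` naming that get_name_dict expects once the "diffusion_model." prefix is stripped.

import torch
from lora import GeneralLoRALoader

# Toy model whose submodule name ("proj") matches the LoRA target name.
class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8, bias=False)

model = ToyModel()

# Fake rank-2 LoRA state dict in the "lora_A"/"lora_B" convention (hypothetical keys).
lora_state_dict = {
    "diffusion_model.proj.lora_A.weight": torch.randn(2, 8),
    "diffusion_model.proj.lora_B.weight": torch.randn(8, 2),
}

loader = GeneralLoRALoader(device="cpu", torch_dtype=torch.float32)
loader.load(model, lora_state_dict, alpha=1.0)  # merges alpha * (B @ A) into model.proj.weight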
models/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .model_manager import *
models/attention.py ADDED
@@ -0,0 +1,130 @@
1
+ import torch
2
+ from einops import rearrange
3
+
4
+
5
+ def low_version_attention(query, key, value, attn_bias=None):
6
+ scale = 1 / query.shape[-1] ** 0.5
7
+ query = query * scale
8
+ attn = torch.matmul(query, key.transpose(-2, -1))
9
+ if attn_bias is not None:
10
+ attn = attn + attn_bias
11
+ attn = attn.softmax(-1)
12
+ return attn @ value
13
+
14
+
15
+ class Attention(torch.nn.Module):
16
+ def __init__(
17
+ self,
18
+ q_dim,
19
+ num_heads,
20
+ head_dim,
21
+ kv_dim=None,
22
+ bias_q=False,
23
+ bias_kv=False,
24
+ bias_out=False,
25
+ ):
26
+ super().__init__()
27
+ dim_inner = head_dim * num_heads
28
+ kv_dim = kv_dim if kv_dim is not None else q_dim
29
+ self.num_heads = num_heads
30
+ self.head_dim = head_dim
31
+
32
+ self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
33
+ self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
34
+ self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
35
+ self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)
36
+
37
+ def interact_with_ipadapter(self, hidden_states, q, ip_k, ip_v, scale=1.0):
38
+ batch_size = q.shape[0]
39
+ ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
40
+ ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
41
+ ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(
42
+ q, ip_k, ip_v
43
+ )
44
+ hidden_states = hidden_states + scale * ip_hidden_states
45
+ return hidden_states
46
+
47
+ def torch_forward(
48
+ self,
49
+ hidden_states,
50
+ encoder_hidden_states=None,
51
+ attn_mask=None,
52
+ ipadapter_kwargs=None,
53
+ qkv_preprocessor=None,
54
+ ):
55
+ if encoder_hidden_states is None:
56
+ encoder_hidden_states = hidden_states
57
+
58
+ batch_size = encoder_hidden_states.shape[0]
59
+
60
+ q = self.to_q(hidden_states)
61
+ k = self.to_k(encoder_hidden_states)
62
+ v = self.to_v(encoder_hidden_states)
63
+
64
+ q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
65
+ k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
66
+ v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
67
+
68
+ if qkv_preprocessor is not None:
69
+ q, k, v = qkv_preprocessor(q, k, v)
70
+
71
+ hidden_states = torch.nn.functional.scaled_dot_product_attention(
72
+ q, k, v, attn_mask=attn_mask
73
+ )
74
+ if ipadapter_kwargs is not None:
75
+ hidden_states = self.interact_with_ipadapter(
76
+ hidden_states, q, **ipadapter_kwargs
77
+ )
78
+ hidden_states = hidden_states.transpose(1, 2).reshape(
79
+ batch_size, -1, self.num_heads * self.head_dim
80
+ )
81
+ hidden_states = hidden_states.to(q.dtype)
82
+
83
+ hidden_states = self.to_out(hidden_states)
84
+
85
+ return hidden_states
86
+
87
+ def xformers_forward(
88
+ self, hidden_states, encoder_hidden_states=None, attn_mask=None
89
+ ):
90
+ if encoder_hidden_states is None:
91
+ encoder_hidden_states = hidden_states
92
+
93
+ q = self.to_q(hidden_states)
94
+ k = self.to_k(encoder_hidden_states)
95
+ v = self.to_v(encoder_hidden_states)
96
+
97
+ q = rearrange(q, "b f (n d) -> (b n) f d", n=self.num_heads)
98
+ k = rearrange(k, "b f (n d) -> (b n) f d", n=self.num_heads)
99
+ v = rearrange(v, "b f (n d) -> (b n) f d", n=self.num_heads)
100
+
101
+ if attn_mask is not None:
102
+ hidden_states = low_version_attention(q, k, v, attn_bias=attn_mask)
103
+ else:
104
+ import xformers.ops as xops
105
+
106
+ hidden_states = xops.memory_efficient_attention(q, k, v)
107
+ hidden_states = rearrange(
108
+ hidden_states, "(b n) f d -> b f (n d)", n=self.num_heads
109
+ )
110
+
111
+ hidden_states = hidden_states.to(q.dtype)
112
+ hidden_states = self.to_out(hidden_states)
113
+
114
+ return hidden_states
115
+
116
+ def forward(
117
+ self,
118
+ hidden_states,
119
+ encoder_hidden_states=None,
120
+ attn_mask=None,
121
+ ipadapter_kwargs=None,
122
+ qkv_preprocessor=None,
123
+ ):
124
+ return self.torch_forward(
125
+ hidden_states,
126
+ encoder_hidden_states=encoder_hidden_states,
127
+ attn_mask=attn_mask,
128
+ ipadapter_kwargs=ipadapter_kwargs,
129
+ qkv_preprocessor=qkv_preprocessor,
130
+ )
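A quick sanity-check sketch for the Attention module above, run with random tensors; the shapes are arbitrary, and when encoder_hidden_states is omitted the layer falls back to self-attention.

import torch
from models.attention import Attention

attn = Attention(q_dim=128, num_heads=4, head_dim=32, bias_q=True, bias_kv=True, bias_out=True)

x = torch.randn(2, 16, 128)    # (batch, query tokens, q_dim)
ctx = torch.randn(2, 77, 128)  # (batch, context tokens, kv_dim defaults to q_dim)

y_self = attn(x)                              # self-attention  -> (2, 16, 128)
y_cross = attn(x, encoder_hidden_states=ctx)  # cross-attention -> (2, 16, 128)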
models/downloader.py ADDED
@@ -0,0 +1,122 @@
1
+ from huggingface_hub import hf_hub_download
+ # snapshot_download is called below with ModelScope's `allow_file_pattern` argument,
+ # so import it from modelscope (huggingface_hub's version expects `allow_patterns`).
+ from modelscope import snapshot_download
2
+ import os, shutil
3
+ from typing_extensions import Literal, TypeAlias
4
+ from typing import List
5
+ from configs.model_config import (
6
+ preset_models_on_huggingface,
7
+ preset_models_on_modelscope,
8
+ Preset_model_id,
9
+ )
10
+
11
+
12
+ def download_from_modelscope(model_id, origin_file_path, local_dir):
13
+ os.makedirs(local_dir, exist_ok=True)
14
+ file_name = os.path.basename(origin_file_path)
15
+ if file_name in os.listdir(local_dir):
16
+ print(f" {file_name} is already in {local_dir}.")
17
+ else:
18
+ print(f" Start downloading {os.path.join(local_dir, file_name)}")
19
+ snapshot_download(
20
+ model_id, allow_file_pattern=origin_file_path, local_dir=local_dir
21
+ )
22
+ downloaded_file_path = os.path.join(local_dir, origin_file_path)
23
+ target_file_path = os.path.join(local_dir, os.path.split(origin_file_path)[-1])
24
+ if downloaded_file_path != target_file_path:
25
+ shutil.move(downloaded_file_path, target_file_path)
26
+ shutil.rmtree(os.path.join(local_dir, origin_file_path.split("/")[0]))
27
+
28
+
29
+ def download_from_huggingface(model_id, origin_file_path, local_dir):
30
+ os.makedirs(local_dir, exist_ok=True)
31
+ file_name = os.path.basename(origin_file_path)
32
+ if file_name in os.listdir(local_dir):
33
+ print(f" {file_name} is already in {local_dir}.")
34
+ else:
35
+ print(f" Start downloading {os.path.join(local_dir, file_name)}")
36
+ hf_hub_download(model_id, origin_file_path, local_dir=local_dir)
37
+ downloaded_file_path = os.path.join(local_dir, origin_file_path)
38
+ target_file_path = os.path.join(local_dir, file_name)
39
+ if downloaded_file_path != target_file_path:
40
+ shutil.move(downloaded_file_path, target_file_path)
41
+ shutil.rmtree(os.path.join(local_dir, origin_file_path.split("/")[0]))
42
+
43
+
44
+ Preset_model_website: TypeAlias = Literal[
45
+ "HuggingFace",
46
+ "ModelScope",
47
+ ]
48
+ website_to_preset_models = {
49
+ "HuggingFace": preset_models_on_huggingface,
50
+ "ModelScope": preset_models_on_modelscope,
51
+ }
52
+ website_to_download_fn = {
53
+ "HuggingFace": download_from_huggingface,
54
+ "ModelScope": download_from_modelscope,
55
+ }
56
+
57
+
58
+ def download_customized_models(
59
+ model_id,
60
+ origin_file_path,
61
+ local_dir,
62
+ downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
63
+ ):
64
+ downloaded_files = []
65
+ for website in downloading_priority:
66
+ # Check if the file is downloaded.
67
+ file_to_download = os.path.join(local_dir, os.path.basename(origin_file_path))
68
+ if file_to_download in downloaded_files:
69
+ continue
70
+ # Download
71
+ website_to_download_fn[website](model_id, origin_file_path, local_dir)
72
+ if os.path.basename(origin_file_path) in os.listdir(local_dir):
73
+ downloaded_files.append(file_to_download)
74
+ return downloaded_files
75
+
76
+
77
+ def download_models(
78
+ model_id_list: List[Preset_model_id] = [],
79
+ downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
80
+ ):
81
+ print(f"Downloading models: {model_id_list}")
82
+ downloaded_files = []
83
+ load_files = []
84
+
85
+ for model_id in model_id_list:
86
+ for website in downloading_priority:
87
+ if model_id in website_to_preset_models[website]:
88
+ # Parse model metadata
89
+ model_metadata = website_to_preset_models[website][model_id]
90
+ if isinstance(model_metadata, list):
91
+ file_data = model_metadata
92
+ else:
93
+ file_data = model_metadata.get("file_list", [])
94
+
95
+ # Try downloading the model from this website.
96
+ model_files = []
97
+ for model_id, origin_file_path, local_dir in file_data:
98
+ # Check if the file is downloaded.
99
+ file_to_download = os.path.join(
100
+ local_dir, os.path.basename(origin_file_path)
101
+ )
102
+ if file_to_download in downloaded_files:
103
+ continue
104
+ # Download
105
+ website_to_download_fn[website](
106
+ model_id, origin_file_path, local_dir
107
+ )
108
+ if os.path.basename(origin_file_path) in os.listdir(local_dir):
109
+ downloaded_files.append(file_to_download)
110
+ model_files.append(file_to_download)
111
+
112
+ # If the model is successfully downloaded, break.
113
+ if len(model_files) > 0:
114
+ if (
115
+ isinstance(model_metadata, dict)
116
+ and "load_path" in model_metadata
117
+ ):
118
+ model_files = model_metadata["load_path"]
119
+ load_files.extend(model_files)
120
+ break
121
+
122
+ return load_files
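A hedged example of the customized-download helper above. The repository id, file path and local directory are placeholders; the call needs network access plus the modelscope / huggingface_hub packages.

from models.downloader import download_customized_models

files = download_customized_models(
    model_id="some-org/some-model",                # placeholder repo id
    origin_file_path="weights/model.safetensors",  # placeholder file inside the repo
    local_dir="checkpoints/some-model",
    downloading_priority=["ModelScope", "HuggingFace"],
)
print(files)  # local paths of the files that are now present in local_dir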
models/model_manager.py ADDED
@@ -0,0 +1,610 @@
1
+ import os, torch, json, importlib
2
+ from typing import List
3
+
4
+ from .downloader import (
5
+ download_models,
6
+ download_customized_models,
7
+ Preset_model_id,
8
+ Preset_model_website,
9
+ )
10
+
11
+ from configs.model_config import (
12
+ model_loader_configs,
13
+ huggingface_model_loader_configs,
14
+ patch_model_loader_configs,
15
+ )
16
+ from .utils import (
17
+ load_state_dict,
18
+ init_weights_on_device,
19
+ hash_state_dict_keys,
20
+ split_state_dict_with_prefix,
21
+ )
22
+
23
+
24
+ def load_model_from_single_file(
25
+ state_dict, model_names, model_classes, model_resource, torch_dtype, device
26
+ ):
27
+ loaded_model_names, loaded_models = [], []
28
+ for model_name, model_class in zip(model_names, model_classes):
29
+ print(f" model_name: {model_name} model_class: {model_class.__name__}")
30
+ state_dict_converter = model_class.state_dict_converter()
31
+ if model_resource == "civitai":
32
+ state_dict_results = state_dict_converter.from_civitai(state_dict)
33
+ elif model_resource == "diffusers":
34
+ state_dict_results = state_dict_converter.from_diffusers(state_dict)
35
+ if isinstance(state_dict_results, tuple):
36
+ model_state_dict, extra_kwargs = state_dict_results
37
+ print(
38
+ f" This model is initialized with extra kwargs: {extra_kwargs}"
39
+ )
40
+ else:
41
+ model_state_dict, extra_kwargs = state_dict_results, {}
42
+ torch_dtype = (
43
+ torch.float32
44
+ if extra_kwargs.get("upcast_to_float32", False)
45
+ else torch_dtype
46
+ )
47
+ with init_weights_on_device():
48
+ model = model_class(**extra_kwargs)
49
+ if hasattr(model, "eval"):
50
+ model = model.eval()
51
+ model.load_state_dict(model_state_dict, assign=True)
52
+ model = model.to(dtype=torch_dtype, device=device)
53
+ loaded_model_names.append(model_name)
54
+ loaded_models.append(model)
55
+ return loaded_model_names, loaded_models
56
+
57
+
58
+ def load_model_from_huggingface_folder(
59
+ file_path, model_names, model_classes, torch_dtype, device
60
+ ):
61
+ loaded_model_names, loaded_models = [], []
62
+ for model_name, model_class in zip(model_names, model_classes):
63
+ if torch_dtype in [torch.float32, torch.float16, torch.bfloat16]:
64
+ model = model_class.from_pretrained(
65
+ file_path, torch_dtype=torch_dtype
66
+ ).eval()
67
+ else:
68
+ model = model_class.from_pretrained(file_path).eval().to(dtype=torch_dtype)
69
+ if torch_dtype == torch.float16 and hasattr(model, "half"):
70
+ model = model.half()
71
+ try:
72
+ model = model.to(device=device)
73
+ except:
74
+ pass
75
+ loaded_model_names.append(model_name)
76
+ loaded_models.append(model)
77
+ return loaded_model_names, loaded_models
78
+
79
+
80
+ def load_single_patch_model_from_single_file(
81
+ state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device
82
+ ):
83
+ print(
84
+ f" model_name: {model_name} model_class: {model_class.__name__} extra_kwargs: {extra_kwargs}"
85
+ )
86
+ base_state_dict = base_model.state_dict()
87
+ base_model.to("cpu")
88
+ del base_model
89
+ model = model_class(**extra_kwargs)
90
+ model.load_state_dict(base_state_dict, strict=False)
91
+ model.load_state_dict(state_dict, strict=False)
92
+ model.to(dtype=torch_dtype, device=device)
93
+ return model
94
+
95
+
96
+ def load_patch_model_from_single_file(
97
+ state_dict,
98
+ model_names,
99
+ model_classes,
100
+ extra_kwargs,
101
+ model_manager,
102
+ torch_dtype,
103
+ device,
104
+ ):
105
+ loaded_model_names, loaded_models = [], []
106
+ for model_name, model_class in zip(model_names, model_classes):
107
+ while True:
108
+ for model_id in range(len(model_manager.model)):
109
+ base_model_name = model_manager.model_name[model_id]
110
+ if base_model_name == model_name:
111
+ base_model_path = model_manager.model_path[model_id]
112
+ base_model = model_manager.model[model_id]
113
+ print(
114
+ f" Adding patch model to {base_model_name} ({base_model_path})"
115
+ )
116
+ patched_model = load_single_patch_model_from_single_file(
117
+ state_dict,
118
+ model_name,
119
+ model_class,
120
+ base_model,
121
+ extra_kwargs,
122
+ torch_dtype,
123
+ device,
124
+ )
125
+ loaded_model_names.append(base_model_name)
126
+ loaded_models.append(patched_model)
127
+ model_manager.model.pop(model_id)
128
+ model_manager.model_path.pop(model_id)
129
+ model_manager.model_name.pop(model_id)
130
+ break
131
+ else:
132
+ break
133
+ return loaded_model_names, loaded_models
134
+
135
+
136
+ class ModelDetectorTemplate:
137
+ def __init__(self):
138
+ pass
139
+
140
+ def match(self, file_path="", state_dict={}):
141
+ return False
142
+
143
+ def load(
144
+ self,
145
+ file_path="",
146
+ state_dict={},
147
+ device="cuda",
148
+ torch_dtype=torch.float16,
149
+ **kwargs,
150
+ ):
151
+ return [], []
152
+
153
+
154
+ class ModelDetectorFromSingleFile:
155
+ def __init__(self, model_loader_configs=[]):
156
+ self.keys_hash_with_shape_dict = {}
157
+ self.keys_hash_dict = {}
158
+ for metadata in model_loader_configs:
159
+ self.add_model_metadata(*metadata)
160
+
161
+ def add_model_metadata(
162
+ self,
163
+ keys_hash,
164
+ keys_hash_with_shape,
165
+ model_names,
166
+ model_classes,
167
+ model_resource,
168
+ ):
169
+ self.keys_hash_with_shape_dict[keys_hash_with_shape] = (
170
+ model_names,
171
+ model_classes,
172
+ model_resource,
173
+ )
174
+ if keys_hash is not None:
175
+ self.keys_hash_dict[keys_hash] = (
176
+ model_names,
177
+ model_classes,
178
+ model_resource,
179
+ )
180
+
181
+ def match(self, file_path="", state_dict={}):
182
+ if isinstance(file_path, str) and os.path.isdir(file_path):
183
+ return False
184
+ if len(state_dict) == 0:
185
+ state_dict = load_state_dict(file_path)
186
+ keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
187
+ if keys_hash_with_shape in self.keys_hash_with_shape_dict:
188
+ return True
189
+ keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
190
+ if keys_hash in self.keys_hash_dict:
191
+ return True
192
+ return False
193
+
194
+ def load(
195
+ self,
196
+ file_path="",
197
+ state_dict={},
198
+ device="cuda",
199
+ torch_dtype=torch.float16,
200
+ **kwargs,
201
+ ):
202
+ if len(state_dict) == 0:
203
+ state_dict = load_state_dict(file_path)
204
+
205
+ # Load models with strict matching
206
+ keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
207
+ if keys_hash_with_shape in self.keys_hash_with_shape_dict:
208
+ model_names, model_classes, model_resource = self.keys_hash_with_shape_dict[
209
+ keys_hash_with_shape
210
+ ]
211
+ loaded_model_names, loaded_models = load_model_from_single_file(
212
+ state_dict,
213
+ model_names,
214
+ model_classes,
215
+ model_resource,
216
+ torch_dtype,
217
+ device,
218
+ )
219
+ return loaded_model_names, loaded_models
220
+
221
+ # Load models without strict matching
222
+ # (the shape of parameters may be inconsistent, and the state_dict_converter will modify the model architecture)
223
+ keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
224
+ if keys_hash in self.keys_hash_dict:
225
+ model_names, model_classes, model_resource = self.keys_hash_dict[keys_hash]
226
+ loaded_model_names, loaded_models = load_model_from_single_file(
227
+ state_dict,
228
+ model_names,
229
+ model_classes,
230
+ model_resource,
231
+ torch_dtype,
232
+ device,
233
+ )
234
+ return loaded_model_names, loaded_models
235
+
236
+ # Neither key hash matched; return empty results instead of undefined names.
+ return [], []
237
+
238
+
239
+ class ModelDetectorFromSplitedSingleFile(ModelDetectorFromSingleFile):
240
+ def __init__(self, model_loader_configs=[]):
241
+ super().__init__(model_loader_configs)
242
+
243
+ def match(self, file_path="", state_dict={}):
244
+ if isinstance(file_path, str) and os.path.isdir(file_path):
245
+ return False
246
+ if len(state_dict) == 0:
247
+ state_dict = load_state_dict(file_path)
248
+ splited_state_dict = split_state_dict_with_prefix(state_dict)
249
+ for sub_state_dict in splited_state_dict:
250
+ if super().match(file_path, sub_state_dict):
251
+ return True
252
+ return False
253
+
254
+ def load(
255
+ self,
256
+ file_path="",
257
+ state_dict={},
258
+ device="cuda",
259
+ torch_dtype=torch.float16,
260
+ **kwargs,
261
+ ):
262
+ # Split the state_dict and load from each component
263
+ splited_state_dict = split_state_dict_with_prefix(state_dict)
264
+ valid_state_dict = {}
265
+ for sub_state_dict in splited_state_dict:
266
+ if super().match(file_path, sub_state_dict):
267
+ valid_state_dict.update(sub_state_dict)
268
+ if super().match(file_path, valid_state_dict):
269
+ loaded_model_names, loaded_models = super().load(
270
+ file_path, valid_state_dict, device, torch_dtype
271
+ )
272
+ else:
273
+ loaded_model_names, loaded_models = [], []
274
+ for sub_state_dict in splited_state_dict:
275
+ if super().match(file_path, sub_state_dict):
276
+ loaded_model_names_, loaded_models_ = super().load(
277
+ file_path, sub_state_dict, device, torch_dtype
278
+ )
279
+ loaded_model_names += loaded_model_names_
280
+ loaded_models += loaded_models_
281
+ return loaded_model_names, loaded_models
282
+
283
+
284
+ class ModelDetectorFromHuggingfaceFolder:
285
+ def __init__(self, model_loader_configs=[]):
286
+ self.architecture_dict = {}
287
+ for metadata in model_loader_configs:
288
+ self.add_model_metadata(*metadata)
289
+
290
+ def add_model_metadata(
291
+ self, architecture, huggingface_lib, model_name, redirected_architecture
292
+ ):
293
+ self.architecture_dict[architecture] = (
294
+ huggingface_lib,
295
+ model_name,
296
+ redirected_architecture,
297
+ )
298
+
299
+ def match(self, file_path="", state_dict={}):
300
+ if not isinstance(file_path, str) or os.path.isfile(file_path):
301
+ return False
302
+ file_list = os.listdir(file_path)
303
+ if "config.json" not in file_list:
304
+ return False
305
+ with open(os.path.join(file_path, "config.json"), "r") as f:
306
+ config = json.load(f)
307
+ if "architectures" not in config and "_class_name" not in config:
308
+ return False
309
+ return True
310
+
311
+ def load(
312
+ self,
313
+ file_path="",
314
+ state_dict={},
315
+ device="cuda",
316
+ torch_dtype=torch.float16,
317
+ **kwargs,
318
+ ):
319
+ with open(os.path.join(file_path, "config.json"), "r") as f:
320
+ config = json.load(f)
321
+ loaded_model_names, loaded_models = [], []
322
+ architectures = (
323
+ config["architectures"]
324
+ if "architectures" in config
325
+ else [config["_class_name"]]
326
+ )
327
+ for architecture in architectures:
328
+ huggingface_lib, model_name, redirected_architecture = (
329
+ self.architecture_dict[architecture]
330
+ )
331
+ if redirected_architecture is not None:
332
+ architecture = redirected_architecture
333
+ model_class = importlib.import_module(huggingface_lib).__getattribute__(
334
+ architecture
335
+ )
336
+ loaded_model_names_, loaded_models_ = load_model_from_huggingface_folder(
337
+ file_path, [model_name], [model_class], torch_dtype, device
338
+ )
339
+ loaded_model_names += loaded_model_names_
340
+ loaded_models += loaded_models_
341
+ return loaded_model_names, loaded_models
342
+
343
+
344
+ class ModelDetectorFromPatchedSingleFile:
345
+ def __init__(self, model_loader_configs=[]):
346
+ self.keys_hash_with_shape_dict = {}
347
+ for metadata in model_loader_configs:
348
+ self.add_model_metadata(*metadata)
349
+
350
+ def add_model_metadata(
351
+ self, keys_hash_with_shape, model_name, model_class, extra_kwargs
352
+ ):
353
+ self.keys_hash_with_shape_dict[keys_hash_with_shape] = (
354
+ model_name,
355
+ model_class,
356
+ extra_kwargs,
357
+ )
358
+
359
+ def match(self, file_path="", state_dict={}):
360
+ if not isinstance(file_path, str) or os.path.isdir(file_path):
361
+ return False
362
+ if len(state_dict) == 0:
363
+ state_dict = load_state_dict(file_path)
364
+ keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
365
+ if keys_hash_with_shape in self.keys_hash_with_shape_dict:
366
+ return True
367
+ return False
368
+
369
+ def load(
370
+ self,
371
+ file_path="",
372
+ state_dict={},
373
+ device="cuda",
374
+ torch_dtype=torch.float16,
375
+ model_manager=None,
376
+ **kwargs,
377
+ ):
378
+ if len(state_dict) == 0:
379
+ state_dict = load_state_dict(file_path)
380
+
381
+ # Load models with strict matching
382
+ loaded_model_names, loaded_models = [], []
383
+ keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
384
+ if keys_hash_with_shape in self.keys_hash_with_shape_dict:
385
+ model_names, model_classes, extra_kwargs = self.keys_hash_with_shape_dict[
386
+ keys_hash_with_shape
387
+ ]
388
+ loaded_model_names_, loaded_models_ = load_patch_model_from_single_file(
389
+ state_dict,
390
+ model_names,
391
+ model_classes,
392
+ extra_kwargs,
393
+ model_manager,
394
+ torch_dtype,
395
+ device,
396
+ )
397
+ loaded_model_names += loaded_model_names_
398
+ loaded_models += loaded_models_
399
+ return loaded_model_names, loaded_models
400
+
401
+
402
+ class ModelManager:
403
+ def __init__(
404
+ self,
405
+ torch_dtype=torch.float16,
406
+ device="cuda",
407
+ model_id_list: List[Preset_model_id] = [],
408
+ downloading_priority: List[Preset_model_website] = [
409
+ "ModelScope",
410
+ "HuggingFace",
411
+ ],
412
+ file_path_list: List[str] = [],
413
+ ):
414
+ self.torch_dtype = torch_dtype
415
+ self.device = device
416
+ self.model = []
417
+ self.model_path = []
418
+ self.model_name = []
419
+ downloaded_files = (
420
+ download_models(model_id_list, downloading_priority)
421
+ if len(model_id_list) > 0
422
+ else []
423
+ )
424
+ self.model_detector = [
425
+ ModelDetectorFromSingleFile(model_loader_configs),
426
+ ModelDetectorFromSplitedSingleFile(model_loader_configs),
427
+ ModelDetectorFromHuggingfaceFolder(huggingface_model_loader_configs),
428
+ ModelDetectorFromPatchedSingleFile(patch_model_loader_configs),
429
+ ]
430
+ self.load_models(downloaded_files + file_path_list)
431
+
432
+ def load_model_from_single_file(
433
+ self,
434
+ file_path="",
435
+ state_dict={},
436
+ model_names=[],
437
+ model_classes=[],
438
+ model_resource=None,
439
+ ):
440
+ print(f"Loading models from file: {file_path}")
441
+ if len(state_dict) == 0:
442
+ state_dict = load_state_dict(file_path)
443
+ model_names, models = load_model_from_single_file(
444
+ state_dict,
445
+ model_names,
446
+ model_classes,
447
+ model_resource,
448
+ self.torch_dtype,
449
+ self.device,
450
+ )
451
+ for model_name, model in zip(model_names, models):
452
+ self.model.append(model)
453
+ self.model_path.append(file_path)
454
+ self.model_name.append(model_name)
455
+ print(f" The following models are loaded: {model_names}.")
456
+
457
+ def load_model_from_huggingface_folder(
458
+ self, file_path="", model_names=[], model_classes=[]
459
+ ):
460
+ print(f"Loading models from folder: {file_path}")
461
+ model_names, models = load_model_from_huggingface_folder(
462
+ file_path, model_names, model_classes, self.torch_dtype, self.device
463
+ )
464
+ for model_name, model in zip(model_names, models):
465
+ self.model.append(model)
466
+ self.model_path.append(file_path)
467
+ self.model_name.append(model_name)
468
+ print(f" The following models are loaded: {model_names}.")
469
+
470
+ def load_patch_model_from_single_file(
471
+ self,
472
+ file_path="",
473
+ state_dict={},
474
+ model_names=[],
475
+ model_classes=[],
476
+ extra_kwargs={},
477
+ ):
478
+ print(f"Loading patch models from file: {file_path}")
479
+ model_names, models = load_patch_model_from_single_file(
480
+ state_dict,
481
+ model_names,
482
+ model_classes,
483
+ extra_kwargs,
484
+ self,
485
+ self.torch_dtype,
486
+ self.device,
487
+ )
488
+ for model_name, model in zip(model_names, models):
489
+ self.model.append(model)
490
+ self.model_path.append(file_path)
491
+ self.model_name.append(model_name)
492
+ print(f" The following patched models are loaded: {model_names}.")
493
+
494
+ def load_lora(self, file_path="", state_dict={}, lora_alpha=1.0):
495
+ if isinstance(file_path, list):
496
+ for file_path_ in file_path:
497
+ self.load_lora(file_path_, state_dict=state_dict, lora_alpha=lora_alpha)
498
+ else:
499
+ print(f"Loading LoRA models from file: {file_path}")
500
+ is_loaded = False
501
+ if len(state_dict) == 0:
502
+ state_dict = load_state_dict(file_path)
503
+ for model_name, model, model_path in zip(
504
+ self.model_name, self.model, self.model_path
505
+ ):
506
+ for lora in get_lora_loaders():
507
+ match_results = lora.match(model, state_dict)
508
+ if match_results is not None:
509
+ print(f" Adding LoRA to {model_name} ({model_path}).")
510
+ lora_prefix, model_resource = match_results
511
+ lora.load(
512
+ model,
513
+ state_dict,
514
+ lora_prefix,
515
+ alpha=lora_alpha,
516
+ model_resource=model_resource,
517
+ )
518
+ is_loaded = True
519
+ break
520
+ if not is_loaded:
521
+ print(f" Cannot load LoRA: {file_path}")
522
+
523
+ def load_model(self, file_path, model_names=None, device=None, torch_dtype=None):
524
+ print(f"Loading models from: {file_path}")
525
+ if device is None:
526
+ device = self.device
527
+ if torch_dtype is None:
528
+ torch_dtype = self.torch_dtype
529
+ if isinstance(file_path, list):
530
+ state_dict = {}
531
+ for path in file_path:
532
+ state_dict.update(load_state_dict(path))
533
+ elif os.path.isfile(file_path):
534
+ state_dict = load_state_dict(file_path)
535
+ else:
536
+ state_dict = None
537
+ for model_detector in self.model_detector:
538
+ if model_detector.match(file_path, state_dict):
539
+ model_names, models = model_detector.load(
540
+ file_path,
541
+ state_dict,
542
+ device=device,
543
+ torch_dtype=torch_dtype,
544
+ allowed_model_names=model_names,
545
+ model_manager=self,
546
+ )
547
+ for model_name, model in zip(model_names, models):
548
+ self.model.append(model)
549
+ self.model_path.append(file_path)
550
+ self.model_name.append(model_name)
551
+ print(f" The following models are loaded: {model_names}.")
552
+ break
553
+ else:
554
+ print(f" We cannot detect the model type. No models are loaded.")
555
+
556
+ def load_models(
557
+ self, file_path_list, model_names=None, device=None, torch_dtype=None
558
+ ):
559
+ for file_path in file_path_list:
560
+ self.load_model(
561
+ file_path, model_names, device=device, torch_dtype=torch_dtype
562
+ )
563
+
564
+ def fetch_model(
565
+ self, model_name, file_path=None, require_model_path=False, index=None
566
+ ):
567
+ fetched_models = []
568
+ fetched_model_paths = []
569
+ for model, model_path, model_name_ in zip(
570
+ self.model, self.model_path, self.model_name
571
+ ):
572
+ if file_path is not None and file_path != model_path:
573
+ continue
574
+ if model_name == model_name_:
575
+ fetched_models.append(model)
576
+ fetched_model_paths.append(model_path)
577
+ if len(fetched_models) == 0:
578
+ print(f"No {model_name} models available.")
579
+ return None
580
+ if len(fetched_models) == 1:
581
+ print(f"Using {model_name} from {fetched_model_paths[0]}.")
582
+ model = fetched_models[0]
583
+ path = fetched_model_paths[0]
584
+ else:
585
+ if index is None:
586
+ model = fetched_models[0]
587
+ path = fetched_model_paths[0]
588
+ print(
589
+ f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[0]}."
590
+ )
591
+ elif isinstance(index, int):
592
+ model = fetched_models[:index]
593
+ path = fetched_model_paths[:index]
594
+ print(
595
+ f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[:index]}."
596
+ )
597
+ else:
598
+ model = fetched_models
599
+ path = fetched_model_paths
600
+ print(
601
+ f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths}."
602
+ )
603
+ if require_model_path:
604
+ return model, path
605
+ else:
606
+ return model
607
+
608
+ def to(self, device):
609
+ for model in self.model:
610
+ model.to(device)
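A minimal sketch of how the ModelManager above is typically driven. The checkpoint file names and the "wan_video_dit" model name are assumptions based on the Wan 2.1 layout used elsewhere in this repo, not values guaranteed by this file.

import torch
from models.model_manager import ModelManager

# Detection is hash-based: each file's state-dict keys are matched against
# configs/model_config.py, so only files registered there will load.
manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
manager.load_models([
    "checkpoints/base_model/diffusion_pytorch_model.safetensors",  # assumed DiT weights
    "checkpoints/base_model/models_t5_umt5-xxl-enc-bf16.pth",      # assumed text encoder
    "checkpoints/base_model/Wan2.1_VAE.pth",                       # assumed VAE
])
dit = manager.fetch_model("wan_video_dit")  # assumed model name; returns None if nothing matched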
models/set_condition_branch.py ADDED
@@ -0,0 +1,41 @@
1
+ import torch
2
+
3
+
4
+ def set_stand_in(pipe, train=False, model_path=None):
5
+ for block in pipe.dit.blocks:
6
+ block.self_attn.init_lora(train)
7
+ if model_path is not None:
8
+ print(f"Loading Stand-In weights from: {model_path}")
9
+ load_lora_weights_into_pipe(pipe, model_path)
10
+
11
+
12
+ def load_lora_weights_into_pipe(pipe, ckpt_path, strict=True):
13
+ ckpt = torch.load(ckpt_path, map_location="cpu")
14
+ state_dict = ckpt.get("state_dict", ckpt)
15
+
16
+ model = {}
17
+ for i, block in enumerate(pipe.dit.blocks):
18
+ prefix = f"blocks.{i}.self_attn."
19
+ attn = block.self_attn
20
+ for name in ["q_loras", "k_loras", "v_loras"]:
21
+ for sub in ["down", "up"]:
22
+ key = f"{prefix}{name}.{sub}.weight"
23
+ if hasattr(getattr(attn, name), sub):
24
+ model[key] = getattr(getattr(attn, name), sub).weight
25
+ else:
26
+ if strict:
27
+ raise KeyError(f"Missing module: {key}")
28
+
29
+ for k, param in state_dict.items():
30
+ if k in model:
31
+ if model[k].shape != param.shape:
32
+ if strict:
33
+ raise ValueError(
34
+ f"Shape mismatch: {k} | {model[k].shape} vs {param.shape}"
35
+ )
36
+ else:
37
+ continue
38
+ model[k].data.copy_(param)
39
+ else:
40
+ if strict:
41
+ raise KeyError(f"Unexpected key in ckpt: {k}")
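For context, this is how the infer scripts earlier in this diff drive set_stand_in, shown here in isolation as a sketch (load_wan_pipe comes from wan_loader, and the paths mirror the argparse defaults above).

import torch
from wan_loader import load_wan_pipe
from models.set_condition_branch import set_stand_in

pipe = load_wan_pipe(base_path="checkpoints/base_model/", torch_dtype=torch.bfloat16)

# Adds the q/k/v LoRA branches to every DiT self-attention block, then copies
# the released Stand-In weights into them.
set_stand_in(pipe, model_path="checkpoints/Stand-In/Stand-In_wan2.1_T2V_14B_ver1.0.ckpt")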
models/tiler.py ADDED
@@ -0,0 +1,333 @@
1
+ import torch
2
+ from einops import rearrange, repeat
3
+
4
+
5
+ class TileWorker:
6
+ def __init__(self):
7
+ pass
8
+
9
+ def mask(self, height, width, border_width):
10
+ # Create a mask with shape (height, width).
11
+ # The centre area is filled with 1, and the border line is filled with values in range (0, 1].
12
+ x = torch.arange(height).repeat(width, 1).T
13
+ y = torch.arange(width).repeat(height, 1)
14
+ mask = torch.stack([x + 1, height - x, y + 1, width - y]).min(dim=0).values
15
+ mask = (mask / border_width).clip(0, 1)
16
+ return mask
17
+
18
+ def tile(self, model_input, tile_size, tile_stride, tile_device, tile_dtype):
19
+ # Convert a tensor (b, c, h, w) to (b, c, tile_size, tile_size, tile_num)
20
+ batch_size, channel, _, _ = model_input.shape
21
+ model_input = model_input.to(device=tile_device, dtype=tile_dtype)
22
+ unfold_operator = torch.nn.Unfold(
23
+ kernel_size=(tile_size, tile_size), stride=(tile_stride, tile_stride)
24
+ )
25
+ model_input = unfold_operator(model_input)
26
+ model_input = model_input.view((batch_size, channel, tile_size, tile_size, -1))
27
+
28
+ return model_input
29
+
30
+ def tiled_inference(
31
+ self,
32
+ forward_fn,
33
+ model_input,
34
+ tile_batch_size,
35
+ inference_device,
36
+ inference_dtype,
37
+ tile_device,
38
+ tile_dtype,
39
+ ):
40
+ # Call y=forward_fn(x) for each tile
41
+ tile_num = model_input.shape[-1]
42
+ model_output_stack = []
43
+
44
+ for tile_id in range(0, tile_num, tile_batch_size):
45
+ # process input
46
+ tile_id_ = min(tile_id + tile_batch_size, tile_num)
47
+ x = model_input[:, :, :, :, tile_id:tile_id_]
48
+ x = x.to(device=inference_device, dtype=inference_dtype)
49
+ x = rearrange(x, "b c h w n -> (n b) c h w")
50
+
51
+ # process output
52
+ y = forward_fn(x)
53
+ y = rearrange(y, "(n b) c h w -> b c h w n", n=tile_id_ - tile_id)
54
+ y = y.to(device=tile_device, dtype=tile_dtype)
55
+ model_output_stack.append(y)
56
+
57
+ model_output = torch.concat(model_output_stack, dim=-1)
58
+ return model_output
59
+
60
+ def io_scale(self, model_output, tile_size):
61
+ # Determine the size modification happened in forward_fn
62
+ # We only consider the same scale on height and width.
63
+ io_scale = model_output.shape[2] / tile_size
64
+ return io_scale
65
+
66
+ def untile(
67
+ self,
68
+ model_output,
69
+ height,
70
+ width,
71
+ tile_size,
72
+ tile_stride,
73
+ border_width,
74
+ tile_device,
75
+ tile_dtype,
76
+ ):
77
+ # The reversed function of tile
78
+ mask = self.mask(tile_size, tile_size, border_width)
79
+ mask = mask.to(device=tile_device, dtype=tile_dtype)
80
+ mask = rearrange(mask, "h w -> 1 1 h w 1")
81
+ model_output = model_output * mask
82
+
83
+ fold_operator = torch.nn.Fold(
84
+ output_size=(height, width),
85
+ kernel_size=(tile_size, tile_size),
86
+ stride=(tile_stride, tile_stride),
87
+ )
88
+ mask = repeat(mask[0, 0, :, :, 0], "h w -> 1 (h w) n", n=model_output.shape[-1])
89
+ model_output = rearrange(model_output, "b c h w n -> b (c h w) n")
90
+ model_output = fold_operator(model_output) / fold_operator(mask)
91
+
92
+ return model_output
93
+
94
+ def tiled_forward(
95
+ self,
96
+ forward_fn,
97
+ model_input,
98
+ tile_size,
99
+ tile_stride,
100
+ tile_batch_size=1,
101
+ tile_device="cpu",
102
+ tile_dtype=torch.float32,
103
+ border_width=None,
104
+ ):
105
+ # Prepare
106
+ inference_device, inference_dtype = model_input.device, model_input.dtype
107
+ height, width = model_input.shape[2], model_input.shape[3]
108
+ border_width = int(tile_stride * 0.5) if border_width is None else border_width
109
+
110
+ # tile
111
+ model_input = self.tile(
112
+ model_input, tile_size, tile_stride, tile_device, tile_dtype
113
+ )
114
+
115
+ # inference
116
+ model_output = self.tiled_inference(
117
+ forward_fn,
118
+ model_input,
119
+ tile_batch_size,
120
+ inference_device,
121
+ inference_dtype,
122
+ tile_device,
123
+ tile_dtype,
124
+ )
125
+
126
+ # resize
127
+ io_scale = self.io_scale(model_output, tile_size)
128
+ height, width = int(height * io_scale), int(width * io_scale)
129
+ tile_size, tile_stride = int(tile_size * io_scale), int(tile_stride * io_scale)
130
+ border_width = int(border_width * io_scale)
131
+
132
+ # untile
133
+ model_output = self.untile(
134
+ model_output,
135
+ height,
136
+ width,
137
+ tile_size,
138
+ tile_stride,
139
+ border_width,
140
+ tile_device,
141
+ tile_dtype,
142
+ )
143
+
144
+ # Done!
145
+ model_output = model_output.to(device=inference_device, dtype=inference_dtype)
146
+ return model_output
147
+
148
+
149
+ class FastTileWorker:
150
+ def __init__(self):
151
+ pass
152
+
153
+ def build_mask(self, data, is_bound):
154
+ _, _, H, W = data.shape
155
+ h = repeat(torch.arange(H), "H -> H W", H=H, W=W)
156
+ w = repeat(torch.arange(W), "W -> H W", H=H, W=W)
157
+ border_width = (H + W) // 4
158
+ pad = torch.ones_like(h) * border_width
159
+ mask = (
160
+ torch.stack(
161
+ [
162
+ pad if is_bound[0] else h + 1,
163
+ pad if is_bound[1] else H - h,
164
+ pad if is_bound[2] else w + 1,
165
+ pad if is_bound[3] else W - w,
166
+ ]
167
+ )
168
+ .min(dim=0)
169
+ .values
170
+ )
171
+ mask = mask.clip(1, border_width)
172
+ mask = (mask / border_width).to(dtype=data.dtype, device=data.device)
173
+ mask = rearrange(mask, "H W -> 1 H W")
174
+ return mask
175
+
176
+ def tiled_forward(
177
+ self,
178
+ forward_fn,
179
+ model_input,
180
+ tile_size,
181
+ tile_stride,
182
+ tile_device="cpu",
183
+ tile_dtype=torch.float32,
184
+ border_width=None,
185
+ ):
186
+ # Prepare
187
+ B, C, H, W = model_input.shape
188
+ border_width = int(tile_stride * 0.5) if border_width is None else border_width
189
+ weight = torch.zeros((1, 1, H, W), dtype=tile_dtype, device=tile_device)
190
+ values = torch.zeros((B, C, H, W), dtype=tile_dtype, device=tile_device)
191
+
192
+ # Split tasks
193
+ tasks = []
194
+ for h in range(0, H, tile_stride):
195
+ for w in range(0, W, tile_stride):
196
+ if (h - tile_stride >= 0 and h - tile_stride + tile_size >= H) or (
197
+ w - tile_stride >= 0 and w - tile_stride + tile_size >= W
198
+ ):
199
+ continue
200
+ h_, w_ = h + tile_size, w + tile_size
201
+ if h_ > H:
202
+ h, h_ = H - tile_size, H
203
+ if w_ > W:
204
+ w, w_ = W - tile_size, W
205
+ tasks.append((h, h_, w, w_))
206
+
207
+ # Run
208
+ for hl, hr, wl, wr in tasks:
209
+ # Forward
210
+ hidden_states_batch = forward_fn(hl, hr, wl, wr).to(
211
+ dtype=tile_dtype, device=tile_device
212
+ )
213
+
214
+ mask = self.build_mask(
215
+ hidden_states_batch, is_bound=(hl == 0, hr >= H, wl == 0, wr >= W)
216
+ )
217
+ values[:, :, hl:hr, wl:wr] += hidden_states_batch * mask
218
+ weight[:, :, hl:hr, wl:wr] += mask
219
+ values /= weight
220
+ return values
221
+
222
+
223
+ class TileWorker2Dto3D:
224
+ """
225
+ Process 3D tensors, but only enable TileWorker on 2D.
226
+ """
227
+
228
+ def __init__(self):
229
+ pass
230
+
231
+ def build_mask(self, T, H, W, dtype, device, is_bound, border_width):
232
+ t = repeat(torch.arange(T), "T -> T H W", T=T, H=H, W=W)
233
+ h = repeat(torch.arange(H), "H -> T H W", T=T, H=H, W=W)
234
+ w = repeat(torch.arange(W), "W -> T H W", T=T, H=H, W=W)
235
+ border_width = (H + W) // 4 if border_width is None else border_width
236
+ pad = torch.ones_like(h) * border_width
237
+ mask = (
238
+ torch.stack(
239
+ [
240
+ pad if is_bound[0] else t + 1,
241
+ pad if is_bound[1] else T - t,
242
+ pad if is_bound[2] else h + 1,
243
+ pad if is_bound[3] else H - h,
244
+ pad if is_bound[4] else w + 1,
245
+ pad if is_bound[5] else W - w,
246
+ ]
247
+ )
248
+ .min(dim=0)
249
+ .values
250
+ )
251
+ mask = mask.clip(1, border_width)
252
+ mask = (mask / border_width).to(dtype=dtype, device=device)
253
+ mask = rearrange(mask, "T H W -> 1 1 T H W")
254
+ return mask
255
+
256
+ def tiled_forward(
257
+ self,
258
+ forward_fn,
259
+ model_input,
260
+ tile_size,
261
+ tile_stride,
262
+ tile_device="cpu",
263
+ tile_dtype=torch.float32,
264
+ computation_device="cuda",
265
+ computation_dtype=torch.float32,
266
+ border_width=None,
267
+ scales=[1, 1, 1, 1],
268
+ progress_bar=lambda x: x,
269
+ ):
270
+ B, C, T, H, W = model_input.shape
271
+ scale_C, scale_T, scale_H, scale_W = scales
272
+ tile_size_H, tile_size_W = tile_size
273
+ tile_stride_H, tile_stride_W = tile_stride
274
+
275
+ value = torch.zeros(
276
+ (B, int(C * scale_C), int(T * scale_T), int(H * scale_H), int(W * scale_W)),
277
+ dtype=tile_dtype,
278
+ device=tile_device,
279
+ )
280
+ weight = torch.zeros(
281
+ (1, 1, int(T * scale_T), int(H * scale_H), int(W * scale_W)),
282
+ dtype=tile_dtype,
283
+ device=tile_device,
284
+ )
285
+
286
+ # Split tasks
287
+ tasks = []
288
+ for h in range(0, H, tile_stride_H):
289
+ for w in range(0, W, tile_stride_W):
290
+ if (
291
+ h - tile_stride_H >= 0 and h - tile_stride_H + tile_size_H >= H
292
+ ) or (w - tile_stride_W >= 0 and w - tile_stride_W + tile_size_W >= W):
293
+ continue
294
+ h_, w_ = h + tile_size_H, w + tile_size_W
295
+ if h_ > H:
296
+ h, h_ = max(H - tile_size_H, 0), H
297
+ if w_ > W:
298
+ w, w_ = max(W - tile_size_W, 0), W
299
+ tasks.append((h, h_, w, w_))
300
+
301
+ # Run
302
+ for hl, hr, wl, wr in progress_bar(tasks):
303
+ mask = self.build_mask(
304
+ int(T * scale_T),
305
+ int((hr - hl) * scale_H),
306
+ int((wr - wl) * scale_W),
307
+ tile_dtype,
308
+ tile_device,
309
+ is_bound=(True, True, hl == 0, hr >= H, wl == 0, wr >= W),
310
+ border_width=border_width,
311
+ )
312
+ grid_input = model_input[:, :, :, hl:hr, wl:wr].to(
313
+ dtype=computation_dtype, device=computation_device
314
+ )
315
+ grid_output = forward_fn(grid_input).to(
316
+ dtype=tile_dtype, device=tile_device
317
+ )
318
+ value[
319
+ :,
320
+ :,
321
+ :,
322
+ int(hl * scale_H) : int(hr * scale_H),
323
+ int(wl * scale_W) : int(wr * scale_W),
324
+ ] += grid_output * mask
325
+ weight[
326
+ :,
327
+ :,
328
+ :,
329
+ int(hl * scale_H) : int(hr * scale_H),
330
+ int(wl * scale_W) : int(wr * scale_W),
331
+ ] += mask
332
+ value = value / weight
333
+ return value
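A self-contained sketch of TileWorker.tiled_forward above, using an arbitrary convolution as the per-tile function; the sizes are illustrative and chosen so a 512x512 input tiles exactly with kernel 128 and stride 96.

import torch
from models.tiler import TileWorker

conv = torch.nn.Conv2d(4, 4, kernel_size=3, padding=1)  # stand-in for an expensive model
x = torch.randn(1, 4, 512, 512)

worker = TileWorker()
y = worker.tiled_forward(
    forward_fn=conv,          # called tile-by-tile on (n, 4, 128, 128) batches
    model_input=x,
    tile_size=128,
    tile_stride=96,
    tile_batch_size=4,
    tile_device="cpu",        # tiles are stored and blended on CPU to save memory
    tile_dtype=torch.float32,
)
print(y.shape)  # torch.Size([1, 4, 512, 512])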
models/utils.py ADDED
@@ -0,0 +1,219 @@
1
+ import torch, os
2
+ from safetensors import safe_open
3
+ from contextlib import contextmanager
4
+ import hashlib
5
+
6
+
7
+ @contextmanager
8
+ def init_weights_on_device(device=torch.device("meta"), include_buffers: bool = False):
9
+ old_register_parameter = torch.nn.Module.register_parameter
10
+ if include_buffers:
11
+ old_register_buffer = torch.nn.Module.register_buffer
12
+
13
+ def register_empty_parameter(module, name, param):
14
+ old_register_parameter(module, name, param)
15
+ if param is not None:
16
+ param_cls = type(module._parameters[name])
17
+ kwargs = module._parameters[name].__dict__
18
+ kwargs["requires_grad"] = param.requires_grad
19
+ module._parameters[name] = param_cls(
20
+ module._parameters[name].to(device), **kwargs
21
+ )
22
+
23
+ def register_empty_buffer(module, name, buffer, persistent=True):
24
+ old_register_buffer(module, name, buffer, persistent=persistent)
25
+ if buffer is not None:
26
+ module._buffers[name] = module._buffers[name].to(device)
27
+
28
+ def patch_tensor_constructor(fn):
29
+ def wrapper(*args, **kwargs):
30
+ kwargs["device"] = device
31
+ return fn(*args, **kwargs)
32
+
33
+ return wrapper
34
+
35
+ if include_buffers:
36
+ tensor_constructors_to_patch = {
37
+ torch_function_name: getattr(torch, torch_function_name)
38
+ for torch_function_name in ["empty", "zeros", "ones", "full"]
39
+ }
40
+ else:
41
+ tensor_constructors_to_patch = {}
42
+
43
+ try:
44
+ torch.nn.Module.register_parameter = register_empty_parameter
45
+ if include_buffers:
46
+ torch.nn.Module.register_buffer = register_empty_buffer
47
+ for torch_function_name in tensor_constructors_to_patch.keys():
48
+ setattr(
49
+ torch,
50
+ torch_function_name,
51
+ patch_tensor_constructor(getattr(torch, torch_function_name)),
52
+ )
53
+ yield
54
+ finally:
55
+ torch.nn.Module.register_parameter = old_register_parameter
56
+ if include_buffers:
57
+ torch.nn.Module.register_buffer = old_register_buffer
58
+ for (
59
+ torch_function_name,
60
+ old_torch_function,
61
+ ) in tensor_constructors_to_patch.items():
62
+ setattr(torch, torch_function_name, old_torch_function)
63
+
64
+
65
+ def load_state_dict_from_folder(file_path, torch_dtype=None):
66
+ state_dict = {}
67
+ for file_name in os.listdir(file_path):
68
+ if "." in file_name and file_name.split(".")[-1] in [
69
+ "safetensors",
70
+ "bin",
71
+ "ckpt",
72
+ "pth",
73
+ "pt",
74
+ ]:
75
+ state_dict.update(
76
+ load_state_dict(
77
+ os.path.join(file_path, file_name), torch_dtype=torch_dtype
78
+ )
79
+ )
80
+ return state_dict
81
+
82
+
83
+ def load_state_dict(file_path, torch_dtype=None, device="cpu"):
84
+ if file_path.endswith(".safetensors"):
85
+ return load_state_dict_from_safetensors(
86
+ file_path, torch_dtype=torch_dtype, device=device
87
+ )
88
+ else:
89
+ return load_state_dict_from_bin(
90
+ file_path, torch_dtype=torch_dtype, device=device
91
+ )
92
+
93
+
94
+ def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"):
95
+ state_dict = {}
96
+ with safe_open(file_path, framework="pt", device=str(device)) as f:
97
+ for k in f.keys():
98
+ state_dict[k] = f.get_tensor(k)
99
+ if torch_dtype is not None:
100
+ state_dict[k] = state_dict[k].to(torch_dtype)
101
+ return state_dict
102
+
103
+
104
+ def load_state_dict_from_bin(file_path, torch_dtype=None, device="cpu"):
105
+ state_dict = torch.load(file_path, map_location=device, weights_only=True)
106
+ if torch_dtype is not None:
107
+ for i in state_dict:
108
+ if isinstance(state_dict[i], torch.Tensor):
109
+ state_dict[i] = state_dict[i].to(torch_dtype)
110
+ return state_dict
111
+
112
+
113
+ def search_for_embeddings(state_dict):
114
+ embeddings = []
115
+ for k in state_dict:
116
+ if isinstance(state_dict[k], torch.Tensor):
117
+ embeddings.append(state_dict[k])
118
+ elif isinstance(state_dict[k], dict):
119
+ embeddings += search_for_embeddings(state_dict[k])
120
+ return embeddings
121
+
122
+
123
+ def search_parameter(param, state_dict):
124
+ for name, param_ in state_dict.items():
125
+ if param.numel() == param_.numel():
126
+ if param.shape == param_.shape:
127
+ if torch.dist(param, param_) < 1e-3:
128
+ return name
129
+ else:
130
+ if torch.dist(param.flatten(), param_.flatten()) < 1e-3:
131
+ return name
132
+ return None
133
+
134
+
135
+ def build_rename_dict(source_state_dict, target_state_dict, split_qkv=False):
136
+ matched_keys = set()
137
+ with torch.no_grad():
138
+ for name in source_state_dict:
139
+ rename = search_parameter(source_state_dict[name], target_state_dict)
140
+ if rename is not None:
141
+ print(f'"{name}": "{rename}",')
142
+ matched_keys.add(rename)
143
+ elif (
144
+ split_qkv
145
+ and len(source_state_dict[name].shape) >= 1
146
+ and source_state_dict[name].shape[0] % 3 == 0
147
+ ):
148
+ length = source_state_dict[name].shape[0] // 3
149
+ rename = []
150
+ for i in range(3):
151
+ rename.append(
152
+ search_parameter(
153
+ source_state_dict[name][i * length : i * length + length],
154
+ target_state_dict,
155
+ )
156
+ )
157
+ if None not in rename:
158
+ print(f'"{name}": {rename},')
159
+ for rename_ in rename:
160
+ matched_keys.add(rename_)
161
+ for name in target_state_dict:
162
+ if name not in matched_keys:
163
+ print("Cannot find", name, target_state_dict[name].shape)
164
+
165
+
166
+ def search_for_files(folder, extensions):
167
+ files = []
168
+ if os.path.isdir(folder):
169
+ for file in sorted(os.listdir(folder)):
170
+ files += search_for_files(os.path.join(folder, file), extensions)
171
+ elif os.path.isfile(folder):
172
+ for extension in extensions:
173
+ if folder.endswith(extension):
174
+ files.append(folder)
175
+ break
176
+ return files
177
+
178
+
179
+ def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
180
+ keys = []
181
+ for key, value in state_dict.items():
182
+ if isinstance(key, str):
183
+ if isinstance(value, torch.Tensor):
184
+ if with_shape:
185
+ shape = "_".join(map(str, list(value.shape)))
186
+ keys.append(key + ":" + shape)
187
+ keys.append(key)
188
+ elif isinstance(value, dict):
189
+ keys.append(
190
+ key
191
+ + "|"
192
+ + convert_state_dict_keys_to_single_str(
193
+ value, with_shape=with_shape
194
+ )
195
+ )
196
+ keys.sort()
197
+ keys_str = ",".join(keys)
198
+ return keys_str
199
+
200
+
201
+ def split_state_dict_with_prefix(state_dict):
202
+ keys = sorted([key for key in state_dict if isinstance(key, str)])
203
+ prefix_dict = {}
204
+ for key in keys:
205
+ prefix = key if "." not in key else key.split(".")[0]
206
+ if prefix not in prefix_dict:
207
+ prefix_dict[prefix] = []
208
+ prefix_dict[prefix].append(key)
209
+ state_dicts = []
210
+ for prefix, keys in prefix_dict.items():
211
+ sub_state_dict = {key: state_dict[key] for key in keys}
212
+ state_dicts.append(sub_state_dict)
213
+ return state_dicts
214
+
215
+
216
+ def hash_state_dict_keys(state_dict, with_shape=True):
217
+ keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
218
+ keys_str = keys_str.encode(encoding="UTF-8")
219
+ return hashlib.md5(keys_str).hexdigest()
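A short sketch of the helpers above: hash_state_dict_keys is the fingerprint used by the detectors in models/model_manager.py, and load_state_dict dispatches on the file extension. The commented-out checkpoint path is illustrative.

import torch
from models.utils import hash_state_dict_keys, load_state_dict

toy_state_dict = torch.nn.Linear(4, 4).state_dict()
print(hash_state_dict_keys(toy_state_dict, with_shape=True))   # stable md5 over key names and shapes
print(hash_state_dict_keys(toy_state_dict, with_shape=False))  # key names only

# For real checkpoints (illustrative path):
# state_dict = load_state_dict("checkpoints/base_model/diffusion_pytorch_model.safetensors")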
models/wan_video_camera_controller.py ADDED
@@ -0,0 +1,290 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ from einops import rearrange
5
+ import os
6
+ from typing_extensions import Literal
7
+
8
+
9
+ class SimpleAdapter(nn.Module):
10
+ def __init__(self, in_dim, out_dim, kernel_size, stride, num_residual_blocks=1):
11
+ super(SimpleAdapter, self).__init__()
12
+
13
+ # Pixel Unshuffle: reduce spatial dimensions by a factor of 8
14
+ self.pixel_unshuffle = nn.PixelUnshuffle(downscale_factor=8)
15
+
16
+ # Convolution: reduce spatial dimensions by a factor
17
+ # of 2 (without overlap)
18
+ self.conv = nn.Conv2d(
19
+ in_dim * 64, out_dim, kernel_size=kernel_size, stride=stride, padding=0
20
+ )
21
+
22
+ # Residual blocks for feature extraction
23
+ self.residual_blocks = nn.Sequential(
24
+ *[ResidualBlock(out_dim) for _ in range(num_residual_blocks)]
25
+ )
26
+
27
+ def forward(self, x):
28
+ # Reshape to merge the frame dimension into batch
29
+ bs, c, f, h, w = x.size()
30
+ x = x.permute(0, 2, 1, 3, 4).contiguous().view(bs * f, c, h, w)
31
+
32
+ # Pixel Unshuffle operation
33
+ x_unshuffled = self.pixel_unshuffle(x)
34
+
35
+ # Convolution operation
36
+ x_conv = self.conv(x_unshuffled)
37
+
38
+ # Feature extraction with residual blocks
39
+ out = self.residual_blocks(x_conv)
40
+
41
+ # Reshape to split the merged (batch * frames) dimension back into batch and frames
42
+ out = out.view(bs, f, out.size(1), out.size(2), out.size(3))
43
+
44
+ # Permute back to (batch, channels, frames, height, width)
45
+ out = out.permute(0, 2, 1, 3, 4)
46
+
47
+ return out
48
+
49
+ def process_camera_coordinates(
50
+ self,
51
+ direction: Literal[
52
+ "Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown"
53
+ ],
54
+ length: int,
55
+ height: int,
56
+ width: int,
57
+ speed: float = 1 / 54,
58
+ origin=(
59
+ 0,
60
+ 0.532139961,
61
+ 0.946026558,
62
+ 0.5,
63
+ 0.5,
64
+ 0,
65
+ 0,
66
+ 1,
67
+ 0,
68
+ 0,
69
+ 0,
70
+ 0,
71
+ 1,
72
+ 0,
73
+ 0,
74
+ 0,
75
+ 0,
76
+ 1,
77
+ 0,
78
+ ),
79
+ ):
80
+ if origin is None:
81
+ origin = (
82
+ 0,
83
+ 0.532139961,
84
+ 0.946026558,
85
+ 0.5,
86
+ 0.5,
87
+ 0,
88
+ 0,
89
+ 1,
90
+ 0,
91
+ 0,
92
+ 0,
93
+ 0,
94
+ 1,
95
+ 0,
96
+ 0,
97
+ 0,
98
+ 0,
99
+ 1,
100
+ 0,
101
+ )
102
+ coordinates = generate_camera_coordinates(direction, length, speed, origin)
103
+ plucker_embedding = process_pose_file(coordinates, width, height)
104
+ return plucker_embedding
105
+
106
+
107
+ class ResidualBlock(nn.Module):
108
+ def __init__(self, dim):
109
+ super(ResidualBlock, self).__init__()
110
+ self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, padding=1)
111
+ self.relu = nn.ReLU(inplace=True)
112
+ self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, padding=1)
113
+
114
+ def forward(self, x):
115
+ residual = x
116
+ out = self.relu(self.conv1(x))
117
+ out = self.conv2(out)
118
+ out += residual
119
+ return out
120
+
121
+
122
+ class Camera(object):
123
+ """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py"""
124
+
125
+ def __init__(self, entry):
126
+ fx, fy, cx, cy = entry[1:5]
127
+ self.fx = fx
128
+ self.fy = fy
129
+ self.cx = cx
130
+ self.cy = cy
131
+ w2c_mat = np.array(entry[7:]).reshape(3, 4)
132
+ w2c_mat_4x4 = np.eye(4)
133
+ w2c_mat_4x4[:3, :] = w2c_mat
134
+ self.w2c_mat = w2c_mat_4x4
135
+ self.c2w_mat = np.linalg.inv(w2c_mat_4x4)
136
+
137
+
138
+ def get_relative_pose(cam_params):
139
+ """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py"""
140
+ abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]
141
+ abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]
142
+ cam_to_origin = 0
143
+ target_cam_c2w = np.array(
144
+ [[1, 0, 0, 0], [0, 1, 0, -cam_to_origin], [0, 0, 1, 0], [0, 0, 0, 1]]
145
+ )
146
+ abs2rel = target_cam_c2w @ abs_w2cs[0]
147
+ ret_poses = [
148
+ target_cam_c2w,
149
+ ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]
150
+ ret_poses = np.array(ret_poses, dtype=np.float32)
151
+ return ret_poses
152
+
153
+
154
+ def custom_meshgrid(*args):
155
+ # torch>=2.0.0 only
156
+ return torch.meshgrid(*args, indexing="ij")
157
+
158
+
159
+ def ray_condition(K, c2w, H, W, device):
160
+ """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py"""
161
+ # c2w: B, V, 4, 4
162
+ # K: B, V, 4
163
+
164
+ B = K.shape[0]
165
+
166
+ j, i = custom_meshgrid(
167
+ torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
168
+ torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
169
+ )
170
+ i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]
171
+ j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]
172
+
173
+ fx, fy, cx, cy = K.chunk(4, dim=-1) # B,V, 1
174
+
175
+ zs = torch.ones_like(i) # [B, HxW]
176
+ xs = (i - cx) / fx * zs
177
+ ys = (j - cy) / fy * zs
178
+ zs = zs.expand_as(ys)
179
+
180
+ directions = torch.stack((xs, ys, zs), dim=-1) # B, V, HW, 3
181
+ directions = directions / directions.norm(dim=-1, keepdim=True) # B, V, HW, 3
182
+
183
+ rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2) # B, V, 3, HW
184
+ rays_o = c2w[..., :3, 3] # B, V, 3
185
+ rays_o = rays_o[:, :, None].expand_as(rays_d) # B, V, 3, HW
186
+ # c2w @ dirctions
187
+ rays_dxo = torch.linalg.cross(rays_o, rays_d)
188
+ plucker = torch.cat([rays_dxo, rays_d], dim=-1)
189
+ plucker = plucker.reshape(B, c2w.shape[1], H, W, 6) # B, V, H, W, 6
190
+ # plucker = plucker.permute(0, 1, 4, 2, 3)
191
+ return plucker
192
+
193
+
194
+ def process_pose_file(
195
+ cam_params,
196
+ width=672,
197
+ height=384,
198
+ original_pose_width=1280,
199
+ original_pose_height=720,
200
+ device="cpu",
201
+ return_poses=False,
202
+ ):
203
+ if return_poses:
204
+ return cam_params
205
+ else:
206
+ cam_params = [Camera(cam_param) for cam_param in cam_params]
207
+
208
+ sample_wh_ratio = width / height
209
+ pose_wh_ratio = (
210
+ original_pose_width / original_pose_height
211
+ ) # Assuming placeholder ratios, change as needed
212
+
213
+ if pose_wh_ratio > sample_wh_ratio:
214
+ resized_ori_w = height * pose_wh_ratio
215
+ for cam_param in cam_params:
216
+ cam_param.fx = resized_ori_w * cam_param.fx / width
217
+ else:
218
+ resized_ori_h = width / pose_wh_ratio
219
+ for cam_param in cam_params:
220
+ cam_param.fy = resized_ori_h * cam_param.fy / height
221
+
222
+ intrinsic = np.asarray(
223
+ [
224
+ [
225
+ cam_param.fx * width,
226
+ cam_param.fy * height,
227
+ cam_param.cx * width,
228
+ cam_param.cy * height,
229
+ ]
230
+ for cam_param in cam_params
231
+ ],
232
+ dtype=np.float32,
233
+ )
234
+
235
+ K = torch.as_tensor(intrinsic)[None] # [1, n_frame, 4]
236
+ c2ws = get_relative_pose(
237
+ cam_params
238
+ ) # get_relative_pose is defined above in this module
239
+ c2ws = torch.as_tensor(c2ws)[None] # [1, n_frame, 4, 4]
240
+ plucker_embedding = (
241
+ ray_condition(K, c2ws, height, width, device=device)[0]
242
+ .permute(0, 3, 1, 2)
243
+ .contiguous()
244
+ ) # V, 6, H, W
245
+ plucker_embedding = plucker_embedding[None]
246
+ plucker_embedding = rearrange(plucker_embedding, "b f c h w -> b f h w c")[0]
247
+ return plucker_embedding
248
+
249
+
250
+ def generate_camera_coordinates(
251
+ direction: Literal[
252
+ "Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown"
253
+ ],
254
+ length: int,
255
+ speed: float = 1 / 54,
256
+ origin=(
257
+ 0,
258
+ 0.532139961,
259
+ 0.946026558,
260
+ 0.5,
261
+ 0.5,
262
+ 0,
263
+ 0,
264
+ 1,
265
+ 0,
266
+ 0,
267
+ 0,
268
+ 0,
269
+ 1,
270
+ 0,
271
+ 0,
272
+ 0,
273
+ 0,
274
+ 1,
275
+ 0,
276
+ ),
277
+ ):
278
+ coordinates = [list(origin)]
279
+ while len(coordinates) < length:
280
+ coor = coordinates[-1].copy()
281
+ if "Left" in direction:
282
+ coor[9] += speed
283
+ if "Right" in direction:
284
+ coor[9] -= speed
285
+ if "Up" in direction:
286
+ coor[13] += speed
287
+ if "Down" in direction:
288
+ coor[13] -= speed
289
+ coordinates.append(coor)
290
+ return coordinates
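A minimal sketch of how the pieces above fit together to build a Plücker camera embedding (the import path is an assumption based on this repo layout; the resolution values simply reuse the defaults of process_pose_file):

    from models.wan_video_camera_controller import (
        generate_camera_coordinates,
        process_pose_file,
    )

    # 81 camera entries for a leftward pan; each entry packs
    # [_, fx, fy, cx, cy, _, _, 3x4 world-to-camera matrix (row-major)].
    coords = generate_camera_coordinates(direction="Left", length=81, speed=1 / 54)

    # Per-frame Plücker ray maps at the target resolution.
    plucker = process_pose_file(coords, width=672, height=384)
    print(plucker.shape)  # torch.Size([81, 384, 672, 6])

The resulting per-frame 6-channel ray maps are what SimpleAdapter consumes as its control signal.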
models/wan_video_dit.py ADDED
@@ -0,0 +1,952 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+ from typing import Tuple, Optional
6
+ from einops import rearrange
7
+ from .utils import hash_state_dict_keys
8
+ from .wan_video_camera_controller import SimpleAdapter
9
+
10
+ try:
11
+ import flash_attn_interface
12
+
13
+ FLASH_ATTN_3_AVAILABLE = True
14
+ except ModuleNotFoundError:
15
+ FLASH_ATTN_3_AVAILABLE = False
16
+
17
+ try:
18
+ import flash_attn
19
+
20
+ FLASH_ATTN_2_AVAILABLE = True
21
+ except ModuleNotFoundError:
22
+ FLASH_ATTN_2_AVAILABLE = False
23
+
24
+ try:
25
+ from sageattention import sageattn
26
+
27
+ SAGE_ATTN_AVAILABLE = True
28
+ except ModuleNotFoundError:
29
+ SAGE_ATTN_AVAILABLE = False
30
+
31
+
32
+ def flash_attention(
33
+ q: torch.Tensor,
34
+ k: torch.Tensor,
35
+ v: torch.Tensor,
36
+ num_heads: int,
37
+ compatibility_mode=False,
38
+ ):
39
+ if compatibility_mode:
40
+ q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
41
+ k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
42
+ v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
43
+ x = F.scaled_dot_product_attention(q, k, v)
44
+ x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
45
+ elif FLASH_ATTN_3_AVAILABLE:
46
+ q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
47
+ k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
48
+ v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
49
+ x = flash_attn_interface.flash_attn_func(q, k, v)
50
+ if isinstance(x, tuple):
51
+ x = x[0]
52
+ x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
53
+ elif FLASH_ATTN_2_AVAILABLE:
54
+ q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
55
+ k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
56
+ v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
57
+ x = flash_attn.flash_attn_func(q, k, v)
58
+ x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
59
+ elif SAGE_ATTN_AVAILABLE:
60
+ q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
61
+ k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
62
+ v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
63
+ x = sageattn(q, k, v, tensor_layout="HND", is_causal=False)
64
+ x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
65
+ else:
66
+ q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
67
+ k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
68
+ v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
69
+ x = F.scaled_dot_product_attention(q, k, v)
70
+ x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
71
+ return x
72
+
73
+
74
+ def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
75
+ return x * (1 + scale) + shift
76
+
77
+
78
+ def sinusoidal_embedding_1d(dim, position):
79
+ sinusoid = torch.outer(
80
+ position.type(torch.float64),
81
+ torch.pow(
82
+ 10000,
83
+ -torch.arange(dim // 2, dtype=torch.float64, device=position.device).div(
84
+ dim // 2
85
+ ),
86
+ ),
87
+ )
88
+ x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
89
+ return x.to(position.dtype)
90
+
91
+
92
+ def precompute_freqs_cis_3d(dim: int, end: int = 1024, theta: float = 10000.0):
93
+ # 3d rope precompute
94
+ f_freqs_cis = precompute_freqs_cis(dim - 2 * (dim // 3), end + 1, theta)
95
+ h_freqs_cis = precompute_freqs_cis(dim // 3, end, theta)
96
+ w_freqs_cis = precompute_freqs_cis(dim // 3, end, theta)
97
+ return f_freqs_cis, h_freqs_cis, w_freqs_cis
98
+
99
+
100
+ def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0):
101
+ # 1d rope precompute
102
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].double() / dim))
103
+ ###################################################### add f = -1
104
+ positions = torch.arange(-1, end, device=freqs.device)
105
+ freqs = torch.outer(positions, freqs)
106
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
107
+ ######################################################
108
+ return freqs_cis
109
+
110
+
111
+ def rope_apply(x, freqs, num_heads):
112
+ x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
113
+ x_out = torch.view_as_complex(
114
+ x.to(torch.float64).reshape(x.shape[0], x.shape[1], x.shape[2], -1, 2)
115
+ )
116
+ x_out = torch.view_as_real(x_out * freqs).flatten(2)
117
+ return x_out.to(x.dtype)
118
+
119
+
120
+ class RMSNorm(nn.Module):
121
+ def __init__(self, dim, eps=1e-5):
122
+ super().__init__()
123
+ self.eps = eps
124
+ self.weight = nn.Parameter(torch.ones(dim))
125
+
126
+ def norm(self, x):
127
+ return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
128
+
129
+ def forward(self, x):
130
+ dtype = x.dtype
131
+ return self.norm(x.float()).to(dtype) * self.weight
132
+
133
+
134
+ class AttentionModule(nn.Module):
135
+ def __init__(self, num_heads):
136
+ super().__init__()
137
+ self.num_heads = num_heads
138
+
139
+ def forward(self, q, k, v):
140
+ x = flash_attention(q=q, k=k, v=v, num_heads=self.num_heads)
141
+ return x
142
+
143
+
144
+ class LoRALinearLayer(nn.Module):
145
+ def __init__(
146
+ self,
147
+ in_features: int,
148
+ out_features: int,
149
+ rank: int = 128,
150
+ device="cuda",
151
+ dtype: Optional[torch.dtype] = torch.float32,
152
+ ):
153
+ super().__init__()
154
+ self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
155
+ self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
156
+ self.rank = rank
157
+ self.out_features = out_features
158
+ self.in_features = in_features
159
+
160
+ nn.init.normal_(self.down.weight, std=1 / rank)
161
+ nn.init.zeros_(self.up.weight)
162
+
163
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
164
+ orig_dtype = hidden_states.dtype
165
+ dtype = self.down.weight.dtype
166
+
167
+ down_hidden_states = self.down(hidden_states.to(dtype))
168
+ up_hidden_states = self.up(down_hidden_states)
169
+ return up_hidden_states.to(orig_dtype)
170
+
171
+
172
+ class SelfAttention(nn.Module):
173
+ def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
174
+ super().__init__()
175
+ self.dim = dim
176
+ self.num_heads = num_heads
177
+ self.head_dim = dim // num_heads
178
+
179
+ self.q = nn.Linear(dim, dim)
180
+ self.k = nn.Linear(dim, dim)
181
+ self.v = nn.Linear(dim, dim)
182
+ self.o = nn.Linear(dim, dim)
183
+ self.norm_q = RMSNorm(dim, eps=eps)
184
+ self.norm_k = RMSNorm(dim, eps=eps)
185
+
186
+ self.attn = AttentionModule(self.num_heads)
187
+
188
+ self.kv_cache = None
189
+ self.cond_size = None
190
+
191
+ def init_lora(self, train=False):
192
+ dim = self.dim
193
+ self.q_loras = LoRALinearLayer(dim, dim, rank=128)
194
+ self.k_loras = LoRALinearLayer(dim, dim, rank=128)
195
+ self.v_loras = LoRALinearLayer(dim, dim, rank=128)
196
+
197
+ requires_grad = train
198
+ for lora in [self.q_loras, self.k_loras, self.v_loras]:
199
+ for param in lora.parameters():
200
+ param.requires_grad = requires_grad
201
+
202
+ def forward(self, x, freqs):
203
+ if self.cond_size is not None:
204
+ if self.kv_cache is None:
205
+ x_main, x_ip = x[:, : -self.cond_size], x[:, -self.cond_size :]
206
+ split_point = freqs.shape[0] - self.cond_size
207
+ freqs_main = freqs[:split_point]
208
+ freqs_ip = freqs[split_point:]
209
+
210
+ q_main = self.norm_q(self.q(x_main))
211
+ k_main = self.norm_k(self.k(x_main))
212
+ v_main = self.v(x_main)
213
+
214
+ q_main = rope_apply(q_main, freqs_main, self.num_heads)
215
+ k_main = rope_apply(k_main, freqs_main, self.num_heads)
216
+
217
+ q_ip = self.norm_q(self.q(x_ip) + self.q_loras(x_ip))
218
+ k_ip = self.norm_k(self.k(x_ip) + self.k_loras(x_ip))
219
+ v_ip = self.v(x_ip) + self.v_loras(x_ip)
220
+
221
+ q_ip = rope_apply(q_ip, freqs_ip, self.num_heads)
222
+ k_ip = rope_apply(k_ip, freqs_ip, self.num_heads)
223
+ self.kv_cache = {"k_ip": k_ip.detach(), "v_ip": v_ip.detach()}
224
+ full_k = torch.concat([k_main, k_ip], dim=1)
225
+ full_v = torch.concat([v_main, v_ip], dim=1)
226
+ cond_out = self.attn(q_ip, k_ip, v_ip)
227
+ main_out = self.attn(q_main, full_k, full_v)
228
+ out = torch.concat([main_out, cond_out], dim=1)
229
+ return self.o(out)
230
+
231
+ else:
232
+ k_ip = self.kv_cache["k_ip"]
233
+ v_ip = self.kv_cache["v_ip"]
234
+ q_main = self.norm_q(self.q(x))
235
+ k_main = self.norm_k(self.k(x))
236
+ v_main = self.v(x)
237
+ q_main = rope_apply(q_main, freqs, self.num_heads)
238
+ k_main = rope_apply(k_main, freqs, self.num_heads)
239
+
240
+ full_k = torch.concat([k_main, k_ip], dim=1)
241
+ full_v = torch.concat([v_main, v_ip], dim=1)
242
+ x = self.attn(q_main, full_k, full_v)
243
+ return self.o(x)
244
+ else:
245
+ q = self.norm_q(self.q(x))
246
+ k = self.norm_k(self.k(x))
247
+ v = self.v(x)
248
+ q = rope_apply(q, freqs, self.num_heads)
249
+ k = rope_apply(k, freqs, self.num_heads)
250
+ x = self.attn(q, k, v)
251
+ return self.o(x)
252
+
253
+
254
+ class CrossAttention(nn.Module):
255
+ def __init__(
256
+ self, dim: int, num_heads: int, eps: float = 1e-6, has_image_input: bool = False
257
+ ):
258
+ super().__init__()
259
+ self.dim = dim
260
+ self.num_heads = num_heads
261
+ self.head_dim = dim // num_heads
262
+
263
+ self.q = nn.Linear(dim, dim)
264
+ self.k = nn.Linear(dim, dim)
265
+ self.v = nn.Linear(dim, dim)
266
+ self.o = nn.Linear(dim, dim)
267
+ self.norm_q = RMSNorm(dim, eps=eps)
268
+ self.norm_k = RMSNorm(dim, eps=eps)
269
+ self.has_image_input = has_image_input
270
+ if has_image_input:
271
+ self.k_img = nn.Linear(dim, dim)
272
+ self.v_img = nn.Linear(dim, dim)
273
+ self.norm_k_img = RMSNorm(dim, eps=eps)
274
+
275
+ self.attn = AttentionModule(self.num_heads)
276
+
277
+ def forward(self, x: torch.Tensor, y: torch.Tensor):
278
+ if self.has_image_input:
279
+ img = y[:, :257]
280
+ ctx = y[:, 257:]
281
+ else:
282
+ ctx = y
283
+ q = self.norm_q(self.q(x))
284
+ k = self.norm_k(self.k(ctx))
285
+ v = self.v(ctx)
286
+ x = self.attn(q, k, v)
287
+ if self.has_image_input:
288
+ k_img = self.norm_k_img(self.k_img(img))
289
+ v_img = self.v_img(img)
290
+ y = flash_attention(q, k_img, v_img, num_heads=self.num_heads)
291
+ x = x + y
292
+ return self.o(x)
293
+
294
+
295
+ class GateModule(nn.Module):
296
+ def __init__(
297
+ self,
298
+ ):
299
+ super().__init__()
300
+
301
+ def forward(self, x, gate, residual):
302
+ return x + gate * residual
303
+
304
+
305
+ class DiTBlock(nn.Module):
306
+ def __init__(
307
+ self,
308
+ has_image_input: bool,
309
+ dim: int,
310
+ num_heads: int,
311
+ ffn_dim: int,
312
+ eps: float = 1e-6,
313
+ ):
314
+ super().__init__()
315
+ self.dim = dim
316
+ self.num_heads = num_heads
317
+ self.ffn_dim = ffn_dim
318
+
319
+ self.self_attn = SelfAttention(dim, num_heads, eps)
320
+ self.cross_attn = CrossAttention(
321
+ dim, num_heads, eps, has_image_input=has_image_input
322
+ )
323
+ self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
324
+ self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
325
+ self.norm3 = nn.LayerNorm(dim, eps=eps)
326
+ self.ffn = nn.Sequential(
327
+ nn.Linear(dim, ffn_dim),
328
+ nn.GELU(approximate="tanh"),
329
+ nn.Linear(ffn_dim, dim),
330
+ )
331
+ self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
332
+ self.gate = GateModule()
333
+
334
+ def forward(self, x, context, t_mod, freqs, x_ip=None, t_mod_ip=None):
335
+ # msa: multi-head self-attention mlp: multi-layer perceptron
336
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
337
+ self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod
338
+ ).chunk(6, dim=1)
339
+
340
+ input_x = modulate(self.norm1(x), shift_msa, scale_msa)
341
+
342
+ if x_ip is not None:
343
+ (
344
+ shift_msa_ip,
345
+ scale_msa_ip,
346
+ gate_msa_ip,
347
+ shift_mlp_ip,
348
+ scale_mlp_ip,
349
+ gate_mlp_ip,
350
+ ) = (
351
+ self.modulation.to(dtype=t_mod_ip.dtype, device=t_mod_ip.device)
352
+ + t_mod_ip
353
+ ).chunk(6, dim=1)
354
+ input_x_ip = modulate(
355
+ self.norm1(x_ip), shift_msa_ip, scale_msa_ip
356
+ ) # [1, 1024, 5120]
357
+ self.self_attn.cond_size = input_x_ip.shape[1]
358
+ input_x = torch.concat([input_x, input_x_ip], dim=1)
359
+ self.self_attn.kv_cache = None
360
+
361
+ attn_out = self.self_attn(input_x, freqs)
362
+ if x_ip is not None:
363
+ attn_out, attn_out_ip = (
364
+ attn_out[:, : -self.self_attn.cond_size],
365
+ attn_out[:, -self.self_attn.cond_size :],
366
+ )
367
+
368
+ x = self.gate(x, gate_msa, attn_out)
369
+ x = x + self.cross_attn(self.norm3(x), context)
370
+ input_x = modulate(self.norm2(x), shift_mlp, scale_mlp)
371
+ x = self.gate(x, gate_mlp, self.ffn(input_x))
372
+
373
+ if x_ip is not None:
374
+ x_ip = self.gate(x_ip, gate_msa_ip, attn_out_ip)
375
+ input_x_ip = modulate(self.norm2(x_ip), shift_mlp_ip, scale_mlp_ip)
376
+ x_ip = self.gate(x_ip, gate_mlp_ip, self.ffn(input_x_ip))
377
+ return x, x_ip
378
+
379
+
380
+ class MLP(torch.nn.Module):
381
+ def __init__(self, in_dim, out_dim, has_pos_emb=False):
382
+ super().__init__()
383
+ self.proj = torch.nn.Sequential(
384
+ nn.LayerNorm(in_dim),
385
+ nn.Linear(in_dim, in_dim),
386
+ nn.GELU(),
387
+ nn.Linear(in_dim, out_dim),
388
+ nn.LayerNorm(out_dim),
389
+ )
390
+ self.has_pos_emb = has_pos_emb
391
+ if has_pos_emb:
392
+ self.emb_pos = torch.nn.Parameter(torch.zeros((1, 514, 1280)))
393
+
394
+ def forward(self, x):
395
+ if self.has_pos_emb:
396
+ x = x + self.emb_pos.to(dtype=x.dtype, device=x.device)
397
+ return self.proj(x)
398
+
399
+
400
+ class Head(nn.Module):
401
+ def __init__(
402
+ self, dim: int, out_dim: int, patch_size: Tuple[int, int, int], eps: float
403
+ ):
404
+ super().__init__()
405
+ self.dim = dim
406
+ self.patch_size = patch_size
407
+ self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
408
+ self.head = nn.Linear(dim, out_dim * math.prod(patch_size))
409
+ self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
410
+
411
+ def forward(self, x, t_mod):
412
+ if len(t_mod.shape) == 3:
413
+ shift, scale = (
414
+ self.modulation.unsqueeze(0).to(dtype=t_mod.dtype, device=t_mod.device)
415
+ + t_mod.unsqueeze(2)
416
+ ).chunk(2, dim=2)
417
+ x = self.head(self.norm(x) * (1 + scale.squeeze(2)) + shift.squeeze(2))
418
+ else:
419
+ shift, scale = (
420
+ self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod
421
+ ).chunk(2, dim=1)
422
+ x = self.head(self.norm(x) * (1 + scale) + shift)
423
+ return x
424
+
425
+
426
+ class WanModel(torch.nn.Module):
427
+ def __init__(
428
+ self,
429
+ dim: int,
430
+ in_dim: int,
431
+ ffn_dim: int,
432
+ out_dim: int,
433
+ text_dim: int,
434
+ freq_dim: int,
435
+ eps: float,
436
+ patch_size: Tuple[int, int, int],
437
+ num_heads: int,
438
+ num_layers: int,
439
+ has_image_input: bool,
440
+ has_image_pos_emb: bool = False,
441
+ has_ref_conv: bool = False,
442
+ add_control_adapter: bool = False,
443
+ in_dim_control_adapter: int = 24,
444
+ seperated_timestep: bool = False,
445
+ require_vae_embedding: bool = True,
446
+ require_clip_embedding: bool = True,
447
+ fuse_vae_embedding_in_latents: bool = False,
448
+ ):
449
+ super().__init__()
450
+ self.dim = dim
451
+ self.freq_dim = freq_dim
452
+ self.has_image_input = has_image_input
453
+ self.patch_size = patch_size
454
+ self.seperated_timestep = seperated_timestep
455
+ self.require_vae_embedding = require_vae_embedding
456
+ self.require_clip_embedding = require_clip_embedding
457
+ self.fuse_vae_embedding_in_latents = fuse_vae_embedding_in_latents
458
+
459
+ self.patch_embedding = nn.Conv3d(
460
+ in_dim, dim, kernel_size=patch_size, stride=patch_size
461
+ )
462
+ self.text_embedding = nn.Sequential(
463
+ nn.Linear(text_dim, dim), nn.GELU(approximate="tanh"), nn.Linear(dim, dim)
464
+ )
465
+ self.time_embedding = nn.Sequential(
466
+ nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim)
467
+ )
468
+ self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
469
+ self.blocks = nn.ModuleList(
470
+ [
471
+ DiTBlock(has_image_input, dim, num_heads, ffn_dim, eps)
472
+ for _ in range(num_layers)
473
+ ]
474
+ )
475
+ self.head = Head(dim, out_dim, patch_size, eps)
476
+ head_dim = dim // num_heads
477
+ self.freqs = precompute_freqs_cis_3d(head_dim)
478
+
479
+ if has_image_input:
480
+ self.img_emb = MLP(
481
+ 1280, dim, has_pos_emb=has_image_pos_emb
482
+ ) # clip_feature_dim = 1280
483
+ if has_ref_conv:
484
+ self.ref_conv = nn.Conv2d(16, dim, kernel_size=(2, 2), stride=(2, 2))
485
+ self.has_image_pos_emb = has_image_pos_emb
486
+ self.has_ref_conv = has_ref_conv
487
+ if add_control_adapter:
488
+ self.control_adapter = SimpleAdapter(
489
+ in_dim_control_adapter,
490
+ dim,
491
+ kernel_size=patch_size[1:],
492
+ stride=patch_size[1:],
493
+ )
494
+ else:
495
+ self.control_adapter = None
496
+
497
+ def patchify(
498
+ self, x: torch.Tensor, control_camera_latents_input: torch.Tensor = None
499
+ ):
500
+ x = self.patch_embedding(x)
501
+ if (
502
+ self.control_adapter is not None
503
+ and control_camera_latents_input is not None
504
+ ):
505
+ y_camera = self.control_adapter(control_camera_latents_input)
506
+ x = [u + v for u, v in zip(x, y_camera)]
507
+ x = x[0].unsqueeze(0)
508
+ grid_size = x.shape[2:]
509
+ x = rearrange(x, "b c f h w -> b (f h w) c").contiguous()
510
+ return x, grid_size # x, grid_size: (f, h, w)
511
+
512
+ def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor):
513
+ return rearrange(
514
+ x,
515
+ "b (f h w) (x y z c) -> b c (f x) (h y) (w z)",
516
+ f=grid_size[0],
517
+ h=grid_size[1],
518
+ w=grid_size[2],
519
+ x=self.patch_size[0],
520
+ y=self.patch_size[1],
521
+ z=self.patch_size[2],
522
+ )
523
+
524
+ def forward(
525
+ self,
526
+ x: torch.Tensor,
527
+ timestep: torch.Tensor,
528
+ context: torch.Tensor,
529
+ clip_feature: Optional[torch.Tensor] = None,
530
+ y: Optional[torch.Tensor] = None,
531
+ use_gradient_checkpointing: bool = False,
532
+ use_gradient_checkpointing_offload: bool = False,
533
+ ip_image=None,
534
+ **kwargs,
535
+ ):
536
+ x_ip = None
537
+ t_mod_ip = None
538
+ t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep))
539
+ t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
540
+ context = self.text_embedding(context)
541
+
542
+ if ip_image is not None:
543
+ timestep_ip = torch.zeros_like(timestep) # [B] with 0s
544
+ t_ip = self.time_embedding(
545
+ sinusoidal_embedding_1d(self.freq_dim, timestep_ip)
546
+ )
547
+ t_mod_ip = self.time_projection(t_ip).unflatten(1, (6, self.dim))
548
+ x, (f, h, w) = self.patchify(x)
549
+
550
+ offset = 1
551
+ freqs = (
552
+ torch.cat(
553
+ [
554
+ self.freqs[0][offset : f + offset]
555
+ .view(f, 1, 1, -1)
556
+ .expand(f, h, w, -1),
557
+ self.freqs[1][offset : h + offset]
558
+ .view(1, h, 1, -1)
559
+ .expand(f, h, w, -1),
560
+ self.freqs[2][offset : w + offset]
561
+ .view(1, 1, w, -1)
562
+ .expand(f, h, w, -1),
563
+ ],
564
+ dim=-1,
565
+ )
566
+ .reshape(f * h * w, 1, -1)
567
+ .to(x.device)
568
+ )
569
+
570
+ ############################################################################################
571
+ if ip_image is not None:
572
+ if ip_image.dim() == 6 and ip_image.shape[3] == 1:
573
+ ip_image = ip_image.squeeze(1)
574
+ x_ip, (f_ip, h_ip, w_ip) = self.patchify(
575
+ ip_image
576
+ ) # x_ip [1, 1024, 5120] [B, N, D] f_ip = 1 h_ip = 32 w_ip = 32
577
+ freqs_ip = (
578
+ torch.cat(
579
+ [
580
+ self.freqs[0][0]
581
+ .view(f_ip, 1, 1, -1)
582
+ .expand(f_ip, h_ip, w_ip, -1),
583
+ self.freqs[1][h + offset : h + offset + h_ip]
584
+ .view(1, h_ip, 1, -1)
585
+ .expand(f_ip, h_ip, w_ip, -1),
586
+ self.freqs[2][w + offset : w + offset + w_ip]
587
+ .view(1, 1, w_ip, -1)
588
+ .expand(f_ip, h_ip, w_ip, -1),
589
+ ],
590
+ dim=-1,
591
+ )
592
+ .reshape(f_ip * h_ip * w_ip, 1, -1)
593
+ .to(x_ip.device)
594
+ )
595
+ freqs = torch.cat([freqs, freqs_ip], dim=0)
596
+
597
+ ############################################################################################
598
+ def create_custom_forward(module):
599
+ def custom_forward(*inputs):
600
+ return module(*inputs)
601
+
602
+ return custom_forward
603
+
604
+ for block in self.blocks:
605
+ if self.training and use_gradient_checkpointing:
606
+ if use_gradient_checkpointing_offload:
607
+ with torch.autograd.graph.save_on_cpu():
608
+ x, x_ip = torch.utils.checkpoint.checkpoint(
609
+ create_custom_forward(block),
610
+ x,
611
+ context,
612
+ t_mod,
613
+ freqs,
614
+ x_ip,
615
+ t_mod_ip,
616
+ use_reentrant=False,
617
+ )
618
+ else:
619
+ x, x_ip = torch.utils.checkpoint.checkpoint(
620
+ create_custom_forward(block),
621
+ x,
622
+ context,
623
+ t_mod,
624
+ freqs,
625
+ x_ip,
626
+ t_mod_ip,
627
+ use_reentrant=False,
628
+ )
629
+ else:
630
+ x, x_ip = block(x, context, t_mod, freqs, x_ip, t_mod_ip)
631
+
632
+ x = self.head(x, t)
633
+ x = self.unpatchify(x, (f, h, w))
634
+ return x
635
+
636
+ @staticmethod
637
+ def state_dict_converter():
638
+ return WanModelStateDictConverter()
639
+
640
+
641
+ class WanModelStateDictConverter:
642
+ def __init__(self):
643
+ pass
644
+
645
+ def from_diffusers(self, state_dict):
646
+ rename_dict = {
647
+ "blocks.0.attn1.norm_k.weight": "blocks.0.self_attn.norm_k.weight",
648
+ "blocks.0.attn1.norm_q.weight": "blocks.0.self_attn.norm_q.weight",
649
+ "blocks.0.attn1.to_k.bias": "blocks.0.self_attn.k.bias",
650
+ "blocks.0.attn1.to_k.weight": "blocks.0.self_attn.k.weight",
651
+ "blocks.0.attn1.to_out.0.bias": "blocks.0.self_attn.o.bias",
652
+ "blocks.0.attn1.to_out.0.weight": "blocks.0.self_attn.o.weight",
653
+ "blocks.0.attn1.to_q.bias": "blocks.0.self_attn.q.bias",
654
+ "blocks.0.attn1.to_q.weight": "blocks.0.self_attn.q.weight",
655
+ "blocks.0.attn1.to_v.bias": "blocks.0.self_attn.v.bias",
656
+ "blocks.0.attn1.to_v.weight": "blocks.0.self_attn.v.weight",
657
+ "blocks.0.attn2.norm_k.weight": "blocks.0.cross_attn.norm_k.weight",
658
+ "blocks.0.attn2.norm_q.weight": "blocks.0.cross_attn.norm_q.weight",
659
+ "blocks.0.attn2.to_k.bias": "blocks.0.cross_attn.k.bias",
660
+ "blocks.0.attn2.to_k.weight": "blocks.0.cross_attn.k.weight",
661
+ "blocks.0.attn2.to_out.0.bias": "blocks.0.cross_attn.o.bias",
662
+ "blocks.0.attn2.to_out.0.weight": "blocks.0.cross_attn.o.weight",
663
+ "blocks.0.attn2.to_q.bias": "blocks.0.cross_attn.q.bias",
664
+ "blocks.0.attn2.to_q.weight": "blocks.0.cross_attn.q.weight",
665
+ "blocks.0.attn2.to_v.bias": "blocks.0.cross_attn.v.bias",
666
+ "blocks.0.attn2.to_v.weight": "blocks.0.cross_attn.v.weight",
667
+ "blocks.0.ffn.net.0.proj.bias": "blocks.0.ffn.0.bias",
668
+ "blocks.0.ffn.net.0.proj.weight": "blocks.0.ffn.0.weight",
669
+ "blocks.0.ffn.net.2.bias": "blocks.0.ffn.2.bias",
670
+ "blocks.0.ffn.net.2.weight": "blocks.0.ffn.2.weight",
671
+ "blocks.0.norm2.bias": "blocks.0.norm3.bias",
672
+ "blocks.0.norm2.weight": "blocks.0.norm3.weight",
673
+ "blocks.0.scale_shift_table": "blocks.0.modulation",
674
+ "condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias",
675
+ "condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight",
676
+ "condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias",
677
+ "condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight",
678
+ "condition_embedder.time_embedder.linear_1.bias": "time_embedding.0.bias",
679
+ "condition_embedder.time_embedder.linear_1.weight": "time_embedding.0.weight",
680
+ "condition_embedder.time_embedder.linear_2.bias": "time_embedding.2.bias",
681
+ "condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight",
682
+ "condition_embedder.time_proj.bias": "time_projection.1.bias",
683
+ "condition_embedder.time_proj.weight": "time_projection.1.weight",
684
+ "patch_embedding.bias": "patch_embedding.bias",
685
+ "patch_embedding.weight": "patch_embedding.weight",
686
+ "scale_shift_table": "head.modulation",
687
+ "proj_out.bias": "head.head.bias",
688
+ "proj_out.weight": "head.head.weight",
689
+ }
690
+ state_dict_ = {}
691
+ for name, param in state_dict.items():
692
+ if name in rename_dict:
693
+ state_dict_[rename_dict[name]] = param
694
+ else:
695
+ name_ = ".".join(name.split(".")[:1] + ["0"] + name.split(".")[2:])
696
+ if name_ in rename_dict:
697
+ name_ = rename_dict[name_]
698
+ name_ = ".".join(
699
+ name_.split(".")[:1]
700
+ + [name.split(".")[1]]
701
+ + name_.split(".")[2:]
702
+ )
703
+ state_dict_[name_] = param
704
+ if hash_state_dict_keys(state_dict) == "cb104773c6c2cb6df4f9529ad5c60d0b":
705
+ config = {
706
+ "model_type": "t2v",
707
+ "patch_size": (1, 2, 2),
708
+ "text_len": 512,
709
+ "in_dim": 16,
710
+ "dim": 5120,
711
+ "ffn_dim": 13824,
712
+ "freq_dim": 256,
713
+ "text_dim": 4096,
714
+ "out_dim": 16,
715
+ "num_heads": 40,
716
+ "num_layers": 40,
717
+ "window_size": (-1, -1),
718
+ "qk_norm": True,
719
+ "cross_attn_norm": True,
720
+ "eps": 1e-6,
721
+ }
722
+ else:
723
+ config = {}
724
+ return state_dict_, config
725
+
726
+ def from_civitai(self, state_dict):
727
+ state_dict = {
728
+ name: param
729
+ for name, param in state_dict.items()
730
+ if not name.startswith("vace")
731
+ }
732
+ if hash_state_dict_keys(state_dict) == "9269f8db9040a9d860eaca435be61814":
733
+ config = {
734
+ "has_image_input": False,
735
+ "patch_size": [1, 2, 2],
736
+ "in_dim": 16,
737
+ "dim": 1536,
738
+ "ffn_dim": 8960,
739
+ "freq_dim": 256,
740
+ "text_dim": 4096,
741
+ "out_dim": 16,
742
+ "num_heads": 12,
743
+ "num_layers": 30,
744
+ "eps": 1e-6,
745
+ }
746
+ elif hash_state_dict_keys(state_dict) == "aafcfd9672c3a2456dc46e1cb6e52c70":
747
+ config = {
748
+ "has_image_input": False,
749
+ "patch_size": [1, 2, 2],
750
+ "in_dim": 16,
751
+ "dim": 5120,
752
+ "ffn_dim": 13824,
753
+ "freq_dim": 256,
754
+ "text_dim": 4096,
755
+ "out_dim": 16,
756
+ "num_heads": 40,
757
+ "num_layers": 40,
758
+ "eps": 1e-6,
759
+ }
760
+ elif hash_state_dict_keys(state_dict) == "6bfcfb3b342cb286ce886889d519a77e":
761
+ config = {
762
+ "has_image_input": True,
763
+ "patch_size": [1, 2, 2],
764
+ "in_dim": 36,
765
+ "dim": 5120,
766
+ "ffn_dim": 13824,
767
+ "freq_dim": 256,
768
+ "text_dim": 4096,
769
+ "out_dim": 16,
770
+ "num_heads": 40,
771
+ "num_layers": 40,
772
+ "eps": 1e-6,
773
+ }
774
+ elif hash_state_dict_keys(state_dict) == "6d6ccde6845b95ad9114ab993d917893":
775
+ config = {
776
+ "has_image_input": True,
777
+ "patch_size": [1, 2, 2],
778
+ "in_dim": 36,
779
+ "dim": 1536,
780
+ "ffn_dim": 8960,
781
+ "freq_dim": 256,
782
+ "text_dim": 4096,
783
+ "out_dim": 16,
784
+ "num_heads": 12,
785
+ "num_layers": 30,
786
+ "eps": 1e-6,
787
+ }
788
+ elif hash_state_dict_keys(state_dict) == "349723183fc063b2bfc10bb2835cf677":
803
+ # 1.3B PAI control
804
+ config = {
805
+ "has_image_input": True,
806
+ "patch_size": [1, 2, 2],
807
+ "in_dim": 48,
808
+ "dim": 1536,
809
+ "ffn_dim": 8960,
810
+ "freq_dim": 256,
811
+ "text_dim": 4096,
812
+ "out_dim": 16,
813
+ "num_heads": 12,
814
+ "num_layers": 30,
815
+ "eps": 1e-6,
816
+ }
817
+ elif hash_state_dict_keys(state_dict) == "efa44cddf936c70abd0ea28b6cbe946c":
818
+ # 14B PAI control
819
+ config = {
820
+ "has_image_input": True,
821
+ "patch_size": [1, 2, 2],
822
+ "in_dim": 48,
823
+ "dim": 5120,
824
+ "ffn_dim": 13824,
825
+ "freq_dim": 256,
826
+ "text_dim": 4096,
827
+ "out_dim": 16,
828
+ "num_heads": 40,
829
+ "num_layers": 40,
830
+ "eps": 1e-6,
831
+ }
832
+ elif hash_state_dict_keys(state_dict) == "3ef3b1f8e1dab83d5b71fd7b617f859f":
833
+ config = {
834
+ "has_image_input": True,
835
+ "patch_size": [1, 2, 2],
836
+ "in_dim": 36,
837
+ "dim": 5120,
838
+ "ffn_dim": 13824,
839
+ "freq_dim": 256,
840
+ "text_dim": 4096,
841
+ "out_dim": 16,
842
+ "num_heads": 40,
843
+ "num_layers": 40,
844
+ "eps": 1e-6,
845
+ "has_image_pos_emb": True,
846
+ }
847
+ elif hash_state_dict_keys(state_dict) == "70ddad9d3a133785da5ea371aae09504":
848
+ # 1.3B PAI control v1.1
849
+ config = {
850
+ "has_image_input": True,
851
+ "patch_size": [1, 2, 2],
852
+ "in_dim": 48,
853
+ "dim": 1536,
854
+ "ffn_dim": 8960,
855
+ "freq_dim": 256,
856
+ "text_dim": 4096,
857
+ "out_dim": 16,
858
+ "num_heads": 12,
859
+ "num_layers": 30,
860
+ "eps": 1e-6,
861
+ "has_ref_conv": True,
862
+ }
863
+ elif hash_state_dict_keys(state_dict) == "26bde73488a92e64cc20b0a7485b9e5b":
864
+ # 14B PAI control v1.1
865
+ config = {
866
+ "has_image_input": True,
867
+ "patch_size": [1, 2, 2],
868
+ "in_dim": 48,
869
+ "dim": 5120,
870
+ "ffn_dim": 13824,
871
+ "freq_dim": 256,
872
+ "text_dim": 4096,
873
+ "out_dim": 16,
874
+ "num_heads": 40,
875
+ "num_layers": 40,
876
+ "eps": 1e-6,
877
+ "has_ref_conv": True,
878
+ }
879
+ elif hash_state_dict_keys(state_dict) == "ac6a5aa74f4a0aab6f64eb9a72f19901":
880
+ # 1.3B PAI control-camera v1.1
881
+ config = {
882
+ "has_image_input": True,
883
+ "patch_size": [1, 2, 2],
884
+ "in_dim": 32,
885
+ "dim": 1536,
886
+ "ffn_dim": 8960,
887
+ "freq_dim": 256,
888
+ "text_dim": 4096,
889
+ "out_dim": 16,
890
+ "num_heads": 12,
891
+ "num_layers": 30,
892
+ "eps": 1e-6,
893
+ "has_ref_conv": False,
894
+ "add_control_adapter": True,
895
+ "in_dim_control_adapter": 24,
896
+ }
897
+ elif hash_state_dict_keys(state_dict) == "b61c605c2adbd23124d152ed28e049ae":
898
+ # 14B PAI control-camera v1.1
899
+ config = {
900
+ "has_image_input": True,
901
+ "patch_size": [1, 2, 2],
902
+ "in_dim": 32,
903
+ "dim": 5120,
904
+ "ffn_dim": 13824,
905
+ "freq_dim": 256,
906
+ "text_dim": 4096,
907
+ "out_dim": 16,
908
+ "num_heads": 40,
909
+ "num_layers": 40,
910
+ "eps": 1e-6,
911
+ "has_ref_conv": False,
912
+ "add_control_adapter": True,
913
+ "in_dim_control_adapter": 24,
914
+ }
915
+ elif hash_state_dict_keys(state_dict) == "1f5ab7703c6fc803fdded85ff040c316":
916
+ # Wan-AI/Wan2.2-TI2V-5B
917
+ config = {
918
+ "has_image_input": False,
919
+ "patch_size": [1, 2, 2],
920
+ "in_dim": 48,
921
+ "dim": 3072,
922
+ "ffn_dim": 14336,
923
+ "freq_dim": 256,
924
+ "text_dim": 4096,
925
+ "out_dim": 48,
926
+ "num_heads": 24,
927
+ "num_layers": 30,
928
+ "eps": 1e-6,
929
+ "seperated_timestep": True,
930
+ "require_clip_embedding": False,
931
+ "require_vae_embedding": False,
932
+ "fuse_vae_embedding_in_latents": True,
933
+ }
934
+ elif hash_state_dict_keys(state_dict) == "5b013604280dd715f8457c6ed6d6a626":
935
+ # Wan-AI/Wan2.2-I2V-A14B
936
+ config = {
937
+ "has_image_input": False,
938
+ "patch_size": [1, 2, 2],
939
+ "in_dim": 36,
940
+ "dim": 5120,
941
+ "ffn_dim": 13824,
942
+ "freq_dim": 256,
943
+ "text_dim": 4096,
944
+ "out_dim": 16,
945
+ "num_heads": 40,
946
+ "num_layers": 40,
947
+ "eps": 1e-6,
948
+ "require_clip_embedding": False,
949
+ }
950
+ else:
951
+ config = {}
952
+ return state_dict, config
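Two small sanity-check sketches for the attention helper and the timestep embedding defined in this file (import path assumed from the repo layout; both run on CPU because compatibility_mode uses the plain PyTorch fallback):

    import torch
    from models.wan_video_dit import flash_attention, sinusoidal_embedding_1d

    # [batch, sequence, num_heads * head_dim]
    q = torch.randn(1, 16, 128)
    k = torch.randn(1, 16, 128)
    v = torch.randn(1, 16, 128)

    # compatibility_mode=True routes through torch.nn.functional.scaled_dot_product_attention,
    # the same path the function falls back to when no flash/sage kernels are installed.
    out = flash_attention(q, k, v, num_heads=8, compatibility_mode=True)
    print(out.shape)  # torch.Size([1, 16, 128])

    # Sinusoidal timestep embedding used by WanModel; the frequency dim must be even.
    t = sinusoidal_embedding_1d(256, torch.tensor([0.0, 500.0, 999.0]))
    print(t.shape)  # torch.Size([3, 256])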
models/wan_video_image_encoder.py ADDED
@@ -0,0 +1,957 @@
1
+ """
2
+ Concise re-implementation of
3
+ ``https://github.com/openai/CLIP'' and
4
+ ``https://github.com/mlfoundations/open_clip''.
5
+ """
6
+
7
+ import math
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ import torchvision.transforms as T
12
+ from .wan_video_dit import flash_attention
13
+
14
+
15
+ class SelfAttention(nn.Module):
16
+ def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
17
+ assert dim % num_heads == 0
18
+ super().__init__()
19
+ self.dim = dim
20
+ self.num_heads = num_heads
21
+ self.head_dim = dim // num_heads
22
+ self.eps = eps
23
+
24
+ # layers
25
+ self.q = nn.Linear(dim, dim)
26
+ self.k = nn.Linear(dim, dim)
27
+ self.v = nn.Linear(dim, dim)
28
+ self.o = nn.Linear(dim, dim)
29
+ self.dropout = nn.Dropout(dropout)
30
+
31
+ def forward(self, x, mask):
32
+ """
33
+ x: [B, L, C].
34
+ """
35
+ b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
36
+
37
+ # compute query, key, value
38
+ q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
39
+ k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
40
+ v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
41
+
42
+ # compute attention
43
+ p = self.dropout.p if self.training else 0.0
44
+ x = F.scaled_dot_product_attention(q, k, v, mask, p)
45
+ x = x.permute(0, 2, 1, 3).reshape(b, s, c)
46
+
47
+ # output
48
+ x = self.o(x)
49
+ x = self.dropout(x)
50
+ return x
51
+
52
+
53
+ class AttentionBlock(nn.Module):
54
+ def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
55
+ super().__init__()
56
+ self.dim = dim
57
+ self.num_heads = num_heads
58
+ self.post_norm = post_norm
59
+ self.eps = eps
60
+
61
+ # layers
62
+ self.attn = SelfAttention(dim, num_heads, dropout, eps)
63
+ self.norm1 = nn.LayerNorm(dim, eps=eps)
64
+ self.ffn = nn.Sequential(
65
+ nn.Linear(dim, dim * 4),
66
+ nn.GELU(),
67
+ nn.Linear(dim * 4, dim),
68
+ nn.Dropout(dropout),
69
+ )
70
+ self.norm2 = nn.LayerNorm(dim, eps=eps)
71
+
72
+ def forward(self, x, mask):
73
+ if self.post_norm:
74
+ x = self.norm1(x + self.attn(x, mask))
75
+ x = self.norm2(x + self.ffn(x))
76
+ else:
77
+ x = x + self.attn(self.norm1(x), mask)
78
+ x = x + self.ffn(self.norm2(x))
79
+ return x
80
+
81
+
82
+ class XLMRoberta(nn.Module):
83
+ """
84
+ XLMRobertaModel with no pooler and no LM head.
85
+ """
86
+
87
+ def __init__(
88
+ self,
89
+ vocab_size=250002,
90
+ max_seq_len=514,
91
+ type_size=1,
92
+ pad_id=1,
93
+ dim=1024,
94
+ num_heads=16,
95
+ num_layers=24,
96
+ post_norm=True,
97
+ dropout=0.1,
98
+ eps=1e-5,
99
+ ):
100
+ super().__init__()
101
+ self.vocab_size = vocab_size
102
+ self.max_seq_len = max_seq_len
103
+ self.type_size = type_size
104
+ self.pad_id = pad_id
105
+ self.dim = dim
106
+ self.num_heads = num_heads
107
+ self.num_layers = num_layers
108
+ self.post_norm = post_norm
109
+ self.eps = eps
110
+
111
+ # embeddings
112
+ self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id)
113
+ self.type_embedding = nn.Embedding(type_size, dim)
114
+ self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id)
115
+ self.dropout = nn.Dropout(dropout)
116
+
117
+ # blocks
118
+ self.blocks = nn.ModuleList(
119
+ [
120
+ AttentionBlock(dim, num_heads, post_norm, dropout, eps)
121
+ for _ in range(num_layers)
122
+ ]
123
+ )
124
+
125
+ # norm layer
126
+ self.norm = nn.LayerNorm(dim, eps=eps)
127
+
128
+ def forward(self, ids):
129
+ """
130
+ ids: [B, L] of torch.LongTensor.
131
+ """
132
+ b, s = ids.shape
133
+ mask = ids.ne(self.pad_id).long()
134
+
135
+ # embeddings
136
+ x = (
137
+ self.token_embedding(ids)
138
+ + self.type_embedding(torch.zeros_like(ids))
139
+ + self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask)
140
+ )
141
+ if self.post_norm:
142
+ x = self.norm(x)
143
+ x = self.dropout(x)
144
+
145
+ # blocks
146
+ mask = torch.where(mask.view(b, 1, 1, s).gt(0), 0.0, torch.finfo(x.dtype).min)
147
+ for block in self.blocks:
148
+ x = block(x, mask)
149
+
150
+ # output
151
+ if not self.post_norm:
152
+ x = self.norm(x)
153
+ return x
154
+
155
+
156
+ def xlm_roberta_large(pretrained=False, return_tokenizer=False, device="cpu", **kwargs):
157
+ """
158
+ XLMRobertaLarge adapted from Huggingface.
159
+ """
160
+ # params
161
+ cfg = dict(
162
+ vocab_size=250002,
163
+ max_seq_len=514,
164
+ type_size=1,
165
+ pad_id=1,
166
+ dim=1024,
167
+ num_heads=16,
168
+ num_layers=24,
169
+ post_norm=True,
170
+ dropout=0.1,
171
+ eps=1e-5,
172
+ )
173
+ cfg.update(**kwargs)
174
+
175
+ # init model
176
+ if pretrained:
177
+ from sora import DOWNLOAD_TO_CACHE
178
+
179
+ # init a meta model
180
+ with torch.device("meta"):
181
+ model = XLMRoberta(**cfg)
182
+
183
+ # load checkpoint
184
+ model.load_state_dict(
185
+ torch.load(
186
+ DOWNLOAD_TO_CACHE("models/xlm_roberta/xlm_roberta_large.pth"),
187
+ map_location=device,
188
+ ),
189
+ assign=True,
190
+ )
191
+ else:
192
+ # init a model on device
193
+ with torch.device(device):
194
+ model = XLMRoberta(**cfg)
195
+
196
+ # init tokenizer
197
+ if return_tokenizer:
198
+ from sora.data import HuggingfaceTokenizer
199
+
200
+ tokenizer = HuggingfaceTokenizer(
201
+ name="xlm-roberta-large", seq_len=model.text_len, clean="whitespace"
202
+ )
203
+ return model, tokenizer
204
+ else:
205
+ return model
206
+
207
+
208
+ def pos_interpolate(pos, seq_len):
209
+ if pos.size(1) == seq_len:
210
+ return pos
211
+ else:
212
+ src_grid = int(math.sqrt(pos.size(1)))
213
+ tar_grid = int(math.sqrt(seq_len))
214
+ n = pos.size(1) - src_grid * src_grid
215
+ return torch.cat(
216
+ [
217
+ pos[:, :n],
218
+ F.interpolate(
219
+ pos[:, n:]
220
+ .float()
221
+ .reshape(1, src_grid, src_grid, -1)
222
+ .permute(0, 3, 1, 2),
223
+ size=(tar_grid, tar_grid),
224
+ mode="bicubic",
225
+ align_corners=False,
226
+ )
227
+ .flatten(2)
228
+ .transpose(1, 2),
229
+ ],
230
+ dim=1,
231
+ )
232
+
233
+
234
+ class QuickGELU(nn.Module):
235
+ def forward(self, x):
236
+ return x * torch.sigmoid(1.702 * x)
237
+
238
+
239
+ class LayerNorm(nn.LayerNorm):
240
+ def forward(self, x):
241
+ return super().forward(x).type_as(x)
242
+
243
+
244
+ class SelfAttention(nn.Module):
245
+ def __init__(
246
+ self, dim, num_heads, causal=False, attn_dropout=0.0, proj_dropout=0.0
247
+ ):
248
+ assert dim % num_heads == 0
249
+ super().__init__()
250
+ self.dim = dim
251
+ self.num_heads = num_heads
252
+ self.head_dim = dim // num_heads
253
+ self.causal = causal
254
+ self.attn_dropout = attn_dropout
255
+ self.proj_dropout = proj_dropout
256
+
257
+ # layers
258
+ self.to_qkv = nn.Linear(dim, dim * 3)
259
+ self.proj = nn.Linear(dim, dim)
260
+
261
+ def forward(self, x):
262
+ """
263
+ x: [B, L, C].
264
+ """
265
+ # compute query, key, value
266
+ q, k, v = self.to_qkv(x).chunk(3, dim=-1)
267
+
268
+ # compute attention
269
+ x = flash_attention(q, k, v, num_heads=self.num_heads, compatibility_mode=True)
270
+
271
+ # output
272
+ x = self.proj(x)
273
+ x = F.dropout(x, self.proj_dropout, self.training)
274
+ return x
275
+
276
+
277
+ class SwiGLU(nn.Module):
278
+ def __init__(self, dim, mid_dim):
279
+ super().__init__()
280
+ self.dim = dim
281
+ self.mid_dim = mid_dim
282
+
283
+ # layers
284
+ self.fc1 = nn.Linear(dim, mid_dim)
285
+ self.fc2 = nn.Linear(dim, mid_dim)
286
+ self.fc3 = nn.Linear(mid_dim, dim)
287
+
288
+ def forward(self, x):
289
+ x = F.silu(self.fc1(x)) * self.fc2(x)
290
+ x = self.fc3(x)
291
+ return x
292
+
293
+
294
+ class AttentionBlock(nn.Module):
295
+ def __init__(
296
+ self,
297
+ dim,
298
+ mlp_ratio,
299
+ num_heads,
300
+ post_norm=False,
301
+ causal=False,
302
+ activation="quick_gelu",
303
+ attn_dropout=0.0,
304
+ proj_dropout=0.0,
305
+ norm_eps=1e-5,
306
+ ):
307
+ assert activation in ["quick_gelu", "gelu", "swi_glu"]
308
+ super().__init__()
309
+ self.dim = dim
310
+ self.mlp_ratio = mlp_ratio
311
+ self.num_heads = num_heads
312
+ self.post_norm = post_norm
313
+ self.causal = causal
314
+ self.norm_eps = norm_eps
315
+
316
+ # layers
317
+ self.norm1 = LayerNorm(dim, eps=norm_eps)
318
+ self.attn = SelfAttention(dim, num_heads, causal, attn_dropout, proj_dropout)
319
+ self.norm2 = LayerNorm(dim, eps=norm_eps)
320
+ if activation == "swi_glu":
321
+ self.mlp = SwiGLU(dim, int(dim * mlp_ratio))
322
+ else:
323
+ self.mlp = nn.Sequential(
324
+ nn.Linear(dim, int(dim * mlp_ratio)),
325
+ QuickGELU() if activation == "quick_gelu" else nn.GELU(),
326
+ nn.Linear(int(dim * mlp_ratio), dim),
327
+ nn.Dropout(proj_dropout),
328
+ )
329
+
330
+ def forward(self, x):
331
+ if self.post_norm:
332
+ x = x + self.norm1(self.attn(x))
333
+ x = x + self.norm2(self.mlp(x))
334
+ else:
335
+ x = x + self.attn(self.norm1(x))
336
+ x = x + self.mlp(self.norm2(x))
337
+ return x
338
+
339
+
340
+ class AttentionPool(nn.Module):
341
+ def __init__(
342
+ self,
343
+ dim,
344
+ mlp_ratio,
345
+ num_heads,
346
+ activation="gelu",
347
+ proj_dropout=0.0,
348
+ norm_eps=1e-5,
349
+ ):
350
+ assert dim % num_heads == 0
351
+ super().__init__()
352
+ self.dim = dim
353
+ self.mlp_ratio = mlp_ratio
354
+ self.num_heads = num_heads
355
+ self.head_dim = dim // num_heads
356
+ self.proj_dropout = proj_dropout
357
+ self.norm_eps = norm_eps
358
+
359
+ # layers
360
+ gain = 1.0 / math.sqrt(dim)
361
+ self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
362
+ self.to_q = nn.Linear(dim, dim)
363
+ self.to_kv = nn.Linear(dim, dim * 2)
364
+ self.proj = nn.Linear(dim, dim)
365
+ self.norm = LayerNorm(dim, eps=norm_eps)
366
+ self.mlp = nn.Sequential(
367
+ nn.Linear(dim, int(dim * mlp_ratio)),
368
+ QuickGELU() if activation == "quick_gelu" else nn.GELU(),
369
+ nn.Linear(int(dim * mlp_ratio), dim),
370
+ nn.Dropout(proj_dropout),
371
+ )
372
+
373
+ def forward(self, x):
374
+ """
375
+ x: [B, L, C].
376
+ """
377
+ b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
378
+
379
+ # compute query, key, value
380
+ q = self.to_q(self.cls_embedding).view(1, 1, n * d).expand(b, -1, -1)
381
+ k, v = self.to_kv(x).chunk(2, dim=-1)
382
+
383
+ # compute attention
384
+ x = flash_attention(q, k, v, num_heads=self.num_heads, compatibility_mode=True)
385
+ x = x.reshape(b, 1, c)
386
+
387
+ # output
388
+ x = self.proj(x)
389
+ x = F.dropout(x, self.proj_dropout, self.training)
390
+
391
+ # mlp
392
+ x = x + self.mlp(self.norm(x))
393
+ return x[:, 0]
394
+
395
+
396
+ class VisionTransformer(nn.Module):
397
+ def __init__(
398
+ self,
399
+ image_size=224,
400
+ patch_size=16,
401
+ dim=768,
402
+ mlp_ratio=4,
403
+ out_dim=512,
404
+ num_heads=12,
405
+ num_layers=12,
406
+ pool_type="token",
407
+ pre_norm=True,
408
+ post_norm=False,
409
+ activation="quick_gelu",
410
+ attn_dropout=0.0,
411
+ proj_dropout=0.0,
412
+ embedding_dropout=0.0,
413
+ norm_eps=1e-5,
414
+ ):
415
+ if image_size % patch_size != 0:
416
+ print("[WARNING] image_size is not divisible by patch_size", flush=True)
417
+ assert pool_type in ("token", "token_fc", "attn_pool")
418
+ out_dim = out_dim or dim
419
+ super().__init__()
420
+ self.image_size = image_size
421
+ self.patch_size = patch_size
422
+ self.num_patches = (image_size // patch_size) ** 2
423
+ self.dim = dim
424
+ self.mlp_ratio = mlp_ratio
425
+ self.out_dim = out_dim
426
+ self.num_heads = num_heads
427
+ self.num_layers = num_layers
428
+ self.pool_type = pool_type
429
+ self.post_norm = post_norm
430
+ self.norm_eps = norm_eps
431
+
432
+ # embeddings
433
+ gain = 1.0 / math.sqrt(dim)
434
+ self.patch_embedding = nn.Conv2d(
435
+ 3, dim, kernel_size=patch_size, stride=patch_size, bias=not pre_norm
436
+ )
437
+ if pool_type in ("token", "token_fc"):
438
+ self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
439
+ self.pos_embedding = nn.Parameter(
440
+ gain
441
+ * torch.randn(
442
+ 1,
443
+ self.num_patches + (1 if pool_type in ("token", "token_fc") else 0),
444
+ dim,
445
+ )
446
+ )
447
+ self.dropout = nn.Dropout(embedding_dropout)
448
+
449
+ # transformer
450
+ self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
451
+ self.transformer = nn.Sequential(
452
+ *[
453
+ AttentionBlock(
454
+ dim,
455
+ mlp_ratio,
456
+ num_heads,
457
+ post_norm,
458
+ False,
459
+ activation,
460
+ attn_dropout,
461
+ proj_dropout,
462
+ norm_eps,
463
+ )
464
+ for _ in range(num_layers)
465
+ ]
466
+ )
467
+ self.post_norm = LayerNorm(dim, eps=norm_eps)
468
+
469
+ # head
470
+ if pool_type == "token":
471
+ self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
472
+ elif pool_type == "token_fc":
473
+ self.head = nn.Linear(dim, out_dim)
474
+ elif pool_type == "attn_pool":
475
+ self.head = AttentionPool(
476
+ dim, mlp_ratio, num_heads, activation, proj_dropout, norm_eps
477
+ )
478
+
479
+ def forward(self, x, interpolation=False, use_31_block=False):
480
+ b = x.size(0)
481
+
482
+ # embeddings
483
+ x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
484
+ if self.pool_type in ("token", "token_fc"):
485
+ x = torch.cat(
486
+ [
487
+ self.cls_embedding.expand(b, -1, -1).to(
488
+ dtype=x.dtype, device=x.device
489
+ ),
490
+ x,
491
+ ],
492
+ dim=1,
493
+ )
494
+ if interpolation:
495
+ e = pos_interpolate(self.pos_embedding, x.size(1))
496
+ else:
497
+ e = self.pos_embedding
498
+ e = e.to(dtype=x.dtype, device=x.device)
499
+ x = self.dropout(x + e)
500
+ if self.pre_norm is not None:
501
+ x = self.pre_norm(x)
502
+
503
+ # transformer
504
+ if use_31_block:
505
+ x = self.transformer[:-1](x)
506
+ return x
507
+ else:
508
+ x = self.transformer(x)
509
+ return x
510
+
511
+
512
+ class CLIP(nn.Module):
513
+ def __init__(
514
+ self,
515
+ embed_dim=512,
516
+ image_size=224,
517
+ patch_size=16,
518
+ vision_dim=768,
519
+ vision_mlp_ratio=4,
520
+ vision_heads=12,
521
+ vision_layers=12,
522
+ vision_pool="token",
523
+ vision_pre_norm=True,
524
+ vision_post_norm=False,
525
+ vocab_size=49408,
526
+ text_len=77,
527
+ text_dim=512,
528
+ text_mlp_ratio=4,
529
+ text_heads=8,
530
+ text_layers=12,
531
+ text_causal=True,
532
+ text_pool="argmax",
533
+ text_head_bias=False,
534
+ logit_bias=None,
535
+ activation="quick_gelu",
536
+ attn_dropout=0.0,
537
+ proj_dropout=0.0,
538
+ embedding_dropout=0.0,
539
+ norm_eps=1e-5,
540
+ ):
541
+ super().__init__()
542
+ self.embed_dim = embed_dim
543
+ self.image_size = image_size
544
+ self.patch_size = patch_size
545
+ self.vision_dim = vision_dim
546
+ self.vision_mlp_ratio = vision_mlp_ratio
547
+ self.vision_heads = vision_heads
548
+ self.vision_layers = vision_layers
549
+ self.vision_pool = vision_pool
550
+ self.vision_pre_norm = vision_pre_norm
551
+ self.vision_post_norm = vision_post_norm
552
+ self.vocab_size = vocab_size
553
+ self.text_len = text_len
554
+ self.text_dim = text_dim
555
+ self.text_mlp_ratio = text_mlp_ratio
556
+ self.text_heads = text_heads
557
+ self.text_layers = text_layers
558
+ self.text_causal = text_causal
559
+ self.text_pool = text_pool
560
+ self.text_head_bias = text_head_bias
561
+ self.norm_eps = norm_eps
562
+
563
+ # models
564
+ self.visual = VisionTransformer(
565
+ image_size=image_size,
566
+ patch_size=patch_size,
567
+ dim=vision_dim,
568
+ mlp_ratio=vision_mlp_ratio,
569
+ out_dim=embed_dim,
570
+ num_heads=vision_heads,
571
+ num_layers=vision_layers,
572
+ pool_type=vision_pool,
573
+ pre_norm=vision_pre_norm,
574
+ post_norm=vision_post_norm,
575
+ activation=activation,
576
+ attn_dropout=attn_dropout,
577
+ proj_dropout=proj_dropout,
578
+ embedding_dropout=embedding_dropout,
579
+ norm_eps=norm_eps,
580
+ )
581
+ self.textual = TextTransformer(
582
+ vocab_size=vocab_size,
583
+ text_len=text_len,
584
+ dim=text_dim,
585
+ mlp_ratio=text_mlp_ratio,
586
+ out_dim=embed_dim,
587
+ num_heads=text_heads,
588
+ num_layers=text_layers,
589
+ causal=text_causal,
590
+ pool_type=text_pool,
591
+ head_bias=text_head_bias,
592
+ activation=activation,
593
+ attn_dropout=attn_dropout,
594
+ proj_dropout=proj_dropout,
595
+ embedding_dropout=embedding_dropout,
596
+ norm_eps=norm_eps,
597
+ )
598
+ self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
599
+ if logit_bias is not None:
600
+ self.logit_bias = nn.Parameter(logit_bias * torch.ones([]))
601
+
602
+ # initialize weights
603
+ self.init_weights()
604
+
605
+ def forward(self, imgs, txt_ids):
606
+ """
607
+ imgs: [B, 3, H, W] of torch.float32.
608
+ - mean: [0.48145466, 0.4578275, 0.40821073]
609
+ - std: [0.26862954, 0.26130258, 0.27577711]
610
+ txt_ids: [B, L] of torch.long. Encoded by data.CLIPTokenizer.
611
+ """
612
+ xi = self.visual(imgs)
613
+ xt = self.textual(txt_ids)
614
+ return xi, xt
615
+
616
+ def init_weights(self):
617
+ # embeddings
618
+ nn.init.normal_(self.textual.token_embedding.weight, std=0.02)
619
+ nn.init.normal_(self.visual.patch_embedding.weight, std=0.1)
620
+
621
+ # attentions
622
+ for modality in ["visual", "textual"]:
623
+ dim = self.vision_dim if modality == "visual" else self.text_dim
624
+ transformer = getattr(self, modality).transformer
625
+ proj_gain = (1.0 / math.sqrt(dim)) * (1.0 / math.sqrt(2 * len(transformer)))
626
+ attn_gain = 1.0 / math.sqrt(dim)
627
+ mlp_gain = 1.0 / math.sqrt(2.0 * dim)
628
+ for block in transformer:
629
+ nn.init.normal_(block.attn.to_qkv.weight, std=attn_gain)
630
+ nn.init.normal_(block.attn.proj.weight, std=proj_gain)
631
+ nn.init.normal_(block.mlp[0].weight, std=mlp_gain)
632
+ nn.init.normal_(block.mlp[2].weight, std=proj_gain)
633
+
634
+ def param_groups(self):
635
+ groups = [
636
+ {
637
+ "params": [
638
+ p
639
+ for n, p in self.named_parameters()
640
+ if "norm" in n or n.endswith("bias")
641
+ ],
642
+ "weight_decay": 0.0,
643
+ },
644
+ {
645
+ "params": [
646
+ p
647
+ for n, p in self.named_parameters()
648
+ if not ("norm" in n or n.endswith("bias"))
649
+ ]
650
+ },
651
+ ]
652
+ return groups
653
+
654
+
655
+ class XLMRobertaWithHead(XLMRoberta):
656
+ def __init__(self, **kwargs):
657
+ self.out_dim = kwargs.pop("out_dim")
658
+ super().__init__(**kwargs)
659
+
660
+ # head
661
+ mid_dim = (self.dim + self.out_dim) // 2
662
+ self.head = nn.Sequential(
663
+ nn.Linear(self.dim, mid_dim, bias=False),
664
+ nn.GELU(),
665
+ nn.Linear(mid_dim, self.out_dim, bias=False),
666
+ )
667
+
668
+ def forward(self, ids):
669
+ # xlm-roberta
670
+ x = super().forward(ids)
671
+
672
+ # average pooling
673
+ mask = ids.ne(self.pad_id).unsqueeze(-1).to(x)
674
+ x = (x * mask).sum(dim=1) / mask.sum(dim=1)
675
+
676
+ # head
677
+ x = self.head(x)
678
+ return x
679
+
680
+
681
+ class XLMRobertaCLIP(nn.Module):
682
+ def __init__(
683
+ self,
684
+ embed_dim=1024,
685
+ image_size=224,
686
+ patch_size=14,
687
+ vision_dim=1280,
688
+ vision_mlp_ratio=4,
689
+ vision_heads=16,
690
+ vision_layers=32,
691
+ vision_pool="token",
692
+ vision_pre_norm=True,
693
+ vision_post_norm=False,
694
+ activation="gelu",
695
+ vocab_size=250002,
696
+ max_text_len=514,
697
+ type_size=1,
698
+ pad_id=1,
699
+ text_dim=1024,
700
+ text_heads=16,
701
+ text_layers=24,
702
+ text_post_norm=True,
703
+ text_dropout=0.1,
704
+ attn_dropout=0.0,
705
+ proj_dropout=0.0,
706
+ embedding_dropout=0.0,
707
+ norm_eps=1e-5,
708
+ ):
709
+ super().__init__()
710
+ self.embed_dim = embed_dim
711
+ self.image_size = image_size
712
+ self.patch_size = patch_size
713
+ self.vision_dim = vision_dim
714
+ self.vision_mlp_ratio = vision_mlp_ratio
715
+ self.vision_heads = vision_heads
716
+ self.vision_layers = vision_layers
717
+ self.vision_pre_norm = vision_pre_norm
718
+ self.vision_post_norm = vision_post_norm
719
+ self.activation = activation
720
+ self.vocab_size = vocab_size
721
+ self.max_text_len = max_text_len
722
+ self.type_size = type_size
723
+ self.pad_id = pad_id
724
+ self.text_dim = text_dim
725
+ self.text_heads = text_heads
726
+ self.text_layers = text_layers
727
+ self.text_post_norm = text_post_norm
728
+ self.norm_eps = norm_eps
729
+
730
+ # models
731
+ self.visual = VisionTransformer(
732
+ image_size=image_size,
733
+ patch_size=patch_size,
734
+ dim=vision_dim,
735
+ mlp_ratio=vision_mlp_ratio,
736
+ out_dim=embed_dim,
737
+ num_heads=vision_heads,
738
+ num_layers=vision_layers,
739
+ pool_type=vision_pool,
740
+ pre_norm=vision_pre_norm,
741
+ post_norm=vision_post_norm,
742
+ activation=activation,
743
+ attn_dropout=attn_dropout,
744
+ proj_dropout=proj_dropout,
745
+ embedding_dropout=embedding_dropout,
746
+ norm_eps=norm_eps,
747
+ )
748
+ self.textual = None
749
+ self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
750
+
751
+ def forward(self, imgs, txt_ids):
752
+ """
753
+ imgs: [B, 3, H, W] of torch.float32.
754
+ - mean: [0.48145466, 0.4578275, 0.40821073]
755
+ - std: [0.26862954, 0.26130258, 0.27577711]
756
+ txt_ids: [B, L] of torch.long.
757
+ Encoded by data.CLIPTokenizer.
758
+ """
759
+ xi = self.visual(imgs)
760
+ xt = self.textual(txt_ids)
761
+ return xi, xt
762
+
763
+ def param_groups(self):
764
+ groups = [
765
+ {
766
+ "params": [
767
+ p
768
+ for n, p in self.named_parameters()
769
+ if "norm" in n or n.endswith("bias")
770
+ ],
771
+ "weight_decay": 0.0,
772
+ },
773
+ {
774
+ "params": [
775
+ p
776
+ for n, p in self.named_parameters()
777
+ if not ("norm" in n or n.endswith("bias"))
778
+ ]
779
+ },
780
+ ]
781
+ return groups
782
+
783
+
784
+ def _clip(
785
+ pretrained=False,
786
+ pretrained_name=None,
787
+ model_cls=CLIP,
788
+ return_transforms=False,
789
+ return_tokenizer=False,
790
+ tokenizer_padding="eos",
791
+ dtype=torch.float32,
792
+ device="cpu",
793
+ **kwargs,
794
+ ):
795
+ # init model
796
+ if pretrained and pretrained_name:
797
+ from sora import BUCKET, DOWNLOAD_TO_CACHE
798
+
799
+ # init a meta model
800
+ with torch.device("meta"):
801
+ model = model_cls(**kwargs)
802
+
803
+ # checkpoint path
804
+ checkpoint = f"models/clip/{pretrained_name}"
805
+ if dtype in (torch.float16, torch.bfloat16):
806
+ suffix = "-" + {torch.float16: "fp16", torch.bfloat16: "bf16"}[dtype]
807
+ if object_exists(BUCKET, f"{checkpoint}{suffix}.pth"):
808
+ checkpoint = f"{checkpoint}{suffix}"
809
+ checkpoint += ".pth"
810
+
811
+ # load
812
+ model.load_state_dict(
813
+ torch.load(DOWNLOAD_TO_CACHE(checkpoint), map_location=device),
814
+ assign=True,
815
+ strict=False,
816
+ )
817
+ else:
818
+ # init a model on device
819
+ with torch.device(device):
820
+ model = model_cls(**kwargs)
821
+
822
+ # set device
823
+ output = (model,)
824
+
825
+ # init transforms
826
+ if return_transforms:
827
+ # mean and std
828
+ if "siglip" in pretrained_name.lower():
829
+ mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
830
+ else:
831
+ mean = [0.48145466, 0.4578275, 0.40821073]
832
+ std = [0.26862954, 0.26130258, 0.27577711]
833
+
834
+ # transforms
835
+ transforms = T.Compose(
836
+ [
837
+ T.Resize(
838
+ (model.image_size, model.image_size),
839
+ interpolation=T.InterpolationMode.BICUBIC,
840
+ ),
841
+ T.ToTensor(),
842
+ T.Normalize(mean=mean, std=std),
843
+ ]
844
+ )
845
+ output += (transforms,)
846
+
847
+ # init tokenizer
848
+ if return_tokenizer:
849
+ from sora import data
850
+
851
+ if "siglip" in pretrained_name.lower():
852
+ tokenizer = data.HuggingfaceTokenizer(
853
+ name=f"timm/{pretrained_name}",
854
+ seq_len=model.text_len,
855
+ clean="canonicalize",
856
+ )
857
+ elif "xlm" in pretrained_name.lower():
858
+ tokenizer = data.HuggingfaceTokenizer(
859
+ name="xlm-roberta-large",
860
+ seq_len=model.max_text_len - 2,
861
+ clean="whitespace",
862
+ )
863
+ elif "mba" in pretrained_name.lower():
864
+ tokenizer = data.HuggingfaceTokenizer(
865
+ name="facebook/xlm-roberta-xl",
866
+ seq_len=model.max_text_len - 2,
867
+ clean="whitespace",
868
+ )
869
+ else:
870
+ tokenizer = data.CLIPTokenizer(
871
+ seq_len=model.text_len, padding=tokenizer_padding
872
+ )
873
+ output += (tokenizer,)
874
+ return output[0] if len(output) == 1 else output
875
+
876
+
877
+ def clip_xlm_roberta_vit_h_14(
878
+ pretrained=False,
879
+ pretrained_name="open-clip-xlm-roberta-large-vit-huge-14",
880
+ **kwargs,
881
+ ):
882
+ cfg = dict(
883
+ embed_dim=1024,
884
+ image_size=224,
885
+ patch_size=14,
886
+ vision_dim=1280,
887
+ vision_mlp_ratio=4,
888
+ vision_heads=16,
889
+ vision_layers=32,
890
+ vision_pool="token",
891
+ activation="gelu",
892
+ vocab_size=250002,
893
+ max_text_len=514,
894
+ type_size=1,
895
+ pad_id=1,
896
+ text_dim=1024,
897
+ text_heads=16,
898
+ text_layers=24,
899
+ text_post_norm=True,
900
+ text_dropout=0.1,
901
+ attn_dropout=0.0,
902
+ proj_dropout=0.0,
903
+ embedding_dropout=0.0,
904
+ )
905
+ cfg.update(**kwargs)
906
+ return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg)
907
+
908
+
909
+ class WanImageEncoder(torch.nn.Module):
910
+ def __init__(self):
911
+ super().__init__()
912
+ # init model
913
+ self.model, self.transforms = clip_xlm_roberta_vit_h_14(
914
+ pretrained=False,
915
+ return_transforms=True,
916
+ return_tokenizer=False,
917
+ dtype=torch.float32,
918
+ device="cpu",
919
+ )
920
+
921
+ def encode_image(self, videos):
922
+ # preprocess
923
+ size = (self.model.image_size,) * 2
924
+ videos = torch.cat(
925
+ [
926
+ F.interpolate(u, size=size, mode="bicubic", align_corners=False)
927
+ for u in videos
928
+ ]
929
+ )
930
+ videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))
931
+
932
+ # forward
933
+ dtype = next(iter(self.model.visual.parameters())).dtype
934
+ videos = videos.to(dtype)
935
+ out = self.model.visual(videos, use_31_block=True)
936
+ return out
937
+
938
+ @staticmethod
939
+ def state_dict_converter():
940
+ return WanImageEncoderStateDictConverter()
941
+
942
+
943
+ class WanImageEncoderStateDictConverter:
944
+ def __init__(self):
945
+ pass
946
+
947
+ def from_diffusers(self, state_dict):
948
+ return state_dict
949
+
950
+ def from_civitai(self, state_dict):
951
+ state_dict_ = {}
952
+ for name, param in state_dict.items():
953
+ if name.startswith("textual."):
954
+ continue
955
+ name = "model." + name
956
+ state_dict_[name] = param
957
+ return state_dict_
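
A minimal usage sketch for the image encoder defined above (hypothetical, not part of the commit; it assumes the repo's `models` package layout and runs with randomly initialized weights). Per `encode_image`, the input is a list of [T, 3, H, W] frame tensors in [-1, 1], and the output is the feature map from the penultimate ViT block.

import torch
from models.wan_video_image_encoder import WanImageEncoder

# Build the encoder with random weights (real checkpoints are loaded elsewhere
# through the state dict converter), then encode a single reference frame.
encoder = WanImageEncoder().eval()
frames = torch.rand(1, 3, 480, 832) * 2 - 1   # [T, 3, H, W] in [-1, 1]
with torch.no_grad():
    feats = encoder.encode_image([frames])
print(feats.shape)  # torch.Size([1, 257, 1280]): 16x16 patches + CLS token at dim 1280
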
models/wan_video_motion_controller.py ADDED
@@ -0,0 +1,41 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from .wan_video_dit import sinusoidal_embedding_1d
4
+
5
+
6
+ class WanMotionControllerModel(torch.nn.Module):
7
+ def __init__(self, freq_dim=256, dim=1536):
8
+ super().__init__()
9
+ self.freq_dim = freq_dim
10
+ self.linear = nn.Sequential(
11
+ nn.Linear(freq_dim, dim),
12
+ nn.SiLU(),
13
+ nn.Linear(dim, dim),
14
+ nn.SiLU(),
15
+ nn.Linear(dim, dim * 6),
16
+ )
17
+
18
+ def forward(self, motion_bucket_id):
19
+ emb = sinusoidal_embedding_1d(self.freq_dim, motion_bucket_id * 10)
20
+ emb = self.linear(emb)
21
+ return emb
22
+
23
+ def init(self):
24
+ state_dict = self.linear[-1].state_dict()
25
+ state_dict = {i: state_dict[i] * 0 for i in state_dict}
26
+ self.linear[-1].load_state_dict(state_dict)
27
+
28
+ @staticmethod
29
+ def state_dict_converter():
30
+ return WanMotionControllerModelDictConverter()
31
+
32
+
33
+ class WanMotionControllerModelDictConverter:
34
+ def __init__(self):
35
+ pass
36
+
37
+ def from_diffusers(self, state_dict):
38
+ return state_dict
39
+
40
+ def from_civitai(self, state_dict):
41
+ return state_dict
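
A construction-only sketch of the motion controller above (hypothetical, not part of the commit; the import path assumes the repo's `models` package layout). Per the code, forward maps a 1-D tensor of motion bucket ids to an [N, dim * 6] modulation embedding; the sketch only checks that init() zeroes the final projection so the added conditioning starts as a no-op.

import torch
from models.wan_video_motion_controller import WanMotionControllerModel

controller = WanMotionControllerModel(freq_dim=256, dim=1536)
controller.init()  # zero the last Linear so the controller initially contributes nothing
assert controller.linear[-1].weight.abs().sum().item() == 0.0
assert controller.linear[-1].bias.abs().sum().item() == 0.0
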
models/wan_video_text_encoder.py ADDED
@@ -0,0 +1,289 @@
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
+ def fp16_clamp(x):
9
+ if x.dtype == torch.float16 and torch.isinf(x).any():
10
+ clamp = torch.finfo(x.dtype).max - 1000
11
+ x = torch.clamp(x, min=-clamp, max=clamp)
12
+ return x
13
+
14
+
15
+ class GELU(nn.Module):
16
+ def forward(self, x):
17
+ return (
18
+ 0.5
19
+ * x
20
+ * (
21
+ 1.0
22
+ + torch.tanh(
23
+ math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))
24
+ )
25
+ )
26
+ )
27
+
28
+
29
+ class T5LayerNorm(nn.Module):
30
+ def __init__(self, dim, eps=1e-6):
31
+ super(T5LayerNorm, self).__init__()
32
+ self.dim = dim
33
+ self.eps = eps
34
+ self.weight = nn.Parameter(torch.ones(dim))
35
+
36
+ def forward(self, x):
37
+ x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + self.eps)
38
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
39
+ x = x.type_as(self.weight)
40
+ return self.weight * x
41
+
42
+
43
+ class T5Attention(nn.Module):
44
+ def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
45
+ assert dim_attn % num_heads == 0
46
+ super(T5Attention, self).__init__()
47
+ self.dim = dim
48
+ self.dim_attn = dim_attn
49
+ self.num_heads = num_heads
50
+ self.head_dim = dim_attn // num_heads
51
+
52
+ # layers
53
+ self.q = nn.Linear(dim, dim_attn, bias=False)
54
+ self.k = nn.Linear(dim, dim_attn, bias=False)
55
+ self.v = nn.Linear(dim, dim_attn, bias=False)
56
+ self.o = nn.Linear(dim_attn, dim, bias=False)
57
+ self.dropout = nn.Dropout(dropout)
58
+
59
+ def forward(self, x, context=None, mask=None, pos_bias=None):
60
+ """
61
+ x: [B, L1, C].
62
+ context: [B, L2, C] or None.
63
+ mask: [B, L2] or [B, L1, L2] or None.
64
+ """
65
+ # check inputs
66
+ context = x if context is None else context
67
+ b, n, c = x.size(0), self.num_heads, self.head_dim
68
+
69
+ # compute query, key, value
70
+ q = self.q(x).view(b, -1, n, c)
71
+ k = self.k(context).view(b, -1, n, c)
72
+ v = self.v(context).view(b, -1, n, c)
73
+
74
+ # attention bias
75
+ attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
76
+ if pos_bias is not None:
77
+ attn_bias += pos_bias
78
+ if mask is not None:
79
+ assert mask.ndim in [2, 3]
80
+ mask = mask.view(b, 1, 1, -1) if mask.ndim == 2 else mask.unsqueeze(1)
81
+ attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)
82
+
83
+ # compute attention (T5 does not use scaling)
84
+ attn = torch.einsum("binc,bjnc->bnij", q, k) + attn_bias
85
+ attn = F.softmax(attn.float(), dim=-1).type_as(attn)
86
+ x = torch.einsum("bnij,bjnc->binc", attn, v)
87
+
88
+ # output
89
+ x = x.reshape(b, -1, n * c)
90
+ x = self.o(x)
91
+ x = self.dropout(x)
92
+ return x
93
+
94
+
95
+ class T5FeedForward(nn.Module):
96
+ def __init__(self, dim, dim_ffn, dropout=0.1):
97
+ super(T5FeedForward, self).__init__()
98
+ self.dim = dim
99
+ self.dim_ffn = dim_ffn
100
+
101
+ # layers
102
+ self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
103
+ self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
104
+ self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
105
+ self.dropout = nn.Dropout(dropout)
106
+
107
+ def forward(self, x):
108
+ x = self.fc1(x) * self.gate(x)
109
+ x = self.dropout(x)
110
+ x = self.fc2(x)
111
+ x = self.dropout(x)
112
+ return x
113
+
114
+
115
+ class T5SelfAttention(nn.Module):
116
+ def __init__(
117
+ self,
118
+ dim,
119
+ dim_attn,
120
+ dim_ffn,
121
+ num_heads,
122
+ num_buckets,
123
+ shared_pos=True,
124
+ dropout=0.1,
125
+ ):
126
+ super(T5SelfAttention, self).__init__()
127
+ self.dim = dim
128
+ self.dim_attn = dim_attn
129
+ self.dim_ffn = dim_ffn
130
+ self.num_heads = num_heads
131
+ self.num_buckets = num_buckets
132
+ self.shared_pos = shared_pos
133
+
134
+ # layers
135
+ self.norm1 = T5LayerNorm(dim)
136
+ self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
137
+ self.norm2 = T5LayerNorm(dim)
138
+ self.ffn = T5FeedForward(dim, dim_ffn, dropout)
139
+ self.pos_embedding = (
140
+ None
141
+ if shared_pos
142
+ else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
143
+ )
144
+
145
+ def forward(self, x, mask=None, pos_bias=None):
146
+ e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1))
147
+ x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
148
+ x = fp16_clamp(x + self.ffn(self.norm2(x)))
149
+ return x
150
+
151
+
152
+ class T5RelativeEmbedding(nn.Module):
153
+ def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
154
+ super(T5RelativeEmbedding, self).__init__()
155
+ self.num_buckets = num_buckets
156
+ self.num_heads = num_heads
157
+ self.bidirectional = bidirectional
158
+ self.max_dist = max_dist
159
+
160
+ # layers
161
+ self.embedding = nn.Embedding(num_buckets, num_heads)
162
+
163
+ def forward(self, lq, lk):
164
+ device = self.embedding.weight.device
165
+ # rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \
166
+ # torch.arange(lq).unsqueeze(1).to(device)
167
+ rel_pos = torch.arange(lk, device=device).unsqueeze(0) - torch.arange(
168
+ lq, device=device
169
+ ).unsqueeze(1)
170
+ rel_pos = self._relative_position_bucket(rel_pos)
171
+ rel_pos_embeds = self.embedding(rel_pos)
172
+ rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(0) # [1, N, Lq, Lk]
173
+ return rel_pos_embeds.contiguous()
174
+
175
+ def _relative_position_bucket(self, rel_pos):
176
+ # preprocess
177
+ if self.bidirectional:
178
+ num_buckets = self.num_buckets // 2
179
+ rel_buckets = (rel_pos > 0).long() * num_buckets
180
+ rel_pos = torch.abs(rel_pos)
181
+ else:
182
+ num_buckets = self.num_buckets
183
+ rel_buckets = 0
184
+ rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))
185
+
186
+ # embeddings for small and large positions
187
+ max_exact = num_buckets // 2
188
+ rel_pos_large = (
189
+ max_exact
190
+ + (
191
+ torch.log(rel_pos.float() / max_exact)
192
+ / math.log(self.max_dist / max_exact)
193
+ * (num_buckets - max_exact)
194
+ ).long()
195
+ )
196
+ rel_pos_large = torch.min(
197
+ rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1)
198
+ )
199
+ rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
200
+ return rel_buckets
201
+
202
+
203
+ def init_weights(m):
204
+ if isinstance(m, T5LayerNorm):
205
+ nn.init.ones_(m.weight)
206
+ elif isinstance(m, T5FeedForward):
207
+ nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)
208
+ nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)
209
+ nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
210
+ elif isinstance(m, T5Attention):
211
+ nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn) ** -0.5)
212
+ nn.init.normal_(m.k.weight, std=m.dim**-0.5)
213
+ nn.init.normal_(m.v.weight, std=m.dim**-0.5)
214
+ nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn) ** -0.5)
215
+ elif isinstance(m, T5RelativeEmbedding):
216
+ nn.init.normal_(
217
+ m.embedding.weight, std=(2 * m.num_buckets * m.num_heads) ** -0.5
218
+ )
219
+
220
+
221
+ class WanTextEncoder(torch.nn.Module):
222
+ def __init__(
223
+ self,
224
+ vocab=256384,
225
+ dim=4096,
226
+ dim_attn=4096,
227
+ dim_ffn=10240,
228
+ num_heads=64,
229
+ num_layers=24,
230
+ num_buckets=32,
231
+ shared_pos=False,
232
+ dropout=0.1,
233
+ ):
234
+ super(WanTextEncoder, self).__init__()
235
+ self.dim = dim
236
+ self.dim_attn = dim_attn
237
+ self.dim_ffn = dim_ffn
238
+ self.num_heads = num_heads
239
+ self.num_layers = num_layers
240
+ self.num_buckets = num_buckets
241
+ self.shared_pos = shared_pos
242
+
243
+ # layers
244
+ self.token_embedding = (
245
+ vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim)
246
+ )
247
+ self.pos_embedding = (
248
+ T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
249
+ if shared_pos
250
+ else None
251
+ )
252
+ self.dropout = nn.Dropout(dropout)
253
+ self.blocks = nn.ModuleList(
254
+ [
255
+ T5SelfAttention(
256
+ dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout
257
+ )
258
+ for _ in range(num_layers)
259
+ ]
260
+ )
261
+ self.norm = T5LayerNorm(dim)
262
+
263
+ # initialize weights
264
+ self.apply(init_weights)
265
+
266
+ def forward(self, ids, mask=None):
267
+ x = self.token_embedding(ids)
268
+ x = self.dropout(x)
269
+ e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None
270
+ for block in self.blocks:
271
+ x = block(x, mask, pos_bias=e)
272
+ x = self.norm(x)
273
+ x = self.dropout(x)
274
+ return x
275
+
276
+ @staticmethod
277
+ def state_dict_converter():
278
+ return WanTextEncoderStateDictConverter()
279
+
280
+
281
+ class WanTextEncoderStateDictConverter:
282
+ def __init__(self):
283
+ pass
284
+
285
+ def from_diffusers(self, state_dict):
286
+ return state_dict
287
+
288
+ def from_civitai(self, state_dict):
289
+ return state_dict
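
A hypothetical smoke test for the text encoder interface (not part of the commit; the import path assumes the repo's `models` package layout). It uses a deliberately tiny configuration so it runs quickly with random weights; the defaults above describe the much larger T5-style encoder whose weights are loaded separately.

import torch
from models.wan_video_text_encoder import WanTextEncoder

# Tiny configuration purely to exercise shapes; real checkpoints use the defaults.
enc = WanTextEncoder(vocab=1000, dim=64, dim_attn=64, dim_ffn=128,
                     num_heads=4, num_layers=2, num_buckets=8).eval()
ids = torch.randint(0, 1000, (1, 16))       # [B, L] token ids
mask = torch.ones(1, 16, dtype=torch.long)  # [B, L] attention mask
out = enc(ids, mask)                        # -> torch.Size([1, 16, 64])
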
models/wan_video_vace.py ADDED
@@ -0,0 +1,140 @@
1
+ import torch
2
+ from .wan_video_dit import DiTBlock
3
+ from .utils import hash_state_dict_keys
4
+
5
+
6
+ class VaceWanAttentionBlock(DiTBlock):
7
+ def __init__(self, has_image_input, dim, num_heads, ffn_dim, eps=1e-6, block_id=0):
8
+ super().__init__(has_image_input, dim, num_heads, ffn_dim, eps=eps)
9
+ self.block_id = block_id
10
+ if block_id == 0:
11
+ self.before_proj = torch.nn.Linear(self.dim, self.dim)
12
+ self.after_proj = torch.nn.Linear(self.dim, self.dim)
13
+
14
+ def forward(self, c, x, context, t_mod, freqs):
15
+ if self.block_id == 0:
16
+ c = self.before_proj(c) + x
17
+ all_c = []
18
+ else:
19
+ all_c = list(torch.unbind(c))
20
+ c = all_c.pop(-1)
21
+ c, _ = super().forward(c, context, t_mod, freqs)
22
+ c_skip = self.after_proj(c)
23
+ all_c += [c_skip, c]
24
+ c = torch.stack(all_c)
25
+ return c
26
+
27
+
28
+ class VaceWanModel(torch.nn.Module):
29
+ def __init__(
30
+ self,
31
+ vace_layers=(0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28),
32
+ vace_in_dim=96,
33
+ patch_size=(1, 2, 2),
34
+ has_image_input=False,
35
+ dim=1536,
36
+ num_heads=12,
37
+ ffn_dim=8960,
38
+ eps=1e-6,
39
+ ):
40
+ super().__init__()
41
+ self.vace_layers = vace_layers
42
+ self.vace_in_dim = vace_in_dim
43
+ self.vace_layers_mapping = {i: n for n, i in enumerate(self.vace_layers)}
44
+
45
+ # vace blocks
46
+ self.vace_blocks = torch.nn.ModuleList(
47
+ [
48
+ VaceWanAttentionBlock(
49
+ has_image_input, dim, num_heads, ffn_dim, eps, block_id=i
50
+ )
51
+ for i in self.vace_layers
52
+ ]
53
+ )
54
+
55
+ # vace patch embeddings
56
+ self.vace_patch_embedding = torch.nn.Conv3d(
57
+ vace_in_dim, dim, kernel_size=patch_size, stride=patch_size
58
+ )
59
+
60
+ def forward(
61
+ self,
62
+ x,
63
+ vace_context,
64
+ context,
65
+ t_mod,
66
+ freqs,
67
+ use_gradient_checkpointing: bool = False,
68
+ use_gradient_checkpointing_offload: bool = False,
69
+ ):
70
+ c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
71
+ c = [u.flatten(2).transpose(1, 2) for u in c]
72
+ c = torch.cat(
73
+ [
74
+ torch.cat([u, u.new_zeros(1, x.shape[1] - u.size(1), u.size(2))], dim=1)
75
+ for u in c
76
+ ]
77
+ )
78
+
79
+ def create_custom_forward(module):
80
+ def custom_forward(*inputs):
81
+ return module(*inputs)
82
+
83
+ return custom_forward
84
+
85
+ for block in self.vace_blocks:
86
+ if use_gradient_checkpointing_offload:
87
+ with torch.autograd.graph.save_on_cpu():
88
+ c = torch.utils.checkpoint.checkpoint(
89
+ create_custom_forward(block),
90
+ c,
91
+ x,
92
+ context,
93
+ t_mod,
94
+ freqs,
95
+ use_reentrant=False,
96
+ )
97
+ elif use_gradient_checkpointing:
98
+ c = torch.utils.checkpoint.checkpoint(
99
+ create_custom_forward(block),
100
+ c,
101
+ x,
102
+ context,
103
+ t_mod,
104
+ freqs,
105
+ use_reentrant=False,
106
+ )
107
+ else:
108
+ c = block(c, x, context, t_mod, freqs)
109
+ hints = torch.unbind(c)[:-1]
110
+ return hints
111
+
112
+ @staticmethod
113
+ def state_dict_converter():
114
+ return VaceWanModelDictConverter()
115
+
116
+
117
+ class VaceWanModelDictConverter:
118
+ def __init__(self):
119
+ pass
120
+
121
+ def from_civitai(self, state_dict):
122
+ state_dict_ = {
123
+ name: param for name, param in state_dict.items() if name.startswith("vace")
124
+ }
125
+ if (
126
+ hash_state_dict_keys(state_dict_) == "3b2726384e4f64837bdf216eea3f310d"
127
+ ): # vace 14B
128
+ config = {
129
+ "vace_layers": (0, 5, 10, 15, 20, 25, 30, 35),
130
+ "vace_in_dim": 96,
131
+ "patch_size": (1, 2, 2),
132
+ "has_image_input": False,
133
+ "dim": 5120,
134
+ "num_heads": 40,
135
+ "ffn_dim": 13824,
136
+ "eps": 1e-06,
137
+ }
138
+ else:
139
+ config = {}
140
+ return state_dict_, config
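
A construction-only sketch (hypothetical, not part of the commit; it assumes the repo's `models` package layout and the DiTBlock constructor used above) of the default VACE layout: a control block is attached to every second DiT layer, and vace_layers_mapping translates a DiT layer index into the index of the corresponding hint.

from models.wan_video_vace import VaceWanModel

vace = VaceWanModel()  # defaults correspond to the 1.3B-sized DiT; weights are random here
print(vace.vace_layers_mapping)                # {0: 0, 2: 1, 4: 2, ..., 28: 14}
print(vace.vace_patch_embedding.weight.shape)  # torch.Size([1536, 96, 1, 2, 2])
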
models/wan_video_vae.py ADDED
@@ -0,0 +1,1634 @@
1
+ from einops import rearrange, repeat
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from tqdm import tqdm
7
+
8
+ CACHE_T = 2
9
+
10
+
11
+ def check_is_instance(model, module_class):
12
+ if isinstance(model, module_class):
13
+ return True
14
+ if hasattr(model, "module") and isinstance(model.module, module_class):
15
+ return True
16
+ return False
17
+
18
+
19
+ def block_causal_mask(x, block_size):
20
+ # params
21
+ b, n, s, _, device = *x.size(), x.device
22
+ assert s % block_size == 0
23
+ num_blocks = s // block_size
24
+
25
+ # build mask
26
+ mask = torch.zeros(b, n, s, s, dtype=torch.bool, device=device)
27
+ for i in range(num_blocks):
28
+ mask[:, :, i * block_size : (i + 1) * block_size, : (i + 1) * block_size] = 1
29
+ return mask
30
+
31
+
32
+ class CausalConv3d(nn.Conv3d):
33
+ """
34
+ Causal 3d convolution.
35
+ """
36
+
37
+ def __init__(self, *args, **kwargs):
38
+ super().__init__(*args, **kwargs)
39
+ self._padding = (
40
+ self.padding[2],
41
+ self.padding[2],
42
+ self.padding[1],
43
+ self.padding[1],
44
+ 2 * self.padding[0],
45
+ 0,
46
+ )
47
+ self.padding = (0, 0, 0)
48
+
49
+ def forward(self, x, cache_x=None):
50
+ padding = list(self._padding)
51
+ if cache_x is not None and self._padding[4] > 0:
52
+ cache_x = cache_x.to(x.device)
53
+ x = torch.cat([cache_x, x], dim=2)
54
+ padding[4] -= cache_x.shape[2]
55
+ x = F.pad(x, padding)
56
+
57
+ return super().forward(x)
58
+
59
+
60
+ class RMS_norm(nn.Module):
61
+ def __init__(self, dim, channel_first=True, images=True, bias=False):
62
+ super().__init__()
63
+ broadcastable_dims = (1, 1, 1) if not images else (1, 1)
64
+ shape = (dim, *broadcastable_dims) if channel_first else (dim,)
65
+
66
+ self.channel_first = channel_first
67
+ self.scale = dim**0.5
68
+ self.gamma = nn.Parameter(torch.ones(shape))
69
+ self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
70
+
71
+ def forward(self, x):
72
+ return (
73
+ F.normalize(x, dim=(1 if self.channel_first else -1))
74
+ * self.scale
75
+ * self.gamma
76
+ + self.bias
77
+ )
78
+
79
+
80
+ class Upsample(nn.Upsample):
81
+ def forward(self, x):
82
+ """
83
+ Fix bfloat16 support for nearest neighbor interpolation.
84
+ """
85
+ return super().forward(x.float()).type_as(x)
86
+
87
+
88
+ class Resample(nn.Module):
89
+ def __init__(self, dim, mode):
90
+ assert mode in (
91
+ "none",
92
+ "upsample2d",
93
+ "upsample3d",
94
+ "downsample2d",
95
+ "downsample3d",
96
+ )
97
+ super().__init__()
98
+ self.dim = dim
99
+ self.mode = mode
100
+
101
+ # layers
102
+ if mode == "upsample2d":
103
+ self.resample = nn.Sequential(
104
+ Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
105
+ nn.Conv2d(dim, dim // 2, 3, padding=1),
106
+ )
107
+ elif mode == "upsample3d":
108
+ self.resample = nn.Sequential(
109
+ Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
110
+ nn.Conv2d(dim, dim // 2, 3, padding=1),
111
+ )
112
+ self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
113
+
114
+ elif mode == "downsample2d":
115
+ self.resample = nn.Sequential(
116
+ nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))
117
+ )
118
+ elif mode == "downsample3d":
119
+ self.resample = nn.Sequential(
120
+ nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))
121
+ )
122
+ self.time_conv = CausalConv3d(
123
+ dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)
124
+ )
125
+
126
+ else:
127
+ self.resample = nn.Identity()
128
+
129
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
130
+ b, c, t, h, w = x.size()
131
+ if self.mode == "upsample3d":
132
+ if feat_cache is not None:
133
+ idx = feat_idx[0]
134
+ if feat_cache[idx] is None:
135
+ feat_cache[idx] = "Rep"
136
+ feat_idx[0] += 1
137
+ else:
138
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
139
+ if (
140
+ cache_x.shape[2] < 2
141
+ and feat_cache[idx] is not None
142
+ and feat_cache[idx] != "Rep"
143
+ ):
144
+ # cache last frame of last two chunk
145
+ cache_x = torch.cat(
146
+ [
147
+ feat_cache[idx][:, :, -1, :, :]
148
+ .unsqueeze(2)
149
+ .to(cache_x.device),
150
+ cache_x,
151
+ ],
152
+ dim=2,
153
+ )
154
+ if (
155
+ cache_x.shape[2] < 2
156
+ and feat_cache[idx] is not None
157
+ and feat_cache[idx] == "Rep"
158
+ ):
159
+ cache_x = torch.cat(
160
+ [torch.zeros_like(cache_x).to(cache_x.device), cache_x],
161
+ dim=2,
162
+ )
163
+ if feat_cache[idx] == "Rep":
164
+ x = self.time_conv(x)
165
+ else:
166
+ x = self.time_conv(x, feat_cache[idx])
167
+ feat_cache[idx] = cache_x
168
+ feat_idx[0] += 1
169
+
170
+ x = x.reshape(b, 2, c, t, h, w)
171
+ x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
172
+ x = x.reshape(b, c, t * 2, h, w)
173
+ t = x.shape[2]
174
+ x = rearrange(x, "b c t h w -> (b t) c h w")
175
+ x = self.resample(x)
176
+ x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
177
+
178
+ if self.mode == "downsample3d":
179
+ if feat_cache is not None:
180
+ idx = feat_idx[0]
181
+ if feat_cache[idx] is None:
182
+ feat_cache[idx] = x.clone()
183
+ feat_idx[0] += 1
184
+ else:
185
+ cache_x = x[:, :, -1:, :, :].clone()
186
+ x = self.time_conv(
187
+ torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)
188
+ )
189
+ feat_cache[idx] = cache_x
190
+ feat_idx[0] += 1
191
+ return x
192
+
193
+ def init_weight(self, conv):
194
+ conv_weight = conv.weight
195
+ nn.init.zeros_(conv_weight)
196
+ c1, c2, t, h, w = conv_weight.size()
197
+ one_matrix = torch.eye(c1, c2)
198
+ init_matrix = one_matrix
199
+ nn.init.zeros_(conv_weight)
200
+ conv_weight.data[:, :, 1, 0, 0] = init_matrix
201
+ conv.weight.data.copy_(conv_weight)
202
+ nn.init.zeros_(conv.bias.data)
203
+
204
+ def init_weight2(self, conv):
205
+ conv_weight = conv.weight.data
206
+ nn.init.zeros_(conv_weight)
207
+ c1, c2, t, h, w = conv_weight.size()
208
+ init_matrix = torch.eye(c1 // 2, c2)
209
+ conv_weight[: c1 // 2, :, -1, 0, 0] = init_matrix
210
+ conv_weight[c1 // 2 :, :, -1, 0, 0] = init_matrix
211
+ conv.weight.data.copy_(conv_weight)
212
+ nn.init.zeros_(conv.bias.data)
213
+
214
+
215
+ def patchify(x, patch_size):
216
+ if patch_size == 1:
217
+ return x
218
+ if x.dim() == 4:
219
+ x = rearrange(x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size, r=patch_size)
220
+ elif x.dim() == 5:
221
+ x = rearrange(
222
+ x, "b c f (h q) (w r) -> b (c r q) f h w", q=patch_size, r=patch_size
223
+ )
224
+ else:
225
+ raise ValueError(f"Invalid input shape: {x.shape}")
226
+ return x
227
+
228
+
229
+ def unpatchify(x, patch_size):
230
+ if patch_size == 1:
231
+ return x
232
+ if x.dim() == 4:
233
+ x = rearrange(x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size)
234
+ elif x.dim() == 5:
235
+ x = rearrange(
236
+ x, "b (c r q) f h w -> b c f (h q) (w r)", q=patch_size, r=patch_size
237
+ )
238
+ return x
239
+
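+
A small hypothetical round-trip check of the two helpers above (not part of the commit; the import path assumes the repo's `models` package layout): patchify folds each 2x2 spatial patch into the channel axis (3 channels become 12, matching the 12-channel conv1 of Encoder3d_38 below), and unpatchify inverts it exactly.

import torch
from models.wan_video_vae import patchify, unpatchify

x = torch.rand(1, 3, 4, 64, 64)      # [B, C, T, H, W]
p = patchify(x, patch_size=2)        # -> [1, 12, 4, 32, 32]
assert torch.equal(unpatchify(p, 2), x)
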
240
+
241
+ class Resample38(Resample):
242
+ def __init__(self, dim, mode):
243
+ assert mode in (
244
+ "none",
245
+ "upsample2d",
246
+ "upsample3d",
247
+ "downsample2d",
248
+ "downsample3d",
249
+ )
250
+ super(Resample, self).__init__()
251
+ self.dim = dim
252
+ self.mode = mode
253
+
254
+ # layers
255
+ if mode == "upsample2d":
256
+ self.resample = nn.Sequential(
257
+ Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
258
+ nn.Conv2d(dim, dim, 3, padding=1),
259
+ )
260
+ elif mode == "upsample3d":
261
+ self.resample = nn.Sequential(
262
+ Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
263
+ nn.Conv2d(dim, dim, 3, padding=1),
264
+ )
265
+ self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
266
+ elif mode == "downsample2d":
267
+ self.resample = nn.Sequential(
268
+ nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))
269
+ )
270
+ elif mode == "downsample3d":
271
+ self.resample = nn.Sequential(
272
+ nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))
273
+ )
274
+ self.time_conv = CausalConv3d(
275
+ dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)
276
+ )
277
+ else:
278
+ self.resample = nn.Identity()
279
+
280
+
281
+ class ResidualBlock(nn.Module):
282
+ def __init__(self, in_dim, out_dim, dropout=0.0):
283
+ super().__init__()
284
+ self.in_dim = in_dim
285
+ self.out_dim = out_dim
286
+
287
+ # layers
288
+ self.residual = nn.Sequential(
289
+ RMS_norm(in_dim, images=False),
290
+ nn.SiLU(),
291
+ CausalConv3d(in_dim, out_dim, 3, padding=1),
292
+ RMS_norm(out_dim, images=False),
293
+ nn.SiLU(),
294
+ nn.Dropout(dropout),
295
+ CausalConv3d(out_dim, out_dim, 3, padding=1),
296
+ )
297
+ self.shortcut = (
298
+ CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()
299
+ )
300
+
301
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
302
+ h = self.shortcut(x)
303
+ for layer in self.residual:
304
+ if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
305
+ idx = feat_idx[0]
306
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
307
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
308
+ # cache last frame of last two chunk
309
+ cache_x = torch.cat(
310
+ [
311
+ feat_cache[idx][:, :, -1, :, :]
312
+ .unsqueeze(2)
313
+ .to(cache_x.device),
314
+ cache_x,
315
+ ],
316
+ dim=2,
317
+ )
318
+ x = layer(x, feat_cache[idx])
319
+ feat_cache[idx] = cache_x
320
+ feat_idx[0] += 1
321
+ else:
322
+ x = layer(x)
323
+ return x + h
324
+
325
+
326
+ class AttentionBlock(nn.Module):
327
+ """
328
+ Causal self-attention with a single head.
329
+ """
330
+
331
+ def __init__(self, dim):
332
+ super().__init__()
333
+ self.dim = dim
334
+
335
+ # layers
336
+ self.norm = RMS_norm(dim)
337
+ self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
338
+ self.proj = nn.Conv2d(dim, dim, 1)
339
+
340
+ # zero out the last layer params
341
+ nn.init.zeros_(self.proj.weight)
342
+
343
+ def forward(self, x):
344
+ identity = x
345
+ b, c, t, h, w = x.size()
346
+ x = rearrange(x, "b c t h w -> (b t) c h w")
347
+ x = self.norm(x)
348
+ # compute query, key, value
349
+ q, k, v = (
350
+ self.to_qkv(x)
351
+ .reshape(b * t, 1, c * 3, -1)
352
+ .permute(0, 1, 3, 2)
353
+ .contiguous()
354
+ .chunk(3, dim=-1)
355
+ )
356
+
357
+ # apply attention
358
+ x = F.scaled_dot_product_attention(
359
+ q,
360
+ k,
361
+ v,
362
+ # attn_mask=block_causal_mask(q, block_size=h * w)
363
+ )
364
+ x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
365
+
366
+ # output
367
+ x = self.proj(x)
368
+ x = rearrange(x, "(b t) c h w-> b c t h w", t=t)
369
+ return x + identity
370
+
371
+
372
+ class AvgDown3D(nn.Module):
373
+ def __init__(
374
+ self,
375
+ in_channels,
376
+ out_channels,
377
+ factor_t,
378
+ factor_s=1,
379
+ ):
380
+ super().__init__()
381
+ self.in_channels = in_channels
382
+ self.out_channels = out_channels
383
+ self.factor_t = factor_t
384
+ self.factor_s = factor_s
385
+ self.factor = self.factor_t * self.factor_s * self.factor_s
386
+
387
+ assert in_channels * self.factor % out_channels == 0
388
+ self.group_size = in_channels * self.factor // out_channels
389
+
390
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
391
+ pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t
392
+ pad = (0, 0, 0, 0, pad_t, 0)
393
+ x = F.pad(x, pad)
394
+ B, C, T, H, W = x.shape
395
+ x = x.view(
396
+ B,
397
+ C,
398
+ T // self.factor_t,
399
+ self.factor_t,
400
+ H // self.factor_s,
401
+ self.factor_s,
402
+ W // self.factor_s,
403
+ self.factor_s,
404
+ )
405
+ x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
406
+ x = x.view(
407
+ B,
408
+ C * self.factor,
409
+ T // self.factor_t,
410
+ H // self.factor_s,
411
+ W // self.factor_s,
412
+ )
413
+ x = x.view(
414
+ B,
415
+ self.out_channels,
416
+ self.group_size,
417
+ T // self.factor_t,
418
+ H // self.factor_s,
419
+ W // self.factor_s,
420
+ )
421
+ x = x.mean(dim=2)
422
+ return x
423
+
424
+
425
+ class DupUp3D(nn.Module):
426
+ def __init__(
427
+ self,
428
+ in_channels: int,
429
+ out_channels: int,
430
+ factor_t,
431
+ factor_s=1,
432
+ ):
433
+ super().__init__()
434
+ self.in_channels = in_channels
435
+ self.out_channels = out_channels
436
+
437
+ self.factor_t = factor_t
438
+ self.factor_s = factor_s
439
+ self.factor = self.factor_t * self.factor_s * self.factor_s
440
+
441
+ assert out_channels * self.factor % in_channels == 0
442
+ self.repeats = out_channels * self.factor // in_channels
443
+
444
+ def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor:
445
+ x = x.repeat_interleave(self.repeats, dim=1)
446
+ x = x.view(
447
+ x.size(0),
448
+ self.out_channels,
449
+ self.factor_t,
450
+ self.factor_s,
451
+ self.factor_s,
452
+ x.size(2),
453
+ x.size(3),
454
+ x.size(4),
455
+ )
456
+ x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
457
+ x = x.view(
458
+ x.size(0),
459
+ self.out_channels,
460
+ x.size(2) * self.factor_t,
461
+ x.size(4) * self.factor_s,
462
+ x.size(6) * self.factor_s,
463
+ )
464
+ if first_chunk:
465
+ x = x[:, :, self.factor_t - 1 :, :, :]
466
+ return x
467
+
468
+
469
+ class Down_ResidualBlock(nn.Module):
470
+ def __init__(
471
+ self, in_dim, out_dim, dropout, mult, temperal_downsample=False, down_flag=False
472
+ ):
473
+ super().__init__()
474
+
475
+ # Shortcut path with downsample
476
+ self.avg_shortcut = AvgDown3D(
477
+ in_dim,
478
+ out_dim,
479
+ factor_t=2 if temperal_downsample else 1,
480
+ factor_s=2 if down_flag else 1,
481
+ )
482
+
483
+ # Main path with residual blocks and downsample
484
+ downsamples = []
485
+ for _ in range(mult):
486
+ downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
487
+ in_dim = out_dim
488
+
489
+ # Add the final downsample block
490
+ if down_flag:
491
+ mode = "downsample3d" if temperal_downsample else "downsample2d"
492
+ downsamples.append(Resample38(out_dim, mode=mode))
493
+
494
+ self.downsamples = nn.Sequential(*downsamples)
495
+
496
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
497
+ x_copy = x.clone()
498
+ for module in self.downsamples:
499
+ x = module(x, feat_cache, feat_idx)
500
+
501
+ return x + self.avg_shortcut(x_copy)
502
+
503
+
504
+ class Up_ResidualBlock(nn.Module):
505
+ def __init__(
506
+ self, in_dim, out_dim, dropout, mult, temperal_upsample=False, up_flag=False
507
+ ):
508
+ super().__init__()
509
+ # Shortcut path with upsample
510
+ if up_flag:
511
+ self.avg_shortcut = DupUp3D(
512
+ in_dim,
513
+ out_dim,
514
+ factor_t=2 if temperal_upsample else 1,
515
+ factor_s=2 if up_flag else 1,
516
+ )
517
+ else:
518
+ self.avg_shortcut = None
519
+
520
+ # Main path with residual blocks and upsample
521
+ upsamples = []
522
+ for _ in range(mult):
523
+ upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
524
+ in_dim = out_dim
525
+
526
+ # Add the final upsample block
527
+ if up_flag:
528
+ mode = "upsample3d" if temperal_upsample else "upsample2d"
529
+ upsamples.append(Resample38(out_dim, mode=mode))
530
+
531
+ self.upsamples = nn.Sequential(*upsamples)
532
+
533
+ def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
534
+ x_main = x.clone()
535
+ for module in self.upsamples:
536
+ x_main = module(x_main, feat_cache, feat_idx)
537
+ if self.avg_shortcut is not None:
538
+ x_shortcut = self.avg_shortcut(x, first_chunk)
539
+ return x_main + x_shortcut
540
+ else:
541
+ return x_main
542
+
543
+
544
+ class Encoder3d(nn.Module):
545
+ def __init__(
546
+ self,
547
+ dim=128,
548
+ z_dim=4,
549
+ dim_mult=[1, 2, 4, 4],
550
+ num_res_blocks=2,
551
+ attn_scales=[],
552
+ temperal_downsample=[True, True, False],
553
+ dropout=0.0,
554
+ ):
555
+ super().__init__()
556
+ self.dim = dim
557
+ self.z_dim = z_dim
558
+ self.dim_mult = dim_mult
559
+ self.num_res_blocks = num_res_blocks
560
+ self.attn_scales = attn_scales
561
+ self.temperal_downsample = temperal_downsample
562
+
563
+ # dimensions
564
+ dims = [dim * u for u in [1] + dim_mult]
565
+ scale = 1.0
566
+
567
+ # init block
568
+ self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
569
+
570
+ # downsample blocks
571
+ downsamples = []
572
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
573
+ # residual (+attention) blocks
574
+ for _ in range(num_res_blocks):
575
+ downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
576
+ if scale in attn_scales:
577
+ downsamples.append(AttentionBlock(out_dim))
578
+ in_dim = out_dim
579
+
580
+ # downsample block
581
+ if i != len(dim_mult) - 1:
582
+ mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
583
+ downsamples.append(Resample(out_dim, mode=mode))
584
+ scale /= 2.0
585
+ self.downsamples = nn.Sequential(*downsamples)
586
+
587
+ # middle blocks
588
+ self.middle = nn.Sequential(
589
+ ResidualBlock(out_dim, out_dim, dropout),
590
+ AttentionBlock(out_dim),
591
+ ResidualBlock(out_dim, out_dim, dropout),
592
+ )
593
+
594
+ # output blocks
595
+ self.head = nn.Sequential(
596
+ RMS_norm(out_dim, images=False),
597
+ nn.SiLU(),
598
+ CausalConv3d(out_dim, z_dim, 3, padding=1),
599
+ )
600
+
601
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
602
+ if feat_cache is not None:
603
+ idx = feat_idx[0]
604
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
605
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
606
+ # cache last frame of last two chunk
607
+ cache_x = torch.cat(
608
+ [
609
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),
610
+ cache_x,
611
+ ],
612
+ dim=2,
613
+ )
614
+ x = self.conv1(x, feat_cache[idx])
615
+ feat_cache[idx] = cache_x
616
+ feat_idx[0] += 1
617
+ else:
618
+ x = self.conv1(x)
619
+
620
+ ## downsamples
621
+ for layer in self.downsamples:
622
+ if feat_cache is not None:
623
+ x = layer(x, feat_cache, feat_idx)
624
+ else:
625
+ x = layer(x)
626
+
627
+ ## middle
628
+ for layer in self.middle:
629
+ if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
630
+ x = layer(x, feat_cache, feat_idx)
631
+ else:
632
+ x = layer(x)
633
+
634
+ ## head
635
+ for layer in self.head:
636
+ if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
637
+ idx = feat_idx[0]
638
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
639
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
640
+ # cache last frame of last two chunk
641
+ cache_x = torch.cat(
642
+ [
643
+ feat_cache[idx][:, :, -1, :, :]
644
+ .unsqueeze(2)
645
+ .to(cache_x.device),
646
+ cache_x,
647
+ ],
648
+ dim=2,
649
+ )
650
+ x = layer(x, feat_cache[idx])
651
+ feat_cache[idx] = cache_x
652
+ feat_idx[0] += 1
653
+ else:
654
+ x = layer(x)
655
+ return x
656
+
657
+
658
+ class Encoder3d_38(nn.Module):
659
+ def __init__(
660
+ self,
661
+ dim=128,
662
+ z_dim=4,
663
+ dim_mult=[1, 2, 4, 4],
664
+ num_res_blocks=2,
665
+ attn_scales=[],
666
+ temperal_downsample=[False, True, True],
667
+ dropout=0.0,
668
+ ):
669
+ super().__init__()
670
+ self.dim = dim
671
+ self.z_dim = z_dim
672
+ self.dim_mult = dim_mult
673
+ self.num_res_blocks = num_res_blocks
674
+ self.attn_scales = attn_scales
675
+ self.temperal_downsample = temperal_downsample
676
+
677
+ # dimensions
678
+ dims = [dim * u for u in [1] + dim_mult]
679
+ scale = 1.0
680
+
681
+ # init block
682
+ self.conv1 = CausalConv3d(12, dims[0], 3, padding=1)
683
+
684
+ # downsample blocks
685
+ downsamples = []
686
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
687
+ t_down_flag = (
688
+ temperal_downsample[i] if i < len(temperal_downsample) else False
689
+ )
690
+ downsamples.append(
691
+ Down_ResidualBlock(
692
+ in_dim=in_dim,
693
+ out_dim=out_dim,
694
+ dropout=dropout,
695
+ mult=num_res_blocks,
696
+ temperal_downsample=t_down_flag,
697
+ down_flag=i != len(dim_mult) - 1,
698
+ )
699
+ )
700
+ scale /= 2.0
701
+ self.downsamples = nn.Sequential(*downsamples)
702
+
703
+ # middle blocks
704
+ self.middle = nn.Sequential(
705
+ ResidualBlock(out_dim, out_dim, dropout),
706
+ AttentionBlock(out_dim),
707
+ ResidualBlock(out_dim, out_dim, dropout),
708
+ )
709
+
710
+ # output blocks
711
+ self.head = nn.Sequential(
712
+ RMS_norm(out_dim, images=False),
713
+ nn.SiLU(),
714
+ CausalConv3d(out_dim, z_dim, 3, padding=1),
715
+ )
716
+
717
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
718
+ if feat_cache is not None:
719
+ idx = feat_idx[0]
720
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
721
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
722
+ cache_x = torch.cat(
723
+ [
724
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),
725
+ cache_x,
726
+ ],
727
+ dim=2,
728
+ )
729
+ x = self.conv1(x, feat_cache[idx])
730
+ feat_cache[idx] = cache_x
731
+ feat_idx[0] += 1
732
+ else:
733
+ x = self.conv1(x)
734
+
735
+ ## downsamples
736
+ for layer in self.downsamples:
737
+ if feat_cache is not None:
738
+ x = layer(x, feat_cache, feat_idx)
739
+ else:
740
+ x = layer(x)
741
+
742
+ ## middle
743
+ for layer in self.middle:
744
+ if isinstance(layer, ResidualBlock) and feat_cache is not None:
745
+ x = layer(x, feat_cache, feat_idx)
746
+ else:
747
+ x = layer(x)
748
+
749
+ ## head
750
+ for layer in self.head:
751
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
752
+ idx = feat_idx[0]
753
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
754
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
755
+ cache_x = torch.cat(
756
+ [
757
+ feat_cache[idx][:, :, -1, :, :]
758
+ .unsqueeze(2)
759
+ .to(cache_x.device),
760
+ cache_x,
761
+ ],
762
+ dim=2,
763
+ )
764
+ x = layer(x, feat_cache[idx])
765
+ feat_cache[idx] = cache_x
766
+ feat_idx[0] += 1
767
+ else:
768
+ x = layer(x)
769
+
770
+ return x
771
+
772
+
773
+ class Decoder3d(nn.Module):
774
+ def __init__(
775
+ self,
776
+ dim=128,
777
+ z_dim=4,
778
+ dim_mult=[1, 2, 4, 4],
779
+ num_res_blocks=2,
780
+ attn_scales=[],
781
+ temperal_upsample=[False, True, True],
782
+ dropout=0.0,
783
+ ):
784
+ super().__init__()
785
+ self.dim = dim
786
+ self.z_dim = z_dim
787
+ self.dim_mult = dim_mult
788
+ self.num_res_blocks = num_res_blocks
789
+ self.attn_scales = attn_scales
790
+ self.temperal_upsample = temperal_upsample
791
+
792
+ # dimensions
793
+ dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
794
+ scale = 1.0 / 2 ** (len(dim_mult) - 2)
795
+
796
+ # init block
797
+ self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
798
+
799
+ # middle blocks
800
+ self.middle = nn.Sequential(
801
+ ResidualBlock(dims[0], dims[0], dropout),
802
+ AttentionBlock(dims[0]),
803
+ ResidualBlock(dims[0], dims[0], dropout),
804
+ )
805
+
806
+ # upsample blocks
807
+ upsamples = []
808
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
809
+ # residual (+attention) blocks
810
+ if i == 1 or i == 2 or i == 3:
811
+ in_dim = in_dim // 2
812
+ for _ in range(num_res_blocks + 1):
813
+ upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
814
+ if scale in attn_scales:
815
+ upsamples.append(AttentionBlock(out_dim))
816
+ in_dim = out_dim
817
+
818
+ # upsample block
819
+ if i != len(dim_mult) - 1:
820
+ mode = "upsample3d" if temperal_upsample[i] else "upsample2d"
821
+ upsamples.append(Resample(out_dim, mode=mode))
822
+ scale *= 2.0
823
+ self.upsamples = nn.Sequential(*upsamples)
824
+
825
+ # output blocks
826
+ self.head = nn.Sequential(
827
+ RMS_norm(out_dim, images=False),
828
+ nn.SiLU(),
829
+ CausalConv3d(out_dim, 3, 3, padding=1),
830
+ )
831
+
832
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
833
+ ## conv1
834
+ if feat_cache is not None:
835
+ idx = feat_idx[0]
836
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
837
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
838
+ # cache last frame of last two chunk
839
+ cache_x = torch.cat(
840
+ [
841
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),
842
+ cache_x,
843
+ ],
844
+ dim=2,
845
+ )
846
+ x = self.conv1(x, feat_cache[idx])
847
+ feat_cache[idx] = cache_x
848
+ feat_idx[0] += 1
849
+ else:
850
+ x = self.conv1(x)
851
+
852
+ ## middle
853
+ for layer in self.middle:
854
+ if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
855
+ x = layer(x, feat_cache, feat_idx)
856
+ else:
857
+ x = layer(x)
858
+
859
+ ## upsamples
860
+ for layer in self.upsamples:
861
+ if feat_cache is not None:
862
+ x = layer(x, feat_cache, feat_idx)
863
+ else:
864
+ x = layer(x)
865
+
866
+ ## head
867
+ for layer in self.head:
868
+ if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
869
+ idx = feat_idx[0]
870
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
871
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
872
+ # cache last frame of last two chunk
873
+ cache_x = torch.cat(
874
+ [
875
+ feat_cache[idx][:, :, -1, :, :]
876
+ .unsqueeze(2)
877
+ .to(cache_x.device),
878
+ cache_x,
879
+ ],
880
+ dim=2,
881
+ )
882
+ x = layer(x, feat_cache[idx])
883
+ feat_cache[idx] = cache_x
884
+ feat_idx[0] += 1
885
+ else:
886
+ x = layer(x)
887
+ return x
888
+
889
+
890
+ class Decoder3d_38(nn.Module):
891
+ def __init__(
892
+ self,
893
+ dim=128,
894
+ z_dim=4,
895
+ dim_mult=[1, 2, 4, 4],
896
+ num_res_blocks=2,
897
+ attn_scales=[],
898
+ temperal_upsample=[False, True, True],
899
+ dropout=0.0,
900
+ ):
901
+ super().__init__()
902
+ self.dim = dim
903
+ self.z_dim = z_dim
904
+ self.dim_mult = dim_mult
905
+ self.num_res_blocks = num_res_blocks
906
+ self.attn_scales = attn_scales
907
+ self.temperal_upsample = temperal_upsample
908
+
909
+ # dimensions
910
+ dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
911
+ scale = 1.0 / 2 ** (len(dim_mult) - 2)
912
+ # init block
913
+ self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
914
+
915
+ # middle blocks
916
+ self.middle = nn.Sequential(
917
+ ResidualBlock(dims[0], dims[0], dropout),
918
+ AttentionBlock(dims[0]),
919
+ ResidualBlock(dims[0], dims[0], dropout),
920
+ )
921
+
922
+ # upsample blocks
923
+ upsamples = []
924
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
925
+ t_up_flag = temperal_upsample[i] if i < len(temperal_upsample) else False
926
+ upsamples.append(
927
+ Up_ResidualBlock(
928
+ in_dim=in_dim,
929
+ out_dim=out_dim,
930
+ dropout=dropout,
931
+ mult=num_res_blocks + 1,
932
+ temperal_upsample=t_up_flag,
933
+ up_flag=i != len(dim_mult) - 1,
934
+ )
935
+ )
936
+ self.upsamples = nn.Sequential(*upsamples)
937
+
938
+ # output blocks
939
+ self.head = nn.Sequential(
940
+ RMS_norm(out_dim, images=False),
941
+ nn.SiLU(),
942
+ CausalConv3d(out_dim, 12, 3, padding=1),
943
+ )
944
+
945
+ def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
946
+ if feat_cache is not None:
947
+ idx = feat_idx[0]
948
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
949
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
950
+ cache_x = torch.cat(
951
+ [
952
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),
953
+ cache_x,
954
+ ],
955
+ dim=2,
956
+ )
957
+ x = self.conv1(x, feat_cache[idx])
958
+ feat_cache[idx] = cache_x
959
+ feat_idx[0] += 1
960
+ else:
961
+ x = self.conv1(x)
962
+
963
+ for layer in self.middle:
964
+ if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
965
+ x = layer(x, feat_cache, feat_idx)
966
+ else:
967
+ x = layer(x)
968
+
969
+ ## upsamples
970
+ for layer in self.upsamples:
971
+ if feat_cache is not None:
972
+ x = layer(x, feat_cache, feat_idx, first_chunk)
973
+ else:
974
+ x = layer(x)
975
+
976
+ ## head
977
+ for layer in self.head:
978
+ if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
979
+ idx = feat_idx[0]
980
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
981
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
982
+ cache_x = torch.cat(
983
+ [
984
+ feat_cache[idx][:, :, -1, :, :]
985
+ .unsqueeze(2)
986
+ .to(cache_x.device),
987
+ cache_x,
988
+ ],
989
+ dim=2,
990
+ )
991
+ x = layer(x, feat_cache[idx])
992
+ feat_cache[idx] = cache_x
993
+ feat_idx[0] += 1
994
+ else:
995
+ x = layer(x)
996
+ return x
997
+
998
+
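+ # Counts CausalConv3d modules so clear_cache() can allocate one feature-cache slot per conv.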
999
+ def count_conv3d(model):
1000
+ count = 0
1001
+ for m in model.modules():
1002
+ if isinstance(m, CausalConv3d):
1003
+ count += 1
1004
+ return count
1005
+
1006
+
1007
+ class VideoVAE_(nn.Module):
1008
+ def __init__(
1009
+ self,
1010
+ dim=96,
1011
+ z_dim=16,
1012
+ dim_mult=[1, 2, 4, 4],
1013
+ num_res_blocks=2,
1014
+ attn_scales=[],
1015
+ temperal_downsample=[False, True, True],
1016
+ dropout=0.0,
1017
+ ):
1018
+ super().__init__()
1019
+ self.dim = dim
1020
+ self.z_dim = z_dim
1021
+ self.dim_mult = dim_mult
1022
+ self.num_res_blocks = num_res_blocks
1023
+ self.attn_scales = attn_scales
1024
+ self.temperal_downsample = temperal_downsample
1025
+ self.temperal_upsample = temperal_downsample[::-1]
1026
+
1027
+ # modules
1028
+ self.encoder = Encoder3d(
1029
+ dim,
1030
+ z_dim * 2,
1031
+ dim_mult,
1032
+ num_res_blocks,
1033
+ attn_scales,
1034
+ self.temperal_downsample,
1035
+ dropout,
1036
+ )
1037
+ self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
1038
+ self.conv2 = CausalConv3d(z_dim, z_dim, 1)
1039
+ self.decoder = Decoder3d(
1040
+ dim,
1041
+ z_dim,
1042
+ dim_mult,
1043
+ num_res_blocks,
1044
+ attn_scales,
1045
+ self.temperal_upsample,
1046
+ dropout,
1047
+ )
1048
+
1049
+ def forward(self, x):
1050
+ mu, log_var = self.encode(x)
1051
+ z = self.reparameterize(mu, log_var)
1052
+ x_recon = self.decode(z)
1053
+ return x_recon, mu, log_var
1054
+
1055
+ def encode(self, x, scale):
1056
+ self.clear_cache()
1057
+ ## cache
1058
+ t = x.shape[2]
1059
+ iter_ = 1 + (t - 1) // 4
1060
+
1061
+ for i in range(iter_):
1062
+ self._enc_conv_idx = [0]
1063
+ if i == 0:
1064
+ out = self.encoder(
1065
+ x[:, :, :1, :, :],
1066
+ feat_cache=self._enc_feat_map,
1067
+ feat_idx=self._enc_conv_idx,
1068
+ )
1069
+ else:
1070
+ out_ = self.encoder(
1071
+ x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :],
1072
+ feat_cache=self._enc_feat_map,
1073
+ feat_idx=self._enc_conv_idx,
1074
+ )
1075
+ out = torch.cat([out, out_], 2)
1076
+ mu, log_var = self.conv1(out).chunk(2, dim=1)
1077
+ if isinstance(scale[0], torch.Tensor):
1078
+ scale = [s.to(dtype=mu.dtype, device=mu.device) for s in scale]
1079
+ mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
1080
+ 1, self.z_dim, 1, 1, 1
1081
+ )
1082
+ else:
1083
+ scale = scale.to(dtype=mu.dtype, device=mu.device)
1084
+ mu = (mu - scale[0]) * scale[1]
1085
+ return mu
1086
+
1087
+ def decode(self, z, scale):
1088
+ self.clear_cache()
1089
+ # z: [b,c,t,h,w]
1090
+ if isinstance(scale[0], torch.Tensor):
1091
+ scale = [s.to(dtype=z.dtype, device=z.device) for s in scale]
1092
+ z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
1093
+ 1, self.z_dim, 1, 1, 1
1094
+ )
1095
+ else:
1096
+ scale = scale.to(dtype=z.dtype, device=z.device)
1097
+ z = z / scale[1] + scale[0]
1098
+ iter_ = z.shape[2]
1099
+ x = self.conv2(z)
1100
+ for i in range(iter_):
1101
+ self._conv_idx = [0]
1102
+ if i == 0:
1103
+ out = self.decoder(
1104
+ x[:, :, i : i + 1, :, :],
1105
+ feat_cache=self._feat_map,
1106
+ feat_idx=self._conv_idx,
1107
+ )
1108
+ else:
1109
+ out_ = self.decoder(
1110
+ x[:, :, i : i + 1, :, :],
1111
+ feat_cache=self._feat_map,
1112
+ feat_idx=self._conv_idx,
1113
+ )
1114
+ out = torch.cat([out, out_], 2)  # note: intermediate chunks could be offloaded to save VRAM
1115
+ return out
1116
+
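+ # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I), keeping sampling differentiable.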
1117
+ def reparameterize(self, mu, log_var):
1118
+ std = torch.exp(0.5 * log_var)
1119
+ eps = torch.randn_like(std)
1120
+ return eps * std + mu
1121
+
1122
+ def sample(self, imgs, deterministic=False):
1123
+ mu, log_var = self.encode(imgs)
1124
+ if deterministic:
1125
+ return mu
1126
+ std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
1127
+ return mu + std * torch.randn_like(std)
1128
+
1129
+ def clear_cache(self):
1130
+ self._conv_num = count_conv3d(self.decoder)
1131
+ self._conv_idx = [0]
1132
+ self._feat_map = [None] * self._conv_num
1133
+ # cache encode
1134
+ self._enc_conv_num = count_conv3d(self.encoder)
1135
+ self._enc_conv_idx = [0]
1136
+ self._enc_feat_map = [None] * self._enc_conv_num
1137
+
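+ # Temporal bookkeeping used by encode/decode above: T input frames map to
+ # 1 + (T - 1) // 4 latent frames (first frame alone, then chunks of 4), and
+ # T_lat latent frames decode back to 4 * T_lat - 3 output frames.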
1138
+
1139
+ class WanVideoVAE(nn.Module):
1140
+ def __init__(self, z_dim=16):
1141
+ super().__init__()
1142
+
1143
+ mean = [
1144
+ -0.7571,
1145
+ -0.7089,
1146
+ -0.9113,
1147
+ 0.1075,
1148
+ -0.1745,
1149
+ 0.9653,
1150
+ -0.1517,
1151
+ 1.5508,
1152
+ 0.4134,
1153
+ -0.0715,
1154
+ 0.5517,
1155
+ -0.3632,
1156
+ -0.1922,
1157
+ -0.9497,
1158
+ 0.2503,
1159
+ -0.2921,
1160
+ ]
1161
+ std = [
1162
+ 2.8184,
1163
+ 1.4541,
1164
+ 2.3275,
1165
+ 2.6558,
1166
+ 1.2196,
1167
+ 1.7708,
1168
+ 2.6052,
1169
+ 2.0743,
1170
+ 3.2687,
1171
+ 2.1526,
1172
+ 2.8652,
1173
+ 1.5579,
1174
+ 1.6382,
1175
+ 1.1253,
1176
+ 2.8251,
1177
+ 1.9160,
1178
+ ]
1179
+ self.mean = torch.tensor(mean)
1180
+ self.std = torch.tensor(std)
1181
+ self.scale = [self.mean, 1.0 / self.std]
1182
+
1183
+ # init model
1184
+ self.model = VideoVAE_(z_dim=z_dim).eval().requires_grad_(False)
1185
+ self.upsampling_factor = 8
1186
+ self.z_dim = z_dim
1187
+
1188
+ def build_1d_mask(self, length, left_bound, right_bound, border_width):
1189
+ x = torch.ones((length,))
1190
+ if not left_bound:
1191
+ x[:border_width] = (torch.arange(border_width) + 1) / border_width
1192
+ if not right_bound:
1193
+ x[-border_width:] = torch.flip(
1194
+ (torch.arange(border_width) + 1) / border_width, dims=(0,)
1195
+ )
1196
+ return x
1197
+
1198
+ def build_mask(self, data, is_bound, border_width):
1199
+ _, _, _, H, W = data.shape
1200
+ h = self.build_1d_mask(H, is_bound[0], is_bound[1], border_width[0])
1201
+ w = self.build_1d_mask(W, is_bound[2], is_bound[3], border_width[1])
1202
+
1203
+ h = repeat(h, "H -> H W", H=H, W=W)
1204
+ w = repeat(w, "W -> H W", H=H, W=W)
1205
+
1206
+ mask = torch.stack([h, w]).min(dim=0).values
1207
+ mask = rearrange(mask, "H W -> 1 1 1 H W")
1208
+ return mask
1209
+
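+ # The masks from build_mask ramp linearly from 0 to 1 over border_width pixels on every
+ # edge that is not an image boundary (combined via an element-wise minimum), so
+ # overlapping tiles cross-fade rather than leaving visible seams.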
1210
+ def tiled_decode(self, hidden_states, device, tile_size, tile_stride):
1211
+ _, _, T, H, W = hidden_states.shape
1212
+ size_h, size_w = tile_size
1213
+ stride_h, stride_w = tile_stride
1214
+
1215
+ # Split tasks
1216
+ tasks = []
1217
+ for h in range(0, H, stride_h):
1218
+ if h - stride_h >= 0 and h - stride_h + size_h >= H:
1219
+ continue
1220
+ for w in range(0, W, stride_w):
1221
+ if w - stride_w >= 0 and w - stride_w + size_w >= W:
1222
+ continue
1223
+ h_, w_ = h + size_h, w + size_w
1224
+ tasks.append((h, h_, w, w_))
1225
+
1226
+ data_device = "cpu"
1227
+ computation_device = device
1228
+
1229
+ out_T = T * 4 - 3
1230
+ weight = torch.zeros(
1231
+ (1, 1, out_T, H * self.upsampling_factor, W * self.upsampling_factor),
1232
+ dtype=hidden_states.dtype,
1233
+ device=data_device,
1234
+ )
1235
+ values = torch.zeros(
1236
+ (1, 3, out_T, H * self.upsampling_factor, W * self.upsampling_factor),
1237
+ dtype=hidden_states.dtype,
1238
+ device=data_device,
1239
+ )
1240
+
1241
+ for h, h_, w, w_ in tqdm(tasks, desc="VAE decoding"):
1242
+ hidden_states_batch = hidden_states[:, :, :, h:h_, w:w_].to(
1243
+ computation_device
1244
+ )
1245
+ hidden_states_batch = self.model.decode(hidden_states_batch, self.scale).to(
1246
+ data_device
1247
+ )
1248
+
1249
+ mask = self.build_mask(
1250
+ hidden_states_batch,
1251
+ is_bound=(h == 0, h_ >= H, w == 0, w_ >= W),
1252
+ border_width=(
1253
+ (size_h - stride_h) * self.upsampling_factor,
1254
+ (size_w - stride_w) * self.upsampling_factor,
1255
+ ),
1256
+ ).to(dtype=hidden_states.dtype, device=data_device)
1257
+
1258
+ target_h = h * self.upsampling_factor
1259
+ target_w = w * self.upsampling_factor
1260
+ values[
1261
+ :,
1262
+ :,
1263
+ :,
1264
+ target_h : target_h + hidden_states_batch.shape[3],
1265
+ target_w : target_w + hidden_states_batch.shape[4],
1266
+ ] += hidden_states_batch * mask
1267
+ weight[
1268
+ :,
1269
+ :,
1270
+ :,
1271
+ target_h : target_h + hidden_states_batch.shape[3],
1272
+ target_w : target_w + hidden_states_batch.shape[4],
1273
+ ] += mask
1274
+ values = values / weight
1275
+ values = values.clamp_(-1, 1)
1276
+ return values
1277
+
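+ # tiled_encode mirrors tiled_decode: each tile contributes value * mask and mask to the
+ # running sums, and dividing by the accumulated weight yields a weighted average per position.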
1278
+ def tiled_encode(self, video, device, tile_size, tile_stride):
1279
+ _, _, T, H, W = video.shape
1280
+ size_h, size_w = tile_size
1281
+ stride_h, stride_w = tile_stride
1282
+
1283
+ # Split tasks
1284
+ tasks = []
1285
+ for h in range(0, H, stride_h):
1286
+ if h - stride_h >= 0 and h - stride_h + size_h >= H:
1287
+ continue
1288
+ for w in range(0, W, stride_w):
1289
+ if w - stride_w >= 0 and w - stride_w + size_w >= W:
1290
+ continue
1291
+ h_, w_ = h + size_h, w + size_w
1292
+ tasks.append((h, h_, w, w_))
1293
+
1294
+ data_device = "cpu"
1295
+ computation_device = device
1296
+
1297
+ out_T = (T + 3) // 4
1298
+ weight = torch.zeros(
1299
+ (1, 1, out_T, H // self.upsampling_factor, W // self.upsampling_factor),
1300
+ dtype=video.dtype,
1301
+ device=data_device,
1302
+ )
1303
+ values = torch.zeros(
1304
+ (
1305
+ 1,
1306
+ self.z_dim,
1307
+ out_T,
1308
+ H // self.upsampling_factor,
1309
+ W // self.upsampling_factor,
1310
+ ),
1311
+ dtype=video.dtype,
1312
+ device=data_device,
1313
+ )
1314
+
1315
+ for h, h_, w, w_ in tqdm(tasks, desc="VAE encoding"):
1316
+ hidden_states_batch = video[:, :, :, h:h_, w:w_].to(computation_device)
1317
+ hidden_states_batch = self.model.encode(hidden_states_batch, self.scale).to(
1318
+ data_device
1319
+ )
1320
+
1321
+ mask = self.build_mask(
1322
+ hidden_states_batch,
1323
+ is_bound=(h == 0, h_ >= H, w == 0, w_ >= W),
1324
+ border_width=(
1325
+ (size_h - stride_h) // self.upsampling_factor,
1326
+ (size_w - stride_w) // self.upsampling_factor,
1327
+ ),
1328
+ ).to(dtype=video.dtype, device=data_device)
1329
+
1330
+ target_h = h // self.upsampling_factor
1331
+ target_w = w // self.upsampling_factor
1332
+ values[
1333
+ :,
1334
+ :,
1335
+ :,
1336
+ target_h : target_h + hidden_states_batch.shape[3],
1337
+ target_w : target_w + hidden_states_batch.shape[4],
1338
+ ] += hidden_states_batch * mask
1339
+ weight[
1340
+ :,
1341
+ :,
1342
+ :,
1343
+ target_h : target_h + hidden_states_batch.shape[3],
1344
+ target_w : target_w + hidden_states_batch.shape[4],
1345
+ ] += mask
1346
+ values = values / weight
1347
+ return values
1348
+
1349
+ def single_encode(self, video, device):
1350
+ video = video.to(device)
1351
+ x = self.model.encode(video, self.scale)
1352
+ return x
1353
+
1354
+ def single_decode(self, hidden_state, device):
1355
+ hidden_state = hidden_state.to(device)
1356
+ video = self.model.decode(hidden_state, self.scale)
1357
+ return video.clamp_(-1, 1)
1358
+
1359
+ def encode(
1360
+ self, videos, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)
1361
+ ):
1362
+ videos = [video.to("cpu") for video in videos]
1363
+ hidden_states = []
1364
+ for video in videos:
1365
+ video = video.unsqueeze(0)
1366
+ if tiled:
1367
+ tile_size = (
1368
+ tile_size[0] * self.upsampling_factor,
1369
+ tile_size[1] * self.upsampling_factor,
1370
+ )
1371
+ tile_stride = (
1372
+ tile_stride[0] * self.upsampling_factor,
1373
+ tile_stride[1] * self.upsampling_factor,
1374
+ )
1375
+ hidden_state = self.tiled_encode(video, device, tile_size, tile_stride)
1376
+ else:
1377
+ hidden_state = self.single_encode(video, device)
1378
+ hidden_state = hidden_state.squeeze(0)
1379
+ hidden_states.append(hidden_state)
1380
+ hidden_states = torch.stack(hidden_states)
1381
+ return hidden_states
1382
+
1383
+ def decode(
1384
+ self,
1385
+ hidden_states,
1386
+ device,
1387
+ tiled=False,
1388
+ tile_size=(34, 34),
1389
+ tile_stride=(18, 16),
1390
+ ):
1391
+ if tiled:
1392
+ video = self.tiled_decode(hidden_states, device, tile_size, tile_stride)
1393
+ else:
1394
+ video = self.single_decode(hidden_states, device)
1395
+ return video
1396
+
1397
+ @staticmethod
1398
+ def state_dict_converter():
1399
+ return WanVideoVAEStateDictConverter()
1400
+
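+ # Usage sketch (assumes weights have already been loaded into self.model). Given video as a
+ # [3, T, H, W] tensor in [-1, 1]:
+ #   latents = vae.encode([video], device="cuda", tiled=True)   # [1, 16, 1+(T-1)//4, H//8, W//8]
+ #   frames  = vae.decode(latents, device="cuda", tiled=True)   # [1, 3, T', H, W] in [-1, 1]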
1401
+
1402
+ class WanVideoVAEStateDictConverter:
1403
+ def __init__(self):
1404
+ pass
1405
+
1406
+ def from_civitai(self, state_dict):
1407
+ state_dict_ = {}
1408
+ if "model_state" in state_dict:
1409
+ state_dict = state_dict["model_state"]
1410
+ for name in state_dict:
1411
+ state_dict_["model." + name] = state_dict[name]
1412
+ return state_dict_
1413
+
1414
+
1415
+ class VideoVAE38_(VideoVAE_):
1416
+ def __init__(
1417
+ self,
1418
+ dim=160,
1419
+ z_dim=48,
1420
+ dec_dim=256,
1421
+ dim_mult=[1, 2, 4, 4],
1422
+ num_res_blocks=2,
1423
+ attn_scales=[],
1424
+ temperal_downsample=[False, True, True],
1425
+ dropout=0.0,
1426
+ ):
1427
+ super(VideoVAE_, self).__init__()
1428
+ self.dim = dim
1429
+ self.z_dim = z_dim
1430
+ self.dim_mult = dim_mult
1431
+ self.num_res_blocks = num_res_blocks
1432
+ self.attn_scales = attn_scales
1433
+ self.temperal_downsample = temperal_downsample
1434
+ self.temperal_upsample = temperal_downsample[::-1]
1435
+
1436
+ # modules
1437
+ self.encoder = Encoder3d_38(
1438
+ dim,
1439
+ z_dim * 2,
1440
+ dim_mult,
1441
+ num_res_blocks,
1442
+ attn_scales,
1443
+ self.temperal_downsample,
1444
+ dropout,
1445
+ )
1446
+ self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
1447
+ self.conv2 = CausalConv3d(z_dim, z_dim, 1)
1448
+ self.decoder = Decoder3d_38(
1449
+ dec_dim,
1450
+ z_dim,
1451
+ dim_mult,
1452
+ num_res_blocks,
1453
+ attn_scales,
1454
+ self.temperal_upsample,
1455
+ dropout,
1456
+ )
1457
+
1458
+ def encode(self, x, scale):
1459
+ self.clear_cache()
1460
+ x = patchify(x, patch_size=2)
1461
+ t = x.shape[2]
1462
+ iter_ = 1 + (t - 1) // 4
1463
+ for i in range(iter_):
1464
+ self._enc_conv_idx = [0]
1465
+ if i == 0:
1466
+ out = self.encoder(
1467
+ x[:, :, :1, :, :],
1468
+ feat_cache=self._enc_feat_map,
1469
+ feat_idx=self._enc_conv_idx,
1470
+ )
1471
+ else:
1472
+ out_ = self.encoder(
1473
+ x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :],
1474
+ feat_cache=self._enc_feat_map,
1475
+ feat_idx=self._enc_conv_idx,
1476
+ )
1477
+ out = torch.cat([out, out_], 2)
1478
+ mu, log_var = self.conv1(out).chunk(2, dim=1)
1479
+ if isinstance(scale[0], torch.Tensor):
1480
+ scale = [s.to(dtype=mu.dtype, device=mu.device) for s in scale]
1481
+ mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
1482
+ 1, self.z_dim, 1, 1, 1
1483
+ )
1484
+ else:
1485
+ scale = scale.to(dtype=mu.dtype, device=mu.device)
1486
+ mu = (mu - scale[0]) * scale[1]
1487
+ self.clear_cache()
1488
+ return mu
1489
+
1490
+ def decode(self, z, scale):
1491
+ self.clear_cache()
1492
+ if isinstance(scale[0], torch.Tensor):
1493
+ scale = [s.to(dtype=z.dtype, device=z.device) for s in scale]
1494
+ z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
1495
+ 1, self.z_dim, 1, 1, 1
1496
+ )
1497
+ else:
1498
+ scale = scale.to(dtype=z.dtype, device=z.device)
1499
+ z = z / scale[1] + scale[0]
1500
+ iter_ = z.shape[2]
1501
+ x = self.conv2(z)
1502
+ for i in range(iter_):
1503
+ self._conv_idx = [0]
1504
+ if i == 0:
1505
+ out = self.decoder(
1506
+ x[:, :, i : i + 1, :, :],
1507
+ feat_cache=self._feat_map,
1508
+ feat_idx=self._conv_idx,
1509
+ first_chunk=True,
1510
+ )
1511
+ else:
1512
+ out_ = self.decoder(
1513
+ x[:, :, i : i + 1, :, :],
1514
+ feat_cache=self._feat_map,
1515
+ feat_idx=self._conv_idx,
1516
+ )
1517
+ out = torch.cat([out, out_], 2)
1518
+ out = unpatchify(out, patch_size=2)
1519
+ self.clear_cache()
1520
+ return out
1521
+
1522
+
1523
+ class WanVideoVAE38(WanVideoVAE):
1524
+ def __init__(self, z_dim=48, dim=160):
1525
+ super(WanVideoVAE, self).__init__()
1526
+
1527
+ mean = [
1528
+ -0.2289,
1529
+ -0.0052,
1530
+ -0.1323,
1531
+ -0.2339,
1532
+ -0.2799,
1533
+ 0.0174,
1534
+ 0.1838,
1535
+ 0.1557,
1536
+ -0.1382,
1537
+ 0.0542,
1538
+ 0.2813,
1539
+ 0.0891,
1540
+ 0.1570,
1541
+ -0.0098,
1542
+ 0.0375,
1543
+ -0.1825,
1544
+ -0.2246,
1545
+ -0.1207,
1546
+ -0.0698,
1547
+ 0.5109,
1548
+ 0.2665,
1549
+ -0.2108,
1550
+ -0.2158,
1551
+ 0.2502,
1552
+ -0.2055,
1553
+ -0.0322,
1554
+ 0.1109,
1555
+ 0.1567,
1556
+ -0.0729,
1557
+ 0.0899,
1558
+ -0.2799,
1559
+ -0.1230,
1560
+ -0.0313,
1561
+ -0.1649,
1562
+ 0.0117,
1563
+ 0.0723,
1564
+ -0.2839,
1565
+ -0.2083,
1566
+ -0.0520,
1567
+ 0.3748,
1568
+ 0.0152,
1569
+ 0.1957,
1570
+ 0.1433,
1571
+ -0.2944,
1572
+ 0.3573,
1573
+ -0.0548,
1574
+ -0.1681,
1575
+ -0.0667,
1576
+ ]
1577
+ std = [
1578
+ 0.4765,
1579
+ 1.0364,
1580
+ 0.4514,
1581
+ 1.1677,
1582
+ 0.5313,
1583
+ 0.4990,
1584
+ 0.4818,
1585
+ 0.5013,
1586
+ 0.8158,
1587
+ 1.0344,
1588
+ 0.5894,
1589
+ 1.0901,
1590
+ 0.6885,
1591
+ 0.6165,
1592
+ 0.8454,
1593
+ 0.4978,
1594
+ 0.5759,
1595
+ 0.3523,
1596
+ 0.7135,
1597
+ 0.6804,
1598
+ 0.5833,
1599
+ 1.4146,
1600
+ 0.8986,
1601
+ 0.5659,
1602
+ 0.7069,
1603
+ 0.5338,
1604
+ 0.4889,
1605
+ 0.4917,
1606
+ 0.4069,
1607
+ 0.4999,
1608
+ 0.6866,
1609
+ 0.4093,
1610
+ 0.5709,
1611
+ 0.6065,
1612
+ 0.6415,
1613
+ 0.4944,
1614
+ 0.5726,
1615
+ 1.2042,
1616
+ 0.5458,
1617
+ 1.6887,
1618
+ 0.3971,
1619
+ 1.0600,
1620
+ 0.3943,
1621
+ 0.5537,
1622
+ 0.5444,
1623
+ 0.4089,
1624
+ 0.7468,
1625
+ 0.7744,
1626
+ ]
1627
+ self.mean = torch.tensor(mean)
1628
+ self.std = torch.tensor(std)
1629
+ self.scale = [self.mean, 1.0 / self.std]
1630
+
1631
+ # init model
1632
+ self.model = VideoVAE38_(z_dim=z_dim, dim=dim).eval().requires_grad_(False)
1633
+ self.upsampling_factor = 16
1634
+ self.z_dim = z_dim
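For orientation before the pipeline files: the two wrappers above differ mainly in latent width and spatial compression (WanVideoVAE: z_dim 16, 8x spatial factor; WanVideoVAE38: z_dim 48, 16x factor via the 2x2 patchify in VideoVAE38_). A minimal, purely illustrative sketch of the latent shapes this implies:

    # Latent shape arithmetic implied by the code above (illustrative helper, not part of the repo).
    def latent_shape(frames, height, width, z_dim, factor):
        return (1, z_dim, 1 + (frames - 1) // 4, height // factor, width // factor)

    print(latent_shape(81, 480, 832, z_dim=16, factor=8))   # (1, 16, 21, 60, 104)
    print(latent_shape(81, 480, 832, z_dim=48, factor=16))  # (1, 48, 21, 30, 52)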
pipelines/base.py ADDED
@@ -0,0 +1,173 @@
1
+ import torch
2
+ import numpy as np
3
+ from PIL import Image
4
+ from torchvision.transforms import GaussianBlur
5
+
6
+
7
+ class BasePipeline(torch.nn.Module):
8
+ def __init__(
9
+ self,
10
+ device="cuda",
11
+ torch_dtype=torch.float16,
12
+ height_division_factor=64,
13
+ width_division_factor=64,
14
+ ):
15
+ super().__init__()
16
+ self.device = device
17
+ self.torch_dtype = torch_dtype
18
+ self.height_division_factor = height_division_factor
19
+ self.width_division_factor = width_division_factor
20
+ self.cpu_offload = False
21
+ self.model_names = []
22
+
23
+ def check_resize_height_width(self, height, width):
24
+ if height % self.height_division_factor != 0:
25
+ height = (
26
+ (height + self.height_division_factor - 1)
27
+ // self.height_division_factor
28
+ * self.height_division_factor
29
+ )
30
+ print(
31
+ f"The height cannot be evenly divided by {self.height_division_factor}. We round it up to {height}."
32
+ )
33
+ if width % self.width_division_factor != 0:
34
+ width = (
35
+ (width + self.width_division_factor - 1)
36
+ // self.width_division_factor
37
+ * self.width_division_factor
38
+ )
39
+ print(
40
+ f"The width cannot be evenly divided by {self.width_division_factor}. We round it up to {width}."
41
+ )
42
+ return height, width
43
+
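+ # Converts a PIL image in [0, 255] to a [1, 3, H, W] tensor in [-1, 1].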
44
+ def preprocess_image(self, image):
45
+ image = (
46
+ torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1)
47
+ .permute(2, 0, 1)
48
+ .unsqueeze(0)
49
+ )
50
+ return image
51
+
52
+ def preprocess_images(self, images):
53
+ return [self.preprocess_image(image) for image in images]
54
+
55
+ def vae_output_to_image(self, vae_output):
56
+ image = vae_output[0].cpu().float().permute(1, 2, 0).numpy()
57
+ image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
58
+ return image
59
+
60
+ def vae_output_to_video(self, vae_output):
61
+ video = vae_output.cpu().permute(1, 2, 0).numpy()
62
+ video = [
63
+ Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
64
+ for image in video
65
+ ]
66
+ return video
67
+
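+ # Blends region-specific latents into value using Gaussian-blurred masks: accumulate
+ # latent * mask * scale along with the weights, then normalize by the summed weight.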
68
+ def merge_latents(
69
+ self, value, latents, masks, scales, blur_kernel_size=33, blur_sigma=10.0
70
+ ):
71
+ if len(latents) > 0:
72
+ blur = GaussianBlur(kernel_size=blur_kernel_size, sigma=blur_sigma)
73
+ height, width = value.shape[-2:]
74
+ weight = torch.ones_like(value)
75
+ for latent, mask, scale in zip(latents, masks, scales):
76
+ mask = (
77
+ self.preprocess_image(mask.resize((width, height))).mean(
78
+ dim=1, keepdim=True
79
+ )
80
+ > 0
81
+ )
82
+ mask = mask.repeat(1, latent.shape[1], 1, 1).to(
83
+ dtype=latent.dtype, device=latent.device
84
+ )
85
+ mask = blur(mask)
86
+ value += latent * mask * scale
87
+ weight += mask * scale
88
+ value /= weight
89
+ return value
90
+
91
+ def control_noise_via_local_prompts(
92
+ self,
93
+ prompt_emb_global,
94
+ prompt_emb_locals,
95
+ masks,
96
+ mask_scales,
97
+ inference_callback,
98
+ special_kwargs=None,
99
+ special_local_kwargs_list=None,
100
+ ):
101
+ if special_kwargs is None:
102
+ noise_pred_global = inference_callback(prompt_emb_global)
103
+ else:
104
+ noise_pred_global = inference_callback(prompt_emb_global, special_kwargs)
105
+ if special_local_kwargs_list is None:
106
+ noise_pred_locals = [
107
+ inference_callback(prompt_emb_local)
108
+ for prompt_emb_local in prompt_emb_locals
109
+ ]
110
+ else:
111
+ noise_pred_locals = [
112
+ inference_callback(prompt_emb_local, special_kwargs)
113
+ for prompt_emb_local, special_kwargs in zip(
114
+ prompt_emb_locals, special_local_kwargs_list
115
+ )
116
+ ]
117
+ noise_pred = self.merge_latents(
118
+ noise_pred_global, noise_pred_locals, masks, mask_scales
119
+ )
120
+ return noise_pred
121
+
122
+ def extend_prompt(self, prompt, local_prompts, masks, mask_scales):
123
+ local_prompts = local_prompts or []
124
+ masks = masks or []
125
+ mask_scales = mask_scales or []
126
+ extended_prompt_dict = self.prompter.extend_prompt(prompt)
127
+ prompt = extended_prompt_dict.get("prompt", prompt)
128
+ local_prompts += extended_prompt_dict.get("prompts", [])
129
+ masks += extended_prompt_dict.get("masks", [])
130
+ mask_scales += [100.0] * len(extended_prompt_dict.get("masks", []))
131
+ return prompt, local_prompts, masks, mask_scales
132
+
133
+ def enable_cpu_offload(self):
134
+ self.cpu_offload = True
135
+
136
+ def load_models_to_device(self, loadmodel_names=[]):
137
+ # only load models to device if cpu_offload is enabled
138
+ if not self.cpu_offload:
139
+ return
140
+ # offload the unneeded models to cpu
141
+ for model_name in self.model_names:
142
+ if model_name not in loadmodel_names:
143
+ model = getattr(self, model_name)
144
+ if model is not None:
145
+ if (
146
+ hasattr(model, "vram_management_enabled")
147
+ and model.vram_management_enabled
148
+ ):
149
+ for module in model.modules():
150
+ if hasattr(module, "offload"):
151
+ module.offload()
152
+ else:
153
+ model.cpu()
154
+ # load the needed models to device
155
+ for model_name in loadmodel_names:
156
+ model = getattr(self, model_name)
157
+ if model is not None:
158
+ if (
159
+ hasattr(model, "vram_management_enabled")
160
+ and model.vram_management_enabled
161
+ ):
162
+ for module in model.modules():
163
+ if hasattr(module, "onload"):
164
+ module.onload()
165
+ else:
166
+ model.to(self.device)
167
+ # release cached GPU memory
168
+ torch.cuda.empty_cache()
169
+
170
+ def generate_noise(self, shape, seed=None, device="cpu", dtype=torch.float16):
171
+ generator = None if seed is None else torch.Generator(device).manual_seed(seed)
172
+ noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
173
+ return noise
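A minimal sketch of the offload pattern BasePipeline expects from its subclasses (ToyPipeline and its linear layer are placeholders, not part of the repository):

    import torch
    from pipelines.base import BasePipeline  # path as laid out in this repo

    class ToyPipeline(BasePipeline):
        def __init__(self):
            super().__init__(device="cuda" if torch.cuda.is_available() else "cpu")
            self.toy_model = torch.nn.Linear(4, 4)  # stand-in for a real submodel
            self.model_names = ["toy_model"]

    pipe = ToyPipeline()
    pipe.enable_cpu_offload()                  # without this, load_models_to_device is a no-op
    pipe.load_models_to_device(["toy_model"])  # onload: moves toy_model to pipe.device
    pipe.load_models_to_device([])             # offload: moves it back to CPU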
pipelines/wan_video.py ADDED
@@ -0,0 +1,1793 @@
1
+ import torch, types
2
+ import numpy as np
3
+ from PIL import Image
4
+ from einops import repeat
5
+ from typing import Optional, Union
6
+ from einops import rearrange
7
+ from tqdm import tqdm
10
+ from typing_extensions import Literal
11
+ import imageio
12
+ import os
13
+ from typing import List, Tuple
14
+ import PIL
15
+ from utils import BasePipeline, ModelConfig, PipelineUnit, PipelineUnitRunner
16
+ from models import ModelManager, load_state_dict
17
+ from models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d
18
+ from models.wan_video_text_encoder import (
19
+ WanTextEncoder,
20
+ T5RelativeEmbedding,
21
+ T5LayerNorm,
22
+ )
23
+ from models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample
24
+ from models.wan_video_image_encoder import WanImageEncoder
25
+ from models.wan_video_vace import VaceWanModel
26
+ from models.wan_video_motion_controller import WanMotionControllerModel
27
+ from schedulers.flow_match import FlowMatchScheduler
28
+ from prompters import WanPrompter
29
+ from vram_management import (
30
+ enable_vram_management,
31
+ AutoWrappedModule,
32
+ AutoWrappedLinear,
33
+ WanAutoCastLayerNorm,
34
+ )
35
+ from lora import GeneralLoRALoader
36
+
37
+ def load_video_as_list(video_path: str) -> Tuple[List[Image.Image], int, int, int]:
38
+ if not os.path.isfile(video_path):
39
+ raise FileNotFoundError(f"Video file not found: {video_path}")
40
+
41
+ reader = imageio.get_reader(video_path)
42
+
43
+ meta_data = reader.get_meta_data()
44
+ original_width = meta_data['size'][0]
45
+ original_height = meta_data['size'][1]
46
+
47
+ new_width = (original_width // 16) * 16
48
+ new_height = (original_height // 16) * 16
49
+
50
+ left = (original_width - new_width) // 2
51
+ top = (original_height - new_height) // 2
52
+ right = left + new_width
53
+ bottom = top + new_height
54
+ crop_box = (left, top, right, bottom)
55
+
56
+ original_frame_count = reader.count_frames()
57
+ new_frame_count = original_frame_count - ((original_frame_count - 1) % 4)
58
+
59
+ frames = []
60
+ for i in range(new_frame_count):
61
+ try:
62
+ frame_data = reader.get_data(i)
63
+ pil_image = Image.fromarray(frame_data)
64
+ cropped_image = pil_image.crop(crop_box)
65
+ frames.append(cropped_image)
66
+ except IndexError:
67
+ print(f"Warning: Actual number of frames is less than expected. Stopping at frame {i}.")
68
+ new_frame_count = len(frames)
69
+ break
70
+
71
+ reader.close()
72
+
73
+ return frames, new_width, new_height, new_frame_count
74
+
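+ # Center-crops each frame to multiples of 16 (the pipeline's spatial division factor) and
+ # trims the clip to 4k + 1 frames, matching the VAE's temporal compression.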
75
+ class WanVideoPipeline(BasePipeline):
76
+ def __init__(self, device="cuda", torch_dtype=torch.bfloat16, tokenizer_path=None):
77
+ super().__init__(
78
+ device=device,
79
+ torch_dtype=torch_dtype,
80
+ height_division_factor=16,
81
+ width_division_factor=16,
82
+ time_division_factor=4,
83
+ time_division_remainder=1,
84
+ )
85
+ self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
86
+ self.prompter = WanPrompter(tokenizer_path=tokenizer_path)
87
+ self.text_encoder: WanTextEncoder = None
88
+ self.image_encoder: WanImageEncoder = None
89
+ self.dit: WanModel = None
90
+ self.dit2: WanModel = None
91
+ self.vae: WanVideoVAE = None
92
+ self.motion_controller: WanMotionControllerModel = None
93
+ self.vace: VaceWanModel = None
94
+ self.in_iteration_models = ("dit", "motion_controller", "vace")
95
+ self.in_iteration_models_2 = ("dit2", "motion_controller", "vace")
96
+ self.unit_runner = PipelineUnitRunner()
97
+ self.units = [
98
+ WanVideoUnit_ShapeChecker(),
99
+ WanVideoUnit_NoiseInitializer(),
100
+ WanVideoUnit_InputVideoEmbedder(),
101
+ WanVideoUnit_PromptEmbedder(),
102
+ WanVideoUnit_ImageEmbedderVAE(),
103
+ WanVideoUnit_ImageEmbedderCLIP(),
104
+ WanVideoUnit_ImageEmbedderFused(),
105
+ WanVideoUnit_FunControl(),
106
+ WanVideoUnit_FunReference(),
107
+ WanVideoUnit_FunCameraControl(),
108
+ WanVideoUnit_SpeedControl(),
109
+ WanVideoUnit_VACE(),
110
+ WanVideoUnit_UnifiedSequenceParallel(),
111
+ WanVideoUnit_TeaCache(),
112
+ WanVideoUnit_CfgMerger(),
113
+ ]
114
+ self.model_fn = model_fn_wan_video
115
+
116
+ def encode_ip_image(self, ip_image):
117
+ self.load_models_to_device(["vae"])
118
+ ip_image = (
119
+ torch.tensor(np.array(ip_image)).permute(2, 0, 1).float() / 255.0
120
+ ) # [3, H, W]
121
+ ip_image = (
122
+ ip_image.unsqueeze(1).unsqueeze(0).to(dtype=self.torch_dtype)
123
+ ) # [B, 3, 1, H, W]
124
+ ip_image = ip_image * 2 - 1
125
+ ip_image_latent = self.vae.encode(ip_image, device=self.device, tiled=False)
126
+ return ip_image_latent
127
+
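+ # VAE-encodes the Stand-In reference image as a single-frame clip ([B, 3, 1, H, W] in [-1, 1]).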
128
+ def load_lora(self, module, path, alpha=1):
129
+ loader = GeneralLoRALoader(torch_dtype=self.torch_dtype, device=self.device)
130
+ lora = load_state_dict(path, torch_dtype=self.torch_dtype, device=self.device)
131
+ loader.load(module, lora, alpha=alpha)
132
+
133
+ def training_loss(self, **inputs):
134
+ max_timestep_boundary = int(
135
+ inputs.get("max_timestep_boundary", 1) * self.scheduler.num_train_timesteps
136
+ )
137
+ min_timestep_boundary = int(
138
+ inputs.get("min_timestep_boundary", 0) * self.scheduler.num_train_timesteps
139
+ )
140
+ timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
141
+ timestep = self.scheduler.timesteps[timestep_id].to(
142
+ dtype=self.torch_dtype, device=self.device
143
+ )
144
+
145
+ inputs["latents"] = self.scheduler.add_noise(
146
+ inputs["input_latents"], inputs["noise"], timestep
147
+ )
148
+ training_target = self.scheduler.training_target(
149
+ inputs["input_latents"], inputs["noise"], timestep
150
+ )
151
+
152
+ noise_pred = self.model_fn(**inputs, timestep=timestep)
153
+
154
+ loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
155
+ loss = loss * self.scheduler.training_weight(timestep)
156
+ return loss
157
+
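+ # Flow-matching training step: sample a timestep within the given boundary fraction of the
+ # schedule, noise the clean latents, and regress the model output onto the scheduler's
+ # training target with a per-timestep weighted MSE loss.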
158
+ def enable_vram_management(
159
+ self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5
160
+ ):
161
+ self.vram_management_enabled = True
162
+ if num_persistent_param_in_dit is not None:
163
+ vram_limit = None
164
+ else:
165
+ if vram_limit is None:
166
+ vram_limit = self.get_vram()
167
+ vram_limit = vram_limit - vram_buffer
168
+ if self.text_encoder is not None:
169
+ dtype = next(iter(self.text_encoder.parameters())).dtype
170
+ enable_vram_management(
171
+ self.text_encoder,
172
+ module_map={
173
+ torch.nn.Linear: AutoWrappedLinear,
174
+ torch.nn.Embedding: AutoWrappedModule,
175
+ T5RelativeEmbedding: AutoWrappedModule,
176
+ T5LayerNorm: AutoWrappedModule,
177
+ },
178
+ module_config=dict(
179
+ offload_dtype=dtype,
180
+ offload_device="cpu",
181
+ onload_dtype=dtype,
182
+ onload_device="cpu",
183
+ computation_dtype=self.torch_dtype,
184
+ computation_device=self.device,
185
+ ),
186
+ vram_limit=vram_limit,
187
+ )
188
+ if self.dit is not None:
189
+ dtype = next(iter(self.dit.parameters())).dtype
190
+ device = "cpu" if vram_limit is not None else self.device
191
+ enable_vram_management(
192
+ self.dit,
193
+ module_map={
194
+ torch.nn.Linear: AutoWrappedLinear,
195
+ torch.nn.Conv3d: AutoWrappedModule,
196
+ torch.nn.LayerNorm: WanAutoCastLayerNorm,
197
+ RMSNorm: AutoWrappedModule,
198
+ torch.nn.Conv2d: AutoWrappedModule,
199
+ },
200
+ module_config=dict(
201
+ offload_dtype=dtype,
202
+ offload_device="cpu",
203
+ onload_dtype=dtype,
204
+ onload_device=device,
205
+ computation_dtype=self.torch_dtype,
206
+ computation_device=self.device,
207
+ ),
208
+ max_num_param=num_persistent_param_in_dit,
209
+ overflow_module_config=dict(
210
+ offload_dtype=dtype,
211
+ offload_device="cpu",
212
+ onload_dtype=dtype,
213
+ onload_device="cpu",
214
+ computation_dtype=self.torch_dtype,
215
+ computation_device=self.device,
216
+ ),
217
+ vram_limit=vram_limit,
218
+ )
219
+ if self.dit2 is not None:
220
+ dtype = next(iter(self.dit2.parameters())).dtype
221
+ device = "cpu" if vram_limit is not None else self.device
222
+ enable_vram_management(
223
+ self.dit2,
224
+ module_map={
225
+ torch.nn.Linear: AutoWrappedLinear,
226
+ torch.nn.Conv3d: AutoWrappedModule,
227
+ torch.nn.LayerNorm: WanAutoCastLayerNorm,
228
+ RMSNorm: AutoWrappedModule,
229
+ torch.nn.Conv2d: AutoWrappedModule,
230
+ },
231
+ module_config=dict(
232
+ offload_dtype=dtype,
233
+ offload_device="cpu",
234
+ onload_dtype=dtype,
235
+ onload_device=device,
236
+ computation_dtype=self.torch_dtype,
237
+ computation_device=self.device,
238
+ ),
239
+ max_num_param=num_persistent_param_in_dit,
240
+ overflow_module_config=dict(
241
+ offload_dtype=dtype,
242
+ offload_device="cpu",
243
+ onload_dtype=dtype,
244
+ onload_device="cpu",
245
+ computation_dtype=self.torch_dtype,
246
+ computation_device=self.device,
247
+ ),
248
+ vram_limit=vram_limit,
249
+ )
250
+ if self.vae is not None:
251
+ dtype = next(iter(self.vae.parameters())).dtype
252
+ enable_vram_management(
253
+ self.vae,
254
+ module_map={
255
+ torch.nn.Linear: AutoWrappedLinear,
256
+ torch.nn.Conv2d: AutoWrappedModule,
257
+ RMS_norm: AutoWrappedModule,
258
+ CausalConv3d: AutoWrappedModule,
259
+ Upsample: AutoWrappedModule,
260
+ torch.nn.SiLU: AutoWrappedModule,
261
+ torch.nn.Dropout: AutoWrappedModule,
262
+ },
263
+ module_config=dict(
264
+ offload_dtype=dtype,
265
+ offload_device="cpu",
266
+ onload_dtype=dtype,
267
+ onload_device=self.device,
268
+ computation_dtype=self.torch_dtype,
269
+ computation_device=self.device,
270
+ ),
271
+ )
272
+ if self.image_encoder is not None:
273
+ dtype = next(iter(self.image_encoder.parameters())).dtype
274
+ enable_vram_management(
275
+ self.image_encoder,
276
+ module_map={
277
+ torch.nn.Linear: AutoWrappedLinear,
278
+ torch.nn.Conv2d: AutoWrappedModule,
279
+ torch.nn.LayerNorm: AutoWrappedModule,
280
+ },
281
+ module_config=dict(
282
+ offload_dtype=dtype,
283
+ offload_device="cpu",
284
+ onload_dtype=dtype,
285
+ onload_device="cpu",
286
+ computation_dtype=dtype,
287
+ computation_device=self.device,
288
+ ),
289
+ )
290
+ if self.motion_controller is not None:
291
+ dtype = next(iter(self.motion_controller.parameters())).dtype
292
+ enable_vram_management(
293
+ self.motion_controller,
294
+ module_map={
295
+ torch.nn.Linear: AutoWrappedLinear,
296
+ },
297
+ module_config=dict(
298
+ offload_dtype=dtype,
299
+ offload_device="cpu",
300
+ onload_dtype=dtype,
301
+ onload_device="cpu",
302
+ computation_dtype=dtype,
303
+ computation_device=self.device,
304
+ ),
305
+ )
306
+ if self.vace is not None:
307
+ device = "cpu" if vram_limit is not None else self.device
308
+ enable_vram_management(
309
+ self.vace,
310
+ module_map={
311
+ torch.nn.Linear: AutoWrappedLinear,
312
+ torch.nn.Conv3d: AutoWrappedModule,
313
+ torch.nn.LayerNorm: AutoWrappedModule,
314
+ RMSNorm: AutoWrappedModule,
315
+ },
316
+ module_config=dict(
317
+ offload_dtype=dtype,
318
+ offload_device="cpu",
319
+ onload_dtype=dtype,
320
+ onload_device=device,
321
+ computation_dtype=self.torch_dtype,
322
+ computation_device=self.device,
323
+ ),
324
+ vram_limit=vram_limit,
325
+ )
326
+
327
+ def initialize_usp(self):
328
+ import torch.distributed as dist
329
+ from xfuser.core.distributed import (
330
+ initialize_model_parallel,
331
+ init_distributed_environment,
332
+ )
333
+
334
+ dist.init_process_group(backend="nccl", init_method="env://")
335
+ init_distributed_environment(
336
+ rank=dist.get_rank(), world_size=dist.get_world_size()
337
+ )
338
+ initialize_model_parallel(
339
+ sequence_parallel_degree=dist.get_world_size(),
340
+ ring_degree=1,
341
+ ulysses_degree=dist.get_world_size(),
342
+ )
343
+ torch.cuda.set_device(dist.get_rank())
344
+
345
+ def enable_usp(self):
346
+ from xfuser.core.distributed import get_sequence_parallel_world_size
347
+ from distributed.xdit_context_parallel import (
348
+ usp_attn_forward,
349
+ usp_dit_forward,
350
+ )
351
+
352
+ for block in self.dit.blocks:
353
+ block.self_attn.forward = types.MethodType(
354
+ usp_attn_forward, block.self_attn
355
+ )
356
+ self.dit.forward = types.MethodType(usp_dit_forward, self.dit)
357
+ if self.dit2 is not None:
358
+ for block in self.dit2.blocks:
359
+ block.self_attn.forward = types.MethodType(
360
+ usp_attn_forward, block.self_attn
361
+ )
362
+ self.dit2.forward = types.MethodType(usp_dit_forward, self.dit2)
363
+ self.sp_size = get_sequence_parallel_world_size()
364
+ self.use_unified_sequence_parallel = True
365
+
366
+ @staticmethod
367
+ def from_pretrained(
368
+ torch_dtype: torch.dtype = torch.bfloat16,
369
+ device: Union[str, torch.device] = "cuda",
370
+ model_configs: list[ModelConfig] = [],
371
+ tokenizer_config: ModelConfig = ModelConfig(
372
+ model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/*"
373
+ ),
374
+ redirect_common_files: bool = True,
375
+ use_usp=False,
376
+ ):
377
+ # Redirect model path
378
+ if redirect_common_files:
379
+ redirect_dict = {
380
+ "models_t5_umt5-xxl-enc-bf16.pth": "Wan-AI/Wan2.1-T2V-1.3B",
381
+ "Wan2.1_VAE.pth": "Wan-AI/Wan2.1-T2V-1.3B",
382
+ "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth": "Wan-AI/Wan2.1-I2V-14B-480P",
383
+ }
384
+ for model_config in model_configs:
385
+ if (
386
+ model_config.origin_file_pattern is None
387
+ or model_config.model_id is None
388
+ ):
389
+ continue
390
+ if (
391
+ model_config.origin_file_pattern in redirect_dict
392
+ and model_config.model_id
393
+ != redirect_dict[model_config.origin_file_pattern]
394
+ ):
395
+ print(
396
+ f"To avoid repeatedly downloading model files, ({model_config.model_id}, {model_config.origin_file_pattern}) is redirected to ({redirect_dict[model_config.origin_file_pattern]}, {model_config.origin_file_pattern}). You can use `redirect_common_files=False` to disable file redirection."
397
+ )
398
+ model_config.model_id = redirect_dict[
399
+ model_config.origin_file_pattern
400
+ ]
401
+
402
+ # Initialize pipeline
403
+ pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
404
+ if use_usp:
405
+ pipe.initialize_usp()
406
+
407
+ # Download and load models
408
+ model_manager = ModelManager()
409
+ for model_config in model_configs:
410
+ model_config.download_if_necessary(use_usp=use_usp)
411
+ model_manager.load_model(
412
+ model_config.path,
413
+ device=model_config.offload_device or device,
414
+ torch_dtype=model_config.offload_dtype or torch_dtype,
415
+ )
416
+
417
+ # Load models
418
+ pipe.text_encoder = model_manager.fetch_model("wan_video_text_encoder")
419
+ dit = model_manager.fetch_model("wan_video_dit", index=2)
420
+ if isinstance(dit, list):
421
+ pipe.dit, pipe.dit2 = dit
422
+ else:
423
+ pipe.dit = dit
424
+ pipe.vae = model_manager.fetch_model("wan_video_vae")
425
+ pipe.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
426
+ pipe.motion_controller = model_manager.fetch_model(
427
+ "wan_video_motion_controller"
428
+ )
429
+ pipe.vace = model_manager.fetch_model("wan_video_vace")
430
+
431
+ # Size division factor
432
+ if pipe.vae is not None:
433
+ pipe.height_division_factor = pipe.vae.upsampling_factor * 2
434
+ pipe.width_division_factor = pipe.vae.upsampling_factor * 2
435
+
436
+ # Initialize tokenizer
437
+ tokenizer_config.download_if_necessary(use_usp=use_usp)
438
+ pipe.prompter.fetch_models(pipe.text_encoder)
439
+ pipe.prompter.fetch_tokenizer(tokenizer_config.path)
440
+
441
+ # Unified Sequence Parallel
442
+ if use_usp:
443
+ pipe.enable_usp()
444
+ return pipe
445
+
446
+ @torch.no_grad()
447
+ def __call__(
448
+ self,
449
+ # Prompt
450
+ prompt: str,
451
+ negative_prompt: Optional[str] = "",
452
+ # Image-to-video
453
+ input_image: Optional[Image.Image] = None,
454
+ # First-last-frame-to-video
455
+ end_image: Optional[Image.Image] = None,
456
+ # Video-to-video
457
+ input_video: Optional[list[Image.Image]] = None,
458
+ denoising_strength: Optional[float] = 1.0,
459
+ # ControlNet
460
+ control_video: Optional[list[Image.Image]] = None,
461
+ reference_image: Optional[Image.Image] = None,
462
+ # Camera control
463
+ camera_control_direction: Optional[
464
+ Literal[
465
+ "Left",
466
+ "Right",
467
+ "Up",
468
+ "Down",
469
+ "LeftUp",
470
+ "LeftDown",
471
+ "RightUp",
472
+ "RightDown",
473
+ ]
474
+ ] = None,
475
+ camera_control_speed: Optional[float] = 1 / 54,
476
+ camera_control_origin: Optional[tuple] = (
477
+ 0,
478
+ 0.532139961,
479
+ 0.946026558,
480
+ 0.5,
481
+ 0.5,
482
+ 0,
483
+ 0,
484
+ 1,
485
+ 0,
486
+ 0,
487
+ 0,
488
+ 0,
489
+ 1,
490
+ 0,
491
+ 0,
492
+ 0,
493
+ 0,
494
+ 1,
495
+ 0,
496
+ ),
497
+ # VACE
498
+ vace_video: Optional[str] = None,  # path to a video file; loaded via load_video_as_list
499
+ vace_video_mask: Optional[Image.Image] = None,
500
+ vace_reference_image: Optional[str] = None,  # path to a reference image; opened with PIL
501
+ vace_scale: Optional[float] = 1.0,
502
+ # Randomness
503
+ seed: Optional[int] = None,
504
+ rand_device: Optional[str] = "cpu",
505
+ # Shape
506
+ height: Optional[int] = 480,
507
+ width: Optional[int] = 832,
508
+ num_frames=81,
509
+ # Classifier-free guidance
510
+ cfg_scale: Optional[float] = 5.0,
511
+ cfg_merge: Optional[bool] = False,
512
+ # Boundary
513
+ switch_DiT_boundary: Optional[float] = 0.875,
514
+ # Scheduler
515
+ num_inference_steps: Optional[int] = 50,
516
+ sigma_shift: Optional[float] = 5.0,
517
+ # Speed control
518
+ motion_bucket_id: Optional[int] = None,
519
+ # VAE tiling
520
+ tiled: Optional[bool] = True,
521
+ tile_size: Optional[tuple[int, int]] = (30, 52),
522
+ tile_stride: Optional[tuple[int, int]] = (15, 26),
523
+ # Sliding window
524
+ sliding_window_size: Optional[int] = None,
525
+ sliding_window_stride: Optional[int] = None,
526
+ # Teacache
527
+ tea_cache_l1_thresh: Optional[float] = None,
528
+ tea_cache_model_id: Optional[str] = "",
529
+ # progress_bar
530
+ progress_bar_cmd=tqdm,
531
+ # Stand-In
532
+ ip_image=None,
533
+ ):
534
+ if ip_image is not None:
535
+ ip_image = self.encode_ip_image(ip_image)
536
+ if vace_video is not None:
537
+ vace_video, width, height, num_frames = load_video_as_list(vace_video)
538
+ if vace_reference_image is not None:
539
+ vace_reference_image = Image.open(vace_reference_image).convert('RGB')
540
+ ref_width, ref_height = vace_reference_image.size
541
+ if ref_width != width or ref_height != height:
542
+ scale_ratio = min(width / ref_width, height / ref_height)
543
+
544
+ new_ref_width = int(ref_width * scale_ratio)
545
+ new_ref_height = int(ref_height * scale_ratio)
546
+
547
+ resized_image = vace_reference_image.resize((new_ref_width, new_ref_height), Image.LANCZOS)
548
+
549
+ background = Image.new('RGB', (width, height), (255, 255, 255))
550
+
551
+ paste_x = (width - new_ref_width) // 2
552
+ paste_y = (height - new_ref_height) // 2
553
+ background.paste(resized_image, (paste_x, paste_y))
554
+
555
+ vace_reference_image = background
556
+ # Scheduler
557
+ self.scheduler.set_timesteps(
558
+ num_inference_steps,
559
+ denoising_strength=denoising_strength,
560
+ shift=sigma_shift,
561
+ )
562
+
563
+ # Inputs
564
+ inputs_posi = {
565
+ "prompt": prompt,
566
+ "tea_cache_l1_thresh": tea_cache_l1_thresh,
567
+ "tea_cache_model_id": tea_cache_model_id,
568
+ "num_inference_steps": num_inference_steps,
569
+ }
570
+ inputs_nega = {
571
+ "negative_prompt": negative_prompt,
572
+ "tea_cache_l1_thresh": tea_cache_l1_thresh,
573
+ "tea_cache_model_id": tea_cache_model_id,
574
+ "num_inference_steps": num_inference_steps,
575
+ }
576
+ inputs_shared = {
577
+ "input_image": input_image,
578
+ "end_image": end_image,
579
+ "input_video": input_video,
580
+ "denoising_strength": denoising_strength,
581
+ "control_video": control_video,
582
+ "reference_image": reference_image,
583
+ "camera_control_direction": camera_control_direction,
584
+ "camera_control_speed": camera_control_speed,
585
+ "camera_control_origin": camera_control_origin,
586
+ "vace_video": vace_video,
587
+ "vace_video_mask": vace_video_mask,
588
+ "vace_reference_image": vace_reference_image,
589
+ "vace_scale": vace_scale,
590
+ "seed": seed,
591
+ "rand_device": rand_device,
592
+ "height": height,
593
+ "width": width,
594
+ "num_frames": num_frames,
595
+ "cfg_scale": cfg_scale,
596
+ "cfg_merge": cfg_merge,
597
+ "sigma_shift": sigma_shift,
598
+ "motion_bucket_id": motion_bucket_id,
599
+ "tiled": tiled,
600
+ "tile_size": tile_size,
601
+ "tile_stride": tile_stride,
602
+ "sliding_window_size": sliding_window_size,
603
+ "sliding_window_stride": sliding_window_stride,
604
+ "ip_image": ip_image,
605
+ }
606
+ for unit in self.units:
607
+ inputs_shared, inputs_posi, inputs_nega = self.unit_runner(
608
+ unit, self, inputs_shared, inputs_posi, inputs_nega
609
+ )
610
+ # Denoise
611
+ self.load_models_to_device(self.in_iteration_models)
612
+ models = {name: getattr(self, name) for name in self.in_iteration_models}
613
+ for progress_id, timestep in enumerate(
614
+ progress_bar_cmd(self.scheduler.timesteps)
615
+ ):
616
+ # Switch DiT if necessary
617
+ if (
618
+ timestep.item()
619
+ < switch_DiT_boundary * self.scheduler.num_train_timesteps
620
+ and self.dit2 is not None
621
+ and not models["dit"] is self.dit2
622
+ ):
623
+ self.load_models_to_device(self.in_iteration_models_2)
624
+ models["dit"] = self.dit2
625
+
626
+ # Timestep
627
+ timestep = timestep.unsqueeze(0).to(
628
+ dtype=self.torch_dtype, device=self.device
629
+ )
630
+
631
+ # Inference
632
+ noise_pred_posi = self.model_fn(
633
+ **models, **inputs_shared, **inputs_posi, timestep=timestep
634
+ )
635
+ inputs_shared["ip_image"] = None
636
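+ # Classifier-free guidance: noise_pred = nega + cfg_scale * (posi - nega).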
+ if cfg_scale != 1.0:
637
+ if cfg_merge:
638
+ noise_pred_posi, noise_pred_nega = noise_pred_posi.chunk(2, dim=0)
639
+ else:
640
+ noise_pred_nega = self.model_fn(
641
+ **models, **inputs_shared, **inputs_nega, timestep=timestep
642
+ )
643
+ noise_pred = noise_pred_nega + cfg_scale * (
644
+ noise_pred_posi - noise_pred_nega
645
+ )
646
+ else:
647
+ noise_pred = noise_pred_posi
648
+
649
+ # Scheduler
650
+ inputs_shared["latents"] = self.scheduler.step(
651
+ noise_pred,
652
+ self.scheduler.timesteps[progress_id],
653
+ inputs_shared["latents"],
654
+ )
655
+ if "first_frame_latents" in inputs_shared:
656
+ inputs_shared["latents"][:, :, 0:1] = inputs_shared[
657
+ "first_frame_latents"
658
+ ]
659
+
660
+ if vace_reference_image is not None:
661
+ inputs_shared["latents"] = inputs_shared["latents"][:, :, 1:]
662
+
663
+ # Decode
664
+ self.load_models_to_device(["vae"])
665
+ video = self.vae.decode(
666
+ inputs_shared["latents"],
667
+ device=self.device,
668
+ tiled=tiled,
669
+ tile_size=tile_size,
670
+ tile_stride=tile_stride,
671
+ )
672
+ video = self.vae_output_to_video(video)
673
+ self.load_models_to_device([])
674
+
675
+ return video
676
+
677
+
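+ # Example usage sketch (illustrative; the exact ModelConfig entries depend on the checkpoints in use):
+ #   pipe = WanVideoPipeline.from_pretrained(
+ #       torch_dtype=torch.bfloat16, device="cuda",
+ #       model_configs=[ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="*.safetensors")],
+ #   )
+ #   video = pipe(prompt="a corgi running on the beach", seed=0)   # returns a list of PIL.Image frames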
678
+ class WanVideoUnit_ShapeChecker(PipelineUnit):
679
+ def __init__(self):
680
+ super().__init__(input_params=("height", "width", "num_frames"))
681
+
682
+ def process(self, pipe: WanVideoPipeline, height, width, num_frames):
683
+ height, width, num_frames = pipe.check_resize_height_width(
684
+ height, width, num_frames
685
+ )
686
+ return {"height": height, "width": width, "num_frames": num_frames}
687
+
688
+
689
+ class WanVideoUnit_NoiseInitializer(PipelineUnit):
690
+ def __init__(self):
691
+ super().__init__(
692
+ input_params=(
693
+ "height",
694
+ "width",
695
+ "num_frames",
696
+ "seed",
697
+ "rand_device",
698
+ "vace_reference_image",
699
+ )
700
+ )
701
+
702
+ def process(
703
+ self,
704
+ pipe: WanVideoPipeline,
705
+ height,
706
+ width,
707
+ num_frames,
708
+ seed,
709
+ rand_device,
710
+ vace_reference_image,
711
+ ):
712
+ length = (num_frames - 1) // 4 + 1
713
+ if vace_reference_image is not None:
714
+ length += 1
715
+ shape = (
716
+ 1,
717
+ pipe.vae.model.z_dim,
718
+ length,
719
+ height // pipe.vae.upsampling_factor,
720
+ width // pipe.vae.upsampling_factor,
721
+ )
722
+ noise = pipe.generate_noise(shape, seed=seed, rand_device=rand_device)
723
+ if vace_reference_image is not None:
724
+ noise = torch.concat((noise[:, :, -1:], noise[:, :, :-1]), dim=2)
725
+ return {"noise": noise}
726
+
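+ # With a VACE reference image, one extra latent frame is allocated and the noise is rolled
+ # so that the reference slot sits at the front of the temporal axis.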
727
+
728
+ class WanVideoUnit_InputVideoEmbedder(PipelineUnit):
729
+ def __init__(self):
730
+ super().__init__(
731
+ input_params=(
732
+ "input_video",
733
+ "noise",
734
+ "tiled",
735
+ "tile_size",
736
+ "tile_stride",
737
+ "vace_reference_image",
738
+ ),
739
+ onload_model_names=("vae",),
740
+ )
741
+
742
+ def process(
743
+ self,
744
+ pipe: WanVideoPipeline,
745
+ input_video,
746
+ noise,
747
+ tiled,
748
+ tile_size,
749
+ tile_stride,
750
+ vace_reference_image,
751
+ ):
752
+ if input_video is None:
753
+ return {"latents": noise}
754
+ pipe.load_models_to_device(["vae"])
755
+ input_video = pipe.preprocess_video(input_video)
756
+ input_latents = pipe.vae.encode(
757
+ input_video,
758
+ device=pipe.device,
759
+ tiled=tiled,
760
+ tile_size=tile_size,
761
+ tile_stride=tile_stride,
762
+ ).to(dtype=pipe.torch_dtype, device=pipe.device)
763
+ if vace_reference_image is not None:
764
+ vace_reference_image = pipe.preprocess_video([vace_reference_image])
765
+ vace_reference_latents = pipe.vae.encode(
766
+ vace_reference_image, device=pipe.device
767
+ ).to(dtype=pipe.torch_dtype, device=pipe.device)
768
+ input_latents = torch.concat([vace_reference_latents, input_latents], dim=2)
769
+ if pipe.scheduler.training:
770
+ return {"latents": noise, "input_latents": input_latents}
771
+ else:
772
+ latents = pipe.scheduler.add_noise(
773
+ input_latents, noise, timestep=pipe.scheduler.timesteps[0]
774
+ )
775
+ return {"latents": latents}
776
+
777
+
778
+ class WanVideoUnit_PromptEmbedder(PipelineUnit):
779
+ def __init__(self):
780
+ super().__init__(
781
+ seperate_cfg=True,
782
+ input_params_posi={"prompt": "prompt", "positive": "positive"},
783
+ input_params_nega={"prompt": "negative_prompt", "positive": "positive"},
784
+ onload_model_names=("text_encoder",),
785
+ )
786
+
787
+ def process(self, pipe: WanVideoPipeline, prompt, positive) -> dict:
788
+ pipe.load_models_to_device(self.onload_model_names)
789
+ prompt_emb = pipe.prompter.encode_prompt(
790
+ prompt, positive=positive, device=pipe.device
791
+ )
792
+ return {"context": prompt_emb}
793
+
794
+
795
+ class WanVideoUnit_ImageEmbedder(PipelineUnit):
796
+ """
797
+ Deprecated
798
+ """
799
+
800
+ def __init__(self):
801
+ super().__init__(
802
+ input_params=(
803
+ "input_image",
804
+ "end_image",
805
+ "num_frames",
806
+ "height",
807
+ "width",
808
+ "tiled",
809
+ "tile_size",
810
+ "tile_stride",
811
+ ),
812
+ onload_model_names=("image_encoder", "vae"),
813
+ )
814
+
815
+ def process(
816
+ self,
817
+ pipe: WanVideoPipeline,
818
+ input_image,
819
+ end_image,
820
+ num_frames,
821
+ height,
822
+ width,
823
+ tiled,
824
+ tile_size,
825
+ tile_stride,
826
+ ):
827
+ if input_image is None or pipe.image_encoder is None:
828
+ return {}
829
+ pipe.load_models_to_device(self.onload_model_names)
830
+ image = pipe.preprocess_image(input_image.resize((width, height))).to(
831
+ pipe.device
832
+ )
833
+ clip_context = pipe.image_encoder.encode_image([image])
834
+ msk = torch.ones(1, num_frames, height // 8, width // 8, device=pipe.device)
835
+ msk[:, 1:] = 0
836
+ if end_image is not None:
837
+ end_image = pipe.preprocess_image(end_image.resize((width, height))).to(
838
+ pipe.device
839
+ )
840
+ vae_input = torch.concat(
841
+ [
842
+ image.transpose(0, 1),
843
+ torch.zeros(3, num_frames - 2, height, width).to(image.device),
844
+ end_image.transpose(0, 1),
845
+ ],
846
+ dim=1,
847
+ )
848
+ if pipe.dit.has_image_pos_emb:
849
+ clip_context = torch.concat(
850
+ [clip_context, pipe.image_encoder.encode_image([end_image])], dim=1
851
+ )
852
+ msk[:, -1:] = 1
853
+ else:
854
+ vae_input = torch.concat(
855
+ [
856
+ image.transpose(0, 1),
857
+ torch.zeros(3, num_frames - 1, height, width).to(image.device),
858
+ ],
859
+ dim=1,
860
+ )
861
+
862
+ msk = torch.concat(
863
+ [torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1
864
+ )
865
+ msk = msk.view(1, msk.shape[1] // 4, 4, height // 8, width // 8)
866
+ msk = msk.transpose(1, 2)[0]
867
+
868
+ y = pipe.vae.encode(
869
+ [vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)],
870
+ device=pipe.device,
871
+ tiled=tiled,
872
+ tile_size=tile_size,
873
+ tile_stride=tile_stride,
874
+ )[0]
875
+ y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
876
+ y = torch.concat([msk, y])
877
+ y = y.unsqueeze(0)
878
+ clip_context = clip_context.to(dtype=pipe.torch_dtype, device=pipe.device)
879
+ y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
880
+ return {"clip_feature": clip_context, "y": y}
881
+
882
+
883
+ class WanVideoUnit_ImageEmbedderCLIP(PipelineUnit):
884
+ def __init__(self):
885
+ super().__init__(
886
+ input_params=("input_image", "end_image", "height", "width"),
887
+ onload_model_names=("image_encoder",),
888
+ )
889
+
890
+ def process(self, pipe: WanVideoPipeline, input_image, end_image, height, width):
891
+ if (
892
+ input_image is None
893
+ or pipe.image_encoder is None
894
+ or not pipe.dit.require_clip_embedding
895
+ ):
896
+ return {}
897
+ pipe.load_models_to_device(self.onload_model_names)
898
+ image = pipe.preprocess_image(input_image.resize((width, height))).to(
899
+ pipe.device
900
+ )
901
+ clip_context = pipe.image_encoder.encode_image([image])
902
+ if end_image is not None:
903
+ end_image = pipe.preprocess_image(end_image.resize((width, height))).to(
904
+ pipe.device
905
+ )
906
+ if pipe.dit.has_image_pos_emb:
907
+ clip_context = torch.concat(
908
+ [clip_context, pipe.image_encoder.encode_image([end_image])], dim=1
909
+ )
910
+ clip_context = clip_context.to(dtype=pipe.torch_dtype, device=pipe.device)
911
+ return {"clip_feature": clip_context}
912
+
913
+
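+ # Builds the image-to-video conditioning y: a frame mask marking which frames are given
+ # (first, and optionally last), expanded 4x on the first frame to match the VAE's temporal
+ # compression, concatenated with the VAE latents of the zero-padded conditioning clip.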
914
+ class WanVideoUnit_ImageEmbedderVAE(PipelineUnit):
915
+ def __init__(self):
916
+ super().__init__(
917
+ input_params=(
918
+ "input_image",
919
+ "end_image",
920
+ "num_frames",
921
+ "height",
922
+ "width",
923
+ "tiled",
924
+ "tile_size",
925
+ "tile_stride",
926
+ ),
927
+ onload_model_names=("vae",),
928
+ )
929
+
930
+ def process(
931
+ self,
932
+ pipe: WanVideoPipeline,
933
+ input_image,
934
+ end_image,
935
+ num_frames,
936
+ height,
937
+ width,
938
+ tiled,
939
+ tile_size,
940
+ tile_stride,
941
+ ):
942
+ if input_image is None or not pipe.dit.require_vae_embedding:
943
+ return {}
944
+ pipe.load_models_to_device(self.onload_model_names)
945
+ image = pipe.preprocess_image(input_image.resize((width, height))).to(
946
+ pipe.device
947
+ )
948
+ msk = torch.ones(1, num_frames, height // 8, width // 8, device=pipe.device)
949
+ msk[:, 1:] = 0
950
+ if end_image is not None:
951
+ end_image = pipe.preprocess_image(end_image.resize((width, height))).to(
952
+ pipe.device
953
+ )
954
+ vae_input = torch.concat(
955
+ [
956
+ image.transpose(0, 1),
957
+ torch.zeros(3, num_frames - 2, height, width).to(image.device),
958
+ end_image.transpose(0, 1),
959
+ ],
960
+ dim=1,
961
+ )
962
+ msk[:, -1:] = 1
963
+ else:
964
+ vae_input = torch.concat(
965
+ [
966
+ image.transpose(0, 1),
967
+ torch.zeros(3, num_frames - 1, height, width).to(image.device),
968
+ ],
969
+ dim=1,
970
+ )
971
+
972
+ msk = torch.concat(
973
+ [torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1
974
+ )
975
+ msk = msk.view(1, msk.shape[1] // 4, 4, height // 8, width // 8)
976
+ msk = msk.transpose(1, 2)[0]
977
+
978
+ y = pipe.vae.encode(
979
+ [vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)],
980
+ device=pipe.device,
981
+ tiled=tiled,
982
+ tile_size=tile_size,
983
+ tile_stride=tile_stride,
984
+ )[0]
985
+ y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
986
+ y = torch.concat([msk, y])
987
+ y = y.unsqueeze(0)
988
+ y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
989
+ return {"y": y}
990
+
991
+
992
+ class WanVideoUnit_ImageEmbedderFused(PipelineUnit):
993
+ """
994
+ Encode input image to latents using VAE. This unit is for Wan-AI/Wan2.2-TI2V-5B.
995
+ """
996
+
997
+ def __init__(self):
998
+ super().__init__(
999
+ input_params=(
1000
+ "input_image",
1001
+ "latents",
1002
+ "height",
1003
+ "width",
1004
+ "tiled",
1005
+ "tile_size",
1006
+ "tile_stride",
1007
+ ),
1008
+ onload_model_names=("vae",),
1009
+ )
1010
+
1011
+ def process(
1012
+ self,
1013
+ pipe: WanVideoPipeline,
1014
+ input_image,
1015
+ latents,
1016
+ height,
1017
+ width,
1018
+ tiled,
1019
+ tile_size,
1020
+ tile_stride,
1021
+ ):
1022
+ if input_image is None or not pipe.dit.fuse_vae_embedding_in_latents:
1023
+ return {}
1024
+ pipe.load_models_to_device(self.onload_model_names)
1025
+ image = pipe.preprocess_image(input_image.resize((width, height))).transpose(
1026
+ 0, 1
1027
+ )
1028
+ z = pipe.vae.encode(
1029
+ [image],
1030
+ device=pipe.device,
1031
+ tiled=tiled,
1032
+ tile_size=tile_size,
1033
+ tile_stride=tile_stride,
1034
+ )
1035
+ latents[:, :, 0:1] = z
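+ # Overwrite the first latent frame with the encoded input image; model_fn_wan_video then treats frame 0 as clean conditioning via the fused timestep handling.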
1036
+ return {
1037
+ "latents": latents,
1038
+ "fuse_vae_embedding_in_latents": True,
1039
+ "first_frame_latents": z,
1040
+ }
1041
+
1042
+
1043
+ class WanVideoUnit_FunControl(PipelineUnit):
1044
+ def __init__(self):
1045
+ super().__init__(
1046
+ input_params=(
1047
+ "control_video",
1048
+ "num_frames",
1049
+ "height",
1050
+ "width",
1051
+ "tiled",
1052
+ "tile_size",
1053
+ "tile_stride",
1054
+ "clip_feature",
1055
+ "y",
1056
+ ),
1057
+ onload_model_names=("vae",),
1058
+ )
1059
+
1060
+ def process(
1061
+ self,
1062
+ pipe: WanVideoPipeline,
1063
+ control_video,
1064
+ num_frames,
1065
+ height,
1066
+ width,
1067
+ tiled,
1068
+ tile_size,
1069
+ tile_stride,
1070
+ clip_feature,
1071
+ y,
1072
+ ):
1073
+ if control_video is None:
1074
+ return {}
1075
+ pipe.load_models_to_device(self.onload_model_names)
1076
+ control_video = pipe.preprocess_video(control_video)
1077
+ control_latents = pipe.vae.encode(
1078
+ control_video,
1079
+ device=pipe.device,
1080
+ tiled=tiled,
1081
+ tile_size=tile_size,
1082
+ tile_stride=tile_stride,
1083
+ ).to(dtype=pipe.torch_dtype, device=pipe.device)
1084
+ control_latents = control_latents.to(dtype=pipe.torch_dtype, device=pipe.device)
1085
+ if clip_feature is None or y is None:
1086
+ clip_feature = torch.zeros(
1087
+ (1, 257, 1280), dtype=pipe.torch_dtype, device=pipe.device
1088
+ )
1089
+ y = torch.zeros(
1090
+ (1, 16, (num_frames - 1) // 4 + 1, height // 8, width // 8),
1091
+ dtype=pipe.torch_dtype,
1092
+ device=pipe.device,
1093
+ )
1094
+ else:
1095
+ y = y[:, -16:]
1096
+ y = torch.concat([control_latents, y], dim=1)
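+ # Stack the control-video latents in front of the (possibly zero-filled) image-conditioning latents along the channel dimension.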
1097
+ return {"clip_feature": clip_feature, "y": y}
1098
+
1099
+
1100
+ class WanVideoUnit_FunReference(PipelineUnit):
1101
+ def __init__(self):
1102
+ super().__init__(
1103
+ input_params=("reference_image", "height", "width"),
1104
+ onload_model_names=("vae",),
1105
+ )
1106
+
1107
+ def process(self, pipe: WanVideoPipeline, reference_image, height, width):
1108
+ if reference_image is None:
1109
+ return {}
1110
+ pipe.load_models_to_device(["vae"])
1111
+ reference_image = reference_image.resize((width, height))
1112
+ reference_latents = pipe.preprocess_video([reference_image])
1113
+ reference_latents = pipe.vae.encode(reference_latents, device=pipe.device)
1114
+ clip_feature = pipe.preprocess_image(reference_image)
1115
+ clip_feature = pipe.image_encoder.encode_image([clip_feature])
1116
+ return {"reference_latents": reference_latents, "clip_feature": clip_feature}
1117
+
1118
+
1119
+ class WanVideoUnit_FunCameraControl(PipelineUnit):
1120
+ def __init__(self):
1121
+ super().__init__(
1122
+ input_params=(
1123
+ "height",
1124
+ "width",
1125
+ "num_frames",
1126
+ "camera_control_direction",
1127
+ "camera_control_speed",
1128
+ "camera_control_origin",
1129
+ "latents",
1130
+ "input_image",
1131
+ ),
1132
+ onload_model_names=("vae",),
1133
+ )
1134
+
1135
+ def process(
1136
+ self,
1137
+ pipe: WanVideoPipeline,
1138
+ height,
1139
+ width,
1140
+ num_frames,
1141
+ camera_control_direction,
1142
+ camera_control_speed,
1143
+ camera_control_origin,
1144
+ latents,
1145
+ input_image,
1146
+ ):
1147
+ if camera_control_direction is None:
1148
+ return {}
1149
+ camera_control_plucker_embedding = (
1150
+ pipe.dit.control_adapter.process_camera_coordinates(
1151
+ camera_control_direction,
1152
+ num_frames,
1153
+ height,
1154
+ width,
1155
+ camera_control_speed,
1156
+ camera_control_origin,
1157
+ )
1158
+ )
1159
+
1160
+ control_camera_video = (
1161
+ camera_control_plucker_embedding[:num_frames]
1162
+ .permute([3, 0, 1, 2])
1163
+ .unsqueeze(0)
1164
+ )
1165
+ control_camera_latents = torch.concat(
1166
+ [
1167
+ torch.repeat_interleave(
1168
+ control_camera_video[:, :, 0:1], repeats=4, dim=2
1169
+ ),
1170
+ control_camera_video[:, :, 1:],
1171
+ ],
1172
+ dim=2,
1173
+ ).transpose(1, 2)
1174
+ b, f, c, h, w = control_camera_latents.shape
1175
+ control_camera_latents = (
1176
+ control_camera_latents.contiguous()
1177
+ .view(b, f // 4, 4, c, h, w)
1178
+ .transpose(2, 3)
1179
+ )
1180
+ control_camera_latents = (
1181
+ control_camera_latents.contiguous()
1182
+ .view(b, f // 4, c * 4, h, w)
1183
+ .transpose(1, 2)
1184
+ )
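+ # Fold every 4 consecutive camera-embedding frames into the channel dimension so the camera condition lines up with the 4x temporally compressed latents.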
1185
+ control_camera_latents_input = control_camera_latents.to(
1186
+ device=pipe.device, dtype=pipe.torch_dtype
1187
+ )
1188
+
1189
+ input_image = input_image.resize((width, height))
1190
+ input_latents = pipe.preprocess_video([input_image])
1191
+ pipe.load_models_to_device(self.onload_model_names)
1192
+ input_latents = pipe.vae.encode(input_latents, device=pipe.device)
1193
+ y = torch.zeros_like(latents).to(pipe.device)
1194
+ y[:, :, :1] = input_latents
1195
+ y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
1196
+ return {"control_camera_latents_input": control_camera_latents_input, "y": y}
1197
+
1198
+
1199
+ class WanVideoUnit_SpeedControl(PipelineUnit):
1200
+ def __init__(self):
1201
+ super().__init__(input_params=("motion_bucket_id",))
1202
+
1203
+ def process(self, pipe: WanVideoPipeline, motion_bucket_id):
1204
+ if motion_bucket_id is None:
1205
+ return {}
1206
+ motion_bucket_id = torch.Tensor((motion_bucket_id,)).to(
1207
+ dtype=pipe.torch_dtype, device=pipe.device
1208
+ )
1209
+ return {"motion_bucket_id": motion_bucket_id}
1210
+
1211
+
1212
+ class WanVideoUnit_VACE(PipelineUnit):
1213
+ def __init__(self):
1214
+ super().__init__(
1215
+ input_params=(
1216
+ "vace_video",
1217
+ "vace_video_mask",
1218
+ "vace_reference_image",
1219
+ "vace_scale",
1220
+ "height",
1221
+ "width",
1222
+ "num_frames",
1223
+ "tiled",
1224
+ "tile_size",
1225
+ "tile_stride",
1226
+ ),
1227
+ onload_model_names=("vae",),
1228
+ )
1229
+
1230
+ def process(
1231
+ self,
1232
+ pipe: WanVideoPipeline,
1233
+ vace_video,
1234
+ vace_video_mask,
1235
+ vace_reference_image,
1236
+ vace_scale,
1237
+ height,
1238
+ width,
1239
+ num_frames,
1240
+ tiled,
1241
+ tile_size,
1242
+ tile_stride,
1243
+ ):
1244
+ if (
1245
+ vace_video is not None
1246
+ or vace_video_mask is not None
1247
+ or vace_reference_image is not None
1248
+ ):
1249
+ pipe.load_models_to_device(["vae"])
1250
+ if vace_video is None:
1251
+ vace_video = torch.zeros(
1252
+ (1, 3, num_frames, height, width),
1253
+ dtype=pipe.torch_dtype,
1254
+ device=pipe.device,
1255
+ )
1256
+ else:
1257
+ vace_video = pipe.preprocess_video(vace_video)
1258
+
1259
+ if vace_video_mask is None:
1260
+ vace_video_mask = torch.ones_like(vace_video)
1261
+ else:
1262
+ vace_video_mask = pipe.preprocess_video(
1263
+ vace_video_mask, min_value=0, max_value=1
1264
+ )
1265
+
1266
+ inactive = vace_video * (1 - vace_video_mask) + 0 * vace_video_mask
1267
+ reactive = vace_video * vace_video_mask + 0 * (1 - vace_video_mask)
1268
+ inactive = pipe.vae.encode(
1269
+ inactive,
1270
+ device=pipe.device,
1271
+ tiled=tiled,
1272
+ tile_size=tile_size,
1273
+ tile_stride=tile_stride,
1274
+ ).to(dtype=pipe.torch_dtype, device=pipe.device)
1275
+ reactive = pipe.vae.encode(
1276
+ reactive,
1277
+ device=pipe.device,
1278
+ tiled=tiled,
1279
+ tile_size=tile_size,
1280
+ tile_stride=tile_stride,
1281
+ ).to(dtype=pipe.torch_dtype, device=pipe.device)
1282
+ vace_video_latents = torch.concat((inactive, reactive), dim=1)
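+ # The mask splits the video into "inactive" (content where the mask is 0) and "reactive" (content where the mask is 1) parts; both are VAE-encoded and concatenated along channels.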
1283
+
1284
+ vace_mask_latents = rearrange(
1285
+ vace_video_mask[0, 0], "T (H P) (W Q) -> 1 (P Q) T H W", P=8, Q=8
1286
+ )
1287
+ vace_mask_latents = torch.nn.functional.interpolate(
1288
+ vace_mask_latents,
1289
+ size=(
1290
+ (vace_mask_latents.shape[2] + 3) // 4,
1291
+ vace_mask_latents.shape[3],
1292
+ vace_mask_latents.shape[4],
1293
+ ),
1294
+ mode="nearest-exact",
1295
+ )
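+ # Map the pixel-space mask onto the latent grid: each 8x8 spatial patch becomes 64 channels and the temporal axis is downsampled by 4 (rounded up).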
1296
+
1297
+ if vace_reference_image is None:
1298
+ pass
1299
+ else:
1300
+ vace_reference_image = pipe.preprocess_video([vace_reference_image])
1301
+ vace_reference_latents = pipe.vae.encode(
1302
+ vace_reference_image,
1303
+ device=pipe.device,
1304
+ tiled=tiled,
1305
+ tile_size=tile_size,
1306
+ tile_stride=tile_stride,
1307
+ ).to(dtype=pipe.torch_dtype, device=pipe.device)
1308
+ vace_reference_latents = torch.concat(
1309
+ (vace_reference_latents, torch.zeros_like(vace_reference_latents)),
1310
+ dim=1,
1311
+ )
1312
+ vace_video_latents = torch.concat(
1313
+ (vace_reference_latents, vace_video_latents), dim=2
1314
+ )
1315
+ vace_mask_latents = torch.concat(
1316
+ (torch.zeros_like(vace_mask_latents[:, :, :1]), vace_mask_latents),
1317
+ dim=2,
1318
+ )
1319
+
1320
+ vace_context = torch.concat((vace_video_latents, vace_mask_latents), dim=1)
1321
+ return {"vace_context": vace_context, "vace_scale": vace_scale}
1322
+ else:
1323
+ return {"vace_context": None, "vace_scale": vace_scale}
1324
+
1325
+
1326
+ class WanVideoUnit_UnifiedSequenceParallel(PipelineUnit):
1327
+ def __init__(self):
1328
+ super().__init__(input_params=())
1329
+
1330
+ def process(self, pipe: WanVideoPipeline):
1331
+ if hasattr(pipe, "use_unified_sequence_parallel"):
1332
+ if pipe.use_unified_sequence_parallel:
1333
+ return {"use_unified_sequence_parallel": True}
1334
+ return {}
1335
+
1336
+
1337
+ class WanVideoUnit_TeaCache(PipelineUnit):
1338
+ def __init__(self):
1339
+ super().__init__(
1340
+ seperate_cfg=True,
1341
+ input_params_posi={
1342
+ "num_inference_steps": "num_inference_steps",
1343
+ "tea_cache_l1_thresh": "tea_cache_l1_thresh",
1344
+ "tea_cache_model_id": "tea_cache_model_id",
1345
+ },
1346
+ input_params_nega={
1347
+ "num_inference_steps": "num_inference_steps",
1348
+ "tea_cache_l1_thresh": "tea_cache_l1_thresh",
1349
+ "tea_cache_model_id": "tea_cache_model_id",
1350
+ },
1351
+ )
1352
+
1353
+ def process(
1354
+ self,
1355
+ pipe: WanVideoPipeline,
1356
+ num_inference_steps,
1357
+ tea_cache_l1_thresh,
1358
+ tea_cache_model_id,
1359
+ ):
1360
+ if tea_cache_l1_thresh is None:
1361
+ return {}
1362
+ return {
1363
+ "tea_cache": TeaCache(
1364
+ num_inference_steps,
1365
+ rel_l1_thresh=tea_cache_l1_thresh,
1366
+ model_id=tea_cache_model_id,
1367
+ )
1368
+ }
1369
+
1370
+
1371
+ class WanVideoUnit_CfgMerger(PipelineUnit):
1372
+ def __init__(self):
1373
+ super().__init__(take_over=True)
1374
+ self.concat_tensor_names = ["context", "clip_feature", "y", "reference_latents"]
1375
+
1376
+ def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega):
1377
+ if not inputs_shared["cfg_merge"]:
1378
+ return inputs_shared, inputs_posi, inputs_nega
1379
+ for name in self.concat_tensor_names:
1380
+ tensor_posi = inputs_posi.get(name)
1381
+ tensor_nega = inputs_nega.get(name)
1382
+ tensor_shared = inputs_shared.get(name)
1383
+ if tensor_posi is not None and tensor_nega is not None:
1384
+ inputs_shared[name] = torch.concat((tensor_posi, tensor_nega), dim=0)
1385
+ elif tensor_shared is not None:
1386
+ inputs_shared[name] = torch.concat(
1387
+ (tensor_shared, tensor_shared), dim=0
1388
+ )
1389
+ inputs_posi.clear()
1390
+ inputs_nega.clear()
1391
+ return inputs_shared, inputs_posi, inputs_nega
1392
+
1393
+
1394
+ class TeaCache:
1395
+ def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
1396
+ self.num_inference_steps = num_inference_steps
1397
+ self.step = 0
1398
+ self.accumulated_rel_l1_distance = 0
1399
+ self.previous_modulated_input = None
1400
+ self.rel_l1_thresh = rel_l1_thresh
1401
+ self.previous_residual = None
1402
+ self.previous_hidden_states = None
1403
+
1404
+ self.coefficients_dict = {
1405
+ "Wan2.1-T2V-1.3B": [
1406
+ -5.21862437e04,
1407
+ 9.23041404e03,
1408
+ -5.28275948e02,
1409
+ 1.36987616e01,
1410
+ -4.99875664e-02,
1411
+ ],
1412
+ "Wan2.1-T2V-14B": [
1413
+ -3.03318725e05,
1414
+ 4.90537029e04,
1415
+ -2.65530556e03,
1416
+ 5.87365115e01,
1417
+ -3.15583525e-01,
1418
+ ],
1419
+ "Wan2.1-I2V-14B-480P": [
1420
+ 2.57151496e05,
1421
+ -3.54229917e04,
1422
+ 1.40286849e03,
1423
+ -1.35890334e01,
1424
+ 1.32517977e-01,
1425
+ ],
1426
+ "Wan2.1-I2V-14B-720P": [
1427
+ 8.10705460e03,
1428
+ 2.13393892e03,
1429
+ -3.72934672e02,
1430
+ 1.66203073e01,
1431
+ -4.17769401e-02,
1432
+ ],
1433
+ }
1434
+ if model_id not in self.coefficients_dict:
1435
+ supported_model_ids = ", ".join([i for i in self.coefficients_dict])
1436
+ raise ValueError(
1437
+ f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids})."
1438
+ )
1439
+ self.coefficients = self.coefficients_dict[model_id]
1440
+
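+ # check() decides whether the transformer blocks can be skipped this step: it accumulates a polynomial-rescaled relative L1 change of the modulated input and returns True (reuse the cached residual) while the accumulation stays below rel_l1_thresh.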
1441
+ def check(self, dit: WanModel, x, t_mod):
1442
+ modulated_inp = t_mod.clone()
1443
+ if self.step == 0 or self.step == self.num_inference_steps - 1:
1444
+ should_calc = True
1445
+ self.accumulated_rel_l1_distance = 0
1446
+ else:
1447
+ coefficients = self.coefficients
1448
+ rescale_func = np.poly1d(coefficients)
1449
+ self.accumulated_rel_l1_distance += rescale_func(
1450
+ (
1451
+ (modulated_inp - self.previous_modulated_input).abs().mean()
1452
+ / self.previous_modulated_input.abs().mean()
1453
+ )
1454
+ .cpu()
1455
+ .item()
1456
+ )
1457
+ if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
1458
+ should_calc = False
1459
+ else:
1460
+ should_calc = True
1461
+ self.accumulated_rel_l1_distance = 0
1462
+ self.previous_modulated_input = modulated_inp
1463
+ self.step += 1
1464
+ if self.step == self.num_inference_steps:
1465
+ self.step = 0
1466
+ if should_calc:
1467
+ self.previous_hidden_states = x.clone()
1468
+ return not should_calc
1469
+
1470
+ def store(self, hidden_states):
1471
+ self.previous_residual = hidden_states - self.previous_hidden_states
1472
+ self.previous_hidden_states = None
1473
+
1474
+ def update(self, hidden_states):
1475
+ hidden_states = hidden_states + self.previous_residual
1476
+ return hidden_states
1477
+
1478
+
1479
+ class TemporalTiler_BCTHW:
1480
+ def __init__(self):
1481
+ pass
1482
+
1483
+ def build_1d_mask(self, length, left_bound, right_bound, border_width):
1484
+ x = torch.ones((length,))
1485
+ if not left_bound:
1486
+ x[:border_width] = (torch.arange(border_width) + 1) / border_width
1487
+ if not right_bound:
1488
+ x[-border_width:] = torch.flip(
1489
+ (torch.arange(border_width) + 1) / border_width, dims=(0,)
1490
+ )
1491
+ return x
1492
+
1493
+ def build_mask(self, data, is_bound, border_width):
1494
+ _, _, T, _, _ = data.shape
1495
+ t = self.build_1d_mask(T, is_bound[0], is_bound[1], border_width[0])
1496
+ mask = repeat(t, "T -> 1 1 T 1 1")
1497
+ return mask
1498
+
1499
+ def run(
1500
+ self,
1501
+ model_fn,
1502
+ sliding_window_size,
1503
+ sliding_window_stride,
1504
+ computation_device,
1505
+ computation_dtype,
1506
+ model_kwargs,
1507
+ tensor_names,
1508
+ batch_size=None,
1509
+ ):
1510
+ tensor_names = [
1511
+ tensor_name
1512
+ for tensor_name in tensor_names
1513
+ if model_kwargs.get(tensor_name) is not None
1514
+ ]
1515
+ tensor_dict = {
1516
+ tensor_name: model_kwargs[tensor_name] for tensor_name in tensor_names
1517
+ }
1518
+ B, C, T, H, W = tensor_dict[tensor_names[0]].shape
1519
+ if batch_size is not None:
1520
+ B *= batch_size
1521
+ data_device, data_dtype = (
1522
+ tensor_dict[tensor_names[0]].device,
1523
+ tensor_dict[tensor_names[0]].dtype,
1524
+ )
1525
+ value = torch.zeros((B, C, T, H, W), device=data_device, dtype=data_dtype)
1526
+ weight = torch.zeros((1, 1, T, 1, 1), device=data_device, dtype=data_dtype)
1527
+ for t in range(0, T, sliding_window_stride):
1528
+ if (
1529
+ t - sliding_window_stride >= 0
1530
+ and t - sliding_window_stride + sliding_window_size >= T
1531
+ ):
1532
+ continue
1533
+ t_ = min(t + sliding_window_size, T)
1534
+ model_kwargs.update(
1535
+ {
1536
+ tensor_name: tensor_dict[tensor_name][:, :, t:t_:, :].to(
1537
+ device=computation_device, dtype=computation_dtype
1538
+ )
1539
+ for tensor_name in tensor_names
1540
+ }
1541
+ )
1542
+ model_output = model_fn(**model_kwargs).to(
1543
+ device=data_device, dtype=data_dtype
1544
+ )
1545
+ mask = self.build_mask(
1546
+ model_output,
1547
+ is_bound=(t == 0, t_ == T),
1548
+ border_width=(sliding_window_size - sliding_window_stride,),
1549
+ ).to(device=data_device, dtype=data_dtype)
1550
+ value[:, :, t:t_, :, :] += model_output * mask
1551
+ weight[:, :, t:t_, :, :] += mask
1552
+ value /= weight
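+ # Overlapping window outputs are cross-faded with the linear border ramps from build_mask and normalized by the accumulated weights.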
1553
+ model_kwargs.update(tensor_dict)
1554
+ return value
1555
+
1556
+
1557
+ def model_fn_wan_video(
1558
+ dit: WanModel,
1559
+ motion_controller: WanMotionControllerModel = None,
1560
+ vace: VaceWanModel = None,
1561
+ latents: torch.Tensor = None,
1562
+ timestep: torch.Tensor = None,
1563
+ context: torch.Tensor = None,
1564
+ clip_feature: Optional[torch.Tensor] = None,
1565
+ y: Optional[torch.Tensor] = None,
1566
+ reference_latents=None,
1567
+ vace_context=None,
1568
+ vace_scale=1.0,
1569
+ tea_cache: TeaCache = None,
1570
+ use_unified_sequence_parallel: bool = False,
1571
+ motion_bucket_id: Optional[torch.Tensor] = None,
1572
+ sliding_window_size: Optional[int] = None,
1573
+ sliding_window_stride: Optional[int] = None,
1574
+ cfg_merge: bool = False,
1575
+ use_gradient_checkpointing: bool = False,
1576
+ use_gradient_checkpointing_offload: bool = False,
1577
+ control_camera_latents_input=None,
1578
+ fuse_vae_embedding_in_latents: bool = False,
1579
+ ip_image=None,
1580
+ **kwargs,
1581
+ ):
1582
+ if sliding_window_size is not None and sliding_window_stride is not None:
1583
+ model_kwargs = dict(
1584
+ dit=dit,
1585
+ motion_controller=motion_controller,
1586
+ vace=vace,
1587
+ latents=latents,
1588
+ timestep=timestep,
1589
+ context=context,
1590
+ clip_feature=clip_feature,
1591
+ y=y,
1592
+ reference_latents=reference_latents,
1593
+ vace_context=vace_context,
1594
+ vace_scale=vace_scale,
1595
+ tea_cache=tea_cache,
1596
+ use_unified_sequence_parallel=use_unified_sequence_parallel,
1597
+ motion_bucket_id=motion_bucket_id,
1598
+ )
1599
+ return TemporalTiler_BCTHW().run(
1600
+ model_fn_wan_video,
1601
+ sliding_window_size,
1602
+ sliding_window_stride,
1603
+ latents.device,
1604
+ latents.dtype,
1605
+ model_kwargs=model_kwargs,
1606
+ tensor_names=["latents", "y"],
1607
+ batch_size=2 if cfg_merge else 1,
1608
+ )
1609
+
1610
+ if use_unified_sequence_parallel:
1611
+ import torch.distributed as dist
1612
+ from xfuser.core.distributed import (
1613
+ get_sequence_parallel_rank,
1614
+ get_sequence_parallel_world_size,
1615
+ get_sp_group,
1616
+ )
1617
+ x_ip = None
1618
+ t_mod_ip = None
1619
+ # Timestep
1620
+ if dit.seperated_timestep and fuse_vae_embedding_in_latents:
1621
+ timestep = torch.concat(
1622
+ [
1623
+ torch.zeros(
1624
+ (1, latents.shape[3] * latents.shape[4] // 4),
1625
+ dtype=latents.dtype,
1626
+ device=latents.device,
1627
+ ),
1628
+ torch.ones(
1629
+ (latents.shape[2] - 1, latents.shape[3] * latents.shape[4] // 4),
1630
+ dtype=latents.dtype,
1631
+ device=latents.device,
1632
+ )
1633
+ * timestep,
1634
+ ]
1635
+ ).flatten()
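+ # With a fused first-frame latent, frame 0 gets timestep 0 (it is already clean) while the remaining latent frames share the current denoising timestep, one value per spatial patch.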
1636
+ t = dit.time_embedding(
1637
+ sinusoidal_embedding_1d(dit.freq_dim, timestep).unsqueeze(0)
1638
+ )
1639
+ t_mod = dit.time_projection(t).unflatten(2, (6, dit.dim))
1640
+ else:
1641
+ t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
1642
+ t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim))
1643
+
1644
+ if ip_image is not None:
1645
+ timestep_ip = torch.zeros_like(timestep) # [B] with 0s
1646
+ t_ip = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep_ip))
1647
+ t_mod_ip = dit.time_projection(t_ip).unflatten(1, (6, dit.dim))
1648
+
1649
+ # Motion Controller
1650
+ if motion_bucket_id is not None and motion_controller is not None:
1651
+ t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim))
1652
+ context = dit.text_embedding(context)
1653
+
1654
+ x = latents
1655
+ # Merged cfg
1656
+ if x.shape[0] != context.shape[0]:
1657
+ x = torch.concat([x] * context.shape[0], dim=0)
1658
+ if timestep.shape[0] != context.shape[0]:
1659
+ timestep = torch.concat([timestep] * context.shape[0], dim=0)
1660
+
1661
+ # Image Embedding
1662
+ if y is not None and dit.require_vae_embedding:
1663
+ x = torch.cat([x, y], dim=1)
1664
+ if clip_feature is not None and dit.require_clip_embedding:
1665
+ clip_embedding = dit.img_emb(clip_feature)
1666
+ context = torch.cat([clip_embedding, context], dim=1)
1667
+
1668
+ # Add camera control
1669
+ x, (f, h, w) = dit.patchify(x, control_camera_latents_input)
1670
+
1671
+ # Reference image
1672
+ if reference_latents is not None:
1673
+ if len(reference_latents.shape) == 5:
1674
+ reference_latents = reference_latents[:, :, 0]
1675
+ reference_latents = dit.ref_conv(reference_latents).flatten(2).transpose(1, 2)
1676
+ x = torch.concat([reference_latents, x], dim=1)
1677
+ f += 1
1678
+
1679
+ offset = 1
1680
+ freqs = (
1681
+ torch.cat(
1682
+ [
1683
+ dit.freqs[0][offset : f + offset].view(f, 1, 1, -1).expand(f, h, w, -1),
1684
+ dit.freqs[1][offset : h + offset].view(1, h, 1, -1).expand(f, h, w, -1),
1685
+ dit.freqs[2][offset : w + offset].view(1, 1, w, -1).expand(f, h, w, -1),
1686
+ ],
1687
+ dim=-1,
1688
+ )
1689
+ .reshape(f * h * w, 1, -1)
1690
+ .to(x.device)
1691
+ )
1692
+
1693
+ ############################################################################################
1694
+ if ip_image is not None:
1695
+ x_ip, (f_ip, h_ip, w_ip) = dit.patchify(
1696
+ ip_image
1697
+ ) # x_ip [1, 1024, 5120] [B, N, D] f_ip = 1 h_ip = 32 w_ip = 32
1698
+ freqs_ip = (
1699
+ torch.cat(
1700
+ [
1701
+ dit.freqs[0][0].view(f_ip, 1, 1, -1).expand(f_ip, h_ip, w_ip, -1),
1702
+ dit.freqs[1][h + offset : h + offset + h_ip]
1703
+ .view(1, h_ip, 1, -1)
1704
+ .expand(f_ip, h_ip, w_ip, -1),
1705
+ dit.freqs[2][w + offset : w + offset + w_ip]
1706
+ .view(1, 1, w_ip, -1)
1707
+ .expand(f_ip, h_ip, w_ip, -1),
1708
+ ],
1709
+ dim=-1,
1710
+ )
1711
+ .reshape(f_ip * h_ip * w_ip, 1, -1)
1712
+ .to(x_ip.device)
1713
+ )
1714
+ freqs_original = freqs
1715
+ freqs = torch.cat([freqs, freqs_ip], dim=0)
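+ # Stand-In: the identity-image (ip_image) tokens are appended after the video tokens, with RoPE positions placed outside the video grid so attention can distinguish them from video content.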
1716
+ ############################################################################################
1717
+ else:
1718
+ freqs_original = freqs
1719
+ # TeaCache
1720
+ if tea_cache is not None:
1721
+ tea_cache_update = tea_cache.check(dit, x, t_mod)
1722
+ else:
1723
+ tea_cache_update = False
1724
+
1725
+ if vace_context is not None:
1726
+ vace_hints = vace(x, vace_context, context, t_mod, freqs_original)
1727
+
1728
+ # blocks
1729
+ if use_unified_sequence_parallel:
1730
+ if dist.is_initialized() and dist.get_world_size() > 1:
1731
+ x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[
1732
+ get_sequence_parallel_rank()
1733
+ ]
1734
+ if tea_cache_update:
1735
+ x = tea_cache.update(x)
1736
+ else:
1737
+
1738
+ def create_custom_forward(module):
1739
+ def custom_forward(*inputs):
1740
+ return module(*inputs)
1741
+
1742
+ return custom_forward
1743
+
1744
+ for block_id, block in enumerate(dit.blocks):
1745
+ if use_gradient_checkpointing_offload:
1746
+ with torch.autograd.graph.save_on_cpu():
1747
+ x, x_ip = torch.utils.checkpoint.checkpoint(
1748
+ create_custom_forward(block),
1749
+ x,
1750
+ context,
1751
+ t_mod,
1752
+ freqs,
1753
+ x_ip=x_ip,
1754
+ t_mod_ip=t_mod_ip,
1755
+ use_reentrant=False,
1756
+ )
1757
+ elif use_gradient_checkpointing:
1758
+ x, x_ip = torch.utils.checkpoint.checkpoint(
1759
+ create_custom_forward(block),
1760
+ x,
1761
+ context,
1762
+ t_mod,
1763
+ freqs,
1764
+ x_ip=x_ip,
1765
+ t_mod_ip=t_mod_ip,
1766
+ use_reentrant=False,
1767
+ )
1768
+ else:
1769
+ x, x_ip = block(x, context, t_mod, freqs, x_ip=x_ip, t_mod_ip=t_mod_ip)
1770
+ if vace_context is not None and block_id in vace.vace_layers_mapping:
1771
+ current_vace_hint = vace_hints[vace.vace_layers_mapping[block_id]]
1772
+ if (
1773
+ use_unified_sequence_parallel
1774
+ and dist.is_initialized()
1775
+ and dist.get_world_size() > 1
1776
+ ):
1777
+ current_vace_hint = torch.chunk(
1778
+ current_vace_hint, get_sequence_parallel_world_size(), dim=1
1779
+ )[get_sequence_parallel_rank()]
1780
+ x = x + current_vace_hint * vace_scale
1781
+ if tea_cache is not None:
1782
+ tea_cache.store(x)
1783
+
1784
+ x = dit.head(x, t)
1785
+ if use_unified_sequence_parallel:
1786
+ if dist.is_initialized() and dist.get_world_size() > 1:
1787
+ x = get_sp_group().all_gather(x, dim=1)
1788
+ # Remove reference latents
1789
+ if reference_latents is not None:
1790
+ x = x[:, reference_latents.shape[1] :]
1791
+ f -= 1
1792
+ x = dit.unpatchify(x, (f, h, w))
1793
+ return x
pipelines/wan_video_face_swap.py ADDED
@@ -0,0 +1,1786 @@
1
+ import torch, types
2
+ import numpy as np
3
+ from PIL import Image
4
+ from einops import repeat
5
+ from typing import Optional, Union
6
+ from einops import rearrange
8
+ from tqdm import tqdm
10
+ from typing_extensions import Literal
11
+ import imageio
12
+ import os
13
+ from typing import List
14
+ import cv2
15
+
16
+ from utils import BasePipeline, ModelConfig, PipelineUnit, PipelineUnitRunner
17
+ from models import ModelManager, load_state_dict
18
+ from models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d
19
+ from models.wan_video_text_encoder import (
20
+ WanTextEncoder,
21
+ T5RelativeEmbedding,
22
+ T5LayerNorm,
23
+ )
24
+ from models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample
25
+ from models.wan_video_image_encoder import WanImageEncoder
26
+ from models.wan_video_vace import VaceWanModel
27
+ from models.wan_video_motion_controller import WanMotionControllerModel
28
+ from schedulers.flow_match import FlowMatchScheduler
29
+ from prompters import WanPrompter
30
+ from vram_management import (
31
+ enable_vram_management,
32
+ AutoWrappedModule,
33
+ AutoWrappedLinear,
34
+ WanAutoCastLayerNorm,
35
+ )
36
+ from lora import GeneralLoRALoader
37
+
38
+
39
+ def load_video_as_list(video_path: str) -> List[Image.Image]:
40
+ if not os.path.isfile(video_path):
41
+ raise FileNotFoundError(video_path)
42
+ reader = imageio.get_reader(video_path)
43
+ frames = []
44
+ for i, frame_data in enumerate(reader):
45
+ pil_image = Image.fromarray(frame_data)
46
+ frames.append(pil_image)
47
+ reader.close()
48
+ return frames
49
+
50
+
51
+ class WanVideoPipeline_FaceSwap(BasePipeline):
52
+ def __init__(self, device="cuda", torch_dtype=torch.bfloat16, tokenizer_path=None):
53
+ super().__init__(
54
+ device=device,
55
+ torch_dtype=torch_dtype,
56
+ height_division_factor=16,
57
+ width_division_factor=16,
58
+ time_division_factor=4,
59
+ time_division_remainder=1,
60
+ )
61
+ self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
62
+ self.prompter = WanPrompter(tokenizer_path=tokenizer_path)
63
+ self.text_encoder: WanTextEncoder = None
64
+ self.image_encoder: WanImageEncoder = None
65
+ self.dit: WanModel = None
66
+ self.dit2: WanModel = None
67
+ self.vae: WanVideoVAE = None
68
+ self.motion_controller: WanMotionControllerModel = None
69
+ self.vace: VaceWanModel = None
70
+ self.in_iteration_models = ("dit", "motion_controller", "vace")
71
+ self.in_iteration_models_2 = ("dit2", "motion_controller", "vace")
72
+ self.unit_runner = PipelineUnitRunner()
73
+ self.units = [
74
+ WanVideoUnit_ShapeChecker(),
75
+ WanVideoUnit_NoiseInitializer(),
76
+ WanVideoUnit_InputVideoEmbedder(),
77
+ WanVideoUnit_PromptEmbedder(),
78
+ WanVideoUnit_ImageEmbedderVAE(),
79
+ WanVideoUnit_ImageEmbedderCLIP(),
80
+ WanVideoUnit_ImageEmbedderFused(),
81
+ WanVideoUnit_FunControl(),
82
+ WanVideoUnit_FunReference(),
83
+ WanVideoUnit_FunCameraControl(),
84
+ WanVideoUnit_SpeedControl(),
85
+ WanVideoUnit_VACE(),
86
+ WanVideoUnit_UnifiedSequenceParallel(),
87
+ WanVideoUnit_TeaCache(),
88
+ WanVideoUnit_CfgMerger(),
89
+ ]
90
+ self.model_fn = model_fn_wan_video
91
+
92
+ def encode_ip_image(self, ip_image):
93
+ self.load_models_to_device(["vae"])
94
+ ip_image = (
95
+ torch.tensor(np.array(ip_image)).permute(2, 0, 1).float() / 255.0
96
+ ) # [3, H, W]
97
+ ip_image = (
98
+ ip_image.unsqueeze(1).unsqueeze(0).to(dtype=self.torch_dtype)
99
+ ) # [B, 3, 1, H, W]
100
+ ip_image = ip_image * 2 - 1
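+ # Scale the identity image from [0, 1] to [-1, 1] and encode it as a single-frame latent with the VAE.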
101
+ ip_image_latent = self.vae.encode(ip_image, device=self.device, tiled=False)
102
+ return ip_image_latent
103
+
104
+ def load_lora(self, module, path, alpha=1):
105
+ loader = GeneralLoRALoader(torch_dtype=self.torch_dtype, device=self.device)
106
+ lora = load_state_dict(path, torch_dtype=self.torch_dtype, device=self.device)
107
+ loader.load(module, lora, alpha=alpha)
108
+
109
+ def training_loss(self, **inputs):
110
+ max_timestep_boundary = int(
111
+ inputs.get("max_timestep_boundary", 1) * self.scheduler.num_train_timesteps
112
+ )
113
+ min_timestep_boundary = int(
114
+ inputs.get("min_timestep_boundary", 0) * self.scheduler.num_train_timesteps
115
+ )
116
+ timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
117
+ timestep = self.scheduler.timesteps[timestep_id].to(
118
+ dtype=self.torch_dtype, device=self.device
119
+ )
120
+
121
+ inputs["latents"] = self.scheduler.add_noise(
122
+ inputs["input_latents"], inputs["noise"], timestep
123
+ )
124
+ training_target = self.scheduler.training_target(
125
+ inputs["input_latents"], inputs["noise"], timestep
126
+ )
127
+
128
+ noise_pred = self.model_fn(**inputs, timestep=timestep)
129
+
130
+ loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
131
+ loss = loss * self.scheduler.training_weight(timestep)
132
+ return loss
133
+
134
+ def enable_vram_management(
135
+ self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5
136
+ ):
137
+ self.vram_management_enabled = True
138
+ if num_persistent_param_in_dit is not None:
139
+ vram_limit = None
140
+ else:
141
+ if vram_limit is None:
142
+ vram_limit = self.get_vram()
143
+ vram_limit = vram_limit - vram_buffer
144
+ if self.text_encoder is not None:
145
+ dtype = next(iter(self.text_encoder.parameters())).dtype
146
+ enable_vram_management(
147
+ self.text_encoder,
148
+ module_map={
149
+ torch.nn.Linear: AutoWrappedLinear,
150
+ torch.nn.Embedding: AutoWrappedModule,
151
+ T5RelativeEmbedding: AutoWrappedModule,
152
+ T5LayerNorm: AutoWrappedModule,
153
+ },
154
+ module_config=dict(
155
+ offload_dtype=dtype,
156
+ offload_device="cpu",
157
+ onload_dtype=dtype,
158
+ onload_device="cpu",
159
+ computation_dtype=self.torch_dtype,
160
+ computation_device=self.device,
161
+ ),
162
+ vram_limit=vram_limit,
163
+ )
164
+ if self.dit is not None:
165
+ dtype = next(iter(self.dit.parameters())).dtype
166
+ device = "cpu" if vram_limit is not None else self.device
167
+ enable_vram_management(
168
+ self.dit,
169
+ module_map={
170
+ torch.nn.Linear: AutoWrappedLinear,
171
+ torch.nn.Conv3d: AutoWrappedModule,
172
+ torch.nn.LayerNorm: WanAutoCastLayerNorm,
173
+ RMSNorm: AutoWrappedModule,
174
+ torch.nn.Conv2d: AutoWrappedModule,
175
+ },
176
+ module_config=dict(
177
+ offload_dtype=dtype,
178
+ offload_device="cpu",
179
+ onload_dtype=dtype,
180
+ onload_device=device,
181
+ computation_dtype=self.torch_dtype,
182
+ computation_device=self.device,
183
+ ),
184
+ max_num_param=num_persistent_param_in_dit,
185
+ overflow_module_config=dict(
186
+ offload_dtype=dtype,
187
+ offload_device="cpu",
188
+ onload_dtype=dtype,
189
+ onload_device="cpu",
190
+ computation_dtype=self.torch_dtype,
191
+ computation_device=self.device,
192
+ ),
193
+ vram_limit=vram_limit,
194
+ )
195
+ if self.dit2 is not None:
196
+ dtype = next(iter(self.dit2.parameters())).dtype
197
+ device = "cpu" if vram_limit is not None else self.device
198
+ enable_vram_management(
199
+ self.dit2,
200
+ module_map={
201
+ torch.nn.Linear: AutoWrappedLinear,
202
+ torch.nn.Conv3d: AutoWrappedModule,
203
+ torch.nn.LayerNorm: WanAutoCastLayerNorm,
204
+ RMSNorm: AutoWrappedModule,
205
+ torch.nn.Conv2d: AutoWrappedModule,
206
+ },
207
+ module_config=dict(
208
+ offload_dtype=dtype,
209
+ offload_device="cpu",
210
+ onload_dtype=dtype,
211
+ onload_device=device,
212
+ computation_dtype=self.torch_dtype,
213
+ computation_device=self.device,
214
+ ),
215
+ max_num_param=num_persistent_param_in_dit,
216
+ overflow_module_config=dict(
217
+ offload_dtype=dtype,
218
+ offload_device="cpu",
219
+ onload_dtype=dtype,
220
+ onload_device="cpu",
221
+ computation_dtype=self.torch_dtype,
222
+ computation_device=self.device,
223
+ ),
224
+ vram_limit=vram_limit,
225
+ )
226
+ if self.vae is not None:
227
+ dtype = next(iter(self.vae.parameters())).dtype
228
+ enable_vram_management(
229
+ self.vae,
230
+ module_map={
231
+ torch.nn.Linear: AutoWrappedLinear,
232
+ torch.nn.Conv2d: AutoWrappedModule,
233
+ RMS_norm: AutoWrappedModule,
234
+ CausalConv3d: AutoWrappedModule,
235
+ Upsample: AutoWrappedModule,
236
+ torch.nn.SiLU: AutoWrappedModule,
237
+ torch.nn.Dropout: AutoWrappedModule,
238
+ },
239
+ module_config=dict(
240
+ offload_dtype=dtype,
241
+ offload_device="cpu",
242
+ onload_dtype=dtype,
243
+ onload_device=self.device,
244
+ computation_dtype=self.torch_dtype,
245
+ computation_device=self.device,
246
+ ),
247
+ )
248
+ if self.image_encoder is not None:
249
+ dtype = next(iter(self.image_encoder.parameters())).dtype
250
+ enable_vram_management(
251
+ self.image_encoder,
252
+ module_map={
253
+ torch.nn.Linear: AutoWrappedLinear,
254
+ torch.nn.Conv2d: AutoWrappedModule,
255
+ torch.nn.LayerNorm: AutoWrappedModule,
256
+ },
257
+ module_config=dict(
258
+ offload_dtype=dtype,
259
+ offload_device="cpu",
260
+ onload_dtype=dtype,
261
+ onload_device="cpu",
262
+ computation_dtype=dtype,
263
+ computation_device=self.device,
264
+ ),
265
+ )
266
+ if self.motion_controller is not None:
267
+ dtype = next(iter(self.motion_controller.parameters())).dtype
268
+ enable_vram_management(
269
+ self.motion_controller,
270
+ module_map={
271
+ torch.nn.Linear: AutoWrappedLinear,
272
+ },
273
+ module_config=dict(
274
+ offload_dtype=dtype,
275
+ offload_device="cpu",
276
+ onload_dtype=dtype,
277
+ onload_device="cpu",
278
+ computation_dtype=dtype,
279
+ computation_device=self.device,
280
+ ),
281
+ )
282
+ if self.vace is not None:
283
+ dtype = next(iter(self.vace.parameters())).dtype
+ device = "cpu" if vram_limit is not None else self.device
284
+ enable_vram_management(
285
+ self.vace,
286
+ module_map={
287
+ torch.nn.Linear: AutoWrappedLinear,
288
+ torch.nn.Conv3d: AutoWrappedModule,
289
+ torch.nn.LayerNorm: AutoWrappedModule,
290
+ RMSNorm: AutoWrappedModule,
291
+ },
292
+ module_config=dict(
293
+ offload_dtype=dtype,
294
+ offload_device="cpu",
295
+ onload_dtype=dtype,
296
+ onload_device=device,
297
+ computation_dtype=self.torch_dtype,
298
+ computation_device=self.device,
299
+ ),
300
+ vram_limit=vram_limit,
301
+ )
302
+
303
+ def initialize_usp(self):
304
+ import torch.distributed as dist
305
+ from xfuser.core.distributed import (
306
+ initialize_model_parallel,
307
+ init_distributed_environment,
308
+ )
309
+
310
+ dist.init_process_group(backend="nccl", init_method="env://")
311
+ init_distributed_environment(
312
+ rank=dist.get_rank(), world_size=dist.get_world_size()
313
+ )
314
+ initialize_model_parallel(
315
+ sequence_parallel_degree=dist.get_world_size(),
316
+ ring_degree=1,
317
+ ulysses_degree=dist.get_world_size(),
318
+ )
319
+ torch.cuda.set_device(dist.get_rank())
320
+
321
+ def enable_usp(self):
322
+ from xfuser.core.distributed import get_sequence_parallel_world_size
323
+ from distributed.xdit_context_parallel import (
324
+ usp_attn_forward,
325
+ usp_dit_forward,
326
+ )
327
+
328
+ for block in self.dit.blocks:
329
+ block.self_attn.forward = types.MethodType(
330
+ usp_attn_forward, block.self_attn
331
+ )
332
+ self.dit.forward = types.MethodType(usp_dit_forward, self.dit)
333
+ if self.dit2 is not None:
334
+ for block in self.dit2.blocks:
335
+ block.self_attn.forward = types.MethodType(
336
+ usp_attn_forward, block.self_attn
337
+ )
338
+ self.dit2.forward = types.MethodType(usp_dit_forward, self.dit2)
339
+ self.sp_size = get_sequence_parallel_world_size()
340
+ self.use_unified_sequence_parallel = True
341
+
342
+ @staticmethod
343
+ def from_pretrained(
344
+ torch_dtype: torch.dtype = torch.bfloat16,
345
+ device: Union[str, torch.device] = "cuda",
346
+ model_configs: list[ModelConfig] = [],
347
+ tokenizer_config: ModelConfig = ModelConfig(
348
+ model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/*"
349
+ ),
350
+ redirect_common_files: bool = True,
351
+ use_usp=False,
352
+ ):
353
+ # Redirect model path
354
+ if redirect_common_files:
355
+ redirect_dict = {
356
+ "models_t5_umt5-xxl-enc-bf16.pth": "Wan-AI/Wan2.1-T2V-1.3B",
357
+ "Wan2.1_VAE.pth": "Wan-AI/Wan2.1-T2V-1.3B",
358
+ "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth": "Wan-AI/Wan2.1-I2V-14B-480P",
359
+ }
360
+ for model_config in model_configs:
361
+ if (
362
+ model_config.origin_file_pattern is None
363
+ or model_config.model_id is None
364
+ ):
365
+ continue
366
+ if (
367
+ model_config.origin_file_pattern in redirect_dict
368
+ and model_config.model_id
369
+ != redirect_dict[model_config.origin_file_pattern]
370
+ ):
371
+ print(
372
+ f"To avoid repeatedly downloading model files, ({model_config.model_id}, {model_config.origin_file_pattern}) is redirected to ({redirect_dict[model_config.origin_file_pattern]}, {model_config.origin_file_pattern}). You can use `redirect_common_files=False` to disable file redirection."
373
+ )
374
+ model_config.model_id = redirect_dict[
375
+ model_config.origin_file_pattern
376
+ ]
377
+
378
+ # Initialize pipeline
379
+ pipe = WanVideoPipeline_FaceSwap(device=device, torch_dtype=torch_dtype)
380
+ if use_usp:
381
+ pipe.initialize_usp()
382
+
383
+ # Download and load models
384
+ model_manager = ModelManager()
385
+ for model_config in model_configs:
386
+ model_config.download_if_necessary(use_usp=use_usp)
387
+ model_manager.load_model(
388
+ model_config.path,
389
+ device=model_config.offload_device or device,
390
+ torch_dtype=model_config.offload_dtype or torch_dtype,
391
+ )
392
+
393
+ # Load models
394
+ pipe.text_encoder = model_manager.fetch_model("wan_video_text_encoder")
395
+ dit = model_manager.fetch_model("wan_video_dit", index=2)
396
+ if isinstance(dit, list):
397
+ pipe.dit, pipe.dit2 = dit
398
+ else:
399
+ pipe.dit = dit
400
+ pipe.vae = model_manager.fetch_model("wan_video_vae")
401
+ pipe.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
402
+ pipe.motion_controller = model_manager.fetch_model(
403
+ "wan_video_motion_controller"
404
+ )
405
+ pipe.vace = model_manager.fetch_model("wan_video_vace")
406
+
407
+ # Size division factor
408
+ if pipe.vae is not None:
409
+ pipe.height_division_factor = pipe.vae.upsampling_factor * 2
410
+ pipe.width_division_factor = pipe.vae.upsampling_factor * 2
411
+
412
+ # Initialize tokenizer
413
+ tokenizer_config.download_if_necessary(use_usp=use_usp)
414
+ pipe.prompter.fetch_models(pipe.text_encoder)
415
+ pipe.prompter.fetch_tokenizer(tokenizer_config.path)
416
+
417
+ # Unified Sequence Parallel
418
+ if use_usp:
419
+ pipe.enable_usp()
420
+ return pipe
421
+
422
+ @torch.no_grad()
423
+ def __call__(
424
+ self,
425
+ # Prompt
426
+ prompt: str,
427
+ negative_prompt: Optional[str] = "",
428
+ # Image-to-video
429
+ input_image: Optional[Image.Image] = None,
430
+ # First-last-frame-to-video
431
+ end_image: Optional[Image.Image] = None,
432
+ # Video-to-video
433
+ input_video: Optional[list[Image.Image]] = None,
434
+ denoising_strength: Optional[float] = 1,
435
+ # ControlNet
436
+ control_video: Optional[list[Image.Image]] = None,
437
+ reference_image: Optional[Image.Image] = None,
438
+ # Camera control
439
+ camera_control_direction: Optional[
440
+ Literal[
441
+ "Left",
442
+ "Right",
443
+ "Up",
444
+ "Down",
445
+ "LeftUp",
446
+ "LeftDown",
447
+ "RightUp",
448
+ "RightDown",
449
+ ]
450
+ ] = None,
451
+ camera_control_speed: Optional[float] = 1 / 54,
452
+ camera_control_origin: Optional[tuple] = (
453
+ 0,
454
+ 0.532139961,
455
+ 0.946026558,
456
+ 0.5,
457
+ 0.5,
458
+ 0,
459
+ 0,
460
+ 1,
461
+ 0,
462
+ 0,
463
+ 0,
464
+ 0,
465
+ 1,
466
+ 0,
467
+ 0,
468
+ 0,
469
+ 0,
470
+ 1,
471
+ 0,
472
+ ),
473
+ # VACE
474
+ vace_video: Optional[list[Image.Image]] = None,
475
+ vace_video_mask: Optional[Image.Image] = None,
476
+ vace_reference_image: Optional[Image.Image] = None,
477
+ vace_scale: Optional[float] = 1.0,
478
+ # Randomness
479
+ seed: Optional[int] = None,
480
+ rand_device: Optional[str] = "cpu",
481
+ # Shape
482
+ height: Optional[int] = 480,
483
+ width: Optional[int] = 832,
484
+ num_frames=81,
485
+ # Classifier-free guidance
486
+ cfg_scale: Optional[float] = 5.0,
487
+ cfg_merge: Optional[bool] = False,
488
+ # Boundary
489
+ switch_DiT_boundary: Optional[float] = 0.875,
490
+ # Scheduler
491
+ num_inference_steps: Optional[int] = 50,
492
+ sigma_shift: Optional[float] = 5.0,
493
+ # Speed control
494
+ motion_bucket_id: Optional[int] = None,
495
+ # VAE tiling
496
+ tiled: Optional[bool] = True,
497
+ tile_size: Optional[tuple[int, int]] = (30, 52),
498
+ tile_stride: Optional[tuple[int, int]] = (15, 26),
499
+ # Sliding window
500
+ sliding_window_size: Optional[int] = None,
501
+ sliding_window_stride: Optional[int] = None,
502
+ # Teacache
503
+ tea_cache_l1_thresh: Optional[float] = None,
504
+ tea_cache_model_id: Optional[str] = "",
505
+ # progress_bar
506
+ progress_bar_cmd=tqdm,
507
+ # Stand-In
508
+ face_mask=None,
509
+ ip_image=None,
510
+ force_background_consistency=False
511
+ ):
512
+ if ip_image is not None:
513
+ ip_image = self.encode_ip_image(ip_image)
514
+ # Scheduler
515
+ self.scheduler.set_timesteps(
516
+ num_inference_steps,
517
+ denoising_strength=denoising_strength,
518
+ shift=sigma_shift,
519
+ )
520
+
521
+ # Inputs
522
+ inputs_posi = {
523
+ "prompt": prompt,
524
+ "tea_cache_l1_thresh": tea_cache_l1_thresh,
525
+ "tea_cache_model_id": tea_cache_model_id,
526
+ "num_inference_steps": num_inference_steps,
527
+ }
528
+ inputs_nega = {
529
+ "negative_prompt": negative_prompt,
530
+ "tea_cache_l1_thresh": tea_cache_l1_thresh,
531
+ "tea_cache_model_id": tea_cache_model_id,
532
+ "num_inference_steps": num_inference_steps,
533
+ }
534
+ inputs_shared = {
535
+ "input_image": input_image,
536
+ "end_image": end_image,
537
+ "input_video": input_video,
538
+ "denoising_strength": denoising_strength,
539
+ "control_video": control_video,
540
+ "reference_image": reference_image,
541
+ "camera_control_direction": camera_control_direction,
542
+ "camera_control_speed": camera_control_speed,
543
+ "camera_control_origin": camera_control_origin,
544
+ "vace_video": vace_video,
545
+ "vace_video_mask": vace_video_mask,
546
+ "vace_reference_image": vace_reference_image,
547
+ "vace_scale": vace_scale,
548
+ "seed": seed,
549
+ "rand_device": rand_device,
550
+ "height": height,
551
+ "width": width,
552
+ "num_frames": num_frames,
553
+ "cfg_scale": cfg_scale,
554
+ "cfg_merge": cfg_merge,
555
+ "sigma_shift": sigma_shift,
556
+ "motion_bucket_id": motion_bucket_id,
557
+ "tiled": tiled,
558
+ "tile_size": tile_size,
559
+ "tile_stride": tile_stride,
560
+ "sliding_window_size": sliding_window_size,
561
+ "sliding_window_stride": sliding_window_stride,
562
+ "ip_image": ip_image,
563
+ }
564
+ for unit in self.units:
565
+ inputs_shared, inputs_posi, inputs_nega = self.unit_runner(
566
+ unit, self, inputs_shared, inputs_posi, inputs_nega
567
+ )
568
+ latent_mask = None
+ if face_mask is not None:
569
+ mask_processed = self.preprocess_video(face_mask)
570
+ mask_processed = mask_processed[:, 0:1, ...]
571
+ latent_mask = torch.nn.functional.interpolate(
572
+ mask_processed,
573
+ size=inputs_shared["latents"].shape[2:],
574
+ mode="nearest-exact",
575
+ )
576
+ # Denoise
577
+ self.load_models_to_device(self.in_iteration_models)
578
+ models = {name: getattr(self, name) for name in self.in_iteration_models}
579
+ for progress_id, timestep in enumerate(
580
+ progress_bar_cmd(self.scheduler.timesteps)
581
+ ):
582
+ # Switch DiT if necessary
583
+ if (
584
+ timestep.item()
585
+ < switch_DiT_boundary * self.scheduler.num_train_timesteps
586
+ and self.dit2 is not None
587
+ and not models["dit"] is self.dit2
588
+ ):
589
+ self.load_models_to_device(self.in_iteration_models_2)
590
+ models["dit"] = self.dit2
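+ # When two DiTs are loaded (e.g. a high-noise and a low-noise expert), switch to dit2 once the timestep falls below switch_DiT_boundary.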
591
+
592
+ # Timestep
593
+ timestep = timestep.unsqueeze(0).to(
594
+ dtype=self.torch_dtype, device=self.device
595
+ )
596
+
597
+ # Inference
598
+ noise_pred_posi = self.model_fn(
599
+ **models, **inputs_shared, **inputs_posi, timestep=timestep
600
+ )
601
+ inputs_shared["ip_image"] = None
602
+ if cfg_scale != 1.0:
603
+ if cfg_merge:
604
+ noise_pred_posi, noise_pred_nega = noise_pred_posi.chunk(2, dim=0)
605
+ else:
606
+ noise_pred_nega = self.model_fn(
607
+ **models, **inputs_shared, **inputs_nega, timestep=timestep
608
+ )
609
+ noise_pred = noise_pred_nega + cfg_scale * (
610
+ noise_pred_posi - noise_pred_nega
611
+ )
612
+ else:
613
+ noise_pred = noise_pred_posi
614
+
615
+ # Scheduler
616
+ inputs_shared["latents"] = self.scheduler.step(
617
+ noise_pred,
618
+ self.scheduler.timesteps[progress_id],
619
+ inputs_shared["latents"],
620
+ )
621
+ if force_background_consistency:
622
+ if (
623
+ inputs_shared.get("input_latents") is not None
624
+ and latent_mask is not None
625
+ ):
626
+ if progress_id == len(self.scheduler.timesteps) - 1:
627
+ noised_original_latents = inputs_shared["input_latents"]
628
+ else:
629
+ next_timestep = self.scheduler.timesteps[progress_id + 1]
630
+ noised_original_latents = self.scheduler.add_noise(
631
+ inputs_shared["input_latents"],
632
+ inputs_shared["noise"],
633
+ timestep=next_timestep,
634
+ )
635
+
636
+ hard_mask = (latent_mask > 0.5).to(
637
+ dtype=inputs_shared["latents"].dtype
638
+ )
639
+
640
+ inputs_shared["latents"] = (
641
+ 1 - hard_mask
642
+ ) * noised_original_latents + hard_mask * inputs_shared["latents"]
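+ # Outside the mask, latents are reset to the original video latents re-noised to the next timestep's level so the background stays consistent; only the masked region keeps the freely denoised result.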
643
+
644
+ if "first_frame_latents" in inputs_shared:
645
+ inputs_shared["latents"][:, :, 0:1] = inputs_shared[
646
+ "first_frame_latents"
647
+ ]
648
+
649
+ if vace_reference_image is not None:
650
+ inputs_shared["latents"] = inputs_shared["latents"][:, :, 1:]
651
+
652
+ # Decode
653
+ self.load_models_to_device(["vae"])
654
+ video = self.vae.decode(
655
+ inputs_shared["latents"],
656
+ device=self.device,
657
+ tiled=tiled,
658
+ tile_size=tile_size,
659
+ tile_stride=tile_stride,
660
+ )
661
+ video = self.vae_output_to_video(video)
662
+ self.load_models_to_device([])
663
+
664
+ return video
665
+
666
+
667
+ class WanVideoUnit_ShapeChecker(PipelineUnit):
668
+ def __init__(self):
669
+ super().__init__(input_params=("height", "width", "num_frames"))
670
+
671
+ def process(self, pipe: WanVideoPipeline_FaceSwap, height, width, num_frames):
672
+ height, width, num_frames = pipe.check_resize_height_width(
673
+ height, width, num_frames
674
+ )
675
+ return {"height": height, "width": width, "num_frames": num_frames}
676
+
677
+
678
+ class WanVideoUnit_NoiseInitializer(PipelineUnit):
679
+ def __init__(self):
680
+ super().__init__(
681
+ input_params=(
682
+ "height",
683
+ "width",
684
+ "num_frames",
685
+ "seed",
686
+ "rand_device",
687
+ "vace_reference_image",
688
+ )
689
+ )
690
+
691
+ def process(
692
+ self,
693
+ pipe: WanVideoPipeline_FaceSwap,
694
+ height,
695
+ width,
696
+ num_frames,
697
+ seed,
698
+ rand_device,
699
+ vace_reference_image,
700
+ ):
701
+ length = (num_frames - 1) // 4 + 1
702
+ if vace_reference_image is not None:
703
+ length += 1
704
+ shape = (
705
+ 1,
706
+ pipe.vae.model.z_dim,
707
+ length,
708
+ height // pipe.vae.upsampling_factor,
709
+ width // pipe.vae.upsampling_factor,
710
+ )
711
+ noise = pipe.generate_noise(shape, seed=seed, rand_device=rand_device)
712
+ if vace_reference_image is not None:
713
+ noise = torch.concat((noise[:, :, -1:], noise[:, :, :-1]), dim=2)
714
+ return {"noise": noise}
715
+
716
+
717
+ class WanVideoUnit_InputVideoEmbedder(PipelineUnit):
718
+ def __init__(self):
719
+ super().__init__(
720
+ input_params=(
721
+ "input_video",
722
+ "noise",
723
+ "tiled",
724
+ "tile_size",
725
+ "tile_stride",
726
+ "vace_reference_image",
727
+ ),
728
+ onload_model_names=("vae",),
729
+ )
730
+
731
+ def process(
732
+ self,
733
+ pipe: WanVideoPipeline_FaceSwap,
734
+ input_video,
735
+ noise,
736
+ tiled,
737
+ tile_size,
738
+ tile_stride,
739
+ vace_reference_image,
740
+ ):
741
+ if input_video is None:
742
+ return {"latents": noise}
743
+ pipe.load_models_to_device(["vae"])
744
+ input_video = pipe.preprocess_video(input_video)
745
+ input_latents = pipe.vae.encode(
746
+ input_video,
747
+ device=pipe.device,
748
+ tiled=tiled,
749
+ tile_size=tile_size,
750
+ tile_stride=tile_stride,
751
+ ).to(dtype=pipe.torch_dtype, device=pipe.device)
752
+ if vace_reference_image is not None:
753
+ vace_reference_image = pipe.preprocess_video([vace_reference_image])
754
+ vace_reference_latents = pipe.vae.encode(
755
+ vace_reference_image, device=pipe.device
756
+ ).to(dtype=pipe.torch_dtype, device=pipe.device)
757
+ input_latents = torch.concat([vace_reference_latents, input_latents], dim=2)
758
+ if pipe.scheduler.training:
759
+ return {"latents": noise, "input_latents": input_latents}
760
+ else:
761
+ latents = pipe.scheduler.add_noise(
762
+ input_latents, noise, timestep=pipe.scheduler.timesteps[0]
763
+ )
764
+ return {"latents": latents, "input_latents": input_latents}
765
+
766
+
767
+ class WanVideoUnit_PromptEmbedder(PipelineUnit):
768
+ def __init__(self):
769
+ super().__init__(
770
+ seperate_cfg=True,
771
+ input_params_posi={"prompt": "prompt", "positive": "positive"},
772
+ input_params_nega={"prompt": "negative_prompt", "positive": "positive"},
773
+ onload_model_names=("text_encoder",),
774
+ )
775
+
776
+ def process(self, pipe: WanVideoPipeline_FaceSwap, prompt, positive) -> dict:
777
+ pipe.load_models_to_device(self.onload_model_names)
778
+ prompt_emb = pipe.prompter.encode_prompt(
779
+ prompt, positive=positive, device=pipe.device
780
+ )
781
+ return {"context": prompt_emb}
782
+
783
+
784
+ class WanVideoUnit_ImageEmbedder(PipelineUnit):
785
+ """
786
+ Deprecated
787
+ """
788
+
789
+ def __init__(self):
790
+ super().__init__(
791
+ input_params=(
792
+ "input_image",
793
+ "end_image",
794
+ "num_frames",
795
+ "height",
796
+ "width",
797
+ "tiled",
798
+ "tile_size",
799
+ "tile_stride",
800
+ ),
801
+ onload_model_names=("image_encoder", "vae"),
802
+ )
803
+
804
+ def process(
805
+ self,
806
+ pipe: WanVideoPipeline_FaceSwap,
807
+ input_image,
808
+ end_image,
809
+ num_frames,
810
+ height,
811
+ width,
812
+ tiled,
813
+ tile_size,
814
+ tile_stride,
815
+ ):
816
+ if input_image is None or pipe.image_encoder is None:
817
+ return {}
818
+ pipe.load_models_to_device(self.onload_model_names)
819
+ image = pipe.preprocess_image(input_image.resize((width, height))).to(
820
+ pipe.device
821
+ )
822
+ clip_context = pipe.image_encoder.encode_image([image])
823
+ msk = torch.ones(1, num_frames, height // 8, width // 8, device=pipe.device)
824
+ msk[:, 1:] = 0
825
+ if end_image is not None:
826
+ end_image = pipe.preprocess_image(end_image.resize((width, height))).to(
827
+ pipe.device
828
+ )
829
+ vae_input = torch.concat(
830
+ [
831
+ image.transpose(0, 1),
832
+ torch.zeros(3, num_frames - 2, height, width).to(image.device),
833
+ end_image.transpose(0, 1),
834
+ ],
835
+ dim=1,
836
+ )
837
+ if pipe.dit.has_image_pos_emb:
838
+ clip_context = torch.concat(
839
+ [clip_context, pipe.image_encoder.encode_image([end_image])], dim=1
840
+ )
841
+ msk[:, -1:] = 1
842
+ else:
843
+ vae_input = torch.concat(
844
+ [
845
+ image.transpose(0, 1),
846
+ torch.zeros(3, num_frames - 1, height, width).to(image.device),
847
+ ],
848
+ dim=1,
849
+ )
850
+
851
+ msk = torch.concat(
852
+ [torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1
853
+ )
854
+ msk = msk.view(1, msk.shape[1] // 4, 4, height // 8, width // 8)
855
+ msk = msk.transpose(1, 2)[0]
856
+
857
+ y = pipe.vae.encode(
858
+ [vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)],
859
+ device=pipe.device,
860
+ tiled=tiled,
861
+ tile_size=tile_size,
862
+ tile_stride=tile_stride,
863
+ )[0]
864
+ y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
865
+ y = torch.concat([msk, y])
866
+ y = y.unsqueeze(0)
867
+ clip_context = clip_context.to(dtype=pipe.torch_dtype, device=pipe.device)
868
+ y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
869
+ return {"clip_feature": clip_context, "y": y}
870
+
871
+
872
+ class WanVideoUnit_ImageEmbedderCLIP(PipelineUnit):
873
+ def __init__(self):
874
+ super().__init__(
875
+ input_params=("input_image", "end_image", "height", "width"),
876
+ onload_model_names=("image_encoder",),
877
+ )
878
+
879
+ def process(
880
+ self, pipe: WanVideoPipeline_FaceSwap, input_image, end_image, height, width
881
+ ):
882
+ if (
883
+ input_image is None
884
+ or pipe.image_encoder is None
885
+ or not pipe.dit.require_clip_embedding
886
+ ):
887
+ return {}
888
+ pipe.load_models_to_device(self.onload_model_names)
889
+ image = pipe.preprocess_image(input_image.resize((width, height))).to(
890
+ pipe.device
891
+ )
892
+ clip_context = pipe.image_encoder.encode_image([image])
893
+ if end_image is not None:
894
+ end_image = pipe.preprocess_image(end_image.resize((width, height))).to(
895
+ pipe.device
896
+ )
897
+ if pipe.dit.has_image_pos_emb:
898
+ clip_context = torch.concat(
899
+ [clip_context, pipe.image_encoder.encode_image([end_image])], dim=1
900
+ )
901
+ clip_context = clip_context.to(dtype=pipe.torch_dtype, device=pipe.device)
902
+ return {"clip_feature": clip_context}
903
+
904
+
905
+ class WanVideoUnit_ImageEmbedderVAE(PipelineUnit):
906
+ def __init__(self):
907
+ super().__init__(
908
+ input_params=(
909
+ "input_image",
910
+ "end_image",
911
+ "num_frames",
912
+ "height",
913
+ "width",
914
+ "tiled",
915
+ "tile_size",
916
+ "tile_stride",
917
+ ),
918
+ onload_model_names=("vae",),
919
+ )
920
+
921
+ def process(
922
+ self,
923
+ pipe: WanVideoPipeline_FaceSwap,
924
+ input_image,
925
+ end_image,
926
+ num_frames,
927
+ height,
928
+ width,
929
+ tiled,
930
+ tile_size,
931
+ tile_stride,
932
+ ):
933
+ if input_image is None or not pipe.dit.require_vae_embedding:
934
+ return {}
935
+ pipe.load_models_to_device(self.onload_model_names)
936
+ image = pipe.preprocess_image(input_image.resize((width, height))).to(
937
+ pipe.device
938
+ )
939
+ msk = torch.ones(1, num_frames, height // 8, width // 8, device=pipe.device)
940
+ msk[:, 1:] = 0
941
+ if end_image is not None:
942
+ end_image = pipe.preprocess_image(end_image.resize((width, height))).to(
943
+ pipe.device
944
+ )
945
+ vae_input = torch.concat(
946
+ [
947
+ image.transpose(0, 1),
948
+ torch.zeros(3, num_frames - 2, height, width).to(image.device),
949
+ end_image.transpose(0, 1),
950
+ ],
951
+ dim=1,
952
+ )
953
+ msk[:, -1:] = 1
954
+ else:
955
+ vae_input = torch.concat(
956
+ [
957
+ image.transpose(0, 1),
958
+ torch.zeros(3, num_frames - 1, height, width).to(image.device),
959
+ ],
960
+ dim=1,
961
+ )
962
+
963
+ msk = torch.concat(
964
+ [torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1
965
+ )
966
+ msk = msk.view(1, msk.shape[1] // 4, 4, height // 8, width // 8)
967
+ msk = msk.transpose(1, 2)[0]
968
+
969
+ y = pipe.vae.encode(
970
+ [vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)],
971
+ device=pipe.device,
972
+ tiled=tiled,
973
+ tile_size=tile_size,
974
+ tile_stride=tile_stride,
975
+ )[0]
976
+ y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
977
+ y = torch.concat([msk, y])
978
+ y = y.unsqueeze(0)
979
+ y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
980
+ return {"y": y}
981
+
982
+
983
+ class WanVideoUnit_ImageEmbedderFused(PipelineUnit):
984
+ """
985
+ Encode input image to latents using VAE. This unit is for Wan-AI/Wan2.2-TI2V-5B.
986
+ """
987
+
988
+ def __init__(self):
989
+ super().__init__(
990
+ input_params=(
991
+ "input_image",
992
+ "latents",
993
+ "height",
994
+ "width",
995
+ "tiled",
996
+ "tile_size",
997
+ "tile_stride",
998
+ ),
999
+ onload_model_names=("vae",),
1000
+ )
1001
+
1002
+ def process(
1003
+ self,
1004
+ pipe: WanVideoPipeline_FaceSwap,
1005
+ input_image,
1006
+ latents,
1007
+ height,
1008
+ width,
1009
+ tiled,
1010
+ tile_size,
1011
+ tile_stride,
1012
+ ):
1013
+ if input_image is None or not pipe.dit.fuse_vae_embedding_in_latents:
1014
+ return {}
1015
+ pipe.load_models_to_device(self.onload_model_names)
1016
+ image = pipe.preprocess_image(input_image.resize((width, height))).transpose(
1017
+ 0, 1
1018
+ )
1019
+ z = pipe.vae.encode(
1020
+ [image],
1021
+ device=pipe.device,
1022
+ tiled=tiled,
1023
+ tile_size=tile_size,
1024
+ tile_stride=tile_stride,
1025
+ )
1026
+ latents[:, :, 0:1] = z
1027
+ return {
1028
+ "latents": latents,
1029
+ "fuse_vae_embedding_in_latents": True,
1030
+ "first_frame_latents": z,
1031
+ }
1032
+
1033
+
1034
+ class WanVideoUnit_FunControl(PipelineUnit):
1035
+ def __init__(self):
1036
+ super().__init__(
1037
+ input_params=(
1038
+ "control_video",
1039
+ "num_frames",
1040
+ "height",
1041
+ "width",
1042
+ "tiled",
1043
+ "tile_size",
1044
+ "tile_stride",
1045
+ "clip_feature",
1046
+ "y",
1047
+ ),
1048
+ onload_model_names=("vae",),
1049
+ )
1050
+
1051
+ def process(
1052
+ self,
1053
+ pipe: WanVideoPipeline_FaceSwap,
1054
+ control_video,
1055
+ num_frames,
1056
+ height,
1057
+ width,
1058
+ tiled,
1059
+ tile_size,
1060
+ tile_stride,
1061
+ clip_feature,
1062
+ y,
1063
+ ):
1064
+ if control_video is None:
1065
+ return {}
1066
+ pipe.load_models_to_device(self.onload_model_names)
1067
+ control_video = pipe.preprocess_video(control_video)
1068
+ control_latents = pipe.vae.encode(
1069
+ control_video,
1070
+ device=pipe.device,
1071
+ tiled=tiled,
1072
+ tile_size=tile_size,
1073
+ tile_stride=tile_stride,
1074
+ ).to(dtype=pipe.torch_dtype, device=pipe.device)
1075
+ control_latents = control_latents.to(dtype=pipe.torch_dtype, device=pipe.device)
1076
+ if clip_feature is None or y is None:
1077
+ clip_feature = torch.zeros(
1078
+ (1, 257, 1280), dtype=pipe.torch_dtype, device=pipe.device
1079
+ )
1080
+ y = torch.zeros(
1081
+ (1, 16, (num_frames - 1) // 4 + 1, height // 8, width // 8),
1082
+ dtype=pipe.torch_dtype,
1083
+ device=pipe.device,
1084
+ )
1085
+ else:
1086
+ y = y[:, -16:]
1087
+ y = torch.concat([control_latents, y], dim=1)
1088
+ return {"clip_feature": clip_feature, "y": y}
1089
+
1090
+
1091
+ class WanVideoUnit_FunReference(PipelineUnit):
1092
+ def __init__(self):
1093
+ super().__init__(
1094
+ input_params=("reference_image", "height", "width", "reference_image"),
1095
+ onload_model_names=("vae",),
1096
+ )
1097
+
1098
+ def process(self, pipe: WanVideoPipeline_FaceSwap, reference_image, height, width):
1099
+ if reference_image is None:
1100
+ return {}
1101
+ pipe.load_models_to_device(["vae"])
1102
+ reference_image = reference_image.resize((width, height))
1103
+ reference_latents = pipe.preprocess_video([reference_image])
1104
+ reference_latents = pipe.vae.encode(reference_latents, device=pipe.device)
1105
+ clip_feature = pipe.preprocess_image(reference_image)
1106
+ clip_feature = pipe.image_encoder.encode_image([clip_feature])
1107
+ return {"reference_latents": reference_latents, "clip_feature": clip_feature}
1108
+
1109
+
1110
+ class WanVideoUnit_FunCameraControl(PipelineUnit):
1111
+ def __init__(self):
1112
+ super().__init__(
1113
+ input_params=(
1114
+ "height",
1115
+ "width",
1116
+ "num_frames",
1117
+ "camera_control_direction",
1118
+ "camera_control_speed",
1119
+ "camera_control_origin",
1120
+ "latents",
1121
+ "input_image",
1122
+ ),
1123
+ onload_model_names=("vae",),
1124
+ )
1125
+
1126
+ def process(
1127
+ self,
1128
+ pipe: WanVideoPipeline_FaceSwap,
1129
+ height,
1130
+ width,
1131
+ num_frames,
1132
+ camera_control_direction,
1133
+ camera_control_speed,
1134
+ camera_control_origin,
1135
+ latents,
1136
+ input_image,
1137
+ ):
1138
+ if camera_control_direction is None:
1139
+ return {}
1140
+ camera_control_plucker_embedding = (
1141
+ pipe.dit.control_adapter.process_camera_coordinates(
1142
+ camera_control_direction,
1143
+ num_frames,
1144
+ height,
1145
+ width,
1146
+ camera_control_speed,
1147
+ camera_control_origin,
1148
+ )
1149
+ )
1150
+
1151
+ control_camera_video = (
1152
+ camera_control_plucker_embedding[:num_frames]
1153
+ .permute([3, 0, 1, 2])
1154
+ .unsqueeze(0)
1155
+ )
1156
+ control_camera_latents = torch.concat(
1157
+ [
1158
+ torch.repeat_interleave(
1159
+ control_camera_video[:, :, 0:1], repeats=4, dim=2
1160
+ ),
1161
+ control_camera_video[:, :, 1:],
1162
+ ],
1163
+ dim=2,
1164
+ ).transpose(1, 2)
1165
+ b, f, c, h, w = control_camera_latents.shape
1166
+ control_camera_latents = (
1167
+ control_camera_latents.contiguous()
1168
+ .view(b, f // 4, 4, c, h, w)
1169
+ .transpose(2, 3)
1170
+ )
1171
+ control_camera_latents = (
1172
+ control_camera_latents.contiguous()
1173
+ .view(b, f // 4, c * 4, h, w)
1174
+ .transpose(1, 2)
1175
+ )
1176
+ control_camera_latents_input = control_camera_latents.to(
1177
+ device=pipe.device, dtype=pipe.torch_dtype
1178
+ )
1179
+
1180
+ input_image = input_image.resize((width, height))
1181
+ input_latents = pipe.preprocess_video([input_image])
1182
+ pipe.load_models_to_device(self.onload_model_names)
1183
+ input_latents = pipe.vae.encode(input_latents, device=pipe.device)
1184
+ y = torch.zeros_like(latents).to(pipe.device)
1185
+ y[:, :, :1] = input_latents
1186
+ y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
1187
+ return {"control_camera_latents_input": control_camera_latents_input, "y": y}
1188
+
1189
+
1190
+ class WanVideoUnit_SpeedControl(PipelineUnit):
1191
+ def __init__(self):
1192
+ super().__init__(input_params=("motion_bucket_id",))
1193
+
1194
+ def process(self, pipe: WanVideoPipeline_FaceSwap, motion_bucket_id):
1195
+ if motion_bucket_id is None:
1196
+ return {}
1197
+ motion_bucket_id = torch.Tensor((motion_bucket_id,)).to(
1198
+ dtype=pipe.torch_dtype, device=pipe.device
1199
+ )
1200
+ return {"motion_bucket_id": motion_bucket_id}
1201
+
1202
+
1203
+ class WanVideoUnit_VACE(PipelineUnit):
1204
+ def __init__(self):
1205
+ super().__init__(
1206
+ input_params=(
1207
+ "vace_video",
1208
+ "vace_video_mask",
1209
+ "vace_reference_image",
1210
+ "vace_scale",
1211
+ "height",
1212
+ "width",
1213
+ "num_frames",
1214
+ "tiled",
1215
+ "tile_size",
1216
+ "tile_stride",
1217
+ ),
1218
+ onload_model_names=("vae",),
1219
+ )
1220
+
1221
+ def process(
1222
+ self,
1223
+ pipe: WanVideoPipeline_FaceSwap,
1224
+ vace_video,
1225
+ vace_video_mask,
1226
+ vace_reference_image,
1227
+ vace_scale,
1228
+ height,
1229
+ width,
1230
+ num_frames,
1231
+ tiled,
1232
+ tile_size,
1233
+ tile_stride,
1234
+ ):
1235
+ if (
1236
+ vace_video is not None
1237
+ or vace_video_mask is not None
1238
+ or vace_reference_image is not None
1239
+ ):
1240
+ pipe.load_models_to_device(["vae"])
1241
+ if vace_video is None:
1242
+ vace_video = torch.zeros(
1243
+ (1, 3, num_frames, height, width),
1244
+ dtype=pipe.torch_dtype,
1245
+ device=pipe.device,
1246
+ )
1247
+ else:
1248
+ vace_video = pipe.preprocess_video(vace_video)
1249
+
1250
+ if vace_video_mask is None:
1251
+ vace_video_mask = torch.ones_like(vace_video)
1252
+ else:
1253
+ vace_video_mask = pipe.preprocess_video(
1254
+ vace_video_mask, min_value=0, max_value=1
1255
+ )
1256
+
1257
+ inactive = vace_video * (1 - vace_video_mask) + 0 * vace_video_mask
1258
+ reactive = vace_video * vace_video_mask + 0 * (1 - vace_video_mask)
1259
+ inactive = pipe.vae.encode(
1260
+ inactive,
1261
+ device=pipe.device,
1262
+ tiled=tiled,
1263
+ tile_size=tile_size,
1264
+ tile_stride=tile_stride,
1265
+ ).to(dtype=pipe.torch_dtype, device=pipe.device)
1266
+ reactive = pipe.vae.encode(
1267
+ reactive,
1268
+ device=pipe.device,
1269
+ tiled=tiled,
1270
+ tile_size=tile_size,
1271
+ tile_stride=tile_stride,
1272
+ ).to(dtype=pipe.torch_dtype, device=pipe.device)
1273
+ vace_video_latents = torch.concat((inactive, reactive), dim=1)
1274
+
1275
+ vace_mask_latents = rearrange(
1276
+ vace_video_mask[0, 0], "T (H P) (W Q) -> 1 (P Q) T H W", P=8, Q=8
1277
+ )
1278
+ vace_mask_latents = torch.nn.functional.interpolate(
1279
+ vace_mask_latents,
1280
+ size=(
1281
+ (vace_mask_latents.shape[2] + 3) // 4,
1282
+ vace_mask_latents.shape[3],
1283
+ vace_mask_latents.shape[4],
1284
+ ),
1285
+ mode="nearest-exact",
1286
+ )
1287
+
1288
+ if vace_reference_image is None:
1289
+ pass
1290
+ else:
1291
+ vace_reference_image = pipe.preprocess_video([vace_reference_image])
1292
+ vace_reference_latents = pipe.vae.encode(
1293
+ vace_reference_image,
1294
+ device=pipe.device,
1295
+ tiled=tiled,
1296
+ tile_size=tile_size,
1297
+ tile_stride=tile_stride,
1298
+ ).to(dtype=pipe.torch_dtype, device=pipe.device)
1299
+ vace_reference_latents = torch.concat(
1300
+ (vace_reference_latents, torch.zeros_like(vace_reference_latents)),
1301
+ dim=1,
1302
+ )
1303
+ vace_video_latents = torch.concat(
1304
+ (vace_reference_latents, vace_video_latents), dim=2
1305
+ )
1306
+ vace_mask_latents = torch.concat(
1307
+ (torch.zeros_like(vace_mask_latents[:, :, :1]), vace_mask_latents),
1308
+ dim=2,
1309
+ )
1310
+
1311
+ vace_context = torch.concat((vace_video_latents, vace_mask_latents), dim=1)
1312
+ return {"vace_context": vace_context, "vace_scale": vace_scale}
1313
+ else:
1314
+ return {"vace_context": None, "vace_scale": vace_scale}
1315
+
1316
+
1317
+ class WanVideoUnit_UnifiedSequenceParallel(PipelineUnit):
1318
+ def __init__(self):
1319
+ super().__init__(input_params=())
1320
+
1321
+ def process(self, pipe: WanVideoPipeline_FaceSwap):
1322
+ if hasattr(pipe, "use_unified_sequence_parallel"):
1323
+ if pipe.use_unified_sequence_parallel:
1324
+ return {"use_unified_sequence_parallel": True}
1325
+ return {}
1326
+
1327
+
1328
+ class WanVideoUnit_TeaCache(PipelineUnit):
1329
+ def __init__(self):
1330
+ super().__init__(
1331
+ seperate_cfg=True,
1332
+ input_params_posi={
1333
+ "num_inference_steps": "num_inference_steps",
1334
+ "tea_cache_l1_thresh": "tea_cache_l1_thresh",
1335
+ "tea_cache_model_id": "tea_cache_model_id",
1336
+ },
1337
+ input_params_nega={
1338
+ "num_inference_steps": "num_inference_steps",
1339
+ "tea_cache_l1_thresh": "tea_cache_l1_thresh",
1340
+ "tea_cache_model_id": "tea_cache_model_id",
1341
+ },
1342
+ )
1343
+
1344
+ def process(
1345
+ self,
1346
+ pipe: WanVideoPipeline_FaceSwap,
1347
+ num_inference_steps,
1348
+ tea_cache_l1_thresh,
1349
+ tea_cache_model_id,
1350
+ ):
1351
+ if tea_cache_l1_thresh is None:
1352
+ return {}
1353
+ return {
1354
+ "tea_cache": TeaCache(
1355
+ num_inference_steps,
1356
+ rel_l1_thresh=tea_cache_l1_thresh,
1357
+ model_id=tea_cache_model_id,
1358
+ )
1359
+ }
1360
+
1361
+
1362
+ class WanVideoUnit_CfgMerger(PipelineUnit):
1363
+ def __init__(self):
1364
+ super().__init__(take_over=True)
1365
+ self.concat_tensor_names = ["context", "clip_feature", "y", "reference_latents"]
1366
+
1367
+ def process(
1368
+ self, pipe: WanVideoPipeline_FaceSwap, inputs_shared, inputs_posi, inputs_nega
1369
+ ):
1370
+ if not inputs_shared["cfg_merge"]:
1371
+ return inputs_shared, inputs_posi, inputs_nega
1372
+ for name in self.concat_tensor_names:
1373
+ tensor_posi = inputs_posi.get(name)
1374
+ tensor_nega = inputs_nega.get(name)
1375
+ tensor_shared = inputs_shared.get(name)
1376
+ if tensor_posi is not None and tensor_nega is not None:
1377
+ inputs_shared[name] = torch.concat((tensor_posi, tensor_nega), dim=0)
1378
+ elif tensor_shared is not None:
1379
+ inputs_shared[name] = torch.concat(
1380
+ (tensor_shared, tensor_shared), dim=0
1381
+ )
1382
+ inputs_posi.clear()
1383
+ inputs_nega.clear()
1384
+ return inputs_shared, inputs_posi, inputs_nega
1385
+
1386
+
1387
+ class TeaCache:
1388
+ def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
1389
+ self.num_inference_steps = num_inference_steps
1390
+ self.step = 0
1391
+ self.accumulated_rel_l1_distance = 0
1392
+ self.previous_modulated_input = None
1393
+ self.rel_l1_thresh = rel_l1_thresh
1394
+ self.previous_residual = None
1395
+ self.previous_hidden_states = None
1396
+
1397
+ self.coefficients_dict = {
1398
+ "Wan2.1-T2V-1.3B": [
1399
+ -5.21862437e04,
1400
+ 9.23041404e03,
1401
+ -5.28275948e02,
1402
+ 1.36987616e01,
1403
+ -4.99875664e-02,
1404
+ ],
1405
+ "Wan2.1-T2V-14B": [
1406
+ -3.03318725e05,
1407
+ 4.90537029e04,
1408
+ -2.65530556e03,
1409
+ 5.87365115e01,
1410
+ -3.15583525e-01,
1411
+ ],
1412
+ "Wan2.1-I2V-14B-480P": [
1413
+ 2.57151496e05,
1414
+ -3.54229917e04,
1415
+ 1.40286849e03,
1416
+ -1.35890334e01,
1417
+ 1.32517977e-01,
1418
+ ],
1419
+ "Wan2.1-I2V-14B-720P": [
1420
+ 8.10705460e03,
1421
+ 2.13393892e03,
1422
+ -3.72934672e02,
1423
+ 1.66203073e01,
1424
+ -4.17769401e-02,
1425
+ ],
1426
+ }
1427
+ if model_id not in self.coefficients_dict:
1428
+ supported_model_ids = ", ".join([i for i in self.coefficients_dict])
1429
+ raise ValueError(
1430
+ f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids})."
1431
+ )
1432
+ self.coefficients = self.coefficients_dict[model_id]
1433
+
1434
+ def check(self, dit: WanModel, x, t_mod):
1435
+ modulated_inp = t_mod.clone()
1436
+ if self.step == 0 or self.step == self.num_inference_steps - 1:
1437
+ should_calc = True
1438
+ self.accumulated_rel_l1_distance = 0
1439
+ else:
1440
+ coefficients = self.coefficients
1441
+ rescale_func = np.poly1d(coefficients)
1442
+ self.accumulated_rel_l1_distance += rescale_func(
1443
+ (
1444
+ (modulated_inp - self.previous_modulated_input).abs().mean()
1445
+ / self.previous_modulated_input.abs().mean()
1446
+ )
1447
+ .cpu()
1448
+ .item()
1449
+ )
1450
+ if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
1451
+ should_calc = False
1452
+ else:
1453
+ should_calc = True
1454
+ self.accumulated_rel_l1_distance = 0
1455
+ self.previous_modulated_input = modulated_inp
1456
+ self.step += 1
1457
+ if self.step == self.num_inference_steps:
1458
+ self.step = 0
1459
+ if should_calc:
1460
+ self.previous_hidden_states = x.clone()
1461
+ return not should_calc
1462
+
1463
+ def store(self, hidden_states):
1464
+ self.previous_residual = hidden_states - self.previous_hidden_states
1465
+ self.previous_hidden_states = None
1466
+
1467
+ def update(self, hidden_states):
1468
+ hidden_states = hidden_states + self.previous_residual
1469
+ return hidden_states
1470
+
1471
+
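A standalone sketch (not part of this file) of the caching rule that TeaCache.check implements above: a fitted polynomial maps the relative L1 change of the modulated input to an estimated output change, and the transformer blocks are skipped while the accumulated estimate stays under the threshold. The rel_change value below is made up for illustration; the coefficients are the Wan2.1-T2V-1.3B entry from the table above.

import numpy as np

coefficients = [-5.21862437e04, 9.23041404e03, -5.28275948e02, 1.36987616e01, -4.99875664e-02]
rescale = np.poly1d(coefficients)          # maps relative input change -> estimated output change

rel_l1_thresh = 0.05                       # plays the role of tea_cache_l1_thresh
accumulated = 0.0
rel_change = 1e-3                          # pretend |t_mod_now - t_mod_prev|.mean() / |t_mod_prev|.mean()

accumulated += rescale(rel_change)
should_skip = accumulated < rel_l1_thresh  # True -> reuse previous_residual instead of running the blocks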
1472
+ class TemporalTiler_BCTHW:
1473
+ def __init__(self):
1474
+ pass
1475
+
1476
+ def build_1d_mask(self, length, left_bound, right_bound, border_width):
1477
+ x = torch.ones((length,))
1478
+ if not left_bound:
1479
+ x[:border_width] = (torch.arange(border_width) + 1) / border_width
1480
+ if not right_bound:
1481
+ x[-border_width:] = torch.flip(
1482
+ (torch.arange(border_width) + 1) / border_width, dims=(0,)
1483
+ )
1484
+ return x
1485
+
1486
+ def build_mask(self, data, is_bound, border_width):
1487
+ _, _, T, _, _ = data.shape
1488
+ t = self.build_1d_mask(T, is_bound[0], is_bound[1], border_width[0])
1489
+ mask = repeat(t, "T -> 1 1 T 1 1")
1490
+ return mask
1491
+
1492
+ def run(
1493
+ self,
1494
+ model_fn,
1495
+ sliding_window_size,
1496
+ sliding_window_stride,
1497
+ computation_device,
1498
+ computation_dtype,
1499
+ model_kwargs,
1500
+ tensor_names,
1501
+ batch_size=None,
1502
+ ):
1503
+ tensor_names = [
1504
+ tensor_name
1505
+ for tensor_name in tensor_names
1506
+ if model_kwargs.get(tensor_name) is not None
1507
+ ]
1508
+ tensor_dict = {
1509
+ tensor_name: model_kwargs[tensor_name] for tensor_name in tensor_names
1510
+ }
1511
+ B, C, T, H, W = tensor_dict[tensor_names[0]].shape
1512
+ if batch_size is not None:
1513
+ B *= batch_size
1514
+ data_device, data_dtype = (
1515
+ tensor_dict[tensor_names[0]].device,
1516
+ tensor_dict[tensor_names[0]].dtype,
1517
+ )
1518
+ value = torch.zeros((B, C, T, H, W), device=data_device, dtype=data_dtype)
1519
+ weight = torch.zeros((1, 1, T, 1, 1), device=data_device, dtype=data_dtype)
1520
+ for t in range(0, T, sliding_window_stride):
1521
+ if (
1522
+ t - sliding_window_stride >= 0
1523
+ and t - sliding_window_stride + sliding_window_size >= T
1524
+ ):
1525
+ continue
1526
+ t_ = min(t + sliding_window_size, T)
1527
+ model_kwargs.update(
1528
+ {
1529
+ tensor_name: tensor_dict[tensor_name][:, :, t:t_, :, :].to(
1530
+ device=computation_device, dtype=computation_dtype
1531
+ )
1532
+ for tensor_name in tensor_names
1533
+ }
1534
+ )
1535
+ model_output = model_fn(**model_kwargs).to(
1536
+ device=data_device, dtype=data_dtype
1537
+ )
1538
+ mask = self.build_mask(
1539
+ model_output,
1540
+ is_bound=(t == 0, t_ == T),
1541
+ border_width=(sliding_window_size - sliding_window_stride,),
1542
+ ).to(device=data_device, dtype=data_dtype)
1543
+ value[:, :, t:t_, :, :] += model_output * mask
1544
+ weight[:, :, t:t_, :, :] += mask
1545
+ value /= weight
1546
+ model_kwargs.update(tensor_dict)
1547
+ return value
1548
+
1549
+
1550
+ def model_fn_wan_video(
1551
+ dit: WanModel,
1552
+ motion_controller: WanMotionControllerModel = None,
1553
+ vace: VaceWanModel = None,
1554
+ latents: torch.Tensor = None,
1555
+ timestep: torch.Tensor = None,
1556
+ context: torch.Tensor = None,
1557
+ clip_feature: Optional[torch.Tensor] = None,
1558
+ y: Optional[torch.Tensor] = None,
1559
+ reference_latents=None,
1560
+ vace_context=None,
1561
+ vace_scale=1.0,
1562
+ tea_cache: TeaCache = None,
1563
+ use_unified_sequence_parallel: bool = False,
1564
+ motion_bucket_id: Optional[torch.Tensor] = None,
1565
+ sliding_window_size: Optional[int] = None,
1566
+ sliding_window_stride: Optional[int] = None,
1567
+ cfg_merge: bool = False,
1568
+ use_gradient_checkpointing: bool = False,
1569
+ use_gradient_checkpointing_offload: bool = False,
1570
+ control_camera_latents_input=None,
1571
+ fuse_vae_embedding_in_latents: bool = False,
1572
+ ip_image=None,
1573
+ **kwargs,
1574
+ ):
1575
+ if sliding_window_size is not None and sliding_window_stride is not None:
1576
+ model_kwargs = dict(
1577
+ dit=dit,
1578
+ motion_controller=motion_controller,
1579
+ vace=vace,
1580
+ latents=latents,
1581
+ timestep=timestep,
1582
+ context=context,
1583
+ clip_feature=clip_feature,
1584
+ y=y,
1585
+ reference_latents=reference_latents,
1586
+ vace_context=vace_context,
1587
+ vace_scale=vace_scale,
1588
+ tea_cache=tea_cache,
1589
+ use_unified_sequence_parallel=use_unified_sequence_parallel,
1590
+ motion_bucket_id=motion_bucket_id,
1591
+ )
1592
+ return TemporalTiler_BCTHW().run(
1593
+ model_fn_wan_video,
1594
+ sliding_window_size,
1595
+ sliding_window_stride,
1596
+ latents.device,
1597
+ latents.dtype,
1598
+ model_kwargs=model_kwargs,
1599
+ tensor_names=["latents", "y"],
1600
+ batch_size=2 if cfg_merge else 1,
1601
+ )
1602
+
1603
+ if use_unified_sequence_parallel:
1604
+ import torch.distributed as dist
1605
+ from xfuser.core.distributed import (
1606
+ get_sequence_parallel_rank,
1607
+ get_sequence_parallel_world_size,
1608
+ get_sp_group,
1609
+ )
1610
+ x_ip = None
1611
+ t_mod_ip = None
1612
+ # Timestep
1613
+ if dit.seperated_timestep and fuse_vae_embedding_in_latents:
1614
+ timestep = torch.concat(
1615
+ [
1616
+ torch.zeros(
1617
+ (1, latents.shape[3] * latents.shape[4] // 4),
1618
+ dtype=latents.dtype,
1619
+ device=latents.device,
1620
+ ),
1621
+ torch.ones(
1622
+ (latents.shape[2] - 1, latents.shape[3] * latents.shape[4] // 4),
1623
+ dtype=latents.dtype,
1624
+ device=latents.device,
1625
+ )
1626
+ * timestep,
1627
+ ]
1628
+ ).flatten()
1629
+ t = dit.time_embedding(
1630
+ sinusoidal_embedding_1d(dit.freq_dim, timestep).unsqueeze(0)
1631
+ )
1632
+ t_mod = dit.time_projection(t).unflatten(2, (6, dit.dim))
1633
+ else:
1634
+ t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
1635
+ t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim))
1636
+
1637
+ if ip_image is not None:
1638
+ timestep_ip = torch.zeros_like(timestep) # [B] with 0s
1639
+ t_ip = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep_ip))
1640
+ t_mod_ip = dit.time_projection(t_ip).unflatten(1, (6, dit.dim))
1641
+
1642
+ # Motion Controller
1643
+ if motion_bucket_id is not None and motion_controller is not None:
1644
+ t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim))
1645
+ context = dit.text_embedding(context)
1646
+
1647
+ x = latents
1648
+ # Merged cfg
1649
+ if x.shape[0] != context.shape[0]:
1650
+ x = torch.concat([x] * context.shape[0], dim=0)
1651
+ if timestep.shape[0] != context.shape[0]:
1652
+ timestep = torch.concat([timestep] * context.shape[0], dim=0)
1653
+
1654
+ # Image Embedding
1655
+ if y is not None and dit.require_vae_embedding:
1656
+ x = torch.cat([x, y], dim=1)
1657
+ if clip_feature is not None and dit.require_clip_embedding:
1658
+ clip_embedding = dit.img_emb(clip_feature)
1659
+ context = torch.cat([clip_embedding, context], dim=1)
1660
+
1661
+ # Add camera control
1662
+ x, (f, h, w) = dit.patchify(x, control_camera_latents_input)
1663
+
1664
+ # Reference image
1665
+ if reference_latents is not None:
1666
+ if len(reference_latents.shape) == 5:
1667
+ reference_latents = reference_latents[:, :, 0]
1668
+ reference_latents = dit.ref_conv(reference_latents).flatten(2).transpose(1, 2)
1669
+ x = torch.concat([reference_latents, x], dim=1)
1670
+ f += 1
1671
+
1672
+ offset = 1
1673
+ freqs = (
1674
+ torch.cat(
1675
+ [
1676
+ dit.freqs[0][offset : f + offset].view(f, 1, 1, -1).expand(f, h, w, -1),
1677
+ dit.freqs[1][offset : h + offset].view(1, h, 1, -1).expand(f, h, w, -1),
1678
+ dit.freqs[2][offset : w + offset].view(1, 1, w, -1).expand(f, h, w, -1),
1679
+ ],
1680
+ dim=-1,
1681
+ )
1682
+ .reshape(f * h * w, 1, -1)
1683
+ .to(x.device)
1684
+ )
1685
+
1686
+ ############################################################################################
1687
+ if ip_image is not None:
1688
+ x_ip, (f_ip, h_ip, w_ip) = dit.patchify(
1689
+ ip_image
1690
+ ) # x_ip [1, 1024, 5120] [B, N, D] f_ip = 1 h_ip = 32 w_ip = 32
1691
+ freqs_ip = (
1692
+ torch.cat(
1693
+ [
1694
+ dit.freqs[0][0].view(f_ip, 1, 1, -1).expand(f_ip, h_ip, w_ip, -1),
1695
+ dit.freqs[1][h + offset : h + offset + h_ip]
1696
+ .view(1, h_ip, 1, -1)
1697
+ .expand(f_ip, h_ip, w_ip, -1),
1698
+ dit.freqs[2][w + offset : w + offset + w_ip]
1699
+ .view(1, 1, w_ip, -1)
1700
+ .expand(f_ip, h_ip, w_ip, -1),
1701
+ ],
1702
+ dim=-1,
1703
+ )
1704
+ .reshape(f_ip * h_ip * w_ip, 1, -1)
1705
+ .to(x_ip.device)
1706
+ )
1707
+ freqs_original = freqs
1708
+ freqs = torch.cat([freqs, freqs_ip], dim=0)
1709
+ ############################################################################################
1710
+ else:
1711
+ freqs_original = freqs
1712
+ # TeaCache
1713
+ if tea_cache is not None:
1714
+ tea_cache_update = tea_cache.check(dit, x, t_mod)
1715
+ else:
1716
+ tea_cache_update = False
1717
+
1718
+ if vace_context is not None:
1719
+ vace_hints = vace(x, vace_context, context, t_mod, freqs)
1720
+
1721
+ # blocks
1722
+ if use_unified_sequence_parallel:
1723
+ if dist.is_initialized() and dist.get_world_size() > 1:
1724
+ x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[
1725
+ get_sequence_parallel_rank()
1726
+ ]
1727
+ if tea_cache_update:
1728
+ x = tea_cache.update(x)
1729
+ else:
1730
+
1731
+ def create_custom_forward(module):
1732
+ def custom_forward(*inputs):
1733
+ return module(*inputs)
1734
+
1735
+ return custom_forward
1736
+
1737
+ for block_id, block in enumerate(dit.blocks):
1738
+ if use_gradient_checkpointing_offload:
1739
+ with torch.autograd.graph.save_on_cpu():
1740
+ x, x_ip = torch.utils.checkpoint.checkpoint(
1741
+ create_custom_forward(block),
1742
+ x,
1743
+ context,
1744
+ t_mod,
1745
+ freqs,
1746
+ x_ip=x_ip,
1747
+ t_mod_ip=t_mod_ip,
1748
+ use_reentrant=False,
1749
+ )
1750
+ elif use_gradient_checkpointing:
1751
+ x, x_ip = torch.utils.checkpoint.checkpoint(
1752
+ create_custom_forward(block),
1753
+ x,
1754
+ context,
1755
+ t_mod,
1756
+ freqs,
1757
+ x_ip=x_ip,
1758
+ t_mod_ip=t_mod_ip,
1759
+ use_reentrant=False,
1760
+ )
1761
+ else:
1762
+ x, x_ip = block(x, context, t_mod, freqs, x_ip=x_ip, t_mod_ip=t_mod_ip)
1763
+ if vace_context is not None and block_id in vace.vace_layers_mapping:
1764
+ current_vace_hint = vace_hints[vace.vace_layers_mapping[block_id]]
1765
+ if (
1766
+ use_unified_sequence_parallel
1767
+ and dist.is_initialized()
1768
+ and dist.get_world_size() > 1
1769
+ ):
1770
+ current_vace_hint = torch.chunk(
1771
+ current_vace_hint, get_sequence_parallel_world_size(), dim=1
1772
+ )[get_sequence_parallel_rank()]
1773
+ x = x + current_vace_hint * vace_scale
1774
+ if tea_cache is not None:
1775
+ tea_cache.store(x)
1776
+
1777
+ x = dit.head(x, t)
1778
+ if use_unified_sequence_parallel:
1779
+ if dist.is_initialized() and dist.get_world_size() > 1:
1780
+ x = get_sp_group().all_gather(x, dim=1)
1781
+ # Remove reference latents
1782
+ if reference_latents is not None:
1783
+ x = x[:, reference_latents.shape[1] :]
1784
+ f -= 1
1785
+ x = dit.unpatchify(x, (f, h, w))
1786
+ return x
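The TemporalTiler_BCTHW class above splits the temporal axis into overlapping windows, weights each window with linear ramps at its inner borders, and renormalizes by the accumulated weights. Below is a self-contained 1-D sketch of that blending rule; the window and stride values are illustrative, and the per-window "model output" is just a constant stand-in.

import torch

def ramp_mask(length, left_bound, right_bound, border):
    # Ones everywhere, with linear ramps at the borders that touch a neighbouring window.
    m = torch.ones(length)
    if not left_bound:
        m[:border] = (torch.arange(border) + 1) / border
    if not right_bound:
        m[-border:] = torch.flip((torch.arange(border) + 1) / border, dims=(0,))
    return m

T, window, stride = 21, 9, 6
value, weight = torch.zeros(T), torch.zeros(T)
for t in range(0, T, stride):
    t_end = min(t + window, T)
    chunk = torch.full((t_end - t,), float(t))                  # stand-in for one sliding-window model call
    mask = ramp_mask(t_end - t, t == 0, t_end == T, window - stride)
    value[t:t_end] += chunk * mask
    weight[t:t_end] += mask
blended = value / weight                                        # overlap regions become weighted averages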
preprocessor/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .image_input_preprocessor import FaceProcessor
2
+ from .videomask_generator import VideoMaskGenerator
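The two classes exported here are implemented in the files below. A minimal usage sketch, with placeholder checkpoint and input paths (the repository's own inference scripts wire these up differently):

from preprocessor import FaceProcessor, VideoMaskGenerator

# "checkpoints/antelopev2", "face.jpg" and "input.mp4" are placeholders, not paths from this repo.
face_processor = FaceProcessor(antelopv2_path="checkpoints/antelopev2")
ip_image = face_processor.process("face.jpg", resize_to=512)   # face crop on a white background

mask_generator = VideoMaskGenerator(antelopv2_path="checkpoints/antelopev2")
video, masks, width, height, num_frames = mask_generator.process("input.mp4", "face.jpg")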
preprocessor/image_input_preprocessor.py ADDED
@@ -0,0 +1,181 @@
1
+ import os
2
+ import cv2
3
+ import requests
4
+ import torch
5
+ import numpy as np
6
+ import PIL.Image
7
+ import PIL.ImageOps
8
+ from insightface.app import FaceAnalysis
9
+ from facexlib.parsing import init_parsing_model
10
+ from torchvision.transforms.functional import normalize
11
+ from typing import Union, Optional
12
+
13
+
14
+ def _img2tensor(img: np.ndarray, bgr2rgb: bool = True) -> torch.Tensor:
15
+ if bgr2rgb:
16
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
17
+ img = img.astype(np.float32) / 255.0
18
+ img = np.transpose(img, (2, 0, 1))
19
+ return torch.from_numpy(img)
20
+
21
+
22
+ def _pad_to_square(img: np.ndarray, pad_color: int = 255) -> np.ndarray:
23
+ h, w, _ = img.shape
24
+ if h == w:
25
+ return img
26
+
27
+ if h > w:
28
+ pad_size = (h - w) // 2
29
+ padded_img = cv2.copyMakeBorder(
30
+ img,
31
+ 0,
32
+ 0,
33
+ pad_size,
34
+ h - w - pad_size,
35
+ cv2.BORDER_CONSTANT,
36
+ value=[pad_color] * 3,
37
+ )
38
+ else:
39
+ pad_size = (w - h) // 2
40
+ padded_img = cv2.copyMakeBorder(
41
+ img,
42
+ pad_size,
43
+ w - h - pad_size,
44
+ 0,
45
+ 0,
46
+ cv2.BORDER_CONSTANT,
47
+ value=[pad_color] * 3,
48
+ )
49
+
50
+ return padded_img
51
+
52
+
53
+ class FaceProcessor:
54
+ def __init__(self, antelopv2_path=".", device: Optional[torch.device] = None):
55
+ if device is None:
56
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
57
+ else:
58
+ self.device = device
59
+
60
+ providers = (
61
+ ["CUDAExecutionProvider"]
62
+ if self.device.type == "cuda"
63
+ else ["CPUExecutionProvider"]
64
+ )
65
+ self.app = FaceAnalysis(
66
+ name="antelopev2", root=antelopv2_path, providers=providers
67
+ )
68
+ self.app.prepare(ctx_id=0, det_size=(640, 640))
69
+
70
+ self.parsing_model = init_parsing_model(
71
+ model_name="bisenet", device=self.device
72
+ )
73
+ self.parsing_model.eval()
74
+
75
+ print("FaceProcessor initialized successfully.")
76
+
77
+ def process(
78
+ self,
79
+ image: Union[str, PIL.Image.Image],
80
+ resize_to: int = 512,
81
+ border_thresh: int = 10,
82
+ face_crop_scale: float = 1.5,
83
+ extra_input: bool = False,
84
+ ) -> PIL.Image.Image:
85
+ if isinstance(image, str):
86
+ if image.startswith("http://") or image.startswith("https://"):
87
+ image = PIL.Image.open(requests.get(image, stream=True, timeout=10).raw)
88
+ elif os.path.isfile(image):
89
+ image = PIL.Image.open(image)
90
+ else:
91
+ raise ValueError(
92
+ f"Input string is not a valid URL or file path: {image}"
93
+ )
94
+ elif not isinstance(image, PIL.Image.Image):
95
+ raise TypeError(
96
+ "Input must be a file path, a URL, or a PIL.Image.Image object."
97
+ )
98
+
99
+ image = PIL.ImageOps.exif_transpose(image).convert("RGB")
100
+
101
+ frame = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
102
+
103
+ faces = self.app.get(frame)
104
+ h, w, _ = frame.shape
105
+ image_to_process = None
106
+
107
+ if not faces:
108
+ print(
109
+ "[Warning] No face detected. Using the whole image, padded to square."
110
+ )
111
+ image_to_process = _pad_to_square(frame, pad_color=255)
112
+ else:
113
+ largest_face = max(
114
+ faces, key=lambda f: (f.bbox[2] - f.bbox[0]) * (f.bbox[3] - f.bbox[1])
115
+ )
116
+ x1, y1, x2, y2 = map(int, largest_face.bbox)
117
+
118
+ is_close_to_border = (
119
+ x1 <= border_thresh
120
+ and y1 <= border_thresh
121
+ and x2 >= w - border_thresh
122
+ and y2 >= h - border_thresh
123
+ )
124
+
125
+ if is_close_to_border:
126
+ print(
127
+ "[Info] Face is close to border, padding original image to square."
128
+ )
129
+ image_to_process = _pad_to_square(frame, pad_color=255)
130
+ else:
131
+ cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
132
+ side = int(max(x2 - x1, y2 - y1) * face_crop_scale)
133
+ half = side // 2
134
+
135
+ left = max(cx - half, 0)
136
+ top = max(cy - half, 0)
137
+ right = min(cx + half, w)
138
+ bottom = min(cy + half, h)
139
+
140
+ cropped_face = frame[top:bottom, left:right]
141
+ image_to_process = _pad_to_square(cropped_face, pad_color=255)
142
+
143
+ image_resized = cv2.resize(
144
+ image_to_process, (resize_to, resize_to), interpolation=cv2.INTER_AREA
145
+ )
146
+
147
+ face_tensor = (
148
+ _img2tensor(image_resized, bgr2rgb=True).unsqueeze(0).to(self.device)
149
+ )
150
+ with torch.no_grad():
151
+ normalized_face = normalize(face_tensor, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
152
+ parsing_out = self.parsing_model(normalized_face)[0]
153
+ parsing_mask = parsing_out.argmax(dim=1, keepdim=True)
154
+
155
+ background_mask_np = (parsing_mask.squeeze().cpu().numpy() == 0).astype(
156
+ np.uint8
157
+ )
158
+ white_background = np.ones_like(image_resized, dtype=np.uint8) * 255
159
+ mask_3channel = cv2.cvtColor(background_mask_np * 255, cv2.COLOR_GRAY2BGR)
160
+ result_img_bgr = np.where(mask_3channel == 255, white_background, image_resized)
161
+ result_img_rgb = cv2.cvtColor(result_img_bgr, cv2.COLOR_BGR2RGB)
162
+ img_white_bg = PIL.Image.fromarray(result_img_rgb)
163
+ if extra_input:
164
+ # Additionally create a version of the image with a transparent background
165
+ # Create an alpha channel: 255 for foreground (not background), 0 for background
166
+ alpha_channel = (parsing_mask.squeeze().cpu().numpy() != 0).astype(
167
+ np.uint8
168
+ ) * 255
169
+
170
+ # Convert the resized BGR image to RGB
171
+ image_resized_rgb = cv2.cvtColor(image_resized, cv2.COLOR_BGR2RGB)
172
+
173
+ # Stack RGB channels with the new alpha channel
174
+ rgba_image = np.dstack((image_resized_rgb, alpha_channel))
175
+
176
+ # Create PIL image from the RGBA numpy array
177
+ img_transparent_bg = PIL.Image.fromarray(rgba_image, "RGBA")
178
+
179
+ return img_white_bg, img_transparent_bg
180
+ else:
181
+ return img_white_bg
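The compositing step above treats BiSeNet label 0 as background and either paints it white or turns it into a zero-alpha region. A numpy-only sketch of that masking logic, using a random stand-in for the parsing map:

import numpy as np

parsing = np.random.randint(0, 14, size=(512, 512))                 # fake BiSeNet labels 0..13
image = np.random.randint(0, 256, size=(512, 512, 3), dtype=np.uint8)

background = parsing == 0
white_bg = image.copy()
white_bg[background] = 255                                          # white where the parser saw background

alpha = np.where(background, 0, 255).astype(np.uint8)               # transparent-background variant
rgba = np.dstack([image, alpha])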
preprocessor/videomask_generator.py ADDED
@@ -0,0 +1,242 @@
1
+ import torch
2
+ import cv2
3
+ import numpy as np
4
+ from torchvision.transforms.functional import normalize
5
+ from tqdm import tqdm
6
+ from PIL import Image, ImageOps
7
+ import random
8
+ import os
9
+ import requests
10
+ from insightface.app import FaceAnalysis
11
+ from facexlib.parsing import init_parsing_model
12
+ from typing import Union, Optional, Tuple, List
13
+
14
+ # --- Helper Functions ---
15
+ def tensor_to_cv2_img(tensor_frame: torch.Tensor) -> np.ndarray:
16
+ """Converts a single RGB torch tensor to a BGR OpenCV image."""
17
+ img_np = (tensor_frame.cpu().numpy() * 255).astype(np.uint8)
18
+ return cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
19
+
20
+ def tensor_to_cv2_bgra_img(tensor_frame: torch.Tensor) -> np.ndarray:
21
+ """Converts a single RGBA torch tensor to a BGRA OpenCV image."""
22
+ if tensor_frame.shape[2] != 4:
23
+ raise ValueError("Input tensor must be an RGBA image with 4 channels.")
24
+ img_np = (tensor_frame.cpu().numpy() * 255).astype(np.uint8)
25
+ return cv2.cvtColor(img_np, cv2.COLOR_RGBA2BGRA)
26
+
27
+ def pil_to_tensor(image: Image.Image) -> torch.Tensor:
28
+ """Converts a PIL image to a torch tensor."""
29
+ return torch.from_numpy(np.array(image).astype(np.float32) / 255.0)
30
+
31
+ class VideoMaskGenerator:
32
+ def __init__(self, antelopv2_path=".", device: Optional[torch.device] = None):
33
+ if device is None:
34
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
35
+ else:
36
+ self.device = device
37
+
38
+ print(f"Using device: {self.device}")
39
+
40
+ providers = ["CUDAExecutionProvider"] if self.device.type == "cuda" else ["CPUExecutionProvider"]
41
+
42
+ # Initialize face detection and landmark model (antelopev2 provides both)
43
+ self.detection_model = FaceAnalysis(name="antelopev2", root=antelopv2_path, providers=providers)
44
+ self.detection_model.prepare(ctx_id=0, det_size=(640, 640))
45
+
46
+ # Initialize face parsing model
47
+ self.parsing_model = init_parsing_model(model_name="bisenet", device=self.device)
48
+ self.parsing_model.eval()
49
+
50
+ print("FaceProcessor initialized successfully.")
51
+
52
+ def process(
53
+ self,
54
+ video_path: str,
55
+ face_image: Union[str, Image.Image],
56
+ confidence_threshold: float = 0.5,
57
+ face_crop_scale: float = 1.5,
58
+ dilation_kernel_size: int = 10,
59
+ feather_amount: int = 21,
60
+ random_horizontal_flip_chance: float = 0.0,
61
+ match_angle_and_size: bool = True
62
+ ) -> Tuple[np.ndarray, np.ndarray, int, int, int]:
63
+ """
64
+ Processes a video to replace a face with a provided face image.
65
+
66
+ Args:
67
+ video_path (str): Path to the input video file.
68
+ face_image (Union[str, Image.Image]): Path or PIL image of the face to paste.
69
+ confidence_threshold (float): Confidence threshold for face detection.
70
+ face_crop_scale (float): Scale factor for cropping the detected face box.
71
+ dilation_kernel_size (int): Kernel size for mask dilation.
72
+ feather_amount (int): Amount of feathering for the mask edges.
73
+ random_horizontal_flip_chance (float): Chance to flip the source face horizontally.
74
+ match_angle_and_size (bool): Whether to use landmark matching for rotation and scale.
75
+
76
+ Returns:
77
+ Tuple[np.ndarray, np.ndarray, int, int, int]:
78
+ - Processed video as a numpy array (F, H, W, C).
79
+ - Generated masks as a numpy array (F, H, W).
80
+ - Width of the processed video.
81
+ - Height of the processed video.
82
+ - Number of frames in the processed video.
83
+ """
84
+ # --- Video Pre-processing ---
85
+ if not os.path.exists(video_path):
86
+ raise FileNotFoundError(f"Video file not found at: {video_path}")
87
+
88
+ cap = cv2.VideoCapture(video_path)
89
+ frames = []
90
+ while cap.isOpened():
91
+ ret, frame = cap.read()
92
+ if not ret:
93
+ break
94
+ frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
95
+ cap.release()
96
+
97
+ if not frames:
98
+ raise ValueError("Could not read any frames from the video.")
99
+
100
+ video_np = np.array(frames)
101
+
102
+ h, w = video_np.shape[1], video_np.shape[2]
103
+ new_h, new_w = (h // 16) * 16, (w // 16) * 16
104
+
105
+ y_start = (h - new_h) // 2
106
+ x_start = (w - new_w) // 2
107
+ video_cropped = video_np[:, y_start:y_start+new_h, x_start:x_start+new_w, :]
108
+
109
+ num_frames = video_cropped.shape[0]
110
+ target_frames = (num_frames // 4) * 4 + 1
111
+ video_trimmed = video_cropped[:target_frames]
112
+
113
+ final_h, final_w, final_frames = video_trimmed.shape[1], video_trimmed.shape[2], video_trimmed.shape[0]
114
+ print(f"Video pre-processed: {final_w}x{final_h}, {final_frames} frames.")
115
+
116
+ # --- Face Image Pre-processing & Source Landmark Extraction ---
117
+ if isinstance(face_image, str):
118
+ if face_image.startswith("http"):
119
+ face_image = Image.open(requests.get(face_image, stream=True, timeout=10).raw)
120
+ else:
121
+ face_image = Image.open(face_image)
122
+
123
+ face_image = ImageOps.exif_transpose(face_image).convert("RGBA")
124
+ face_rgba_tensor = pil_to_tensor(face_image)
125
+ face_to_paste_cv2 = tensor_to_cv2_bgra_img(face_rgba_tensor)
126
+
127
+ source_kpts = None
128
+ if match_angle_and_size:
129
+ # Use insightface (antelopev2) to get landmarks from the source face image
130
+ source_face_bgr = cv2.cvtColor(face_to_paste_cv2, cv2.COLOR_BGRA2BGR)
131
+ source_faces = self.detection_model.get(source_face_bgr)
132
+ if source_faces:
133
+ # Use the landmarks from the first (and likely only) detected face
134
+ source_kpts = source_faces[0].kps
135
+ else:
136
+ print("[Warning] No face or landmarks found in source image. Disabling angle matching.")
137
+ match_angle_and_size = False
138
+
139
+ face_to_paste_pil = Image.fromarray((face_rgba_tensor.cpu().numpy() * 255).astype(np.uint8), 'RGBA')
140
+
141
+ # --- Main Processing Loop ---
142
+ processed_frames_list = []
143
+ mask_list = []
144
+
145
+ for i in tqdm(range(final_frames), desc="Pasting face onto frames"):
146
+ frame_rgb = video_trimmed[i]
147
+ frame_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR)
148
+
149
+ # Use insightface for detection and landmarks
150
+ faces = self.detection_model.get(frame_bgr)
151
+
152
+ pasted = False
153
+ final_mask = np.zeros((final_h, final_w), dtype=np.uint8)
154
+
155
+ if faces:
156
+ largest_face = max(faces, key=lambda f: (f.bbox[2] - f.bbox[0]) * (f.bbox[3] - f.bbox[1]))
157
+
158
+ if largest_face.det_score > confidence_threshold:
159
+ # Use insightface landmarks to estimate the affine transform
160
+ if match_angle_and_size and source_kpts is not None:
161
+ target_kpts = largest_face.kps # Get landmarks directly from the detected face
162
+
163
+ # Estimate the transformation matrix
164
+ M, _ = cv2.estimateAffinePartial2D(source_kpts, target_kpts, method=cv2.LMEDS)
165
+
166
+ if M is not None:
167
+ # Split the RGBA source face for separate warping
168
+ b, g, r, a = cv2.split(face_to_paste_cv2)
169
+ source_rgb_cv2 = cv2.merge([r, g, b])
170
+
171
+ # Warp the face and its alpha channel
172
+ warped_face = cv2.warpAffine(source_rgb_cv2, M, (final_w, final_h))
173
+ warped_alpha = cv2.warpAffine(a, M, (final_w, final_h))
174
+
175
+ # Blend the warped face onto the frame using the warped alpha channel
176
+ alpha_float = warped_alpha.astype(np.float32) / 255.0
177
+ alpha_expanded = np.expand_dims(alpha_float, axis=2)
178
+
179
+ frame_rgb = (1.0 - alpha_expanded) * frame_rgb + alpha_expanded * warped_face
180
+ frame_rgb = frame_rgb.astype(np.uint8)
181
+ final_mask = warped_alpha
182
+ pasted = True
183
+
184
+ # Fallback to simple box-pasting if angle matching is off or fails
185
+ if not pasted:
186
+ x1, y1, x2, y2 = map(int, largest_face.bbox)
187
+ center_x, center_y = (x1 + x2) // 2, (y1 + y2) // 2
188
+ side_len = int(max(x2 - x1, y2 - y1) * face_crop_scale)
189
+ half_side = side_len // 2
190
+
191
+ crop_y1, crop_x1 = max(center_y - half_side, 0), max(center_x - half_side, 0)
192
+ crop_y2, crop_x2 = min(center_y + half_side, final_h), min(center_x + half_side, final_w)
193
+
194
+ box_w, box_h = crop_x2 - crop_x1, crop_y2 - crop_y1
195
+
196
+ if box_w > 0 and box_h > 0:
197
+ source_img = face_to_paste_pil.copy()
198
+ if random.random() < random_horizontal_flip_chance:
199
+ source_img = source_img.transpose(Image.FLIP_LEFT_RIGHT)
200
+
201
+ face_resized = source_img.resize((box_w, box_h), Image.Resampling.LANCZOS)
202
+
203
+ target_frame_pil = Image.fromarray(frame_rgb)
204
+
205
+ # --- Mask Generation using BiSeNet ---
206
+ face_crop_bgr = cv2.cvtColor(frame_rgb[crop_y1:crop_y2, crop_x1:crop_x2], cv2.COLOR_RGB2BGR)
207
+ if face_crop_bgr.size > 0:
208
+ face_resized_512 = cv2.resize(face_crop_bgr, (512, 512), interpolation=cv2.INTER_AREA)
209
+ face_rgb_512 = cv2.cvtColor(face_resized_512, cv2.COLOR_BGR2RGB)
210
+ face_tensor_in = torch.from_numpy(face_rgb_512.astype(np.float32) / 255.0).permute(2, 0, 1).unsqueeze(0).to(self.device)
211
+
212
+ with torch.no_grad():
213
+ normalized_face = normalize(face_tensor_in, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
214
+ parsing_map = self.parsing_model(normalized_face)[0].argmax(dim=1, keepdim=True)
215
+
216
+ parsing_map_np = parsing_map.squeeze().cpu().numpy().astype(np.uint8)
217
+ parts_to_include = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13] # All face parts
218
+ final_mask_512 = np.isin(parsing_map_np, parts_to_include).astype(np.uint8) * 255
219
+
220
+ if dilation_kernel_size > 0:
221
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (dilation_kernel_size, dilation_kernel_size))
222
+ final_mask_512 = cv2.dilate(final_mask_512, kernel, iterations=1)
223
+
224
+ if feather_amount > 0:
225
+ if feather_amount % 2 == 0: feather_amount += 1
226
+ final_mask_512 = cv2.GaussianBlur(final_mask_512, (feather_amount, feather_amount), 0)
227
+
228
+ mask_resized_to_crop = cv2.resize(final_mask_512, (box_w, box_h), interpolation=cv2.INTER_LINEAR)
229
+ generated_mask_pil = Image.fromarray(mask_resized_to_crop, mode='L')
230
+
231
+ target_frame_pil.paste(face_resized, (crop_x1, crop_y1), mask=generated_mask_pil)
232
+ frame_rgb = np.array(target_frame_pil)
233
+ final_mask[crop_y1:crop_y2, crop_x1:crop_x2] = mask_resized_to_crop
234
+
235
+ processed_frames_list.append(frame_rgb)
236
+ mask_list.append(final_mask)
237
+
238
+ output_video = np.stack(processed_frames_list)
239
+ # Ensure mask has a channel dimension for consistency
240
+ output_masks = np.stack(mask_list)[..., np.newaxis]
241
+
242
+ return (output_video, output_masks, final_w, final_h, final_frames)
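The angle/size matching above relies on cv2.estimateAffinePartial2D to map the five source keypoints onto the detected target keypoints, then warps the reference face (and its alpha channel) with the resulting matrix. A standalone sketch with made-up keypoints:

import cv2
import numpy as np

# Five face keypoints (eyes, nose, mouth corners); the coordinates are invented for illustration.
source_kpts = np.array([[180, 210], [330, 210], [256, 300], [200, 380], [310, 380]], dtype=np.float32)
target_kpts = source_kpts * 0.5 + np.array([40.0, 60.0], dtype=np.float32)   # smaller, shifted face

source_face = np.zeros((512, 512, 3), dtype=np.uint8)               # stand-in for the RGB reference face
M, _ = cv2.estimateAffinePartial2D(source_kpts, target_kpts, method=cv2.LMEDS)
warped = cv2.warpAffine(source_face, M, (512, 512))                 # dsize is (width, height) of the frame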
prompters/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .prompt_refiners import Translator, BeautifulPrompt, QwenPrompt
2
+ from .omost import OmostPromter
3
+ from .wan_prompter import WanPrompter
prompters/base_prompter.py ADDED
@@ -0,0 +1,68 @@
1
+ from models.model_manager import ModelManager
2
+ import torch
3
+
4
+
5
+ def tokenize_long_prompt(tokenizer, prompt, max_length=None):
6
+ # Use the tokenizer's model_max_length unless an explicit max_length is given.
7
+ length = tokenizer.model_max_length if max_length is None else max_length
8
+
9
+ # Temporarily raise tokenizer.model_max_length to avoid the long-sequence warning.
10
+ tokenizer.model_max_length = 99999999
11
+
12
+ # Tokenize it!
13
+ input_ids = tokenizer(prompt, return_tensors="pt").input_ids
14
+
15
+ # Determine the real length.
16
+ max_length = (input_ids.shape[1] + length - 1) // length * length
17
+
18
+ # Restore tokenizer.model_max_length
19
+ tokenizer.model_max_length = length
20
+
21
+ # Tokenize it again with fixed length.
22
+ input_ids = tokenizer(
23
+ prompt,
24
+ return_tensors="pt",
25
+ padding="max_length",
26
+ max_length=max_length,
27
+ truncation=True,
28
+ ).input_ids
29
+
30
+ # Reshape input_ids to fit the text encoder.
31
+ num_sentence = input_ids.shape[1] // length
32
+ input_ids = input_ids.reshape((num_sentence, length))
33
+
34
+ return input_ids
35
+
36
+
37
+ class BasePrompter:
38
+ def __init__(self):
39
+ self.refiners = []
40
+ self.extenders = []
41
+
42
+ def load_prompt_refiners(self, model_manager: ModelManager, refiner_classes=[]):
43
+ for refiner_class in refiner_classes:
44
+ refiner = refiner_class.from_model_manager(model_manager)
45
+ self.refiners.append(refiner)
46
+
47
+ def load_prompt_extenders(self, model_manager: ModelManager, extender_classes=[]):
48
+ for extender_class in extender_classes:
49
+ extender = extender_class.from_model_manager(model_manager)
50
+ self.extenders.append(extender)
51
+
52
+ @torch.no_grad()
53
+ def process_prompt(self, prompt, positive=True):
54
+ if isinstance(prompt, list):
55
+ prompt = [
56
+ self.process_prompt(prompt_, positive=positive) for prompt_ in prompt
57
+ ]
58
+ else:
59
+ for refiner in self.refiners:
60
+ prompt = refiner(prompt, positive=positive)
61
+ return prompt
62
+
63
+ @torch.no_grad()
64
+ def extend_prompt(self, prompt: str, positive=True):
65
+ extended_prompt = dict(prompt=prompt)
66
+ for extender in self.extenders:
67
+ extended_prompt = extender(extended_prompt)
68
+ return extended_prompt
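tokenize_long_prompt above rounds the tokenized length up to a multiple of the tokenizer window and reshapes the result so the text encoder sees several fixed-length "sentences". A tensor-only sketch of the rounding and reshape (no real tokenizer involved):

import torch

model_max_length = 77
num_tokens = 180                                            # pretend the prompt tokenized to 180 ids
padded = (num_tokens + model_max_length - 1) // model_max_length * model_max_length   # -> 231
input_ids = torch.zeros(1, padded, dtype=torch.long)        # stand-in for the padded tokenizer output
chunks = input_ids.reshape(padded // model_max_length, model_max_length)              # shape (3, 77)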
prompters/omost.py ADDED
@@ -0,0 +1,472 @@
1
+ from transformers import AutoTokenizer, TextIteratorStreamer
2
+ import difflib
3
+ import torch
4
+ import numpy as np
5
+ import re
6
+ from models.model_manager import ModelManager
7
+ from PIL import Image
8
+
9
+ valid_colors = { # r, g, b
10
+ "aliceblue": (240, 248, 255),
11
+ "antiquewhite": (250, 235, 215),
12
+ "aqua": (0, 255, 255),
13
+ "aquamarine": (127, 255, 212),
14
+ "azure": (240, 255, 255),
15
+ "beige": (245, 245, 220),
16
+ "bisque": (255, 228, 196),
17
+ "black": (0, 0, 0),
18
+ "blanchedalmond": (255, 235, 205),
19
+ "blue": (0, 0, 255),
20
+ "blueviolet": (138, 43, 226),
21
+ "brown": (165, 42, 42),
22
+ "burlywood": (222, 184, 135),
23
+ "cadetblue": (95, 158, 160),
24
+ "chartreuse": (127, 255, 0),
25
+ "chocolate": (210, 105, 30),
26
+ "coral": (255, 127, 80),
27
+ "cornflowerblue": (100, 149, 237),
28
+ "cornsilk": (255, 248, 220),
29
+ "crimson": (220, 20, 60),
30
+ "cyan": (0, 255, 255),
31
+ "darkblue": (0, 0, 139),
32
+ "darkcyan": (0, 139, 139),
33
+ "darkgoldenrod": (184, 134, 11),
34
+ "darkgray": (169, 169, 169),
35
+ "darkgrey": (169, 169, 169),
36
+ "darkgreen": (0, 100, 0),
37
+ "darkkhaki": (189, 183, 107),
38
+ "darkmagenta": (139, 0, 139),
39
+ "darkolivegreen": (85, 107, 47),
40
+ "darkorange": (255, 140, 0),
41
+ "darkorchid": (153, 50, 204),
42
+ "darkred": (139, 0, 0),
43
+ "darksalmon": (233, 150, 122),
44
+ "darkseagreen": (143, 188, 143),
45
+ "darkslateblue": (72, 61, 139),
46
+ "darkslategray": (47, 79, 79),
47
+ "darkslategrey": (47, 79, 79),
48
+ "darkturquoise": (0, 206, 209),
49
+ "darkviolet": (148, 0, 211),
50
+ "deeppink": (255, 20, 147),
51
+ "deepskyblue": (0, 191, 255),
52
+ "dimgray": (105, 105, 105),
53
+ "dimgrey": (105, 105, 105),
54
+ "dodgerblue": (30, 144, 255),
55
+ "firebrick": (178, 34, 34),
56
+ "floralwhite": (255, 250, 240),
57
+ "forestgreen": (34, 139, 34),
58
+ "fuchsia": (255, 0, 255),
59
+ "gainsboro": (220, 220, 220),
60
+ "ghostwhite": (248, 248, 255),
61
+ "gold": (255, 215, 0),
62
+ "goldenrod": (218, 165, 32),
63
+ "gray": (128, 128, 128),
64
+ "grey": (128, 128, 128),
65
+ "green": (0, 128, 0),
66
+ "greenyellow": (173, 255, 47),
67
+ "honeydew": (240, 255, 240),
68
+ "hotpink": (255, 105, 180),
69
+ "indianred": (205, 92, 92),
70
+ "indigo": (75, 0, 130),
71
+ "ivory": (255, 255, 240),
72
+ "khaki": (240, 230, 140),
73
+ "lavender": (230, 230, 250),
74
+ "lavenderblush": (255, 240, 245),
75
+ "lawngreen": (124, 252, 0),
76
+ "lemonchiffon": (255, 250, 205),
77
+ "lightblue": (173, 216, 230),
78
+ "lightcoral": (240, 128, 128),
79
+ "lightcyan": (224, 255, 255),
80
+ "lightgoldenrodyellow": (250, 250, 210),
81
+ "lightgray": (211, 211, 211),
82
+ "lightgrey": (211, 211, 211),
83
+ "lightgreen": (144, 238, 144),
84
+ "lightpink": (255, 182, 193),
85
+ "lightsalmon": (255, 160, 122),
86
+ "lightseagreen": (32, 178, 170),
87
+ "lightskyblue": (135, 206, 250),
88
+ "lightslategray": (119, 136, 153),
89
+ "lightslategrey": (119, 136, 153),
90
+ "lightsteelblue": (176, 196, 222),
91
+ "lightyellow": (255, 255, 224),
92
+ "lime": (0, 255, 0),
93
+ "limegreen": (50, 205, 50),
94
+ "linen": (250, 240, 230),
95
+ "magenta": (255, 0, 255),
96
+ "maroon": (128, 0, 0),
97
+ "mediumaquamarine": (102, 205, 170),
98
+ "mediumblue": (0, 0, 205),
99
+ "mediumorchid": (186, 85, 211),
100
+ "mediumpurple": (147, 112, 219),
101
+ "mediumseagreen": (60, 179, 113),
102
+ "mediumslateblue": (123, 104, 238),
103
+ "mediumspringgreen": (0, 250, 154),
104
+ "mediumturquoise": (72, 209, 204),
105
+ "mediumvioletred": (199, 21, 133),
106
+ "midnightblue": (25, 25, 112),
107
+ "mintcream": (245, 255, 250),
108
+ "mistyrose": (255, 228, 225),
109
+ "moccasin": (255, 228, 181),
110
+ "navajowhite": (255, 222, 173),
111
+ "navy": (0, 0, 128),
112
+ "navyblue": (0, 0, 128),
113
+ "oldlace": (253, 245, 230),
114
+ "olive": (128, 128, 0),
115
+ "olivedrab": (107, 142, 35),
116
+ "orange": (255, 165, 0),
117
+ "orangered": (255, 69, 0),
118
+ "orchid": (218, 112, 214),
119
+ "palegoldenrod": (238, 232, 170),
120
+ "palegreen": (152, 251, 152),
121
+ "paleturquoise": (175, 238, 238),
122
+ "palevioletred": (219, 112, 147),
123
+ "papayawhip": (255, 239, 213),
124
+ "peachpuff": (255, 218, 185),
125
+ "peru": (205, 133, 63),
126
+ "pink": (255, 192, 203),
127
+ "plum": (221, 160, 221),
128
+ "powderblue": (176, 224, 230),
129
+ "purple": (128, 0, 128),
130
+ "rebeccapurple": (102, 51, 153),
131
+ "red": (255, 0, 0),
132
+ "rosybrown": (188, 143, 143),
133
+ "royalblue": (65, 105, 225),
134
+ "saddlebrown": (139, 69, 19),
135
+ "salmon": (250, 128, 114),
136
+ "sandybrown": (244, 164, 96),
137
+ "seagreen": (46, 139, 87),
138
+ "seashell": (255, 245, 238),
139
+ "sienna": (160, 82, 45),
140
+ "silver": (192, 192, 192),
141
+ "skyblue": (135, 206, 235),
142
+ "slateblue": (106, 90, 205),
143
+ "slategray": (112, 128, 144),
144
+ "slategrey": (112, 128, 144),
145
+ "snow": (255, 250, 250),
146
+ "springgreen": (0, 255, 127),
147
+ "steelblue": (70, 130, 180),
148
+ "tan": (210, 180, 140),
149
+ "teal": (0, 128, 128),
150
+ "thistle": (216, 191, 216),
151
+ "tomato": (255, 99, 71),
152
+ "turquoise": (64, 224, 208),
153
+ "violet": (238, 130, 238),
154
+ "wheat": (245, 222, 179),
155
+ "white": (255, 255, 255),
156
+ "whitesmoke": (245, 245, 245),
157
+ "yellow": (255, 255, 0),
158
+ "yellowgreen": (154, 205, 50),
159
+ }
160
+
161
+ valid_locations = { # x, y in 90*90
162
+ "in the center": (45, 45),
163
+ "on the left": (15, 45),
164
+ "on the right": (75, 45),
165
+ "on the top": (45, 15),
166
+ "on the bottom": (45, 75),
167
+ "on the top-left": (15, 15),
168
+ "on the top-right": (75, 15),
169
+ "on the bottom-left": (15, 75),
170
+ "on the bottom-right": (75, 75),
171
+ }
172
+
173
+ valid_offsets = { # x, y in 90*90
174
+ "no offset": (0, 0),
175
+ "slightly to the left": (-10, 0),
176
+ "slightly to the right": (10, 0),
177
+ "slightly to the upper": (0, -10),
178
+ "slightly to the lower": (0, 10),
179
+ "slightly to the upper-left": (-10, -10),
180
+ "slightly to the upper-right": (10, -10),
181
+ "slightly to the lower-left": (-10, 10),
182
+ "slightly to the lower-right": (10, 10),
183
+ }
184
+
185
+ valid_areas = { # w, h in 90*90
186
+ "a small square area": (50, 50),
187
+ "a small vertical area": (40, 60),
188
+ "a small horizontal area": (60, 40),
189
+ "a medium-sized square area": (60, 60),
190
+ "a medium-sized vertical area": (50, 80),
191
+ "a medium-sized horizontal area": (80, 50),
192
+ "a large square area": (70, 70),
193
+ "a large vertical area": (60, 90),
194
+ "a large horizontal area": (90, 60),
195
+ }
196
+
197
+
198
+ def safe_str(x):
199
+ return x.strip(",. ") + "."
200
+
201
+
202
+ def closest_name(input_str, options):
203
+ input_str = input_str.lower()
204
+
205
+ closest_match = difflib.get_close_matches(
206
+ input_str, list(options.keys()), n=1, cutoff=0.5
207
+ )
208
+ assert isinstance(closest_match, list) and len(closest_match) > 0, (
209
+ f"The value [{input_str}] is not valid!"
210
+ )
211
+ result = closest_match[0]
212
+
213
+ if result != input_str:
214
+ print(f"Automatically corrected [{input_str}] -> [{result}].")
215
+
216
+ return result
217
+
218
+
219
+ class Canvas:
220
+ @staticmethod
221
+ def from_bot_response(response: str):
222
+ matched = re.search(r"```python\n(.*?)\n```", response, re.DOTALL)
223
+ assert matched, "Response does not contain codes!"
224
+ code_content = matched.group(1)
225
+ assert "canvas = Canvas()" in code_content, (
226
+ "Code block must include valid canvas var!"
227
+ )
228
+ local_vars = {"Canvas": Canvas}
229
+ exec(code_content, {}, local_vars)
230
+ canvas = local_vars.get("canvas", None)
231
+ assert isinstance(canvas, Canvas), "Code block must produce valid canvas var!"
232
+ return canvas
233
+
234
+ def __init__(self):
235
+ self.components = []
236
+ self.color = None
237
+ self.record_tags = True
238
+ self.prefixes = []
239
+ self.suffixes = []
240
+ return
241
+
242
+ def set_global_description(
243
+ self,
244
+ description: str,
245
+ detailed_descriptions: list,
246
+ tags: str,
247
+ HTML_web_color_name: str,
248
+ ):
249
+ assert isinstance(description, str), "Global description is not valid!"
250
+ assert isinstance(detailed_descriptions, list) and all(
251
+ isinstance(item, str) for item in detailed_descriptions
252
+ ), "Global detailed_descriptions is not valid!"
253
+ assert isinstance(tags, str), "Global tags is not valid!"
254
+
255
+ HTML_web_color_name = closest_name(HTML_web_color_name, valid_colors)
256
+ self.color = np.array([[valid_colors[HTML_web_color_name]]], dtype=np.uint8)
257
+
258
+ self.prefixes = [description]
259
+ self.suffixes = detailed_descriptions
260
+
261
+ if self.record_tags:
262
+ self.suffixes = self.suffixes + [tags]
263
+
264
+ self.prefixes = [safe_str(x) for x in self.prefixes]
265
+ self.suffixes = [safe_str(x) for x in self.suffixes]
266
+
267
+ return
268
+
269
+ def add_local_description(
270
+ self,
271
+ location: str,
272
+ offset: str,
273
+ area: str,
274
+ distance_to_viewer: float,
275
+ description: str,
276
+ detailed_descriptions: list,
277
+ tags: str,
278
+ atmosphere: str,
279
+ style: str,
280
+ quality_meta: str,
281
+ HTML_web_color_name: str,
282
+ ):
283
+ assert isinstance(description, str), "Local description is wrong!"
284
+ assert (
285
+ isinstance(distance_to_viewer, (int, float)) and distance_to_viewer > 0
286
+ ), f"The distance_to_viewer for [{description}] is not positive float number!"
287
+ assert isinstance(detailed_descriptions, list) and all(
288
+ isinstance(item, str) for item in detailed_descriptions
289
+ ), f"The detailed_descriptions for [{description}] is not valid!"
290
+ assert isinstance(tags, str), f"The tags for [{description}] is not valid!"
291
+ assert isinstance(atmosphere, str), (
292
+ f"The atmosphere for [{description}] is not valid!"
293
+ )
294
+ assert isinstance(style, str), f"The style for [{description}] is not valid!"
295
+ assert isinstance(quality_meta, str), (
296
+ f"The quality_meta for [{description}] is not valid!"
297
+ )
298
+
299
+ location = closest_name(location, valid_locations)
300
+ offset = closest_name(offset, valid_offsets)
301
+ area = closest_name(area, valid_areas)
302
+ HTML_web_color_name = closest_name(HTML_web_color_name, valid_colors)
303
+
304
+ xb, yb = valid_locations[location]
305
+ xo, yo = valid_offsets[offset]
306
+ w, h = valid_areas[area]
307
+ rect = (yb + yo - h // 2, yb + yo + h // 2, xb + xo - w // 2, xb + xo + w // 2)
308
+ rect = [max(0, min(90, i)) for i in rect]
309
+ color = np.array([[valid_colors[HTML_web_color_name]]], dtype=np.uint8)
310
+
311
+ prefixes = self.prefixes + [description]
312
+ suffixes = detailed_descriptions
313
+
314
+ if self.record_tags:
315
+ suffixes = suffixes + [tags, atmosphere, style, quality_meta]
316
+
317
+ prefixes = [safe_str(x) for x in prefixes]
318
+ suffixes = [safe_str(x) for x in suffixes]
319
+
320
+ self.components.append(
321
+ dict(
322
+ rect=rect,
323
+ distance_to_viewer=distance_to_viewer,
324
+ color=color,
325
+ prefixes=prefixes,
326
+ suffixes=suffixes,
327
+ location=location,
328
+ )
329
+ )
330
+
331
+ return
332
+
333
+ def process(self):
334
+ # sort components
335
+ self.components = sorted(
336
+ self.components, key=lambda x: x["distance_to_viewer"], reverse=True
337
+ )
338
+
339
+ # compute initial latent
340
+ # print(self.color)
341
+ initial_latent = np.zeros(shape=(90, 90, 3), dtype=np.float32) + self.color
342
+
343
+ for component in self.components:
344
+ a, b, c, d = component["rect"]
345
+ initial_latent[a:b, c:d] = (
346
+ 0.7 * component["color"] + 0.3 * initial_latent[a:b, c:d]
347
+ )
348
+
349
+ initial_latent = initial_latent.clip(0, 255).astype(np.uint8)
350
+
351
+ # compute conditions
352
+
353
+ bag_of_conditions = [
354
+ dict(
355
+ mask=np.ones(shape=(90, 90), dtype=np.float32),
356
+ prefixes=self.prefixes,
357
+ suffixes=self.suffixes,
358
+ location="full",
359
+ )
360
+ ]
361
+
362
+ for i, component in enumerate(self.components):
363
+ a, b, c, d = component["rect"]
364
+ m = np.zeros(shape=(90, 90), dtype=np.float32)
365
+ m[a:b, c:d] = 1.0
366
+ bag_of_conditions.append(
367
+ dict(
368
+ mask=m,
369
+ prefixes=component["prefixes"],
370
+ suffixes=component["suffixes"],
371
+ location=component["location"],
372
+ )
373
+ )
374
+
375
+ return dict(
376
+ initial_latent=initial_latent,
377
+ bag_of_conditions=bag_of_conditions,
378
+ )
379
+
380
+
381
+ class OmostPromter(torch.nn.Module):
382
+ def __init__(self, model=None, tokenizer=None, template="", device="cpu"):
383
+ super().__init__()
384
+ self.model = model
385
+ self.tokenizer = tokenizer
386
+ self.device = device
387
+ if template == "":
388
+ template = r"""You are a helpful AI assistant to compose images using the below python class `Canvas`:
389
+ ```python
390
+ class Canvas:
391
+ def set_global_description(self, description: str, detailed_descriptions: list[str], tags: str, HTML_web_color_name: str):
392
+ pass
393
+
394
+ def add_local_description(self, location: str, offset: str, area: str, distance_to_viewer: float, description: str, detailed_descriptions: list[str], tags: str, atmosphere: str, style: str, quality_meta: str, HTML_web_color_name: str):
395
+ assert location in ["in the center", "on the left", "on the right", "on the top", "on the bottom", "on the top-left", "on the top-right", "on the bottom-left", "on the bottom-right"]
396
+ assert offset in ["no offset", "slightly to the left", "slightly to the right", "slightly to the upper", "slightly to the lower", "slightly to the upper-left", "slightly to the upper-right", "slightly to the lower-left", "slightly to the lower-right"]
397
+ assert area in ["a small square area", "a small vertical area", "a small horizontal area", "a medium-sized square area", "a medium-sized vertical area", "a medium-sized horizontal area", "a large square area", "a large vertical area", "a large horizontal area"]
398
+ assert distance_to_viewer > 0
399
+ pass
400
+ ```"""
401
+ self.template = template
402
+
403
+ @staticmethod
404
+ def from_model_manager(model_manager: ModelManager):
405
+ model, model_path = model_manager.fetch_model(
406
+ "omost_prompt", require_model_path=True
407
+ )
408
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
409
+ omost = OmostPromter(
410
+ model=model, tokenizer=tokenizer, device=model_manager.device
411
+ )
412
+ return omost
413
+
414
+ def __call__(self, prompt_dict: dict):
415
+ raw_prompt = prompt_dict["prompt"]
416
+ conversation = [{"role": "system", "content": self.template}]
417
+ conversation.append({"role": "user", "content": raw_prompt})
418
+
419
+ input_ids = self.tokenizer.apply_chat_template(
420
+ conversation, return_tensors="pt", add_generation_prompt=True
421
+ ).to(self.device)
422
+ streamer = TextIteratorStreamer(
423
+ self.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
424
+ )
425
+ attention_mask = torch.ones(
426
+ input_ids.shape, dtype=torch.bfloat16, device=self.device
427
+ )
428
+
429
+ generate_kwargs = dict(
430
+ input_ids=input_ids,
431
+ streamer=streamer,
432
+ # stopping_criteria=stopping_criteria,
433
+ # max_new_tokens=max_new_tokens,
434
+ do_sample=True,
435
+ attention_mask=attention_mask,
436
+ pad_token_id=self.tokenizer.eos_token_id,
437
+ # temperature=temperature,
438
+ # top_p=top_p,
439
+ )
440
+ self.model.generate(**generate_kwargs)
441
+ outputs = []
442
+ for text in streamer:
443
+ outputs.append(text)
444
+ llm_outputs = "".join(outputs)
445
+
446
+ canvas = Canvas.from_bot_response(llm_outputs)
447
+ canvas_output = canvas.process()
448
+
449
+ prompts = [
450
+ " ".join(_["prefixes"] + _["suffixes"][:2])
451
+ for _ in canvas_output["bag_of_conditions"]
452
+ ]
453
+ canvas_output["prompt"] = prompts[0]
454
+ canvas_output["prompts"] = prompts[1:]
455
+
456
+ raw_masks = [_["mask"] for _ in canvas_output["bag_of_conditions"]]
457
+ masks = []
458
+ for mask in raw_masks:
459
+ mask[mask > 0.5] = 255
460
+ mask = np.stack([mask] * 3, axis=-1).astype("uint8")
461
+ masks.append(Image.fromarray(mask))
462
+
463
+ canvas_output["masks"] = masks
464
+ prompt_dict.update(canvas_output)
465
+ print(f"Your prompt is extended by Omost:\n")
466
+ cnt = 0
467
+ for component, pmt in zip(canvas_output["bag_of_conditions"], prompts):
468
+ loc = component["location"]
469
+ cnt += 1
470
+ print(f"Component {cnt} - Location : {loc}\nPrompt:{pmt}\n")
471
+
472
+ return prompt_dict
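The `Canvas` class above is the contract the Omost prompter expects the LLM response to fill in. A minimal hand-built sketch (illustrative only, run from the repo root; the scene text and color names are made up and not part of this commit) of how a canvas turns into the 90x90 color layout and per-region masks consumed downstream:

```python
# Hand-built example of the Canvas API from prompters/omost.py (illustrative only).
from prompters.omost import Canvas

canvas = Canvas()
canvas.set_global_description(
    description="A quiet harbor at sunset",
    detailed_descriptions=["Small boats rest on calm water.", "Warm orange light fills the sky."],
    tags="harbor, sunset, boats",
    HTML_web_color_name="navy",
)
canvas.add_local_description(
    location="on the left",
    offset="no offset",
    area="a medium-sized vertical area",
    distance_to_viewer=3.0,
    description="A wooden sailboat",
    detailed_descriptions=["Weathered planks and a folded white sail."],
    tags="sailboat, wood",
    atmosphere="peaceful",
    style="photorealistic",
    quality_meta="high detail",
    HTML_web_color_name="saddlebrown",
)

result = canvas.process()
# result["initial_latent"]: 90x90x3 uint8 color layout (global color blended with region colors)
# result["bag_of_conditions"]: one full-frame entry plus one masked entry per local region
```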
prompters/prompt_refiners.py ADDED
@@ -0,0 +1,131 @@
1
+ from transformers import AutoTokenizer
2
+ from models.model_manager import ModelManager
3
+ import torch
4
+ from .omost import OmostPromter
5
+
6
+
7
+ class BeautifulPrompt(torch.nn.Module):
8
+ def __init__(self, tokenizer_path=None, model=None, template=""):
9
+ super().__init__()
10
+ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
11
+ self.model = model
12
+ self.template = template
13
+
14
+ @staticmethod
15
+ def from_model_manager(model_manager: ModelManager):
16
+ model, model_path = model_manager.fetch_model(
17
+ "beautiful_prompt", require_model_path=True
18
+ )
19
+ template = "Instruction: Give a simple description of the image to generate a drawing prompt.\nInput: {raw_prompt}\nOutput:"
20
+ if model_path.endswith("v2"):
21
+ template = """Converts a simple image description into a prompt. \
22
+ Prompts are formatted as multiple related tags separated by commas, plus you can use () to increase the weight, [] to decrease the weight, \
23
+ or use a number to specify the weight. You should add appropriate words to make the images described in the prompt more aesthetically pleasing, \
24
+ but make sure there is a correlation between the input and output.\n\
25
+ ### Input: {raw_prompt}\n### Output:"""
26
+ beautiful_prompt = BeautifulPrompt(
27
+ tokenizer_path=model_path, model=model, template=template
28
+ )
29
+ return beautiful_prompt
30
+
31
+ def __call__(self, raw_prompt, positive=True, **kwargs):
32
+ if positive:
33
+ model_input = self.template.format(raw_prompt=raw_prompt)
34
+ input_ids = self.tokenizer.encode(model_input, return_tensors="pt").to(
35
+ self.model.device
36
+ )
37
+ outputs = self.model.generate(
38
+ input_ids,
39
+ max_new_tokens=384,
40
+ do_sample=True,
41
+ temperature=0.9,
42
+ top_k=50,
43
+ top_p=0.95,
44
+ repetition_penalty=1.1,
45
+ num_return_sequences=1,
46
+ )
47
+ prompt = (
48
+ raw_prompt
49
+ + ", "
50
+ + self.tokenizer.batch_decode(
51
+ outputs[:, input_ids.size(1) :], skip_special_tokens=True
52
+ )[0].strip()
53
+ )
54
+ print(f"Your prompt is refined by BeautifulPrompt: {prompt}")
55
+ return prompt
56
+ else:
57
+ return raw_prompt
58
+
59
+
60
+ class QwenPrompt(torch.nn.Module):
61
+ # This class leverages the open-source Qwen model to translate Chinese prompts into English,
62
+ # with an integrated optimization mechanism for enhanced translation quality.
63
+ def __init__(self, tokenizer_path=None, model=None, system_prompt=""):
64
+ super().__init__()
65
+ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
66
+ self.model = model
67
+ self.system_prompt = system_prompt
68
+
69
+ @staticmethod
70
+ def from_model_manager(model_manager: ModelManager):
71
+ model, model_path = model_manager.fetch_model(
72
+ "qwen_prompt", require_model_path=True
73
+ )
74
+ system_prompt = """You are an English image describer. Here are some example image styles:\n\n1. Extreme close-up: Clear focus on a single object with a blurred background, highlighted under natural sunlight.\n2. Vintage: A photograph of a historical scene, using techniques such as Daguerreotype or cyanotype.\n3. Anime: A stylized cartoon image, emphasizing hyper-realistic portraits and luminous brushwork.\n4. Candid: A natural, unposed shot capturing spontaneous moments, often with cinematic qualities.\n5. Landscape: A photorealistic image of natural scenery, such as a sunrise over the sea.\n6. Design: Colorful and detailed illustrations, often in the style of 2D game art or botanical illustrations.\n7. Urban: An ultrarealistic scene in a modern setting, possibly a cityscape viewed from indoors.\n\nYour task is to translate a given Chinese image description into a concise and precise English description. Ensure that the imagery is vivid and descriptive, and include stylistic elements to enrich the description.\nPlease note the following points:\n\n1. Capture the essence and mood of the Chinese description without including direct phrases or words from the examples provided.\n2. You should add appropriate words to make the images described in the prompt more aesthetically pleasing. If the Chinese description does not specify a style, you need to add some stylistic descriptions based on the essence of the Chinese text.\n3. The generated English description should not exceed 200 words.\n\n"""
75
+ qwen_prompt = QwenPrompt(
76
+ tokenizer_path=model_path, model=model, system_prompt=system_prompt
77
+ )
78
+ return qwen_prompt
79
+
80
+ def __call__(self, raw_prompt, positive=True, **kwargs):
81
+ if positive:
82
+ messages = [
83
+ {"role": "system", "content": self.system_prompt},
84
+ {"role": "user", "content": raw_prompt},
85
+ ]
86
+ text = self.tokenizer.apply_chat_template(
87
+ messages, tokenize=False, add_generation_prompt=True
88
+ )
89
+ model_inputs = self.tokenizer([text], return_tensors="pt").to(
90
+ self.model.device
91
+ )
92
+
93
+ generated_ids = self.model.generate(
94
+ model_inputs.input_ids, max_new_tokens=512
95
+ )
96
+ generated_ids = [
97
+ output_ids[len(input_ids) :]
98
+ for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
99
+ ]
100
+
101
+ prompt = self.tokenizer.batch_decode(
102
+ generated_ids, skip_special_tokens=True
103
+ )[0]
104
+ print(f"Your prompt is refined by Qwen: {prompt}")
105
+ return prompt
106
+ else:
107
+ return raw_prompt
108
+
109
+
110
+ class Translator(torch.nn.Module):
111
+ def __init__(self, tokenizer_path=None, model=None):
112
+ super().__init__()
113
+ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
114
+ self.model = model
115
+
116
+ @staticmethod
117
+ def from_model_manager(model_manager: ModelManager):
118
+ model, model_path = model_manager.fetch_model(
119
+ "translator", require_model_path=True
120
+ )
121
+ translator = Translator(tokenizer_path=model_path, model=model)
122
+ return translator
123
+
124
+ def __call__(self, prompt, **kwargs):
125
+ input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(
126
+ self.model.device
127
+ )
128
+ output_ids = self.model.generate(input_ids)
129
+ prompt = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
130
+ print(f"Your prompt is translated: {prompt}")
131
+ return prompt
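All three refiners above share one call convention: raw prompt in, refined prompt out, and `positive=False` passes the prompt through untouched. A rough sketch of driving `QwenPrompt` directly from a Hugging Face checkpoint (the checkpoint name is a placeholder; in this repo the refiner is normally built via `ModelManager.fetch_model`, as in `from_model_manager` above):

```python
# Illustrative only: builds QwenPrompt without ModelManager; the checkpoint name is a placeholder.
from transformers import AutoModelForCausalLM
from prompters.prompt_refiners import QwenPrompt

checkpoint = "Qwen/Qwen2.5-3B-Instruct"  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype="auto").to("cuda")

refiner = QwenPrompt(
    tokenizer_path=checkpoint,
    model=model,
    system_prompt="You are an English image describer. Translate the Chinese prompt into a vivid English description.",
)

english_prompt = refiner("一只在月光下奔跑的白色小猫", positive=True)
# positive=False returns the prompt unchanged, so negative prompts are never rewritten.
```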
prompters/wan_prompter.py ADDED
@@ -0,0 +1,112 @@
1
+ from .base_prompter import BasePrompter
2
+ from models.wan_video_text_encoder import WanTextEncoder
3
+ from transformers import AutoTokenizer
4
+ import os, torch
5
+ import ftfy
6
+ import html
7
+ import string
8
+ import regex as re
9
+
10
+
11
+ def basic_clean(text):
12
+ text = ftfy.fix_text(text)
13
+ text = html.unescape(html.unescape(text))
14
+ return text.strip()
15
+
16
+
17
+ def whitespace_clean(text):
18
+ text = re.sub(r"\s+", " ", text)
19
+ text = text.strip()
20
+ return text
21
+
22
+
23
+ def canonicalize(text, keep_punctuation_exact_string=None):
24
+ text = text.replace("_", " ")
25
+ if keep_punctuation_exact_string:
26
+ text = keep_punctuation_exact_string.join(
27
+ part.translate(str.maketrans("", "", string.punctuation))
28
+ for part in text.split(keep_punctuation_exact_string)
29
+ )
30
+ else:
31
+ text = text.translate(str.maketrans("", "", string.punctuation))
32
+ text = text.lower()
33
+ text = re.sub(r"\s+", " ", text)
34
+ return text.strip()
35
+
36
+
37
+ class HuggingfaceTokenizer:
38
+ def __init__(self, name, seq_len=None, clean=None, **kwargs):
39
+ assert clean in (None, "whitespace", "lower", "canonicalize")
40
+ self.name = name
41
+ self.seq_len = seq_len
42
+ self.clean = clean
43
+
44
+ # init tokenizer
45
+ self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
46
+ self.vocab_size = self.tokenizer.vocab_size
47
+
48
+ def __call__(self, sequence, **kwargs):
49
+ return_mask = kwargs.pop("return_mask", False)
50
+
51
+ # arguments
52
+ _kwargs = {"return_tensors": "pt"}
53
+ if self.seq_len is not None:
54
+ _kwargs.update(
55
+ {
56
+ "padding": "max_length",
57
+ "truncation": True,
58
+ "max_length": self.seq_len,
59
+ }
60
+ )
61
+ _kwargs.update(**kwargs)
62
+
63
+ # tokenization
64
+ if isinstance(sequence, str):
65
+ sequence = [sequence]
66
+ if self.clean:
67
+ sequence = [self._clean(u) for u in sequence]
68
+ ids = self.tokenizer(sequence, **_kwargs)
69
+
70
+ # output
71
+ if return_mask:
72
+ return ids.input_ids, ids.attention_mask
73
+ else:
74
+ return ids.input_ids
75
+
76
+ def _clean(self, text):
77
+ if self.clean == "whitespace":
78
+ text = whitespace_clean(basic_clean(text))
79
+ elif self.clean == "lower":
80
+ text = whitespace_clean(basic_clean(text)).lower()
81
+ elif self.clean == "canonicalize":
82
+ text = canonicalize(basic_clean(text))
83
+ return text
84
+
85
+
86
+ class WanPrompter(BasePrompter):
87
+ def __init__(self, tokenizer_path=None, text_len=512):
88
+ super().__init__()
89
+ self.text_len = text_len
90
+ self.text_encoder = None
91
+ self.fetch_tokenizer(tokenizer_path)
92
+
93
+ def fetch_tokenizer(self, tokenizer_path=None):
94
+ if tokenizer_path is not None:
95
+ self.tokenizer = HuggingfaceTokenizer(
96
+ name=tokenizer_path, seq_len=self.text_len, clean="whitespace"
97
+ )
98
+
99
+ def fetch_models(self, text_encoder: WanTextEncoder = None):
100
+ self.text_encoder = text_encoder
101
+
102
+ def encode_prompt(self, prompt, positive=True, device="cuda"):
103
+ prompt = self.process_prompt(prompt, positive=positive)
104
+
105
+ ids, mask = self.tokenizer(prompt, return_mask=True, add_special_tokens=True)
106
+ ids = ids.to(device)
107
+ mask = mask.to(device)
108
+ seq_lens = mask.gt(0).sum(dim=1).long()
109
+ prompt_emb = self.text_encoder(ids, mask)
110
+ for i, v in enumerate(seq_lens):
111
+ prompt_emb[:, v:] = 0
112
+ return prompt_emb
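`WanPrompter` only cleans, tokenizes, and forwards to the T5-style `WanTextEncoder`, so the `HuggingfaceTokenizer` wrapper can be exercised on its own. A small sketch using a public tokenizer as a stand-in for the Wan tokenizer files (`t5-small` is purely illustrative):

```python
# Illustrative: exercise the HuggingfaceTokenizer wrapper with a public stand-in tokenizer.
from prompters.wan_prompter import HuggingfaceTokenizer

tokenizer = HuggingfaceTokenizer(name="t5-small", seq_len=512, clean="whitespace")
ids, mask = tokenizer("A  corgi   running on   the beach", return_mask=True)

print(ids.shape, mask.shape)  # both (1, 512): padded / truncated to seq_len
print(int(mask.sum()))        # number of real (non-padding) tokens
```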
requirements.txt ADDED
@@ -0,0 +1,17 @@
1
+ torch==2.7.0
2
+ torchvision==0.22.0
3
+ ftfy==6.3.1
4
+ huggingface_hub==0.31.1
5
+ imageio==2.37.0
6
+ insightface==0.7.3
7
+ numpy==2.2.6
8
+ opencv_python==4.11.0.86
9
+ Pillow==11.3.0
10
+ safetensors==0.5.3
11
+ tqdm==4.67.1
12
+ transformers==4.46.2
13
+ facexlib==0.3.0
14
+ einops==0.8.1
15
+ onnxruntime-gpu==1.22.0
16
+ imageio-ffmpeg==0.6.0
17
+ scikit-image==0.25.2
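The pins are exact (`==`), so the usual failure mode is silent version drift in an existing environment. A tiny, purely illustrative sanity check for the two heaviest pins:

```python
# Illustrative environment check against the pinned versions above.
import torch, torchvision

print(torch.__version__)         # expected to start with 2.7.0
print(torchvision.__version__)   # expected to start with 0.22.0
print(torch.cuda.is_available()) # onnxruntime-gpu in the pins assumes a CUDA device
```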
schedulers/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .ddim import EnhancedDDIMScheduler
2
+ from .continuous_ode import ContinuousODEScheduler
3
+ from .flow_match import FlowMatchScheduler
schedulers/continuous_ode.py ADDED
@@ -0,0 +1,61 @@
1
+ import torch
2
+
3
+
4
+ class ContinuousODEScheduler:
5
+ def __init__(
6
+ self, num_inference_steps=100, sigma_max=700.0, sigma_min=0.002, rho=7.0
7
+ ):
8
+ self.sigma_max = sigma_max
9
+ self.sigma_min = sigma_min
10
+ self.rho = rho
11
+ self.set_timesteps(num_inference_steps)
12
+
13
+ def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, **kwargs):
14
+ ramp = torch.linspace(1 - denoising_strength, 1, num_inference_steps)
15
+ min_inv_rho = torch.pow(torch.tensor((self.sigma_min,)), (1 / self.rho))
16
+ max_inv_rho = torch.pow(torch.tensor((self.sigma_max,)), (1 / self.rho))
17
+ self.sigmas = torch.pow(
18
+ max_inv_rho + ramp * (min_inv_rho - max_inv_rho), self.rho
19
+ )
20
+ self.timesteps = torch.log(self.sigmas) * 0.25
21
+
22
+ def step(self, model_output, timestep, sample, to_final=False):
23
+ timestep_id = torch.argmin((self.timesteps - timestep).abs())
24
+ sigma = self.sigmas[timestep_id]
25
+ sample *= (sigma * sigma + 1).sqrt()
26
+ estimated_sample = (
27
+ -sigma / (sigma * sigma + 1).sqrt() * model_output
28
+ + 1 / (sigma * sigma + 1) * sample
29
+ )
30
+ if to_final or timestep_id + 1 >= len(self.timesteps):
31
+ prev_sample = estimated_sample
32
+ else:
33
+ sigma_ = self.sigmas[timestep_id + 1]
34
+ derivative = 1 / sigma * (sample - estimated_sample)
35
+ prev_sample = sample + derivative * (sigma_ - sigma)
36
+ prev_sample /= (sigma_ * sigma_ + 1).sqrt()
37
+ return prev_sample
38
+
39
+ def return_to_timestep(self, timestep, sample, sample_stablized):
40
+ # This scheduler doesn't support this function.
41
+ pass
42
+
43
+ def add_noise(self, original_samples, noise, timestep):
44
+ timestep_id = torch.argmin((self.timesteps - timestep).abs())
45
+ sigma = self.sigmas[timestep_id]
46
+ sample = (original_samples + noise * sigma) / (sigma * sigma + 1).sqrt()
47
+ return sample
48
+
49
+ def training_target(self, sample, noise, timestep):
50
+ timestep_id = torch.argmin((self.timesteps - timestep).abs())
51
+ sigma = self.sigmas[timestep_id]
52
+ target = (
53
+ -(sigma * sigma + 1).sqrt() / sigma + 1 / (sigma * sigma + 1).sqrt() / sigma
54
+ ) * sample + 1 / (sigma * sigma + 1).sqrt() * noise
55
+ return target
56
+
57
+ def training_weight(self, timestep):
58
+ timestep_id = torch.argmin((self.timesteps - timestep).abs())
59
+ sigma = self.sigmas[timestep_id]
60
+ weight = (1 + sigma * sigma).sqrt() / sigma
61
+ return weight
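The sigmas above follow a rho-spaced (Karras-style) ramp from `sigma_max` down to `sigma_min`, and the scheduler keeps samples divided by `sqrt(sigma^2 + 1)` (see `add_noise` and the rescaling inside `step`). A bare sampling-loop skeleton, with the denoiser call left as a placeholder and shapes chosen for illustration:

```python
# Illustrative sampling-loop skeleton for ContinuousODEScheduler; the denoiser is a placeholder.
import torch
from schedulers.continuous_ode import ContinuousODEScheduler

scheduler = ContinuousODEScheduler(num_inference_steps=30)

latent_shape = (1, 4, 16, 16)
noise = torch.randn(latent_shape)
# Start from pure noise, expressed in the scheduler's own scaling convention.
sample = scheduler.add_noise(torch.zeros(latent_shape), noise, scheduler.timesteps[0])

for timestep in scheduler.timesteps:
    model_output = torch.zeros_like(sample)  # placeholder for denoiser(sample, timestep)
    sample = scheduler.step(model_output, timestep, sample)
```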
schedulers/ddim.py ADDED
@@ -0,0 +1,136 @@
1
+ import torch, math
2
+
3
+
4
+ class EnhancedDDIMScheduler:
5
+ def __init__(
6
+ self,
7
+ num_train_timesteps=1000,
8
+ beta_start=0.00085,
9
+ beta_end=0.012,
10
+ beta_schedule="scaled_linear",
11
+ prediction_type="epsilon",
12
+ rescale_zero_terminal_snr=False,
13
+ ):
14
+ self.num_train_timesteps = num_train_timesteps
15
+ if beta_schedule == "scaled_linear":
16
+ betas = torch.square(
17
+ torch.linspace(
18
+ math.sqrt(beta_start),
19
+ math.sqrt(beta_end),
20
+ num_train_timesteps,
21
+ dtype=torch.float32,
22
+ )
23
+ )
24
+ elif beta_schedule == "linear":
25
+ betas = torch.linspace(
26
+ beta_start, beta_end, num_train_timesteps, dtype=torch.float32
27
+ )
28
+ else:
29
+ raise NotImplementedError(f"{beta_schedule} is not implemented")
30
+ self.alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
31
+ if rescale_zero_terminal_snr:
32
+ self.alphas_cumprod = self.rescale_zero_terminal_snr(self.alphas_cumprod)
33
+ self.alphas_cumprod = self.alphas_cumprod.tolist()
34
+ self.set_timesteps(10)
35
+ self.prediction_type = prediction_type
36
+
37
+ def rescale_zero_terminal_snr(self, alphas_cumprod):
38
+ alphas_bar_sqrt = alphas_cumprod.sqrt()
39
+
40
+ # Store old values.
41
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
42
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
43
+
44
+ # Shift so the last timestep is zero.
45
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
46
+
47
+ # Scale so the first timestep is back to the old value.
48
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
49
+
50
+ # Convert alphas_bar_sqrt to betas
51
+ alphas_bar = alphas_bar_sqrt.square() # Revert sqrt
52
+
53
+ return alphas_bar
54
+
55
+ def set_timesteps(self, num_inference_steps, denoising_strength=1.0, **kwargs):
56
+ # The timesteps are aligned to 999...0, which is different from other implementations,
57
+ # but I think this implementation is more reasonable in theory.
58
+ max_timestep = max(round(self.num_train_timesteps * denoising_strength) - 1, 0)
59
+ num_inference_steps = min(num_inference_steps, max_timestep + 1)
60
+ if num_inference_steps == 1:
61
+ self.timesteps = torch.Tensor([max_timestep])
62
+ else:
63
+ step_length = max_timestep / (num_inference_steps - 1)
64
+ self.timesteps = torch.Tensor(
65
+ [
66
+ round(max_timestep - i * step_length)
67
+ for i in range(num_inference_steps)
68
+ ]
69
+ )
70
+
71
+ def denoise(self, model_output, sample, alpha_prod_t, alpha_prod_t_prev):
72
+ if self.prediction_type == "epsilon":
73
+ weight_e = math.sqrt(1 - alpha_prod_t_prev) - math.sqrt(
74
+ alpha_prod_t_prev * (1 - alpha_prod_t) / alpha_prod_t
75
+ )
76
+ weight_x = math.sqrt(alpha_prod_t_prev / alpha_prod_t)
77
+ prev_sample = sample * weight_x + model_output * weight_e
78
+ elif self.prediction_type == "v_prediction":
79
+ weight_e = -math.sqrt(alpha_prod_t_prev * (1 - alpha_prod_t)) + math.sqrt(
80
+ alpha_prod_t * (1 - alpha_prod_t_prev)
81
+ )
82
+ weight_x = math.sqrt(alpha_prod_t * alpha_prod_t_prev) + math.sqrt(
83
+ (1 - alpha_prod_t) * (1 - alpha_prod_t_prev)
84
+ )
85
+ prev_sample = sample * weight_x + model_output * weight_e
86
+ else:
87
+ raise NotImplementedError(f"{self.prediction_type} is not implemented")
88
+ return prev_sample
89
+
90
+ def step(self, model_output, timestep, sample, to_final=False):
91
+ alpha_prod_t = self.alphas_cumprod[int(timestep.flatten().tolist()[0])]
92
+ if isinstance(timestep, torch.Tensor):
93
+ timestep = timestep.cpu()
94
+ timestep_id = torch.argmin((self.timesteps - timestep).abs())
95
+ if to_final or timestep_id + 1 >= len(self.timesteps):
96
+ alpha_prod_t_prev = 1.0
97
+ else:
98
+ timestep_prev = int(self.timesteps[timestep_id + 1])
99
+ alpha_prod_t_prev = self.alphas_cumprod[timestep_prev]
100
+
101
+ return self.denoise(model_output, sample, alpha_prod_t, alpha_prod_t_prev)
102
+
103
+ def return_to_timestep(self, timestep, sample, sample_stablized):
104
+ alpha_prod_t = self.alphas_cumprod[int(timestep.flatten().tolist()[0])]
105
+ noise_pred = (sample - math.sqrt(alpha_prod_t) * sample_stablized) / math.sqrt(
106
+ 1 - alpha_prod_t
107
+ )
108
+ return noise_pred
109
+
110
+ def add_noise(self, original_samples, noise, timestep):
111
+ sqrt_alpha_prod = math.sqrt(
112
+ self.alphas_cumprod[int(timestep.flatten().tolist()[0])]
113
+ )
114
+ sqrt_one_minus_alpha_prod = math.sqrt(
115
+ 1 - self.alphas_cumprod[int(timestep.flatten().tolist()[0])]
116
+ )
117
+ noisy_samples = (
118
+ sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
119
+ )
120
+ return noisy_samples
121
+
122
+ def training_target(self, sample, noise, timestep):
123
+ if self.prediction_type == "epsilon":
124
+ return noise
125
+ else:
126
+ sqrt_alpha_prod = math.sqrt(
127
+ self.alphas_cumprod[int(timestep.flatten().tolist()[0])]
128
+ )
129
+ sqrt_one_minus_alpha_prod = math.sqrt(
130
+ 1 - self.alphas_cumprod[int(timestep.flatten().tolist()[0])]
131
+ )
132
+ target = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
133
+ return target
134
+
135
+ def training_weight(self, timestep):
136
+ return 1.0
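`denoise` above is the deterministic DDIM update for both epsilon- and v-prediction. A quick round-trip sketch (shapes illustrative): adding noise at t=999 and stepping once with the true noise as the model output should land back on the clean sample:

```python
# Illustrative round trip for EnhancedDDIMScheduler with epsilon prediction.
import torch
from schedulers.ddim import EnhancedDDIMScheduler

scheduler = EnhancedDDIMScheduler(prediction_type="epsilon")
scheduler.set_timesteps(num_inference_steps=1)  # single step: t=999 -> clean

clean = torch.randn(1, 4, 8, 8)
noise = torch.randn_like(clean)
t = scheduler.timesteps[0]

noisy = scheduler.add_noise(clean, noise, t)
recovered = scheduler.step(noise, t, noisy)  # feed the true noise as the "model output"

print(torch.allclose(recovered, clean, atol=1e-4))  # True up to float32 rounding
```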
schedulers/flow_match.py ADDED
@@ -0,0 +1,100 @@
1
+ import torch
2
+
3
+
4
+ class FlowMatchScheduler:
5
+ def __init__(
6
+ self,
7
+ num_inference_steps=100,
8
+ num_train_timesteps=1000,
9
+ shift=3.0,
10
+ sigma_max=1.0,
11
+ sigma_min=0.003 / 1.002,
12
+ inverse_timesteps=False,
13
+ extra_one_step=False,
14
+ reverse_sigmas=False,
15
+ ):
16
+ self.num_train_timesteps = num_train_timesteps
17
+ self.shift = shift
18
+ self.sigma_max = sigma_max
19
+ self.sigma_min = sigma_min
20
+ self.inverse_timesteps = inverse_timesteps
21
+ self.extra_one_step = extra_one_step
22
+ self.reverse_sigmas = reverse_sigmas
23
+ self.set_timesteps(num_inference_steps)
24
+
25
+ def set_timesteps(
26
+ self,
27
+ num_inference_steps=100,
28
+ denoising_strength=1.0,
29
+ training=False,
30
+ shift=None,
31
+ ):
32
+ if shift is not None:
33
+ self.shift = shift
34
+ sigma_start = (
35
+ self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
36
+ )
37
+ if self.extra_one_step:
38
+ self.sigmas = torch.linspace(
39
+ sigma_start, self.sigma_min, num_inference_steps + 1
40
+ )[:-1]
41
+ else:
42
+ self.sigmas = torch.linspace(
43
+ sigma_start, self.sigma_min, num_inference_steps
44
+ )
45
+ if self.inverse_timesteps:
46
+ self.sigmas = torch.flip(self.sigmas, dims=[0])
47
+ self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas)
48
+ if self.reverse_sigmas:
49
+ self.sigmas = 1 - self.sigmas
50
+ self.timesteps = self.sigmas * self.num_train_timesteps
51
+ if training:
52
+ x = self.timesteps
53
+ y = torch.exp(
54
+ -2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2
55
+ )
56
+ y_shifted = y - y.min()
57
+ bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum())
58
+ self.linear_timesteps_weights = bsmntw_weighing
59
+ self.training = True
60
+ else:
61
+ self.training = False
62
+
63
+ def step(self, model_output, timestep, sample, to_final=False, **kwargs):
64
+ if isinstance(timestep, torch.Tensor):
65
+ timestep = timestep.cpu()
66
+ timestep_id = torch.argmin((self.timesteps - timestep).abs())
67
+ sigma = self.sigmas[timestep_id]
68
+ if to_final or timestep_id + 1 >= len(self.timesteps):
69
+ sigma_ = 1 if (self.inverse_timesteps or self.reverse_sigmas) else 0
70
+ else:
71
+ sigma_ = self.sigmas[timestep_id + 1]
72
+ prev_sample = sample + model_output * (sigma_ - sigma)
73
+ return prev_sample
74
+
75
+ def return_to_timestep(self, timestep, sample, sample_stablized):
76
+ if isinstance(timestep, torch.Tensor):
77
+ timestep = timestep.cpu()
78
+ timestep_id = torch.argmin((self.timesteps - timestep).abs())
79
+ sigma = self.sigmas[timestep_id]
80
+ model_output = (sample - sample_stablized) / sigma
81
+ return model_output
82
+
83
+ def add_noise(self, original_samples, noise, timestep):
84
+ if isinstance(timestep, torch.Tensor):
85
+ timestep = timestep.cpu()
86
+ timestep_id = torch.argmin((self.timesteps - timestep).abs())
87
+ sigma = self.sigmas[timestep_id]
88
+ sample = (1 - sigma) * original_samples + sigma * noise
89
+ return sample
90
+
91
+ def training_target(self, sample, noise, timestep):
92
+ target = noise - sample
93
+ return target
94
+
95
+ def training_weight(self, timestep):
96
+ timestep_id = torch.argmin(
97
+ (self.timesteps - timestep.to(self.timesteps.device)).abs()
98
+ )
99
+ weights = self.linear_timesteps_weights[timestep_id]
100
+ return weights
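`FlowMatchScheduler` mixes noise linearly, `x_sigma = (1 - sigma) * x0 + sigma * noise`, and `step` moves along the predicted velocity from one sigma to the next. A single-step sketch (shapes illustrative) showing that the training target `noise - sample` carries the state exactly onto the next sigma's mixture:

```python
# Illustrative single-step check of the flow-matching update.
import torch
from schedulers.flow_match import FlowMatchScheduler

scheduler = FlowMatchScheduler(num_inference_steps=10, shift=3.0)

clean = torch.randn(1, 4, 8, 8)
noise = torch.randn_like(clean)
t0, t1 = scheduler.timesteps[0], scheduler.timesteps[1]

x_t0 = scheduler.add_noise(clean, noise, t0)
velocity = scheduler.training_target(clean, noise, t0)  # = noise - clean
x_t1 = scheduler.step(velocity, t0, x_t0)

print(torch.allclose(x_t1, scheduler.add_noise(clean, noise, t1), atol=1e-5))  # True
```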
test/input/first_frame.png ADDED

Git LFS Details

  • SHA256: 1f864b3330b1b47f11cc71235993766a34187a107b01815bf2708f4a459ffa67
  • Pointer size: 131 Bytes
  • Size of remote file: 403 kB
test/input/lecun.jpg ADDED
test/input/pose.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f2038896160fb990162832cbff6eaebcf05b25e1a3b8c201e5b147a4ce3ce01d
3
+ size 173260
test/input/ruonan.jpg ADDED

Git LFS Details

  • SHA256: d0f82d2b7c91c08033ca2ce14d475675ccd302a966a6abad4f028568a5d078d2
  • Pointer size: 131 Bytes
  • Size of remote file: 204 kB
test/input/woman.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ee8fb303f53a89c0ab36c0457c9452149c58a881055becb4da8abc41766bc6db
3
+ size 8399484