Spaces:
Paused
Paused
Migrated from GitHub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +4 -0
- LICENSE +201 -0
- ORIGINAL_README.md +238 -0
- app.py +163 -0
- assets/Stand-In.png +0 -0
- configs/model_config.py +1809 -0
- data/video.py +158 -0
- distributed/__init__.py +0 -0
- distributed/xdit_context_parallel.py +154 -0
- download_models.py +21 -0
- infer.py +85 -0
- infer_face_swap.py +119 -0
- infer_with_lora.py +94 -0
- infer_with_vace.py +106 -0
- lora/__init__.py +91 -0
- models/__init__.py +1 -0
- models/attention.py +130 -0
- models/downloader.py +122 -0
- models/model_manager.py +610 -0
- models/set_condition_branch.py +41 -0
- models/tiler.py +333 -0
- models/utils.py +219 -0
- models/wan_video_camera_controller.py +290 -0
- models/wan_video_dit.py +952 -0
- models/wan_video_image_encoder.py +957 -0
- models/wan_video_motion_controller.py +41 -0
- models/wan_video_text_encoder.py +289 -0
- models/wan_video_vace.py +140 -0
- models/wan_video_vae.py +1634 -0
- pipelines/base.py +173 -0
- pipelines/wan_video.py +1793 -0
- pipelines/wan_video_face_swap.py +1786 -0
- preprocessor/__init__.py +2 -0
- preprocessor/image_input_preprocessor.py +181 -0
- preprocessor/videomask_generator.py +242 -0
- prompters/__init__.py +3 -0
- prompters/base_prompter.py +68 -0
- prompters/omost.py +472 -0
- prompters/prompt_refiners.py +131 -0
- prompters/wan_prompter.py +112 -0
- requirements.txt +17 -0
- schedulers/__init__.py +3 -0
- schedulers/continuous_ode.py +61 -0
- schedulers/ddim.py +136 -0
- schedulers/flow_match.py +100 -0
- test/input/first_frame.png +3 -0
- test/input/lecun.jpg +0 -0
- test/input/pose.mp4 +3 -0
- test/input/ruonan.jpg +3 -0
- test/input/woman.mp4 +3 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
test/input/first_frame.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
test/input/pose.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
test/input/ruonan.jpg filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
test/input/woman.mp4 filter=lfs diff=lfs merge=lfs -text
|
LICENSE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright [yyyy] [name of copyright owner]
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
ORIGINAL_README.md
ADDED
|
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<div align="center">
|
| 2 |
+
|
| 3 |
+
<h1>
|
| 4 |
+
<img src="assets/Stand-In.png" width="85" alt="Logo" valign="middle">
|
| 5 |
+
Stand-In
|
| 6 |
+
</h1>
|
| 7 |
+
|
| 8 |
+
<h3>A Lightweight and Plug-and-Play Identity Control for Video Generation</h3>
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
[](https://arxiv.org/abs/2508.07901)
|
| 13 |
+
[](https://www.stand-in.tech)
|
| 14 |
+
[](https://huggingface.co/BowenXue/Stand-In)
|
| 15 |
+
|
| 16 |
+
</div>
|
| 17 |
+
|
| 18 |
+
<img width="5333" height="2983" alt="Image" src="https://github.com/user-attachments/assets/2fe1e505-bcf7-4eb6-8628-f23e70020966" />
|
| 19 |
+
|
| 20 |
+
> **Stand-In** is a lightweight, plug-and-play framework for identity-preserving video generation. By training only **1%** additional parameters compared to the base video generation model, we achieve state-of-the-art results in both Face Similarity and Naturalness, outperforming various full-parameter training methods. Moreover, **Stand-In** can be seamlessly integrated into other tasks such as subject-driven video generation, pose-controlled video generation, video stylization, and face swapping.
|
| 21 |
+
|
| 22 |
+
---
|
| 23 |
+
|
| 24 |
+
## 🔥 News
|
| 25 |
+
* **[2025.08.18]** We have released a version compatible with VACE. Not only pose control, but you can also try other control methods such as depth maps, combined with Stand-In to maintain identity simultaneously.
|
| 26 |
+
|
| 27 |
+
* **[2025.08.16]** We have updated the experimental version of the face swapping feature. Feel free to try it out!
|
| 28 |
+
|
| 29 |
+
* **[2025.08.13]** Special thanks to @kijai for integrating Stand-In into the custom ComfyUI node **WanVideoWrapper**. However, the implementation differs from the official version, which may affect Stand-In’s performance.
|
| 30 |
+
In order to address part of the issue, we have urgently released the official Stand-In preprocessing ComfyUI node:
|
| 31 |
+
👉 https://github.com/WeChatCV/Stand-In_Preprocessor_ComfyUI
|
| 32 |
+
If you wish to experience Stand-In within ComfyUI, please use **our official preprocessing node** to replace the one implemented by kijai.
|
| 33 |
+
For the best results, we recommend waiting for the release of our full **official Stand-In ComfyUI**.
|
| 34 |
+
|
| 35 |
+
* **[2025.08.12]** Released Stand-In v1.0 (153M parameters), the Wan2.1-14B-T2V–adapted weights and inference code are now open-sourced.
|
| 36 |
+
|
| 37 |
+
---
|
| 38 |
+
|
| 39 |
+
## 🌟 Showcase
|
| 40 |
+
|
| 41 |
+
### Identity-Preserving Text-to-Video Generation
|
| 42 |
+
|
| 43 |
+
| Reference Image | Prompt | Generated Video |
|
| 44 |
+
| :---: | :---: | :---: |
|
| 45 |
+
|| "In a corridor where the walls ripple like water, a woman reaches out to touch the flowing surface, causing circles of ripples to spread. The camera moves from a medium shot to a close-up, capturing her curious expression as she sees her distorted reflection." | |
|
| 46 |
+
|| "A young man dressed in traditional attire draws the long sword from his waist and begins to wield it. The blade flashes with light as he moves—his eyes sharp, his actions swift and powerful, with his flowing robes dancing in the wind." ||
|
| 47 |
+
|
| 48 |
+
---
|
| 49 |
+
### Non-Human Subjects-Preserving Video Generation
|
| 50 |
+
|
| 51 |
+
| Reference Image | Prompt | Generated Video |
|
| 52 |
+
| :---: | :---: | :---: |
|
| 53 |
+
|<img width="415" height="415" alt="Image" src="https://github.com/user-attachments/assets/b929444d-d724-4cf9-b422-be82b380ff78" />|"A chibi-style boy speeding on a skateboard, holding a detective novel in one hand. The background features city streets, with trees, streetlights, and billboards along the roads."| |
|
| 54 |
+
|
| 55 |
+
---
|
| 56 |
+
|
| 57 |
+
### Identity-Preserving Stylized Video Generation
|
| 58 |
+
|
| 59 |
+
| Reference Image | LoRA | Generated Video |
|
| 60 |
+
| :---: | :---: | :---: |
|
| 61 |
+
||Ghibli LoRA||
|
| 62 |
+
---
|
| 63 |
+
|
| 64 |
+
### Video Face Swapping
|
| 65 |
+
|
| 66 |
+
| Reference Video | Identity | Generated Video |
|
| 67 |
+
| :---: | :---: | :---: |
|
| 68 |
+
||<img width="415" height="415" alt="Image" src="https://github.com/user-attachments/assets/d2cd8da0-7aa0-4ee4-a61d-b52718c33756" />||
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
---
|
| 72 |
+
|
| 73 |
+
### Pose-Guided Video Generation (With VACE)
|
| 74 |
+
|
| 75 |
+
| Reference Pose | First Frame | Generated Video |
|
| 76 |
+
| :---: | :---: | :---: |
|
| 77 |
+
||<img width="719" height="415" alt="Image" src="https://github.com/user-attachments/assets/1c2a69e1-e530-4164-848b-e7ea85a99763" />||
|
| 78 |
+
|
| 79 |
+
---
|
| 80 |
+
### For more results, please visit [https://stand-in-video.github.io/](https://www.Stand-In.tech)
|
| 81 |
+
|
| 82 |
+
## 📖 Key Features
|
| 83 |
+
- Efficient Training: Only 1% of the base model parameters need to be trained.
|
| 84 |
+
- High Fidelity: Outstanding identity consistency without sacrificing video generation quality.
|
| 85 |
+
- Plug-and-Play: Easily integrates into existing T2V (Text-to-Video) models.
|
| 86 |
+
- Highly Extensible: Compatible with community models such as LoRA, and supports various downstream video tasks.
|
| 87 |
+
|
| 88 |
+
---
|
| 89 |
+
|
| 90 |
+
## ✅ Todo List
|
| 91 |
+
- [x] Release IP2V inference script (compatible with community LoRA).
|
| 92 |
+
- [x] Open-source model weights compatible with Wan2.1-14B-T2V: `Stand-In_Wan2.1-T2V-14B_153M_v1.0`。
|
| 93 |
+
- [ ] Open-source model weights compatible with Wan2.2-T2V-A14B.
|
| 94 |
+
- [ ] Release training dataset, data preprocessing scripts, and training code.
|
| 95 |
+
|
| 96 |
+
---
|
| 97 |
+
|
| 98 |
+
## 🚀 Quick Start
|
| 99 |
+
|
| 100 |
+
### 1. Environment Setup
|
| 101 |
+
```bash
|
| 102 |
+
# Clone the project repository
|
| 103 |
+
git clone https://github.com/WeChatCV/Stand-In.git
|
| 104 |
+
cd Stand-In
|
| 105 |
+
|
| 106 |
+
# Create and activate Conda environment
|
| 107 |
+
conda create -n Stand-In python=3.11 -y
|
| 108 |
+
conda activate Stand-In
|
| 109 |
+
|
| 110 |
+
# Install dependencies
|
| 111 |
+
pip install -r requirements.txt
|
| 112 |
+
|
| 113 |
+
# (Optional) Install Flash Attention for faster inference
|
| 114 |
+
# Note: Make sure your GPU and CUDA version are compatible with Flash Attention
|
| 115 |
+
pip install flash-attn --no-build-isolation
|
| 116 |
+
```
|
| 117 |
+
|
| 118 |
+
### 2. Model Download
|
| 119 |
+
We provide an automatic download script that will fetch all required model weights into the `checkpoints` directory.
|
| 120 |
+
```bash
|
| 121 |
+
python download_models.py
|
| 122 |
+
```
|
| 123 |
+
This script will download the following models:
|
| 124 |
+
* `wan2.1-T2V-14B` (base text-to-video model)
|
| 125 |
+
* `antelopev2` (face recognition model)
|
| 126 |
+
* `Stand-In` (our Stand-In model)
|
| 127 |
+
|
| 128 |
+
> Note: If you already have the `wan2.1-T2V-14B model` locally, you can manually edit the `download_model.py` script to comment out the relevant download code and place the model in the `checkpoints/wan2.1-T2V-14B` directory.
|
| 129 |
+
|
| 130 |
+
---
|
| 131 |
+
|
| 132 |
+
## 🧪 Usage
|
| 133 |
+
|
| 134 |
+
### Standard Inference
|
| 135 |
+
|
| 136 |
+
Use the `infer.py` script for standard identity-preserving text-to-video generation.
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
```bash
|
| 140 |
+
python infer.py \
|
| 141 |
+
--prompt "A man sits comfortably at a desk, facing the camera as if talking to a friend or family member on the screen. His gaze is focused and gentle, with a natural smile. The background is his carefully decorated personal space, with photos and a world map on the wall, conveying a sense of intimate and modern communication." \
|
| 142 |
+
--ip_image "test/input/lecun.jpg" \
|
| 143 |
+
--output "test/output/lecun.mp4"
|
| 144 |
+
```
|
| 145 |
+
**Prompt Writing Tip:** If you do not wish to alter the subject's facial features, simply use *"a man"* or *"a woman"* without adding extra descriptions of their appearance. Prompts support both Chinese and English input. The prompt is intended for generating frontal, medium-to-close-up videos.
|
| 146 |
+
|
| 147 |
+
**Input Image Recommendation:** For best results, use a high-resolution frontal face image. There are no restrictions on resolution or file extension, as our built-in preprocessing pipeline will handle them automatically.
|
| 148 |
+
|
| 149 |
+
---
|
| 150 |
+
|
| 151 |
+
### Inference with Community LoRA
|
| 152 |
+
|
| 153 |
+
Use the `infer_with_lora.py` script to load one or more community LoRA models alongside Stand-In.
|
| 154 |
+
|
| 155 |
+
```bash
|
| 156 |
+
python infer_with_lora.py \
|
| 157 |
+
--prompt "A man sits comfortably at a desk, facing the camera as if talking to a friend or family member on the screen. His gaze is focused and gentle, with a natural smile. The background is his carefully decorated personal space, with photos and a world map on the wall, conveying a sense of intimate and modern communication." \
|
| 158 |
+
--ip_image "test/input/lecun.jpg" \
|
| 159 |
+
--output "test/output/lecun.mp4" \
|
| 160 |
+
--lora_path "path/to/your/lora.safetensors" \
|
| 161 |
+
--lora_scale 1.0
|
| 162 |
+
```
|
| 163 |
+
|
| 164 |
+
We recommend using this stylization LoRA: [https://civitai.com/models/1404755/studio-ghibli-wan21-t2v-14b](https://civitai.com/models/1404755/studio-ghibli-wan21-t2v-14b)
|
| 165 |
+
|
| 166 |
+
---
|
| 167 |
+
|
| 168 |
+
### Video Face Swapping
|
| 169 |
+
|
| 170 |
+
Use the `infer_face_swap.py` script to perform video face swapping with Stand-In.
|
| 171 |
+
|
| 172 |
+
```bash
|
| 173 |
+
python infer_face_swap.py \
|
| 174 |
+
--prompt "The video features a woman standing in front of a large screen displaying the words ""Tech Minute"" and the logo for CNET. She is wearing a purple top and appears to be presenting or speaking about technology-related topics. The background includes a cityscape with tall buildings, suggesting an urban setting. The woman seems to be engaged in a discussion or providing information on technology news or trends. The overall atmosphere is professional and informative, likely aimed at educating viewers about the latest developments in the tech industry." \
|
| 175 |
+
--ip_image "test/input/ruonan.jpg" \
|
| 176 |
+
--output "test/output/ruonan.mp4" \
|
| 177 |
+
--denoising_strength 0.85
|
| 178 |
+
```
|
| 179 |
+
**Note**: Since Wan2.1 itself does not have an inpainting function, our face swapping feature is still experimental.
|
| 180 |
+
|
| 181 |
+
The higher the denoising_strength, the more the background area is redrawn, and the more natural the face area becomes. Conversely, the lower the denoising_strength, the less the background area is redrawn, and the higher the degree of overfitting in the face area.
|
| 182 |
+
|
| 183 |
+
You can set --force_background_consistency to make the background completely consistent, but this may lead to potential and noticeable contour issues. Enabling this feature requires experimenting with different denoising_strength values to achieve the most natural effect. If slight changes to the background are not a concern, please do not enable this feature.
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
### Infer with VACE
|
| 187 |
+
Use the `infer_with_vace.py` script to perform identity-preserving video generation with Stand-In, compatible with VACE.
|
| 188 |
+
```bash
|
| 189 |
+
python infer_with_vace.py \
|
| 190 |
+
--prompt "A woman raises her hands." \
|
| 191 |
+
--vace_path "checkpoints/VACE/" \
|
| 192 |
+
--ip_image "test/input/first_frame.png" \
|
| 193 |
+
--reference_video "test/input/pose.mp4" \
|
| 194 |
+
--reference_image "test/input/first_frame.png" \
|
| 195 |
+
--output "test/output/woman.mp4" \
|
| 196 |
+
--vace_scale 0.8
|
| 197 |
+
```
|
| 198 |
+
You need to download the corresponding weights from the `VACE` repository or provide the path to the `VACE` weights in the `vace_path` parameter.
|
| 199 |
+
|
| 200 |
+
```bash
|
| 201 |
+
python download_models.py --vace
|
| 202 |
+
```
|
| 203 |
+
|
| 204 |
+
The input control video needs to be preprocessed using VACE's preprocessing tool. Both `reference_video` and `reference_image` are optional and can exist simultaneously. Additionally, VACE’s control has a preset bias towards faces, which affects identity preservation. Please lower the `vace_scale` to a balance point where both motion and identity are preserved. When only `ip_image` and `reference_video` are provided, the weight can be reduced to 0.5.
|
| 205 |
+
|
| 206 |
+
Using both Stand-In and VACE together is more challenging than using Stand-In alone. We are still maintaining this feature, so if you encounter unexpected outputs or have other questions, feel free to raise them in the issue.
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
## 🤝 Acknowledgements
|
| 210 |
+
|
| 211 |
+
This project is built upon the following excellent open-source projects:
|
| 212 |
+
* [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio) (training/inference framework)
|
| 213 |
+
* [Wan2.1](https://github.com/Wan-Video/Wan2.1) (base video generation model)
|
| 214 |
+
|
| 215 |
+
We sincerely thank the authors and contributors of these projects.
|
| 216 |
+
|
| 217 |
+
The original raw material of our dataset was collected with the help of our team member [Binxin Yang](https://binxinyang.github.io/), and we appreciate his contribution!
|
| 218 |
+
|
| 219 |
+
---
|
| 220 |
+
|
| 221 |
+
## ✏ Citation
|
| 222 |
+
|
| 223 |
+
If you find our work helpful for your research, please consider citing our paper:
|
| 224 |
+
|
| 225 |
+
```bibtex
|
| 226 |
+
@article{xue2025standin,
|
| 227 |
+
title={Stand-In: A Lightweight and Plug-and-Play Identity Control for Video Generation},
|
| 228 |
+
author={Bowen Xue and Qixin Yan and Wenjing Wang and Hao Liu and Chen Li},
|
| 229 |
+
journal={arXiv preprint arXiv:2508.07901},
|
| 230 |
+
year={2025},
|
| 231 |
+
}
|
| 232 |
+
```
|
| 233 |
+
|
| 234 |
+
---
|
| 235 |
+
|
| 236 |
+
## 📬 Contact Us
|
| 237 |
+
|
| 238 |
+
If you have any questions or suggestions, feel free to reach out via [GitHub Issues](https://github.com/WeChatCV/Stand-In/issues) . We look forward to your feedback!
|
app.py
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import torch
|
| 3 |
+
import time
|
| 4 |
+
from PIL import Image
|
| 5 |
+
import tempfile
|
| 6 |
+
import os
|
| 7 |
+
|
| 8 |
+
from data.video import save_video
|
| 9 |
+
from wan_loader import load_wan_pipe
|
| 10 |
+
from models.set_condition_branch import set_stand_in
|
| 11 |
+
from preprocessor import FaceProcessor
|
| 12 |
+
|
| 13 |
+
print("Loading model, please wait...")
|
| 14 |
+
try:
|
| 15 |
+
ANTELOPEV2_PATH = "checkpoints/antelopev2"
|
| 16 |
+
BASE_MODEL_PATH = "checkpoints/base_model/"
|
| 17 |
+
LORA_MODEL_PATH = "checkpoints/Stand-In/Stand-In_wan2.1_T2V_14B_ver1.0.ckpt"
|
| 18 |
+
|
| 19 |
+
if not os.path.exists(ANTELOPEV2_PATH):
|
| 20 |
+
raise FileNotFoundError(
|
| 21 |
+
f"AntelopeV2 checkpoint not found at: {ANTELOPEV2_PATH}"
|
| 22 |
+
)
|
| 23 |
+
if not os.path.exists(BASE_MODEL_PATH):
|
| 24 |
+
raise FileNotFoundError(f"Base model not found at: {BASE_MODEL_PATH}")
|
| 25 |
+
if not os.path.exists(LORA_MODEL_PATH):
|
| 26 |
+
raise FileNotFoundError(f"LoRA model not found at: {LORA_MODEL_PATH}")
|
| 27 |
+
|
| 28 |
+
face_processor = FaceProcessor(antelopv2_path=ANTELOPEV2_PATH)
|
| 29 |
+
pipe = load_wan_pipe(base_path=BASE_MODEL_PATH, torch_dtype=torch.bfloat16)
|
| 30 |
+
set_stand_in(pipe, model_path=LORA_MODEL_PATH)
|
| 31 |
+
print("Model loaded successfully!")
|
| 32 |
+
except Exception as e:
|
| 33 |
+
print(f"Model loading failed: {e}")
|
| 34 |
+
with gr.Blocks() as demo:
|
| 35 |
+
gr.Markdown("# Error: Model Loading Failed")
|
| 36 |
+
gr.Markdown(f"""
|
| 37 |
+
Please check the following:
|
| 38 |
+
1. Make sure the checkpoint files are placed in the correct directory.
|
| 39 |
+
2. Ensure all dependencies are properly installed.
|
| 40 |
+
3. Check the console output for detailed error information.
|
| 41 |
+
|
| 42 |
+
**Error details**: {e}
|
| 43 |
+
""")
|
| 44 |
+
demo.launch()
|
| 45 |
+
exit()
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def generate_video(
|
| 49 |
+
pil_image: Image.Image,
|
| 50 |
+
prompt: str,
|
| 51 |
+
seed: int,
|
| 52 |
+
negative_prompt: str,
|
| 53 |
+
num_steps: int,
|
| 54 |
+
fps: int,
|
| 55 |
+
quality: int,
|
| 56 |
+
):
|
| 57 |
+
if pil_image is None:
|
| 58 |
+
raise gr.Error("Please upload a face image first!")
|
| 59 |
+
|
| 60 |
+
print("Processing face...")
|
| 61 |
+
ip_image = face_processor.process(pil_image)
|
| 62 |
+
print("Face processing completed.")
|
| 63 |
+
|
| 64 |
+
print("Generating video...")
|
| 65 |
+
start_time = time.time()
|
| 66 |
+
video = pipe(
|
| 67 |
+
prompt=prompt,
|
| 68 |
+
negative_prompt=negative_prompt,
|
| 69 |
+
seed=int(seed),
|
| 70 |
+
ip_image=ip_image,
|
| 71 |
+
num_inference_steps=int(num_steps),
|
| 72 |
+
tiled=False,
|
| 73 |
+
)
|
| 74 |
+
end_time = time.time()
|
| 75 |
+
print(f"Video generated in {end_time - start_time:.2f} seconds.")
|
| 76 |
+
|
| 77 |
+
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_file:
|
| 78 |
+
video_path = temp_file.name
|
| 79 |
+
save_video(video, video_path, fps=int(fps), quality=quality)
|
| 80 |
+
print(f"Video saved to: {video_path}")
|
| 81 |
+
return video_path
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
with gr.Blocks(theme=gr.themes.Soft(), css="footer {display: none !important}") as demo:
|
| 85 |
+
gr.Markdown(
|
| 86 |
+
"""
|
| 87 |
+
# Stand-In IP2V
|
| 88 |
+
"""
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
with gr.Row():
|
| 92 |
+
with gr.Column(scale=1):
|
| 93 |
+
gr.Markdown("### 1. Upload a Face Image")
|
| 94 |
+
input_image = gr.Image(
|
| 95 |
+
label="Upload Image",
|
| 96 |
+
type="pil",
|
| 97 |
+
image_mode="RGB",
|
| 98 |
+
height=300,
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
gr.Markdown("### 2. Enter Core Parameters")
|
| 102 |
+
input_prompt = gr.Textbox(
|
| 103 |
+
label="Prompt",
|
| 104 |
+
lines=4,
|
| 105 |
+
value="一位男性舒适地坐在书桌前,正对着镜头,仿佛在与屏幕前的亲友对话。他的眼神专注而温柔,嘴角带着自然的笑意。背景是他精心布置的个人空间,墙上贴着照片和一张世界地图,传达出一种亲密而现代的沟通感。",
|
| 106 |
+
placeholder="Please enter a detailed description of the scene, character actions, expressions, etc...",
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
input_seed = gr.Slider(
|
| 110 |
+
label="Seed",
|
| 111 |
+
minimum=0,
|
| 112 |
+
maximum=100000,
|
| 113 |
+
step=1,
|
| 114 |
+
value=0,
|
| 115 |
+
info="The same seed and parameters will generate the same result.",
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
with gr.Accordion("Advanced Options", open=False):
|
| 119 |
+
input_negative_prompt = gr.Textbox(
|
| 120 |
+
label="Negative Prompt",
|
| 121 |
+
lines=3,
|
| 122 |
+
value="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
| 123 |
+
)
|
| 124 |
+
input_steps = gr.Slider(
|
| 125 |
+
label="Inference Steps",
|
| 126 |
+
minimum=10,
|
| 127 |
+
maximum=50,
|
| 128 |
+
step=1,
|
| 129 |
+
value=20,
|
| 130 |
+
info="More steps may improve details but will take longer to generate.",
|
| 131 |
+
)
|
| 132 |
+
output_fps = gr.Slider(
|
| 133 |
+
label="Video FPS", minimum=10, maximum=30, step=1, value=25
|
| 134 |
+
)
|
| 135 |
+
output_quality = gr.Slider(
|
| 136 |
+
label="Video Quality", minimum=1, maximum=10, step=1, value=9
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
generate_btn = gr.Button("Generate Video", variant="primary")
|
| 140 |
+
|
| 141 |
+
with gr.Column(scale=1):
|
| 142 |
+
gr.Markdown("### 3. View Generated Result")
|
| 143 |
+
output_video = gr.Video(
|
| 144 |
+
label="Generated Video",
|
| 145 |
+
height=480,
|
| 146 |
+
)
|
| 147 |
+
generate_btn.click(
|
| 148 |
+
fn=generate_video,
|
| 149 |
+
inputs=[
|
| 150 |
+
input_image,
|
| 151 |
+
input_prompt,
|
| 152 |
+
input_seed,
|
| 153 |
+
input_negative_prompt,
|
| 154 |
+
input_steps,
|
| 155 |
+
output_fps,
|
| 156 |
+
output_quality,
|
| 157 |
+
],
|
| 158 |
+
outputs=output_video,
|
| 159 |
+
api_name="generate_video",
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
if __name__ == "__main__":
|
| 163 |
+
demo.launch(share=True, server_name="0.0.0.0", server_port=8080)
|
assets/Stand-In.png
ADDED
|
configs/model_config.py
ADDED
|
@@ -0,0 +1,1809 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing_extensions import Literal, TypeAlias
|
| 2 |
+
|
| 3 |
+
from models.wan_video_dit import WanModel
|
| 4 |
+
from models.wan_video_text_encoder import WanTextEncoder
|
| 5 |
+
from models.wan_video_image_encoder import WanImageEncoder
|
| 6 |
+
from models.wan_video_vae import WanVideoVAE, WanVideoVAE38
|
| 7 |
+
from models.wan_video_motion_controller import WanMotionControllerModel
|
| 8 |
+
from models.wan_video_vace import VaceWanModel
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
model_loader_configs = [
|
| 12 |
+
(
|
| 13 |
+
None,
|
| 14 |
+
"9269f8db9040a9d860eaca435be61814",
|
| 15 |
+
["wan_video_dit"],
|
| 16 |
+
[WanModel],
|
| 17 |
+
"civitai",
|
| 18 |
+
),
|
| 19 |
+
(
|
| 20 |
+
None,
|
| 21 |
+
"aafcfd9672c3a2456dc46e1cb6e52c70",
|
| 22 |
+
["wan_video_dit"],
|
| 23 |
+
[WanModel],
|
| 24 |
+
"civitai",
|
| 25 |
+
),
|
| 26 |
+
(
|
| 27 |
+
None,
|
| 28 |
+
"6bfcfb3b342cb286ce886889d519a77e",
|
| 29 |
+
["wan_video_dit"],
|
| 30 |
+
[WanModel],
|
| 31 |
+
"civitai",
|
| 32 |
+
),
|
| 33 |
+
(
|
| 34 |
+
None,
|
| 35 |
+
"6d6ccde6845b95ad9114ab993d917893",
|
| 36 |
+
["wan_video_dit"],
|
| 37 |
+
[WanModel],
|
| 38 |
+
"civitai",
|
| 39 |
+
),
|
| 40 |
+
(
|
| 41 |
+
None,
|
| 42 |
+
"6bfcfb3b342cb286ce886889d519a77e",
|
| 43 |
+
["wan_video_dit"],
|
| 44 |
+
[WanModel],
|
| 45 |
+
"civitai",
|
| 46 |
+
),
|
| 47 |
+
(
|
| 48 |
+
None,
|
| 49 |
+
"349723183fc063b2bfc10bb2835cf677",
|
| 50 |
+
["wan_video_dit"],
|
| 51 |
+
[WanModel],
|
| 52 |
+
"civitai",
|
| 53 |
+
),
|
| 54 |
+
(
|
| 55 |
+
None,
|
| 56 |
+
"efa44cddf936c70abd0ea28b6cbe946c",
|
| 57 |
+
["wan_video_dit"],
|
| 58 |
+
[WanModel],
|
| 59 |
+
"civitai",
|
| 60 |
+
),
|
| 61 |
+
(
|
| 62 |
+
None,
|
| 63 |
+
"3ef3b1f8e1dab83d5b71fd7b617f859f",
|
| 64 |
+
["wan_video_dit"],
|
| 65 |
+
[WanModel],
|
| 66 |
+
"civitai",
|
| 67 |
+
),
|
| 68 |
+
(
|
| 69 |
+
None,
|
| 70 |
+
"70ddad9d3a133785da5ea371aae09504",
|
| 71 |
+
["wan_video_dit"],
|
| 72 |
+
[WanModel],
|
| 73 |
+
"civitai",
|
| 74 |
+
),
|
| 75 |
+
(
|
| 76 |
+
None,
|
| 77 |
+
"26bde73488a92e64cc20b0a7485b9e5b",
|
| 78 |
+
["wan_video_dit"],
|
| 79 |
+
[WanModel],
|
| 80 |
+
"civitai",
|
| 81 |
+
),
|
| 82 |
+
(
|
| 83 |
+
None,
|
| 84 |
+
"ac6a5aa74f4a0aab6f64eb9a72f19901",
|
| 85 |
+
["wan_video_dit"],
|
| 86 |
+
[WanModel],
|
| 87 |
+
"civitai",
|
| 88 |
+
),
|
| 89 |
+
(
|
| 90 |
+
None,
|
| 91 |
+
"b61c605c2adbd23124d152ed28e049ae",
|
| 92 |
+
["wan_video_dit"],
|
| 93 |
+
[WanModel],
|
| 94 |
+
"civitai",
|
| 95 |
+
),
|
| 96 |
+
(
|
| 97 |
+
None,
|
| 98 |
+
"1f5ab7703c6fc803fdded85ff040c316",
|
| 99 |
+
["wan_video_dit"],
|
| 100 |
+
[WanModel],
|
| 101 |
+
"civitai",
|
| 102 |
+
),
|
| 103 |
+
(
|
| 104 |
+
None,
|
| 105 |
+
"5b013604280dd715f8457c6ed6d6a626",
|
| 106 |
+
["wan_video_dit"],
|
| 107 |
+
[WanModel],
|
| 108 |
+
"civitai",
|
| 109 |
+
),
|
| 110 |
+
(
|
| 111 |
+
None,
|
| 112 |
+
"a61453409b67cd3246cf0c3bebad47ba",
|
| 113 |
+
["wan_video_dit", "wan_video_vace"],
|
| 114 |
+
[WanModel, VaceWanModel],
|
| 115 |
+
"civitai",
|
| 116 |
+
),
|
| 117 |
+
(
|
| 118 |
+
None,
|
| 119 |
+
"7a513e1f257a861512b1afd387a8ecd9",
|
| 120 |
+
["wan_video_dit", "wan_video_vace"],
|
| 121 |
+
[WanModel, VaceWanModel],
|
| 122 |
+
"civitai",
|
| 123 |
+
),
|
| 124 |
+
(
|
| 125 |
+
None,
|
| 126 |
+
"cb104773c6c2cb6df4f9529ad5c60d0b",
|
| 127 |
+
["wan_video_dit"],
|
| 128 |
+
[WanModel],
|
| 129 |
+
"diffusers",
|
| 130 |
+
),
|
| 131 |
+
(
|
| 132 |
+
None,
|
| 133 |
+
"9c8818c2cbea55eca56c7b447df170da",
|
| 134 |
+
["wan_video_text_encoder"],
|
| 135 |
+
[WanTextEncoder],
|
| 136 |
+
"civitai",
|
| 137 |
+
),
|
| 138 |
+
(
|
| 139 |
+
None,
|
| 140 |
+
"5941c53e207d62f20f9025686193c40b",
|
| 141 |
+
["wan_video_image_encoder"],
|
| 142 |
+
[WanImageEncoder],
|
| 143 |
+
"civitai",
|
| 144 |
+
),
|
| 145 |
+
(
|
| 146 |
+
None,
|
| 147 |
+
"1378ea763357eea97acdef78e65d6d96",
|
| 148 |
+
["wan_video_vae"],
|
| 149 |
+
[WanVideoVAE],
|
| 150 |
+
"civitai",
|
| 151 |
+
),
|
| 152 |
+
(
|
| 153 |
+
None,
|
| 154 |
+
"ccc42284ea13e1ad04693284c7a09be6",
|
| 155 |
+
["wan_video_vae"],
|
| 156 |
+
[WanVideoVAE],
|
| 157 |
+
"civitai",
|
| 158 |
+
),
|
| 159 |
+
(
|
| 160 |
+
None,
|
| 161 |
+
"e1de6c02cdac79f8b739f4d3698cd216",
|
| 162 |
+
["wan_video_vae"],
|
| 163 |
+
[WanVideoVAE38],
|
| 164 |
+
"civitai",
|
| 165 |
+
),
|
| 166 |
+
(
|
| 167 |
+
None,
|
| 168 |
+
"dbd5ec76bbf977983f972c151d545389",
|
| 169 |
+
["wan_video_motion_controller"],
|
| 170 |
+
[WanMotionControllerModel],
|
| 171 |
+
"civitai",
|
| 172 |
+
),
|
| 173 |
+
]
|
| 174 |
+
huggingface_model_loader_configs = [
|
| 175 |
+
# These configs are provided for detecting model type automatically.
|
| 176 |
+
# The format is (architecture_in_huggingface_config, huggingface_lib, model_name, redirected_architecture)
|
| 177 |
+
(
|
| 178 |
+
"ChatGLMModel",
|
| 179 |
+
"diffsynth.models.kolors_text_encoder",
|
| 180 |
+
"kolors_text_encoder",
|
| 181 |
+
None,
|
| 182 |
+
),
|
| 183 |
+
("MarianMTModel", "transformers.models.marian.modeling_marian", "translator", None),
|
| 184 |
+
(
|
| 185 |
+
"BloomForCausalLM",
|
| 186 |
+
"transformers.models.bloom.modeling_bloom",
|
| 187 |
+
"beautiful_prompt",
|
| 188 |
+
None,
|
| 189 |
+
),
|
| 190 |
+
(
|
| 191 |
+
"Qwen2ForCausalLM",
|
| 192 |
+
"transformers.models.qwen2.modeling_qwen2",
|
| 193 |
+
"qwen_prompt",
|
| 194 |
+
None,
|
| 195 |
+
),
|
| 196 |
+
# ("LlamaForCausalLM", "transformers.models.llama.modeling_llama", "omost_prompt", None),
|
| 197 |
+
(
|
| 198 |
+
"T5EncoderModel",
|
| 199 |
+
"diffsynth.models.flux_text_encoder",
|
| 200 |
+
"flux_text_encoder_2",
|
| 201 |
+
"FluxTextEncoder2",
|
| 202 |
+
),
|
| 203 |
+
("CogVideoXTransformer3DModel", "diffsynth.models.cog_dit", "cog_dit", "CogDiT"),
|
| 204 |
+
(
|
| 205 |
+
"SiglipModel",
|
| 206 |
+
"transformers.models.siglip.modeling_siglip",
|
| 207 |
+
"siglip_vision_model",
|
| 208 |
+
"SiglipVisionModel",
|
| 209 |
+
),
|
| 210 |
+
(
|
| 211 |
+
"LlamaForCausalLM",
|
| 212 |
+
"diffsynth.models.hunyuan_video_text_encoder",
|
| 213 |
+
"hunyuan_video_text_encoder_2",
|
| 214 |
+
"HunyuanVideoLLMEncoder",
|
| 215 |
+
),
|
| 216 |
+
(
|
| 217 |
+
"LlavaForConditionalGeneration",
|
| 218 |
+
"diffsynth.models.hunyuan_video_text_encoder",
|
| 219 |
+
"hunyuan_video_text_encoder_2",
|
| 220 |
+
"HunyuanVideoMLLMEncoder",
|
| 221 |
+
),
|
| 222 |
+
(
|
| 223 |
+
"Step1Model",
|
| 224 |
+
"diffsynth.models.stepvideo_text_encoder",
|
| 225 |
+
"stepvideo_text_encoder_2",
|
| 226 |
+
"STEP1TextEncoder",
|
| 227 |
+
),
|
| 228 |
+
(
|
| 229 |
+
"Qwen2_5_VLForConditionalGeneration",
|
| 230 |
+
"diffsynth.models.qwenvl",
|
| 231 |
+
"qwenvl",
|
| 232 |
+
"Qwen25VL_7b_Embedder",
|
| 233 |
+
),
|
| 234 |
+
]
|
| 235 |
+
patch_model_loader_configs = [
|
| 236 |
+
# These configs are provided for detecting model type automatically.
|
| 237 |
+
# The format is (state_dict_keys_hash_with_shape, model_name, model_class, extra_kwargs)
|
| 238 |
+
]
|
| 239 |
+
|
| 240 |
+
preset_models_on_huggingface = {
|
| 241 |
+
"HunyuanDiT": [
|
| 242 |
+
(
|
| 243 |
+
"Tencent-Hunyuan/HunyuanDiT",
|
| 244 |
+
"t2i/clip_text_encoder/pytorch_model.bin",
|
| 245 |
+
"models/HunyuanDiT/t2i/clip_text_encoder",
|
| 246 |
+
),
|
| 247 |
+
(
|
| 248 |
+
"Tencent-Hunyuan/HunyuanDiT",
|
| 249 |
+
"t2i/mt5/pytorch_model.bin",
|
| 250 |
+
"models/HunyuanDiT/t2i/mt5",
|
| 251 |
+
),
|
| 252 |
+
(
|
| 253 |
+
"Tencent-Hunyuan/HunyuanDiT",
|
| 254 |
+
"t2i/model/pytorch_model_ema.pt",
|
| 255 |
+
"models/HunyuanDiT/t2i/model",
|
| 256 |
+
),
|
| 257 |
+
(
|
| 258 |
+
"Tencent-Hunyuan/HunyuanDiT",
|
| 259 |
+
"t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin",
|
| 260 |
+
"models/HunyuanDiT/t2i/sdxl-vae-fp16-fix",
|
| 261 |
+
),
|
| 262 |
+
],
|
| 263 |
+
"stable-video-diffusion-img2vid-xt": [
|
| 264 |
+
(
|
| 265 |
+
"stabilityai/stable-video-diffusion-img2vid-xt",
|
| 266 |
+
"svd_xt.safetensors",
|
| 267 |
+
"models/stable_video_diffusion",
|
| 268 |
+
),
|
| 269 |
+
],
|
| 270 |
+
"ExVideo-SVD-128f-v1": [
|
| 271 |
+
(
|
| 272 |
+
"ECNU-CILab/ExVideo-SVD-128f-v1",
|
| 273 |
+
"model.fp16.safetensors",
|
| 274 |
+
"models/stable_video_diffusion",
|
| 275 |
+
),
|
| 276 |
+
],
|
| 277 |
+
# Stable Diffusion
|
| 278 |
+
"StableDiffusion_v15": [
|
| 279 |
+
(
|
| 280 |
+
"benjamin-paine/stable-diffusion-v1-5",
|
| 281 |
+
"v1-5-pruned-emaonly.safetensors",
|
| 282 |
+
"models/stable_diffusion",
|
| 283 |
+
),
|
| 284 |
+
],
|
| 285 |
+
"DreamShaper_8": [
|
| 286 |
+
("Yntec/Dreamshaper8", "dreamshaper_8.safetensors", "models/stable_diffusion"),
|
| 287 |
+
],
|
| 288 |
+
# Textual Inversion
|
| 289 |
+
"TextualInversion_VeryBadImageNegative_v1.3": [
|
| 290 |
+
(
|
| 291 |
+
"gemasai/verybadimagenegative_v1.3",
|
| 292 |
+
"verybadimagenegative_v1.3.pt",
|
| 293 |
+
"models/textual_inversion",
|
| 294 |
+
),
|
| 295 |
+
],
|
| 296 |
+
# Stable Diffusion XL
|
| 297 |
+
"StableDiffusionXL_v1": [
|
| 298 |
+
(
|
| 299 |
+
"stabilityai/stable-diffusion-xl-base-1.0",
|
| 300 |
+
"sd_xl_base_1.0.safetensors",
|
| 301 |
+
"models/stable_diffusion_xl",
|
| 302 |
+
),
|
| 303 |
+
],
|
| 304 |
+
"BluePencilXL_v200": [
|
| 305 |
+
(
|
| 306 |
+
"frankjoshua/bluePencilXL_v200",
|
| 307 |
+
"bluePencilXL_v200.safetensors",
|
| 308 |
+
"models/stable_diffusion_xl",
|
| 309 |
+
),
|
| 310 |
+
],
|
| 311 |
+
"StableDiffusionXL_Turbo": [
|
| 312 |
+
(
|
| 313 |
+
"stabilityai/sdxl-turbo",
|
| 314 |
+
"sd_xl_turbo_1.0_fp16.safetensors",
|
| 315 |
+
"models/stable_diffusion_xl_turbo",
|
| 316 |
+
),
|
| 317 |
+
],
|
| 318 |
+
# Stable Diffusion 3
|
| 319 |
+
"StableDiffusion3": [
|
| 320 |
+
(
|
| 321 |
+
"stabilityai/stable-diffusion-3-medium",
|
| 322 |
+
"sd3_medium_incl_clips_t5xxlfp16.safetensors",
|
| 323 |
+
"models/stable_diffusion_3",
|
| 324 |
+
),
|
| 325 |
+
],
|
| 326 |
+
"StableDiffusion3_without_T5": [
|
| 327 |
+
(
|
| 328 |
+
"stabilityai/stable-diffusion-3-medium",
|
| 329 |
+
"sd3_medium_incl_clips.safetensors",
|
| 330 |
+
"models/stable_diffusion_3",
|
| 331 |
+
),
|
| 332 |
+
],
|
| 333 |
+
# ControlNet
|
| 334 |
+
"ControlNet_v11f1p_sd15_depth": [
|
| 335 |
+
(
|
| 336 |
+
"lllyasviel/ControlNet-v1-1",
|
| 337 |
+
"control_v11f1p_sd15_depth.pth",
|
| 338 |
+
"models/ControlNet",
|
| 339 |
+
),
|
| 340 |
+
("lllyasviel/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators"),
|
| 341 |
+
],
|
| 342 |
+
"ControlNet_v11p_sd15_softedge": [
|
| 343 |
+
(
|
| 344 |
+
"lllyasviel/ControlNet-v1-1",
|
| 345 |
+
"control_v11p_sd15_softedge.pth",
|
| 346 |
+
"models/ControlNet",
|
| 347 |
+
),
|
| 348 |
+
("lllyasviel/Annotators", "ControlNetHED.pth", "models/Annotators"),
|
| 349 |
+
],
|
| 350 |
+
"ControlNet_v11f1e_sd15_tile": [
|
| 351 |
+
(
|
| 352 |
+
"lllyasviel/ControlNet-v1-1",
|
| 353 |
+
"control_v11f1e_sd15_tile.pth",
|
| 354 |
+
"models/ControlNet",
|
| 355 |
+
)
|
| 356 |
+
],
|
| 357 |
+
"ControlNet_v11p_sd15_lineart": [
|
| 358 |
+
(
|
| 359 |
+
"lllyasviel/ControlNet-v1-1",
|
| 360 |
+
"control_v11p_sd15_lineart.pth",
|
| 361 |
+
"models/ControlNet",
|
| 362 |
+
),
|
| 363 |
+
("lllyasviel/Annotators", "sk_model.pth", "models/Annotators"),
|
| 364 |
+
("lllyasviel/Annotators", "sk_model2.pth", "models/Annotators"),
|
| 365 |
+
],
|
| 366 |
+
"ControlNet_union_sdxl_promax": [
|
| 367 |
+
(
|
| 368 |
+
"xinsir/controlnet-union-sdxl-1.0",
|
| 369 |
+
"diffusion_pytorch_model_promax.safetensors",
|
| 370 |
+
"models/ControlNet/controlnet_union",
|
| 371 |
+
),
|
| 372 |
+
("lllyasviel/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators"),
|
| 373 |
+
],
|
| 374 |
+
# AnimateDiff
|
| 375 |
+
"AnimateDiff_v2": [
|
| 376 |
+
("guoyww/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
|
| 377 |
+
],
|
| 378 |
+
"AnimateDiff_xl_beta": [
|
| 379 |
+
("guoyww/animatediff", "mm_sdxl_v10_beta.ckpt", "models/AnimateDiff"),
|
| 380 |
+
],
|
| 381 |
+
# Qwen Prompt
|
| 382 |
+
"QwenPrompt": [
|
| 383 |
+
(
|
| 384 |
+
"Qwen/Qwen2-1.5B-Instruct",
|
| 385 |
+
"config.json",
|
| 386 |
+
"models/QwenPrompt/qwen2-1.5b-instruct",
|
| 387 |
+
),
|
| 388 |
+
(
|
| 389 |
+
"Qwen/Qwen2-1.5B-Instruct",
|
| 390 |
+
"generation_config.json",
|
| 391 |
+
"models/QwenPrompt/qwen2-1.5b-instruct",
|
| 392 |
+
),
|
| 393 |
+
(
|
| 394 |
+
"Qwen/Qwen2-1.5B-Instruct",
|
| 395 |
+
"model.safetensors",
|
| 396 |
+
"models/QwenPrompt/qwen2-1.5b-instruct",
|
| 397 |
+
),
|
| 398 |
+
(
|
| 399 |
+
"Qwen/Qwen2-1.5B-Instruct",
|
| 400 |
+
"special_tokens_map.json",
|
| 401 |
+
"models/QwenPrompt/qwen2-1.5b-instruct",
|
| 402 |
+
),
|
| 403 |
+
(
|
| 404 |
+
"Qwen/Qwen2-1.5B-Instruct",
|
| 405 |
+
"tokenizer.json",
|
| 406 |
+
"models/QwenPrompt/qwen2-1.5b-instruct",
|
| 407 |
+
),
|
| 408 |
+
(
|
| 409 |
+
"Qwen/Qwen2-1.5B-Instruct",
|
| 410 |
+
"tokenizer_config.json",
|
| 411 |
+
"models/QwenPrompt/qwen2-1.5b-instruct",
|
| 412 |
+
),
|
| 413 |
+
(
|
| 414 |
+
"Qwen/Qwen2-1.5B-Instruct",
|
| 415 |
+
"merges.txt",
|
| 416 |
+
"models/QwenPrompt/qwen2-1.5b-instruct",
|
| 417 |
+
),
|
| 418 |
+
(
|
| 419 |
+
"Qwen/Qwen2-1.5B-Instruct",
|
| 420 |
+
"vocab.json",
|
| 421 |
+
"models/QwenPrompt/qwen2-1.5b-instruct",
|
| 422 |
+
),
|
| 423 |
+
],
|
| 424 |
+
# Beautiful Prompt
|
| 425 |
+
"BeautifulPrompt": [
|
| 426 |
+
(
|
| 427 |
+
"alibaba-pai/pai-bloom-1b1-text2prompt-sd",
|
| 428 |
+
"config.json",
|
| 429 |
+
"models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd",
|
| 430 |
+
),
|
| 431 |
+
(
|
| 432 |
+
"alibaba-pai/pai-bloom-1b1-text2prompt-sd",
|
| 433 |
+
"generation_config.json",
|
| 434 |
+
"models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd",
|
| 435 |
+
),
|
| 436 |
+
(
|
| 437 |
+
"alibaba-pai/pai-bloom-1b1-text2prompt-sd",
|
| 438 |
+
"model.safetensors",
|
| 439 |
+
"models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd",
|
| 440 |
+
),
|
| 441 |
+
(
|
| 442 |
+
"alibaba-pai/pai-bloom-1b1-text2prompt-sd",
|
| 443 |
+
"special_tokens_map.json",
|
| 444 |
+
"models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd",
|
| 445 |
+
),
|
| 446 |
+
(
|
| 447 |
+
"alibaba-pai/pai-bloom-1b1-text2prompt-sd",
|
| 448 |
+
"tokenizer.json",
|
| 449 |
+
"models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd",
|
| 450 |
+
),
|
| 451 |
+
(
|
| 452 |
+
"alibaba-pai/pai-bloom-1b1-text2prompt-sd",
|
| 453 |
+
"tokenizer_config.json",
|
| 454 |
+
"models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd",
|
| 455 |
+
),
|
| 456 |
+
],
|
| 457 |
+
# Omost prompt
|
| 458 |
+
"OmostPrompt": [
|
| 459 |
+
(
|
| 460 |
+
"lllyasviel/omost-llama-3-8b-4bits",
|
| 461 |
+
"model-00001-of-00002.safetensors",
|
| 462 |
+
"models/OmostPrompt/omost-llama-3-8b-4bits",
|
| 463 |
+
),
|
| 464 |
+
(
|
| 465 |
+
"lllyasviel/omost-llama-3-8b-4bits",
|
| 466 |
+
"model-00002-of-00002.safetensors",
|
| 467 |
+
"models/OmostPrompt/omost-llama-3-8b-4bits",
|
| 468 |
+
),
|
| 469 |
+
(
|
| 470 |
+
"lllyasviel/omost-llama-3-8b-4bits",
|
| 471 |
+
"tokenizer.json",
|
| 472 |
+
"models/OmostPrompt/omost-llama-3-8b-4bits",
|
| 473 |
+
),
|
| 474 |
+
(
|
| 475 |
+
"lllyasviel/omost-llama-3-8b-4bits",
|
| 476 |
+
"tokenizer_config.json",
|
| 477 |
+
"models/OmostPrompt/omost-llama-3-8b-4bits",
|
| 478 |
+
),
|
| 479 |
+
(
|
| 480 |
+
"lllyasviel/omost-llama-3-8b-4bits",
|
| 481 |
+
"config.json",
|
| 482 |
+
"models/OmostPrompt/omost-llama-3-8b-4bits",
|
| 483 |
+
),
|
| 484 |
+
(
|
| 485 |
+
"lllyasviel/omost-llama-3-8b-4bits",
|
| 486 |
+
"generation_config.json",
|
| 487 |
+
"models/OmostPrompt/omost-llama-3-8b-4bits",
|
| 488 |
+
),
|
| 489 |
+
(
|
| 490 |
+
"lllyasviel/omost-llama-3-8b-4bits",
|
| 491 |
+
"model.safetensors.index.json",
|
| 492 |
+
"models/OmostPrompt/omost-llama-3-8b-4bits",
|
| 493 |
+
),
|
| 494 |
+
(
|
| 495 |
+
"lllyasviel/omost-llama-3-8b-4bits",
|
| 496 |
+
"special_tokens_map.json",
|
| 497 |
+
"models/OmostPrompt/omost-llama-3-8b-4bits",
|
| 498 |
+
),
|
| 499 |
+
],
|
| 500 |
+
# Translator
|
| 501 |
+
"opus-mt-zh-en": [
|
| 502 |
+
(
|
| 503 |
+
"Helsinki-NLP/opus-mt-zh-en",
|
| 504 |
+
"config.json",
|
| 505 |
+
"models/translator/opus-mt-zh-en",
|
| 506 |
+
),
|
| 507 |
+
(
|
| 508 |
+
"Helsinki-NLP/opus-mt-zh-en",
|
| 509 |
+
"generation_config.json",
|
| 510 |
+
"models/translator/opus-mt-zh-en",
|
| 511 |
+
),
|
| 512 |
+
(
|
| 513 |
+
"Helsinki-NLP/opus-mt-zh-en",
|
| 514 |
+
"metadata.json",
|
| 515 |
+
"models/translator/opus-mt-zh-en",
|
| 516 |
+
),
|
| 517 |
+
(
|
| 518 |
+
"Helsinki-NLP/opus-mt-zh-en",
|
| 519 |
+
"pytorch_model.bin",
|
| 520 |
+
"models/translator/opus-mt-zh-en",
|
| 521 |
+
),
|
| 522 |
+
("Helsinki-NLP/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
|
| 523 |
+
("Helsinki-NLP/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
|
| 524 |
+
(
|
| 525 |
+
"Helsinki-NLP/opus-mt-zh-en",
|
| 526 |
+
"tokenizer_config.json",
|
| 527 |
+
"models/translator/opus-mt-zh-en",
|
| 528 |
+
),
|
| 529 |
+
("Helsinki-NLP/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
|
| 530 |
+
],
|
| 531 |
+
# IP-Adapter
|
| 532 |
+
"IP-Adapter-SD": [
|
| 533 |
+
(
|
| 534 |
+
"h94/IP-Adapter",
|
| 535 |
+
"models/image_encoder/model.safetensors",
|
| 536 |
+
"models/IpAdapter/stable_diffusion/image_encoder",
|
| 537 |
+
),
|
| 538 |
+
(
|
| 539 |
+
"h94/IP-Adapter",
|
| 540 |
+
"models/ip-adapter_sd15.bin",
|
| 541 |
+
"models/IpAdapter/stable_diffusion",
|
| 542 |
+
),
|
| 543 |
+
],
|
| 544 |
+
"IP-Adapter-SDXL": [
|
| 545 |
+
(
|
| 546 |
+
"h94/IP-Adapter",
|
| 547 |
+
"sdxl_models/image_encoder/model.safetensors",
|
| 548 |
+
"models/IpAdapter/stable_diffusion_xl/image_encoder",
|
| 549 |
+
),
|
| 550 |
+
(
|
| 551 |
+
"h94/IP-Adapter",
|
| 552 |
+
"sdxl_models/ip-adapter_sdxl.bin",
|
| 553 |
+
"models/IpAdapter/stable_diffusion_xl",
|
| 554 |
+
),
|
| 555 |
+
],
|
| 556 |
+
"SDXL-vae-fp16-fix": [
|
| 557 |
+
(
|
| 558 |
+
"madebyollin/sdxl-vae-fp16-fix",
|
| 559 |
+
"diffusion_pytorch_model.safetensors",
|
| 560 |
+
"models/sdxl-vae-fp16-fix",
|
| 561 |
+
)
|
| 562 |
+
],
|
| 563 |
+
# Kolors
|
| 564 |
+
"Kolors": [
|
| 565 |
+
(
|
| 566 |
+
"Kwai-Kolors/Kolors",
|
| 567 |
+
"text_encoder/config.json",
|
| 568 |
+
"models/kolors/Kolors/text_encoder",
|
| 569 |
+
),
|
| 570 |
+
(
|
| 571 |
+
"Kwai-Kolors/Kolors",
|
| 572 |
+
"text_encoder/pytorch_model.bin.index.json",
|
| 573 |
+
"models/kolors/Kolors/text_encoder",
|
| 574 |
+
),
|
| 575 |
+
(
|
| 576 |
+
"Kwai-Kolors/Kolors",
|
| 577 |
+
"text_encoder/pytorch_model-00001-of-00007.bin",
|
| 578 |
+
"models/kolors/Kolors/text_encoder",
|
| 579 |
+
),
|
| 580 |
+
(
|
| 581 |
+
"Kwai-Kolors/Kolors",
|
| 582 |
+
"text_encoder/pytorch_model-00002-of-00007.bin",
|
| 583 |
+
"models/kolors/Kolors/text_encoder",
|
| 584 |
+
),
|
| 585 |
+
(
|
| 586 |
+
"Kwai-Kolors/Kolors",
|
| 587 |
+
"text_encoder/pytorch_model-00003-of-00007.bin",
|
| 588 |
+
"models/kolors/Kolors/text_encoder",
|
| 589 |
+
),
|
| 590 |
+
(
|
| 591 |
+
"Kwai-Kolors/Kolors",
|
| 592 |
+
"text_encoder/pytorch_model-00004-of-00007.bin",
|
| 593 |
+
"models/kolors/Kolors/text_encoder",
|
| 594 |
+
),
|
| 595 |
+
(
|
| 596 |
+
"Kwai-Kolors/Kolors",
|
| 597 |
+
"text_encoder/pytorch_model-00005-of-00007.bin",
|
| 598 |
+
"models/kolors/Kolors/text_encoder",
|
| 599 |
+
),
|
| 600 |
+
(
|
| 601 |
+
"Kwai-Kolors/Kolors",
|
| 602 |
+
"text_encoder/pytorch_model-00006-of-00007.bin",
|
| 603 |
+
"models/kolors/Kolors/text_encoder",
|
| 604 |
+
),
|
| 605 |
+
(
|
| 606 |
+
"Kwai-Kolors/Kolors",
|
| 607 |
+
"text_encoder/pytorch_model-00007-of-00007.bin",
|
| 608 |
+
"models/kolors/Kolors/text_encoder",
|
| 609 |
+
),
|
| 610 |
+
(
|
| 611 |
+
"Kwai-Kolors/Kolors",
|
| 612 |
+
"unet/diffusion_pytorch_model.safetensors",
|
| 613 |
+
"models/kolors/Kolors/unet",
|
| 614 |
+
),
|
| 615 |
+
(
|
| 616 |
+
"Kwai-Kolors/Kolors",
|
| 617 |
+
"vae/diffusion_pytorch_model.safetensors",
|
| 618 |
+
"models/kolors/Kolors/vae",
|
| 619 |
+
),
|
| 620 |
+
],
|
| 621 |
+
# FLUX
|
| 622 |
+
"FLUX.1-dev": [
|
| 623 |
+
(
|
| 624 |
+
"black-forest-labs/FLUX.1-dev",
|
| 625 |
+
"text_encoder/model.safetensors",
|
| 626 |
+
"models/FLUX/FLUX.1-dev/text_encoder",
|
| 627 |
+
),
|
| 628 |
+
(
|
| 629 |
+
"black-forest-labs/FLUX.1-dev",
|
| 630 |
+
"text_encoder_2/config.json",
|
| 631 |
+
"models/FLUX/FLUX.1-dev/text_encoder_2",
|
| 632 |
+
),
|
| 633 |
+
(
|
| 634 |
+
"black-forest-labs/FLUX.1-dev",
|
| 635 |
+
"text_encoder_2/model-00001-of-00002.safetensors",
|
| 636 |
+
"models/FLUX/FLUX.1-dev/text_encoder_2",
|
| 637 |
+
),
|
| 638 |
+
(
|
| 639 |
+
"black-forest-labs/FLUX.1-dev",
|
| 640 |
+
"text_encoder_2/model-00002-of-00002.safetensors",
|
| 641 |
+
"models/FLUX/FLUX.1-dev/text_encoder_2",
|
| 642 |
+
),
|
| 643 |
+
(
|
| 644 |
+
"black-forest-labs/FLUX.1-dev",
|
| 645 |
+
"text_encoder_2/model.safetensors.index.json",
|
| 646 |
+
"models/FLUX/FLUX.1-dev/text_encoder_2",
|
| 647 |
+
),
|
| 648 |
+
("black-forest-labs/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
|
| 649 |
+
(
|
| 650 |
+
"black-forest-labs/FLUX.1-dev",
|
| 651 |
+
"flux1-dev.safetensors",
|
| 652 |
+
"models/FLUX/FLUX.1-dev",
|
| 653 |
+
),
|
| 654 |
+
],
|
| 655 |
+
"InstantX/FLUX.1-dev-IP-Adapter": {
|
| 656 |
+
"file_list": [
|
| 657 |
+
(
|
| 658 |
+
"InstantX/FLUX.1-dev-IP-Adapter",
|
| 659 |
+
"ip-adapter.bin",
|
| 660 |
+
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter",
|
| 661 |
+
),
|
| 662 |
+
(
|
| 663 |
+
"google/siglip-so400m-patch14-384",
|
| 664 |
+
"model.safetensors",
|
| 665 |
+
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
|
| 666 |
+
),
|
| 667 |
+
(
|
| 668 |
+
"google/siglip-so400m-patch14-384",
|
| 669 |
+
"config.json",
|
| 670 |
+
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
|
| 671 |
+
),
|
| 672 |
+
],
|
| 673 |
+
"load_path": [
|
| 674 |
+
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin",
|
| 675 |
+
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
|
| 676 |
+
],
|
| 677 |
+
},
|
| 678 |
+
# RIFE
|
| 679 |
+
"RIFE": [
|
| 680 |
+
("AlexWortega/RIFE", "flownet.pkl", "models/RIFE"),
|
| 681 |
+
],
|
| 682 |
+
# CogVideo
|
| 683 |
+
"CogVideoX-5B": [
|
| 684 |
+
(
|
| 685 |
+
"THUDM/CogVideoX-5b",
|
| 686 |
+
"text_encoder/config.json",
|
| 687 |
+
"models/CogVideo/CogVideoX-5b/text_encoder",
|
| 688 |
+
),
|
| 689 |
+
(
|
| 690 |
+
"THUDM/CogVideoX-5b",
|
| 691 |
+
"text_encoder/model.safetensors.index.json",
|
| 692 |
+
"models/CogVideo/CogVideoX-5b/text_encoder",
|
| 693 |
+
),
|
| 694 |
+
(
|
| 695 |
+
"THUDM/CogVideoX-5b",
|
| 696 |
+
"text_encoder/model-00001-of-00002.safetensors",
|
| 697 |
+
"models/CogVideo/CogVideoX-5b/text_encoder",
|
| 698 |
+
),
|
| 699 |
+
(
|
| 700 |
+
"THUDM/CogVideoX-5b",
|
| 701 |
+
"text_encoder/model-00002-of-00002.safetensors",
|
| 702 |
+
"models/CogVideo/CogVideoX-5b/text_encoder",
|
| 703 |
+
),
|
| 704 |
+
(
|
| 705 |
+
"THUDM/CogVideoX-5b",
|
| 706 |
+
"transformer/config.json",
|
| 707 |
+
"models/CogVideo/CogVideoX-5b/transformer",
|
| 708 |
+
),
|
| 709 |
+
(
|
| 710 |
+
"THUDM/CogVideoX-5b",
|
| 711 |
+
"transformer/diffusion_pytorch_model.safetensors.index.json",
|
| 712 |
+
"models/CogVideo/CogVideoX-5b/transformer",
|
| 713 |
+
),
|
| 714 |
+
(
|
| 715 |
+
"THUDM/CogVideoX-5b",
|
| 716 |
+
"transformer/diffusion_pytorch_model-00001-of-00002.safetensors",
|
| 717 |
+
"models/CogVideo/CogVideoX-5b/transformer",
|
| 718 |
+
),
|
| 719 |
+
(
|
| 720 |
+
"THUDM/CogVideoX-5b",
|
| 721 |
+
"transformer/diffusion_pytorch_model-00002-of-00002.safetensors",
|
| 722 |
+
"models/CogVideo/CogVideoX-5b/transformer",
|
| 723 |
+
),
|
| 724 |
+
(
|
| 725 |
+
"THUDM/CogVideoX-5b",
|
| 726 |
+
"vae/diffusion_pytorch_model.safetensors",
|
| 727 |
+
"models/CogVideo/CogVideoX-5b/vae",
|
| 728 |
+
),
|
| 729 |
+
],
|
| 730 |
+
# Stable Diffusion 3.5
|
| 731 |
+
"StableDiffusion3.5-large": [
|
| 732 |
+
(
|
| 733 |
+
"stabilityai/stable-diffusion-3.5-large",
|
| 734 |
+
"sd3.5_large.safetensors",
|
| 735 |
+
"models/stable_diffusion_3",
|
| 736 |
+
),
|
| 737 |
+
(
|
| 738 |
+
"stabilityai/stable-diffusion-3.5-large",
|
| 739 |
+
"text_encoders/clip_l.safetensors",
|
| 740 |
+
"models/stable_diffusion_3/text_encoders",
|
| 741 |
+
),
|
| 742 |
+
(
|
| 743 |
+
"stabilityai/stable-diffusion-3.5-large",
|
| 744 |
+
"text_encoders/clip_g.safetensors",
|
| 745 |
+
"models/stable_diffusion_3/text_encoders",
|
| 746 |
+
),
|
| 747 |
+
(
|
| 748 |
+
"stabilityai/stable-diffusion-3.5-large",
|
| 749 |
+
"text_encoders/t5xxl_fp16.safetensors",
|
| 750 |
+
"models/stable_diffusion_3/text_encoders",
|
| 751 |
+
),
|
| 752 |
+
],
|
| 753 |
+
}
|
| 754 |
+
preset_models_on_modelscope = {
|
| 755 |
+
# Hunyuan DiT
|
| 756 |
+
"HunyuanDiT": [
|
| 757 |
+
(
|
| 758 |
+
"modelscope/HunyuanDiT",
|
| 759 |
+
"t2i/clip_text_encoder/pytorch_model.bin",
|
| 760 |
+
"models/HunyuanDiT/t2i/clip_text_encoder",
|
| 761 |
+
),
|
| 762 |
+
(
|
| 763 |
+
"modelscope/HunyuanDiT",
|
| 764 |
+
"t2i/mt5/pytorch_model.bin",
|
| 765 |
+
"models/HunyuanDiT/t2i/mt5",
|
| 766 |
+
),
|
| 767 |
+
(
|
| 768 |
+
"modelscope/HunyuanDiT",
|
| 769 |
+
"t2i/model/pytorch_model_ema.pt",
|
| 770 |
+
"models/HunyuanDiT/t2i/model",
|
| 771 |
+
),
|
| 772 |
+
(
|
| 773 |
+
"modelscope/HunyuanDiT",
|
| 774 |
+
"t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin",
|
| 775 |
+
"models/HunyuanDiT/t2i/sdxl-vae-fp16-fix",
|
| 776 |
+
),
|
| 777 |
+
],
|
| 778 |
+
# Stable Video Diffusion
|
| 779 |
+
"stable-video-diffusion-img2vid-xt": [
|
| 780 |
+
(
|
| 781 |
+
"AI-ModelScope/stable-video-diffusion-img2vid-xt",
|
| 782 |
+
"svd_xt.safetensors",
|
| 783 |
+
"models/stable_video_diffusion",
|
| 784 |
+
),
|
| 785 |
+
],
|
| 786 |
+
# ExVideo
|
| 787 |
+
"ExVideo-SVD-128f-v1": [
|
| 788 |
+
(
|
| 789 |
+
"ECNU-CILab/ExVideo-SVD-128f-v1",
|
| 790 |
+
"model.fp16.safetensors",
|
| 791 |
+
"models/stable_video_diffusion",
|
| 792 |
+
),
|
| 793 |
+
],
|
| 794 |
+
"ExVideo-CogVideoX-LoRA-129f-v1": [
|
| 795 |
+
(
|
| 796 |
+
"ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1",
|
| 797 |
+
"ExVideo-CogVideoX-LoRA-129f-v1.safetensors",
|
| 798 |
+
"models/lora",
|
| 799 |
+
),
|
| 800 |
+
],
|
| 801 |
+
# Stable Diffusion
|
| 802 |
+
"StableDiffusion_v15": [
|
| 803 |
+
(
|
| 804 |
+
"AI-ModelScope/stable-diffusion-v1-5",
|
| 805 |
+
"v1-5-pruned-emaonly.safetensors",
|
| 806 |
+
"models/stable_diffusion",
|
| 807 |
+
),
|
| 808 |
+
],
|
| 809 |
+
"DreamShaper_8": [
|
| 810 |
+
(
|
| 811 |
+
"sd_lora/dreamshaper_8",
|
| 812 |
+
"dreamshaper_8.safetensors",
|
| 813 |
+
"models/stable_diffusion",
|
| 814 |
+
),
|
| 815 |
+
],
|
| 816 |
+
"AingDiffusion_v12": [
|
| 817 |
+
(
|
| 818 |
+
"sd_lora/aingdiffusion_v12",
|
| 819 |
+
"aingdiffusion_v12.safetensors",
|
| 820 |
+
"models/stable_diffusion",
|
| 821 |
+
),
|
| 822 |
+
],
|
| 823 |
+
"Flat2DAnimerge_v45Sharp": [
|
| 824 |
+
(
|
| 825 |
+
"sd_lora/Flat-2D-Animerge",
|
| 826 |
+
"flat2DAnimerge_v45Sharp.safetensors",
|
| 827 |
+
"models/stable_diffusion",
|
| 828 |
+
),
|
| 829 |
+
],
|
| 830 |
+
# Textual Inversion
|
| 831 |
+
"TextualInversion_VeryBadImageNegative_v1.3": [
|
| 832 |
+
(
|
| 833 |
+
"sd_lora/verybadimagenegative_v1.3",
|
| 834 |
+
"verybadimagenegative_v1.3.pt",
|
| 835 |
+
"models/textual_inversion",
|
| 836 |
+
),
|
| 837 |
+
],
|
| 838 |
+
# Stable Diffusion XL
|
| 839 |
+
"StableDiffusionXL_v1": [
|
| 840 |
+
(
|
| 841 |
+
"AI-ModelScope/stable-diffusion-xl-base-1.0",
|
| 842 |
+
"sd_xl_base_1.0.safetensors",
|
| 843 |
+
"models/stable_diffusion_xl",
|
| 844 |
+
),
|
| 845 |
+
],
|
| 846 |
+
"BluePencilXL_v200": [
|
| 847 |
+
(
|
| 848 |
+
"sd_lora/bluePencilXL_v200",
|
| 849 |
+
"bluePencilXL_v200.safetensors",
|
| 850 |
+
"models/stable_diffusion_xl",
|
| 851 |
+
),
|
| 852 |
+
],
|
| 853 |
+
"StableDiffusionXL_Turbo": [
|
| 854 |
+
(
|
| 855 |
+
"AI-ModelScope/sdxl-turbo",
|
| 856 |
+
"sd_xl_turbo_1.0_fp16.safetensors",
|
| 857 |
+
"models/stable_diffusion_xl_turbo",
|
| 858 |
+
),
|
| 859 |
+
],
|
| 860 |
+
"SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0": [
|
| 861 |
+
(
|
| 862 |
+
"sd_lora/zyd232_ChineseInkStyle_SDXL_v1_0",
|
| 863 |
+
"zyd232_ChineseInkStyle_SDXL_v1_0.safetensors",
|
| 864 |
+
"models/lora",
|
| 865 |
+
),
|
| 866 |
+
],
|
| 867 |
+
# Stable Diffusion 3
|
| 868 |
+
"StableDiffusion3": [
|
| 869 |
+
(
|
| 870 |
+
"AI-ModelScope/stable-diffusion-3-medium",
|
| 871 |
+
"sd3_medium_incl_clips_t5xxlfp16.safetensors",
|
| 872 |
+
"models/stable_diffusion_3",
|
| 873 |
+
),
|
| 874 |
+
],
|
| 875 |
+
"StableDiffusion3_without_T5": [
|
| 876 |
+
(
|
| 877 |
+
"AI-ModelScope/stable-diffusion-3-medium",
|
| 878 |
+
"sd3_medium_incl_clips.safetensors",
|
| 879 |
+
"models/stable_diffusion_3",
|
| 880 |
+
),
|
| 881 |
+
],
|
| 882 |
+
# ControlNet
|
| 883 |
+
"ControlNet_v11f1p_sd15_depth": [
|
| 884 |
+
(
|
| 885 |
+
"AI-ModelScope/ControlNet-v1-1",
|
| 886 |
+
"control_v11f1p_sd15_depth.pth",
|
| 887 |
+
"models/ControlNet",
|
| 888 |
+
),
|
| 889 |
+
("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators"),
|
| 890 |
+
],
|
| 891 |
+
"ControlNet_v11p_sd15_softedge": [
|
| 892 |
+
(
|
| 893 |
+
"AI-ModelScope/ControlNet-v1-1",
|
| 894 |
+
"control_v11p_sd15_softedge.pth",
|
| 895 |
+
"models/ControlNet",
|
| 896 |
+
),
|
| 897 |
+
("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators"),
|
| 898 |
+
],
|
| 899 |
+
"ControlNet_v11f1e_sd15_tile": [
|
| 900 |
+
(
|
| 901 |
+
"AI-ModelScope/ControlNet-v1-1",
|
| 902 |
+
"control_v11f1e_sd15_tile.pth",
|
| 903 |
+
"models/ControlNet",
|
| 904 |
+
)
|
| 905 |
+
],
|
| 906 |
+
"ControlNet_v11p_sd15_lineart": [
|
| 907 |
+
(
|
| 908 |
+
"AI-ModelScope/ControlNet-v1-1",
|
| 909 |
+
"control_v11p_sd15_lineart.pth",
|
| 910 |
+
"models/ControlNet",
|
| 911 |
+
),
|
| 912 |
+
("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
|
| 913 |
+
("sd_lora/Annotators", "sk_model2.pth", "models/Annotators"),
|
| 914 |
+
],
|
| 915 |
+
"ControlNet_union_sdxl_promax": [
|
| 916 |
+
(
|
| 917 |
+
"AI-ModelScope/controlnet-union-sdxl-1.0",
|
| 918 |
+
"diffusion_pytorch_model_promax.safetensors",
|
| 919 |
+
"models/ControlNet/controlnet_union",
|
| 920 |
+
),
|
| 921 |
+
("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators"),
|
| 922 |
+
],
|
| 923 |
+
"Annotators:Depth": [
|
| 924 |
+
("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators"),
|
| 925 |
+
],
|
| 926 |
+
"Annotators:Softedge": [
|
| 927 |
+
("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators"),
|
| 928 |
+
],
|
| 929 |
+
"Annotators:Lineart": [
|
| 930 |
+
("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
|
| 931 |
+
("sd_lora/Annotators", "sk_model2.pth", "models/Annotators"),
|
| 932 |
+
],
|
| 933 |
+
"Annotators:Normal": [
|
| 934 |
+
("sd_lora/Annotators", "scannet.pt", "models/Annotators"),
|
| 935 |
+
],
|
| 936 |
+
"Annotators:Openpose": [
|
| 937 |
+
("sd_lora/Annotators", "body_pose_model.pth", "models/Annotators"),
|
| 938 |
+
("sd_lora/Annotators", "facenet.pth", "models/Annotators"),
|
| 939 |
+
("sd_lora/Annotators", "hand_pose_model.pth", "models/Annotators"),
|
| 940 |
+
],
|
| 941 |
+
# AnimateDiff
|
| 942 |
+
"AnimateDiff_v2": [
|
| 943 |
+
(
|
| 944 |
+
"Shanghai_AI_Laboratory/animatediff",
|
| 945 |
+
"mm_sd_v15_v2.ckpt",
|
| 946 |
+
"models/AnimateDiff",
|
| 947 |
+
),
|
| 948 |
+
],
|
| 949 |
+
"AnimateDiff_xl_beta": [
|
| 950 |
+
(
|
| 951 |
+
"Shanghai_AI_Laboratory/animatediff",
|
| 952 |
+
"mm_sdxl_v10_beta.ckpt",
|
| 953 |
+
"models/AnimateDiff",
|
| 954 |
+
),
|
| 955 |
+
],
|
| 956 |
+
# RIFE
|
| 957 |
+
"RIFE": [
|
| 958 |
+
("Damo_XR_Lab/cv_rife_video-frame-interpolation", "flownet.pkl", "models/RIFE"),
|
| 959 |
+
],
|
| 960 |
+
# Qwen Prompt
|
| 961 |
+
"QwenPrompt": {
|
| 962 |
+
"file_list": [
|
| 963 |
+
(
|
| 964 |
+
"qwen/Qwen2-1.5B-Instruct",
|
| 965 |
+
"config.json",
|
| 966 |
+
"models/QwenPrompt/qwen2-1.5b-instruct",
|
| 967 |
+
),
|
| 968 |
+
(
|
| 969 |
+
"qwen/Qwen2-1.5B-Instruct",
|
| 970 |
+
"generation_config.json",
|
| 971 |
+
"models/QwenPrompt/qwen2-1.5b-instruct",
|
| 972 |
+
),
|
| 973 |
+
(
|
| 974 |
+
"qwen/Qwen2-1.5B-Instruct",
|
| 975 |
+
"model.safetensors",
|
| 976 |
+
"models/QwenPrompt/qwen2-1.5b-instruct",
|
| 977 |
+
),
|
| 978 |
+
(
|
| 979 |
+
"qwen/Qwen2-1.5B-Instruct",
|
| 980 |
+
"special_tokens_map.json",
|
| 981 |
+
"models/QwenPrompt/qwen2-1.5b-instruct",
|
| 982 |
+
),
|
| 983 |
+
(
|
| 984 |
+
"qwen/Qwen2-1.5B-Instruct",
|
| 985 |
+
"tokenizer.json",
|
| 986 |
+
"models/QwenPrompt/qwen2-1.5b-instruct",
|
| 987 |
+
),
|
| 988 |
+
(
|
| 989 |
+
"qwen/Qwen2-1.5B-Instruct",
|
| 990 |
+
"tokenizer_config.json",
|
| 991 |
+
"models/QwenPrompt/qwen2-1.5b-instruct",
|
| 992 |
+
),
|
| 993 |
+
(
|
| 994 |
+
"qwen/Qwen2-1.5B-Instruct",
|
| 995 |
+
"merges.txt",
|
| 996 |
+
"models/QwenPrompt/qwen2-1.5b-instruct",
|
| 997 |
+
),
|
| 998 |
+
(
|
| 999 |
+
"qwen/Qwen2-1.5B-Instruct",
|
| 1000 |
+
"vocab.json",
|
| 1001 |
+
"models/QwenPrompt/qwen2-1.5b-instruct",
|
| 1002 |
+
),
|
| 1003 |
+
],
|
| 1004 |
+
"load_path": [
|
| 1005 |
+
"models/QwenPrompt/qwen2-1.5b-instruct",
|
| 1006 |
+
],
|
| 1007 |
+
},
|
| 1008 |
+
# Beautiful Prompt
|
| 1009 |
+
"BeautifulPrompt": {
|
| 1010 |
+
"file_list": [
|
| 1011 |
+
(
|
| 1012 |
+
"AI-ModelScope/pai-bloom-1b1-text2prompt-sd",
|
| 1013 |
+
"config.json",
|
| 1014 |
+
"models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd",
|
| 1015 |
+
),
|
| 1016 |
+
(
|
| 1017 |
+
"AI-ModelScope/pai-bloom-1b1-text2prompt-sd",
|
| 1018 |
+
"generation_config.json",
|
| 1019 |
+
"models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd",
|
| 1020 |
+
),
|
| 1021 |
+
(
|
| 1022 |
+
"AI-ModelScope/pai-bloom-1b1-text2prompt-sd",
|
| 1023 |
+
"model.safetensors",
|
| 1024 |
+
"models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd",
|
| 1025 |
+
),
|
| 1026 |
+
(
|
| 1027 |
+
"AI-ModelScope/pai-bloom-1b1-text2prompt-sd",
|
| 1028 |
+
"special_tokens_map.json",
|
| 1029 |
+
"models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd",
|
| 1030 |
+
),
|
| 1031 |
+
(
|
| 1032 |
+
"AI-ModelScope/pai-bloom-1b1-text2prompt-sd",
|
| 1033 |
+
"tokenizer.json",
|
| 1034 |
+
"models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd",
|
| 1035 |
+
),
|
| 1036 |
+
(
|
| 1037 |
+
"AI-ModelScope/pai-bloom-1b1-text2prompt-sd",
|
| 1038 |
+
"tokenizer_config.json",
|
| 1039 |
+
"models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd",
|
| 1040 |
+
),
|
| 1041 |
+
],
|
| 1042 |
+
"load_path": [
|
| 1043 |
+
"models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd",
|
| 1044 |
+
],
|
| 1045 |
+
},
|
| 1046 |
+
# Omost prompt
|
| 1047 |
+
"OmostPrompt": {
|
| 1048 |
+
"file_list": [
|
| 1049 |
+
(
|
| 1050 |
+
"Omost/omost-llama-3-8b-4bits",
|
| 1051 |
+
"model-00001-of-00002.safetensors",
|
| 1052 |
+
"models/OmostPrompt/omost-llama-3-8b-4bits",
|
| 1053 |
+
),
|
| 1054 |
+
(
|
| 1055 |
+
"Omost/omost-llama-3-8b-4bits",
|
| 1056 |
+
"model-00002-of-00002.safetensors",
|
| 1057 |
+
"models/OmostPrompt/omost-llama-3-8b-4bits",
|
| 1058 |
+
),
|
| 1059 |
+
(
|
| 1060 |
+
"Omost/omost-llama-3-8b-4bits",
|
| 1061 |
+
"tokenizer.json",
|
| 1062 |
+
"models/OmostPrompt/omost-llama-3-8b-4bits",
|
| 1063 |
+
),
|
| 1064 |
+
(
|
| 1065 |
+
"Omost/omost-llama-3-8b-4bits",
|
| 1066 |
+
"tokenizer_config.json",
|
| 1067 |
+
"models/OmostPrompt/omost-llama-3-8b-4bits",
|
| 1068 |
+
),
|
| 1069 |
+
(
|
| 1070 |
+
"Omost/omost-llama-3-8b-4bits",
|
| 1071 |
+
"config.json",
|
| 1072 |
+
"models/OmostPrompt/omost-llama-3-8b-4bits",
|
| 1073 |
+
),
|
| 1074 |
+
(
|
| 1075 |
+
"Omost/omost-llama-3-8b-4bits",
|
| 1076 |
+
"generation_config.json",
|
| 1077 |
+
"models/OmostPrompt/omost-llama-3-8b-4bits",
|
| 1078 |
+
),
|
| 1079 |
+
(
|
| 1080 |
+
"Omost/omost-llama-3-8b-4bits",
|
| 1081 |
+
"model.safetensors.index.json",
|
| 1082 |
+
"models/OmostPrompt/omost-llama-3-8b-4bits",
|
| 1083 |
+
),
|
| 1084 |
+
(
|
| 1085 |
+
"Omost/omost-llama-3-8b-4bits",
|
| 1086 |
+
"special_tokens_map.json",
|
| 1087 |
+
"models/OmostPrompt/omost-llama-3-8b-4bits",
|
| 1088 |
+
),
|
| 1089 |
+
],
|
| 1090 |
+
"load_path": [
|
| 1091 |
+
"models/OmostPrompt/omost-llama-3-8b-4bits",
|
| 1092 |
+
],
|
| 1093 |
+
},
|
| 1094 |
+
# Translator
|
| 1095 |
+
"opus-mt-zh-en": {
|
| 1096 |
+
"file_list": [
|
| 1097 |
+
("moxying/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
|
| 1098 |
+
(
|
| 1099 |
+
"moxying/opus-mt-zh-en",
|
| 1100 |
+
"generation_config.json",
|
| 1101 |
+
"models/translator/opus-mt-zh-en",
|
| 1102 |
+
),
|
| 1103 |
+
(
|
| 1104 |
+
"moxying/opus-mt-zh-en",
|
| 1105 |
+
"metadata.json",
|
| 1106 |
+
"models/translator/opus-mt-zh-en",
|
| 1107 |
+
),
|
| 1108 |
+
(
|
| 1109 |
+
"moxying/opus-mt-zh-en",
|
| 1110 |
+
"pytorch_model.bin",
|
| 1111 |
+
"models/translator/opus-mt-zh-en",
|
| 1112 |
+
),
|
| 1113 |
+
("moxying/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
|
| 1114 |
+
("moxying/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
|
| 1115 |
+
(
|
| 1116 |
+
"moxying/opus-mt-zh-en",
|
| 1117 |
+
"tokenizer_config.json",
|
| 1118 |
+
"models/translator/opus-mt-zh-en",
|
| 1119 |
+
),
|
| 1120 |
+
("moxying/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
|
| 1121 |
+
],
|
| 1122 |
+
"load_path": [
|
| 1123 |
+
"models/translator/opus-mt-zh-en",
|
| 1124 |
+
],
|
| 1125 |
+
},
|
| 1126 |
+
# IP-Adapter
|
| 1127 |
+
"IP-Adapter-SD": [
|
| 1128 |
+
(
|
| 1129 |
+
"AI-ModelScope/IP-Adapter",
|
| 1130 |
+
"models/image_encoder/model.safetensors",
|
| 1131 |
+
"models/IpAdapter/stable_diffusion/image_encoder",
|
| 1132 |
+
),
|
| 1133 |
+
(
|
| 1134 |
+
"AI-ModelScope/IP-Adapter",
|
| 1135 |
+
"models/ip-adapter_sd15.bin",
|
| 1136 |
+
"models/IpAdapter/stable_diffusion",
|
| 1137 |
+
),
|
| 1138 |
+
],
|
| 1139 |
+
"IP-Adapter-SDXL": [
|
| 1140 |
+
(
|
| 1141 |
+
"AI-ModelScope/IP-Adapter",
|
| 1142 |
+
"sdxl_models/image_encoder/model.safetensors",
|
| 1143 |
+
"models/IpAdapter/stable_diffusion_xl/image_encoder",
|
| 1144 |
+
),
|
| 1145 |
+
(
|
| 1146 |
+
"AI-ModelScope/IP-Adapter",
|
| 1147 |
+
"sdxl_models/ip-adapter_sdxl.bin",
|
| 1148 |
+
"models/IpAdapter/stable_diffusion_xl",
|
| 1149 |
+
),
|
| 1150 |
+
],
|
| 1151 |
+
# Kolors
|
| 1152 |
+
"Kolors": {
|
| 1153 |
+
"file_list": [
|
| 1154 |
+
(
|
| 1155 |
+
"Kwai-Kolors/Kolors",
|
| 1156 |
+
"text_encoder/config.json",
|
| 1157 |
+
"models/kolors/Kolors/text_encoder",
|
| 1158 |
+
),
|
| 1159 |
+
(
|
| 1160 |
+
"Kwai-Kolors/Kolors",
|
| 1161 |
+
"text_encoder/pytorch_model.bin.index.json",
|
| 1162 |
+
"models/kolors/Kolors/text_encoder",
|
| 1163 |
+
),
|
| 1164 |
+
(
|
| 1165 |
+
"Kwai-Kolors/Kolors",
|
| 1166 |
+
"text_encoder/pytorch_model-00001-of-00007.bin",
|
| 1167 |
+
"models/kolors/Kolors/text_encoder",
|
| 1168 |
+
),
|
| 1169 |
+
(
|
| 1170 |
+
"Kwai-Kolors/Kolors",
|
| 1171 |
+
"text_encoder/pytorch_model-00002-of-00007.bin",
|
| 1172 |
+
"models/kolors/Kolors/text_encoder",
|
| 1173 |
+
),
|
| 1174 |
+
(
|
| 1175 |
+
"Kwai-Kolors/Kolors",
|
| 1176 |
+
"text_encoder/pytorch_model-00003-of-00007.bin",
|
| 1177 |
+
"models/kolors/Kolors/text_encoder",
|
| 1178 |
+
),
|
| 1179 |
+
(
|
| 1180 |
+
"Kwai-Kolors/Kolors",
|
| 1181 |
+
"text_encoder/pytorch_model-00004-of-00007.bin",
|
| 1182 |
+
"models/kolors/Kolors/text_encoder",
|
| 1183 |
+
),
|
| 1184 |
+
(
|
| 1185 |
+
"Kwai-Kolors/Kolors",
|
| 1186 |
+
"text_encoder/pytorch_model-00005-of-00007.bin",
|
| 1187 |
+
"models/kolors/Kolors/text_encoder",
|
| 1188 |
+
),
|
| 1189 |
+
(
|
| 1190 |
+
"Kwai-Kolors/Kolors",
|
| 1191 |
+
"text_encoder/pytorch_model-00006-of-00007.bin",
|
| 1192 |
+
"models/kolors/Kolors/text_encoder",
|
| 1193 |
+
),
|
| 1194 |
+
(
|
| 1195 |
+
"Kwai-Kolors/Kolors",
|
| 1196 |
+
"text_encoder/pytorch_model-00007-of-00007.bin",
|
| 1197 |
+
"models/kolors/Kolors/text_encoder",
|
| 1198 |
+
),
|
| 1199 |
+
(
|
| 1200 |
+
"Kwai-Kolors/Kolors",
|
| 1201 |
+
"unet/diffusion_pytorch_model.safetensors",
|
| 1202 |
+
"models/kolors/Kolors/unet",
|
| 1203 |
+
),
|
| 1204 |
+
(
|
| 1205 |
+
"Kwai-Kolors/Kolors",
|
| 1206 |
+
"vae/diffusion_pytorch_model.safetensors",
|
| 1207 |
+
"models/kolors/Kolors/vae",
|
| 1208 |
+
),
|
| 1209 |
+
],
|
| 1210 |
+
"load_path": [
|
| 1211 |
+
"models/kolors/Kolors/text_encoder",
|
| 1212 |
+
"models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors",
|
| 1213 |
+
"models/kolors/Kolors/vae/diffusion_pytorch_model.safetensors",
|
| 1214 |
+
],
|
| 1215 |
+
},
|
| 1216 |
+
"SDXL-vae-fp16-fix": [
|
| 1217 |
+
(
|
| 1218 |
+
"AI-ModelScope/sdxl-vae-fp16-fix",
|
| 1219 |
+
"diffusion_pytorch_model.safetensors",
|
| 1220 |
+
"models/sdxl-vae-fp16-fix",
|
| 1221 |
+
)
|
| 1222 |
+
],
|
| 1223 |
+
# FLUX
|
| 1224 |
+
"FLUX.1-dev": {
|
| 1225 |
+
"file_list": [
|
| 1226 |
+
(
|
| 1227 |
+
"AI-ModelScope/FLUX.1-dev",
|
| 1228 |
+
"text_encoder/model.safetensors",
|
| 1229 |
+
"models/FLUX/FLUX.1-dev/text_encoder",
|
| 1230 |
+
),
|
| 1231 |
+
(
|
| 1232 |
+
"AI-ModelScope/FLUX.1-dev",
|
| 1233 |
+
"text_encoder_2/config.json",
|
| 1234 |
+
"models/FLUX/FLUX.1-dev/text_encoder_2",
|
| 1235 |
+
),
|
| 1236 |
+
(
|
| 1237 |
+
"AI-ModelScope/FLUX.1-dev",
|
| 1238 |
+
"text_encoder_2/model-00001-of-00002.safetensors",
|
| 1239 |
+
"models/FLUX/FLUX.1-dev/text_encoder_2",
|
| 1240 |
+
),
|
| 1241 |
+
(
|
| 1242 |
+
"AI-ModelScope/FLUX.1-dev",
|
| 1243 |
+
"text_encoder_2/model-00002-of-00002.safetensors",
|
| 1244 |
+
"models/FLUX/FLUX.1-dev/text_encoder_2",
|
| 1245 |
+
),
|
| 1246 |
+
(
|
| 1247 |
+
"AI-ModelScope/FLUX.1-dev",
|
| 1248 |
+
"text_encoder_2/model.safetensors.index.json",
|
| 1249 |
+
"models/FLUX/FLUX.1-dev/text_encoder_2",
|
| 1250 |
+
),
|
| 1251 |
+
("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
|
| 1252 |
+
(
|
| 1253 |
+
"AI-ModelScope/FLUX.1-dev",
|
| 1254 |
+
"flux1-dev.safetensors",
|
| 1255 |
+
"models/FLUX/FLUX.1-dev",
|
| 1256 |
+
),
|
| 1257 |
+
],
|
| 1258 |
+
"load_path": [
|
| 1259 |
+
"models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
|
| 1260 |
+
"models/FLUX/FLUX.1-dev/text_encoder_2",
|
| 1261 |
+
"models/FLUX/FLUX.1-dev/ae.safetensors",
|
| 1262 |
+
"models/FLUX/FLUX.1-dev/flux1-dev.safetensors",
|
| 1263 |
+
],
|
| 1264 |
+
},
|
| 1265 |
+
"FLUX.1-schnell": {
|
| 1266 |
+
"file_list": [
|
| 1267 |
+
(
|
| 1268 |
+
"AI-ModelScope/FLUX.1-dev",
|
| 1269 |
+
"text_encoder/model.safetensors",
|
| 1270 |
+
"models/FLUX/FLUX.1-dev/text_encoder",
|
| 1271 |
+
),
|
| 1272 |
+
(
|
| 1273 |
+
"AI-ModelScope/FLUX.1-dev",
|
| 1274 |
+
"text_encoder_2/config.json",
|
| 1275 |
+
"models/FLUX/FLUX.1-dev/text_encoder_2",
|
| 1276 |
+
),
|
| 1277 |
+
(
|
| 1278 |
+
"AI-ModelScope/FLUX.1-dev",
|
| 1279 |
+
"text_encoder_2/model-00001-of-00002.safetensors",
|
| 1280 |
+
"models/FLUX/FLUX.1-dev/text_encoder_2",
|
| 1281 |
+
),
|
| 1282 |
+
(
|
| 1283 |
+
"AI-ModelScope/FLUX.1-dev",
|
| 1284 |
+
"text_encoder_2/model-00002-of-00002.safetensors",
|
| 1285 |
+
"models/FLUX/FLUX.1-dev/text_encoder_2",
|
| 1286 |
+
),
|
| 1287 |
+
(
|
| 1288 |
+
"AI-ModelScope/FLUX.1-dev",
|
| 1289 |
+
"text_encoder_2/model.safetensors.index.json",
|
| 1290 |
+
"models/FLUX/FLUX.1-dev/text_encoder_2",
|
| 1291 |
+
),
|
| 1292 |
+
("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
|
| 1293 |
+
(
|
| 1294 |
+
"AI-ModelScope/FLUX.1-schnell",
|
| 1295 |
+
"flux1-schnell.safetensors",
|
| 1296 |
+
"models/FLUX/FLUX.1-schnell",
|
| 1297 |
+
),
|
| 1298 |
+
],
|
| 1299 |
+
"load_path": [
|
| 1300 |
+
"models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
|
| 1301 |
+
"models/FLUX/FLUX.1-dev/text_encoder_2",
|
| 1302 |
+
"models/FLUX/FLUX.1-dev/ae.safetensors",
|
| 1303 |
+
"models/FLUX/FLUX.1-schnell/flux1-schnell.safetensors",
|
| 1304 |
+
],
|
| 1305 |
+
},
|
| 1306 |
+
"InstantX/FLUX.1-dev-Controlnet-Union-alpha": [
|
| 1307 |
+
(
|
| 1308 |
+
"InstantX/FLUX.1-dev-Controlnet-Union-alpha",
|
| 1309 |
+
"diffusion_pytorch_model.safetensors",
|
| 1310 |
+
"models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha",
|
| 1311 |
+
),
|
| 1312 |
+
],
|
| 1313 |
+
"jasperai/Flux.1-dev-Controlnet-Depth": [
|
| 1314 |
+
(
|
| 1315 |
+
"jasperai/Flux.1-dev-Controlnet-Depth",
|
| 1316 |
+
"diffusion_pytorch_model.safetensors",
|
| 1317 |
+
"models/ControlNet/jasperai/Flux.1-dev-Controlnet-Depth",
|
| 1318 |
+
),
|
| 1319 |
+
],
|
| 1320 |
+
"jasperai/Flux.1-dev-Controlnet-Surface-Normals": [
|
| 1321 |
+
(
|
| 1322 |
+
"jasperai/Flux.1-dev-Controlnet-Surface-Normals",
|
| 1323 |
+
"diffusion_pytorch_model.safetensors",
|
| 1324 |
+
"models/ControlNet/jasperai/Flux.1-dev-Controlnet-Surface-Normals",
|
| 1325 |
+
),
|
| 1326 |
+
],
|
| 1327 |
+
"jasperai/Flux.1-dev-Controlnet-Upscaler": [
|
| 1328 |
+
(
|
| 1329 |
+
"jasperai/Flux.1-dev-Controlnet-Upscaler",
|
| 1330 |
+
"diffusion_pytorch_model.safetensors",
|
| 1331 |
+
"models/ControlNet/jasperai/Flux.1-dev-Controlnet-Upscaler",
|
| 1332 |
+
),
|
| 1333 |
+
],
|
| 1334 |
+
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha": [
|
| 1335 |
+
(
|
| 1336 |
+
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha",
|
| 1337 |
+
"diffusion_pytorch_model.safetensors",
|
| 1338 |
+
"models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha",
|
| 1339 |
+
),
|
| 1340 |
+
],
|
| 1341 |
+
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta": [
|
| 1342 |
+
(
|
| 1343 |
+
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta",
|
| 1344 |
+
"diffusion_pytorch_model.safetensors",
|
| 1345 |
+
"models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta",
|
| 1346 |
+
),
|
| 1347 |
+
],
|
| 1348 |
+
"Shakker-Labs/FLUX.1-dev-ControlNet-Depth": [
|
| 1349 |
+
(
|
| 1350 |
+
"Shakker-Labs/FLUX.1-dev-ControlNet-Depth",
|
| 1351 |
+
"diffusion_pytorch_model.safetensors",
|
| 1352 |
+
"models/ControlNet/Shakker-Labs/FLUX.1-dev-ControlNet-Depth",
|
| 1353 |
+
),
|
| 1354 |
+
],
|
| 1355 |
+
"Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro": [
|
| 1356 |
+
(
|
| 1357 |
+
"Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
|
| 1358 |
+
"diffusion_pytorch_model.safetensors",
|
| 1359 |
+
"models/ControlNet/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
|
| 1360 |
+
),
|
| 1361 |
+
],
|
| 1362 |
+
"InstantX/FLUX.1-dev-IP-Adapter": {
|
| 1363 |
+
"file_list": [
|
| 1364 |
+
(
|
| 1365 |
+
"InstantX/FLUX.1-dev-IP-Adapter",
|
| 1366 |
+
"ip-adapter.bin",
|
| 1367 |
+
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter",
|
| 1368 |
+
),
|
| 1369 |
+
(
|
| 1370 |
+
"AI-ModelScope/siglip-so400m-patch14-384",
|
| 1371 |
+
"model.safetensors",
|
| 1372 |
+
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
|
| 1373 |
+
),
|
| 1374 |
+
(
|
| 1375 |
+
"AI-ModelScope/siglip-so400m-patch14-384",
|
| 1376 |
+
"config.json",
|
| 1377 |
+
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
|
| 1378 |
+
),
|
| 1379 |
+
],
|
| 1380 |
+
"load_path": [
|
| 1381 |
+
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin",
|
| 1382 |
+
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
|
| 1383 |
+
],
|
| 1384 |
+
},
|
| 1385 |
+
"InfiniteYou": {
|
| 1386 |
+
"file_list": [
|
| 1387 |
+
(
|
| 1388 |
+
"ByteDance/InfiniteYou",
|
| 1389 |
+
"infu_flux_v1.0/aes_stage2/InfuseNetModel/diffusion_pytorch_model-00001-of-00002.safetensors",
|
| 1390 |
+
"models/InfiniteYou/InfuseNetModel",
|
| 1391 |
+
),
|
| 1392 |
+
(
|
| 1393 |
+
"ByteDance/InfiniteYou",
|
| 1394 |
+
"infu_flux_v1.0/aes_stage2/InfuseNetModel/diffusion_pytorch_model-00002-of-00002.safetensors",
|
| 1395 |
+
"models/InfiniteYou/InfuseNetModel",
|
| 1396 |
+
),
|
| 1397 |
+
(
|
| 1398 |
+
"ByteDance/InfiniteYou",
|
| 1399 |
+
"infu_flux_v1.0/aes_stage2/image_proj_model.bin",
|
| 1400 |
+
"models/InfiniteYou",
|
| 1401 |
+
),
|
| 1402 |
+
(
|
| 1403 |
+
"ByteDance/InfiniteYou",
|
| 1404 |
+
"supports/insightface/models/antelopev2/1k3d68.onnx",
|
| 1405 |
+
"models/InfiniteYou/insightface/models/antelopev2",
|
| 1406 |
+
),
|
| 1407 |
+
(
|
| 1408 |
+
"ByteDance/InfiniteYou",
|
| 1409 |
+
"supports/insightface/models/antelopev2/2d106det.onnx",
|
| 1410 |
+
"models/InfiniteYou/insightface/models/antelopev2",
|
| 1411 |
+
),
|
| 1412 |
+
(
|
| 1413 |
+
"ByteDance/InfiniteYou",
|
| 1414 |
+
"supports/insightface/models/antelopev2/genderage.onnx",
|
| 1415 |
+
"models/InfiniteYou/insightface/models/antelopev2",
|
| 1416 |
+
),
|
| 1417 |
+
(
|
| 1418 |
+
"ByteDance/InfiniteYou",
|
| 1419 |
+
"supports/insightface/models/antelopev2/glintr100.onnx",
|
| 1420 |
+
"models/InfiniteYou/insightface/models/antelopev2",
|
| 1421 |
+
),
|
| 1422 |
+
(
|
| 1423 |
+
"ByteDance/InfiniteYou",
|
| 1424 |
+
"supports/insightface/models/antelopev2/scrfd_10g_bnkps.onnx",
|
| 1425 |
+
"models/InfiniteYou/insightface/models/antelopev2",
|
| 1426 |
+
),
|
| 1427 |
+
],
|
| 1428 |
+
"load_path": [
|
| 1429 |
+
[
|
| 1430 |
+
"models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00001-of-00002.safetensors",
|
| 1431 |
+
"models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00002-of-00002.safetensors",
|
| 1432 |
+
],
|
| 1433 |
+
"models/InfiniteYou/image_proj_model.bin",
|
| 1434 |
+
],
|
| 1435 |
+
},
|
| 1436 |
+
# ESRGAN
|
| 1437 |
+
"ESRGAN_x4": [
|
| 1438 |
+
("AI-ModelScope/Real-ESRGAN", "RealESRGAN_x4.pth", "models/ESRGAN"),
|
| 1439 |
+
],
|
| 1440 |
+
# RIFE
|
| 1441 |
+
"RIFE": [
|
| 1442 |
+
("AI-ModelScope/RIFE", "flownet.pkl", "models/RIFE"),
|
| 1443 |
+
],
|
| 1444 |
+
# Omnigen
|
| 1445 |
+
"OmniGen-v1": {
|
| 1446 |
+
"file_list": [
|
| 1447 |
+
(
|
| 1448 |
+
"BAAI/OmniGen-v1",
|
| 1449 |
+
"vae/diffusion_pytorch_model.safetensors",
|
| 1450 |
+
"models/OmniGen/OmniGen-v1/vae",
|
| 1451 |
+
),
|
| 1452 |
+
("BAAI/OmniGen-v1", "model.safetensors", "models/OmniGen/OmniGen-v1"),
|
| 1453 |
+
("BAAI/OmniGen-v1", "config.json", "models/OmniGen/OmniGen-v1"),
|
| 1454 |
+
("BAAI/OmniGen-v1", "special_tokens_map.json", "models/OmniGen/OmniGen-v1"),
|
| 1455 |
+
("BAAI/OmniGen-v1", "tokenizer_config.json", "models/OmniGen/OmniGen-v1"),
|
| 1456 |
+
("BAAI/OmniGen-v1", "tokenizer.json", "models/OmniGen/OmniGen-v1"),
|
| 1457 |
+
],
|
| 1458 |
+
"load_path": [
|
| 1459 |
+
"models/OmniGen/OmniGen-v1/vae/diffusion_pytorch_model.safetensors",
|
| 1460 |
+
"models/OmniGen/OmniGen-v1/model.safetensors",
|
| 1461 |
+
],
|
| 1462 |
+
},
|
| 1463 |
+
# CogVideo
|
| 1464 |
+
"CogVideoX-5B": {
|
| 1465 |
+
"file_list": [
|
| 1466 |
+
(
|
| 1467 |
+
"ZhipuAI/CogVideoX-5b",
|
| 1468 |
+
"text_encoder/config.json",
|
| 1469 |
+
"models/CogVideo/CogVideoX-5b/text_encoder",
|
| 1470 |
+
),
|
| 1471 |
+
(
|
| 1472 |
+
"ZhipuAI/CogVideoX-5b",
|
| 1473 |
+
"text_encoder/model.safetensors.index.json",
|
| 1474 |
+
"models/CogVideo/CogVideoX-5b/text_encoder",
|
| 1475 |
+
),
|
| 1476 |
+
(
|
| 1477 |
+
"ZhipuAI/CogVideoX-5b",
|
| 1478 |
+
"text_encoder/model-00001-of-00002.safetensors",
|
| 1479 |
+
"models/CogVideo/CogVideoX-5b/text_encoder",
|
| 1480 |
+
),
|
| 1481 |
+
(
|
| 1482 |
+
"ZhipuAI/CogVideoX-5b",
|
| 1483 |
+
"text_encoder/model-00002-of-00002.safetensors",
|
| 1484 |
+
"models/CogVideo/CogVideoX-5b/text_encoder",
|
| 1485 |
+
),
|
| 1486 |
+
(
|
| 1487 |
+
"ZhipuAI/CogVideoX-5b",
|
| 1488 |
+
"transformer/config.json",
|
| 1489 |
+
"models/CogVideo/CogVideoX-5b/transformer",
|
| 1490 |
+
),
|
| 1491 |
+
(
|
| 1492 |
+
"ZhipuAI/CogVideoX-5b",
|
| 1493 |
+
"transformer/diffusion_pytorch_model.safetensors.index.json",
|
| 1494 |
+
"models/CogVideo/CogVideoX-5b/transformer",
|
| 1495 |
+
),
|
| 1496 |
+
(
|
| 1497 |
+
"ZhipuAI/CogVideoX-5b",
|
| 1498 |
+
"transformer/diffusion_pytorch_model-00001-of-00002.safetensors",
|
| 1499 |
+
"models/CogVideo/CogVideoX-5b/transformer",
|
| 1500 |
+
),
|
| 1501 |
+
(
|
| 1502 |
+
"ZhipuAI/CogVideoX-5b",
|
| 1503 |
+
"transformer/diffusion_pytorch_model-00002-of-00002.safetensors",
|
| 1504 |
+
"models/CogVideo/CogVideoX-5b/transformer",
|
| 1505 |
+
),
|
| 1506 |
+
(
|
| 1507 |
+
"ZhipuAI/CogVideoX-5b",
|
| 1508 |
+
"vae/diffusion_pytorch_model.safetensors",
|
| 1509 |
+
"models/CogVideo/CogVideoX-5b/vae",
|
| 1510 |
+
),
|
| 1511 |
+
],
|
| 1512 |
+
"load_path": [
|
| 1513 |
+
"models/CogVideo/CogVideoX-5b/text_encoder",
|
| 1514 |
+
"models/CogVideo/CogVideoX-5b/transformer",
|
| 1515 |
+
"models/CogVideo/CogVideoX-5b/vae/diffusion_pytorch_model.safetensors",
|
| 1516 |
+
],
|
| 1517 |
+
},
|
| 1518 |
+
# Stable Diffusion 3.5
|
| 1519 |
+
"StableDiffusion3.5-large": [
|
| 1520 |
+
(
|
| 1521 |
+
"AI-ModelScope/stable-diffusion-3.5-large",
|
| 1522 |
+
"sd3.5_large.safetensors",
|
| 1523 |
+
"models/stable_diffusion_3",
|
| 1524 |
+
),
|
| 1525 |
+
(
|
| 1526 |
+
"AI-ModelScope/stable-diffusion-3.5-large",
|
| 1527 |
+
"text_encoders/clip_l.safetensors",
|
| 1528 |
+
"models/stable_diffusion_3/text_encoders",
|
| 1529 |
+
),
|
| 1530 |
+
(
|
| 1531 |
+
"AI-ModelScope/stable-diffusion-3.5-large",
|
| 1532 |
+
"text_encoders/clip_g.safetensors",
|
| 1533 |
+
"models/stable_diffusion_3/text_encoders",
|
| 1534 |
+
),
|
| 1535 |
+
(
|
| 1536 |
+
"AI-ModelScope/stable-diffusion-3.5-large",
|
| 1537 |
+
"text_encoders/t5xxl_fp16.safetensors",
|
| 1538 |
+
"models/stable_diffusion_3/text_encoders",
|
| 1539 |
+
),
|
| 1540 |
+
],
|
| 1541 |
+
"StableDiffusion3.5-medium": [
|
| 1542 |
+
(
|
| 1543 |
+
"AI-ModelScope/stable-diffusion-3.5-medium",
|
| 1544 |
+
"sd3.5_medium.safetensors",
|
| 1545 |
+
"models/stable_diffusion_3",
|
| 1546 |
+
),
|
| 1547 |
+
(
|
| 1548 |
+
"AI-ModelScope/stable-diffusion-3.5-large",
|
| 1549 |
+
"text_encoders/clip_l.safetensors",
|
| 1550 |
+
"models/stable_diffusion_3/text_encoders",
|
| 1551 |
+
),
|
| 1552 |
+
(
|
| 1553 |
+
"AI-ModelScope/stable-diffusion-3.5-large",
|
| 1554 |
+
"text_encoders/clip_g.safetensors",
|
| 1555 |
+
"models/stable_diffusion_3/text_encoders",
|
| 1556 |
+
),
|
| 1557 |
+
(
|
| 1558 |
+
"AI-ModelScope/stable-diffusion-3.5-large",
|
| 1559 |
+
"text_encoders/t5xxl_fp16.safetensors",
|
| 1560 |
+
"models/stable_diffusion_3/text_encoders",
|
| 1561 |
+
),
|
| 1562 |
+
],
|
| 1563 |
+
"StableDiffusion3.5-large-turbo": [
|
| 1564 |
+
(
|
| 1565 |
+
"AI-ModelScope/stable-diffusion-3.5-large-turbo",
|
| 1566 |
+
"sd3.5_large_turbo.safetensors",
|
| 1567 |
+
"models/stable_diffusion_3",
|
| 1568 |
+
),
|
| 1569 |
+
(
|
| 1570 |
+
"AI-ModelScope/stable-diffusion-3.5-large",
|
| 1571 |
+
"text_encoders/clip_l.safetensors",
|
| 1572 |
+
"models/stable_diffusion_3/text_encoders",
|
| 1573 |
+
),
|
| 1574 |
+
(
|
| 1575 |
+
"AI-ModelScope/stable-diffusion-3.5-large",
|
| 1576 |
+
"text_encoders/clip_g.safetensors",
|
| 1577 |
+
"models/stable_diffusion_3/text_encoders",
|
| 1578 |
+
),
|
| 1579 |
+
(
|
| 1580 |
+
"AI-ModelScope/stable-diffusion-3.5-large",
|
| 1581 |
+
"text_encoders/t5xxl_fp16.safetensors",
|
| 1582 |
+
"models/stable_diffusion_3/text_encoders",
|
| 1583 |
+
),
|
| 1584 |
+
],
|
| 1585 |
+
"HunyuanVideo": {
|
| 1586 |
+
"file_list": [
|
| 1587 |
+
(
|
| 1588 |
+
"AI-ModelScope/clip-vit-large-patch14",
|
| 1589 |
+
"model.safetensors",
|
| 1590 |
+
"models/HunyuanVideo/text_encoder",
|
| 1591 |
+
),
|
| 1592 |
+
(
|
| 1593 |
+
"DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder",
|
| 1594 |
+
"model-00001-of-00004.safetensors",
|
| 1595 |
+
"models/HunyuanVideo/text_encoder_2",
|
| 1596 |
+
),
|
| 1597 |
+
(
|
| 1598 |
+
"DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder",
|
| 1599 |
+
"model-00002-of-00004.safetensors",
|
| 1600 |
+
"models/HunyuanVideo/text_encoder_2",
|
| 1601 |
+
),
|
| 1602 |
+
(
|
| 1603 |
+
"DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder",
|
| 1604 |
+
"model-00003-of-00004.safetensors",
|
| 1605 |
+
"models/HunyuanVideo/text_encoder_2",
|
| 1606 |
+
),
|
| 1607 |
+
(
|
| 1608 |
+
"DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder",
|
| 1609 |
+
"model-00004-of-00004.safetensors",
|
| 1610 |
+
"models/HunyuanVideo/text_encoder_2",
|
| 1611 |
+
),
|
| 1612 |
+
(
|
| 1613 |
+
"DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder",
|
| 1614 |
+
"config.json",
|
| 1615 |
+
"models/HunyuanVideo/text_encoder_2",
|
| 1616 |
+
),
|
| 1617 |
+
(
|
| 1618 |
+
"DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder",
|
| 1619 |
+
"model.safetensors.index.json",
|
| 1620 |
+
"models/HunyuanVideo/text_encoder_2",
|
| 1621 |
+
),
|
| 1622 |
+
(
|
| 1623 |
+
"AI-ModelScope/HunyuanVideo",
|
| 1624 |
+
"hunyuan-video-t2v-720p/vae/pytorch_model.pt",
|
| 1625 |
+
"models/HunyuanVideo/vae",
|
| 1626 |
+
),
|
| 1627 |
+
(
|
| 1628 |
+
"AI-ModelScope/HunyuanVideo",
|
| 1629 |
+
"hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt",
|
| 1630 |
+
"models/HunyuanVideo/transformers",
|
| 1631 |
+
),
|
| 1632 |
+
],
|
| 1633 |
+
"load_path": [
|
| 1634 |
+
"models/HunyuanVideo/text_encoder/model.safetensors",
|
| 1635 |
+
"models/HunyuanVideo/text_encoder_2",
|
| 1636 |
+
"models/HunyuanVideo/vae/pytorch_model.pt",
|
| 1637 |
+
"models/HunyuanVideo/transformers/mp_rank_00_model_states.pt",
|
| 1638 |
+
],
|
| 1639 |
+
},
|
| 1640 |
+
"HunyuanVideoI2V": {
|
| 1641 |
+
"file_list": [
|
| 1642 |
+
(
|
| 1643 |
+
"AI-ModelScope/clip-vit-large-patch14",
|
| 1644 |
+
"model.safetensors",
|
| 1645 |
+
"models/HunyuanVideoI2V/text_encoder",
|
| 1646 |
+
),
|
| 1647 |
+
(
|
| 1648 |
+
"AI-ModelScope/llava-llama-3-8b-v1_1-transformers",
|
| 1649 |
+
"model-00001-of-00004.safetensors",
|
| 1650 |
+
"models/HunyuanVideoI2V/text_encoder_2",
|
| 1651 |
+
),
|
| 1652 |
+
(
|
| 1653 |
+
"AI-ModelScope/llava-llama-3-8b-v1_1-transformers",
|
| 1654 |
+
"model-00002-of-00004.safetensors",
|
| 1655 |
+
"models/HunyuanVideoI2V/text_encoder_2",
|
| 1656 |
+
),
|
| 1657 |
+
(
|
| 1658 |
+
"AI-ModelScope/llava-llama-3-8b-v1_1-transformers",
|
| 1659 |
+
"model-00003-of-00004.safetensors",
|
| 1660 |
+
"models/HunyuanVideoI2V/text_encoder_2",
|
| 1661 |
+
),
|
| 1662 |
+
(
|
| 1663 |
+
"AI-ModelScope/llava-llama-3-8b-v1_1-transformers",
|
| 1664 |
+
"model-00004-of-00004.safetensors",
|
| 1665 |
+
"models/HunyuanVideoI2V/text_encoder_2",
|
| 1666 |
+
),
|
| 1667 |
+
(
|
| 1668 |
+
"AI-ModelScope/llava-llama-3-8b-v1_1-transformers",
|
| 1669 |
+
"config.json",
|
| 1670 |
+
"models/HunyuanVideoI2V/text_encoder_2",
|
| 1671 |
+
),
|
| 1672 |
+
(
|
| 1673 |
+
"AI-ModelScope/llava-llama-3-8b-v1_1-transformers",
|
| 1674 |
+
"model.safetensors.index.json",
|
| 1675 |
+
"models/HunyuanVideoI2V/text_encoder_2",
|
| 1676 |
+
),
|
| 1677 |
+
(
|
| 1678 |
+
"AI-ModelScope/HunyuanVideo-I2V",
|
| 1679 |
+
"hunyuan-video-i2v-720p/vae/pytorch_model.pt",
|
| 1680 |
+
"models/HunyuanVideoI2V/vae",
|
| 1681 |
+
),
|
| 1682 |
+
(
|
| 1683 |
+
"AI-ModelScope/HunyuanVideo-I2V",
|
| 1684 |
+
"hunyuan-video-i2v-720p/transformers/mp_rank_00_model_states.pt",
|
| 1685 |
+
"models/HunyuanVideoI2V/transformers",
|
| 1686 |
+
),
|
| 1687 |
+
],
|
| 1688 |
+
"load_path": [
|
| 1689 |
+
"models/HunyuanVideoI2V/text_encoder/model.safetensors",
|
| 1690 |
+
"models/HunyuanVideoI2V/text_encoder_2",
|
| 1691 |
+
"models/HunyuanVideoI2V/vae/pytorch_model.pt",
|
| 1692 |
+
"models/HunyuanVideoI2V/transformers/mp_rank_00_model_states.pt",
|
| 1693 |
+
],
|
| 1694 |
+
},
|
| 1695 |
+
"HunyuanVideo-fp8": {
|
| 1696 |
+
"file_list": [
|
| 1697 |
+
(
|
| 1698 |
+
"AI-ModelScope/clip-vit-large-patch14",
|
| 1699 |
+
"model.safetensors",
|
| 1700 |
+
"models/HunyuanVideo/text_encoder",
|
| 1701 |
+
),
|
| 1702 |
+
(
|
| 1703 |
+
"DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder",
|
| 1704 |
+
"model-00001-of-00004.safetensors",
|
| 1705 |
+
"models/HunyuanVideo/text_encoder_2",
|
| 1706 |
+
),
|
| 1707 |
+
(
|
| 1708 |
+
"DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder",
|
| 1709 |
+
"model-00002-of-00004.safetensors",
|
| 1710 |
+
"models/HunyuanVideo/text_encoder_2",
|
| 1711 |
+
),
|
| 1712 |
+
(
|
| 1713 |
+
"DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder",
|
| 1714 |
+
"model-00003-of-00004.safetensors",
|
| 1715 |
+
"models/HunyuanVideo/text_encoder_2",
|
| 1716 |
+
),
|
| 1717 |
+
(
|
| 1718 |
+
"DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder",
|
| 1719 |
+
"model-00004-of-00004.safetensors",
|
| 1720 |
+
"models/HunyuanVideo/text_encoder_2",
|
| 1721 |
+
),
|
| 1722 |
+
(
|
| 1723 |
+
"DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder",
|
| 1724 |
+
"config.json",
|
| 1725 |
+
"models/HunyuanVideo/text_encoder_2",
|
| 1726 |
+
),
|
| 1727 |
+
(
|
| 1728 |
+
"DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder",
|
| 1729 |
+
"model.safetensors.index.json",
|
| 1730 |
+
"models/HunyuanVideo/text_encoder_2",
|
| 1731 |
+
),
|
| 1732 |
+
(
|
| 1733 |
+
"AI-ModelScope/HunyuanVideo",
|
| 1734 |
+
"hunyuan-video-t2v-720p/vae/pytorch_model.pt",
|
| 1735 |
+
"models/HunyuanVideo/vae",
|
| 1736 |
+
),
|
| 1737 |
+
(
|
| 1738 |
+
"DiffSynth-Studio/HunyuanVideo-safetensors",
|
| 1739 |
+
"model.fp8.safetensors",
|
| 1740 |
+
"models/HunyuanVideo/transformers",
|
| 1741 |
+
),
|
| 1742 |
+
],
|
| 1743 |
+
"load_path": [
|
| 1744 |
+
"models/HunyuanVideo/text_encoder/model.safetensors",
|
| 1745 |
+
"models/HunyuanVideo/text_encoder_2",
|
| 1746 |
+
"models/HunyuanVideo/vae/pytorch_model.pt",
|
| 1747 |
+
"models/HunyuanVideo/transformers/model.fp8.safetensors",
|
| 1748 |
+
],
|
| 1749 |
+
},
|
| 1750 |
+
}
|
| 1751 |
+
Preset_model_id: TypeAlias = Literal[
|
| 1752 |
+
"HunyuanDiT",
|
| 1753 |
+
"stable-video-diffusion-img2vid-xt",
|
| 1754 |
+
"ExVideo-SVD-128f-v1",
|
| 1755 |
+
"ExVideo-CogVideoX-LoRA-129f-v1",
|
| 1756 |
+
"StableDiffusion_v15",
|
| 1757 |
+
"DreamShaper_8",
|
| 1758 |
+
"AingDiffusion_v12",
|
| 1759 |
+
"Flat2DAnimerge_v45Sharp",
|
| 1760 |
+
"TextualInversion_VeryBadImageNegative_v1.3",
|
| 1761 |
+
"StableDiffusionXL_v1",
|
| 1762 |
+
"BluePencilXL_v200",
|
| 1763 |
+
"StableDiffusionXL_Turbo",
|
| 1764 |
+
"ControlNet_v11f1p_sd15_depth",
|
| 1765 |
+
"ControlNet_v11p_sd15_softedge",
|
| 1766 |
+
"ControlNet_v11f1e_sd15_tile",
|
| 1767 |
+
"ControlNet_v11p_sd15_lineart",
|
| 1768 |
+
"AnimateDiff_v2",
|
| 1769 |
+
"AnimateDiff_xl_beta",
|
| 1770 |
+
"RIFE",
|
| 1771 |
+
"BeautifulPrompt",
|
| 1772 |
+
"opus-mt-zh-en",
|
| 1773 |
+
"IP-Adapter-SD",
|
| 1774 |
+
"IP-Adapter-SDXL",
|
| 1775 |
+
"StableDiffusion3",
|
| 1776 |
+
"StableDiffusion3_without_T5",
|
| 1777 |
+
"Kolors",
|
| 1778 |
+
"SDXL-vae-fp16-fix",
|
| 1779 |
+
"ControlNet_union_sdxl_promax",
|
| 1780 |
+
"FLUX.1-dev",
|
| 1781 |
+
"FLUX.1-schnell",
|
| 1782 |
+
"InstantX/FLUX.1-dev-Controlnet-Union-alpha",
|
| 1783 |
+
"jasperai/Flux.1-dev-Controlnet-Depth",
|
| 1784 |
+
"jasperai/Flux.1-dev-Controlnet-Surface-Normals",
|
| 1785 |
+
"jasperai/Flux.1-dev-Controlnet-Upscaler",
|
| 1786 |
+
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha",
|
| 1787 |
+
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta",
|
| 1788 |
+
"Shakker-Labs/FLUX.1-dev-ControlNet-Depth",
|
| 1789 |
+
"Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
|
| 1790 |
+
"InstantX/FLUX.1-dev-IP-Adapter",
|
| 1791 |
+
"InfiniteYou",
|
| 1792 |
+
"SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0",
|
| 1793 |
+
"QwenPrompt",
|
| 1794 |
+
"OmostPrompt",
|
| 1795 |
+
"ESRGAN_x4",
|
| 1796 |
+
"RIFE",
|
| 1797 |
+
"OmniGen-v1",
|
| 1798 |
+
"CogVideoX-5B",
|
| 1799 |
+
"Annotators:Depth",
|
| 1800 |
+
"Annotators:Softedge",
|
| 1801 |
+
"Annotators:Lineart",
|
| 1802 |
+
"Annotators:Normal",
|
| 1803 |
+
"Annotators:Openpose",
|
| 1804 |
+
"StableDiffusion3.5-large",
|
| 1805 |
+
"StableDiffusion3.5-medium",
|
| 1806 |
+
"HunyuanVideo",
|
| 1807 |
+
"HunyuanVideo-fp8",
|
| 1808 |
+
"HunyuanVideoI2V",
|
| 1809 |
+
]
|
data/video.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import imageio, os
|
| 2 |
+
import numpy as np
|
| 3 |
+
from PIL import Image
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class LowMemoryVideo:
|
| 8 |
+
def __init__(self, file_name):
|
| 9 |
+
self.reader = imageio.get_reader(file_name)
|
| 10 |
+
|
| 11 |
+
def __len__(self):
|
| 12 |
+
return self.reader.count_frames()
|
| 13 |
+
|
| 14 |
+
def __getitem__(self, item):
|
| 15 |
+
return Image.fromarray(np.array(self.reader.get_data(item))).convert("RGB")
|
| 16 |
+
|
| 17 |
+
def __del__(self):
|
| 18 |
+
self.reader.close()
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def split_file_name(file_name):
|
| 22 |
+
result = []
|
| 23 |
+
number = -1
|
| 24 |
+
for i in file_name:
|
| 25 |
+
if ord(i) >= ord("0") and ord(i) <= ord("9"):
|
| 26 |
+
if number == -1:
|
| 27 |
+
number = 0
|
| 28 |
+
number = number * 10 + ord(i) - ord("0")
|
| 29 |
+
else:
|
| 30 |
+
if number != -1:
|
| 31 |
+
result.append(number)
|
| 32 |
+
number = -1
|
| 33 |
+
result.append(i)
|
| 34 |
+
if number != -1:
|
| 35 |
+
result.append(number)
|
| 36 |
+
result = tuple(result)
|
| 37 |
+
return result
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def search_for_images(folder):
|
| 41 |
+
file_list = [
|
| 42 |
+
i for i in os.listdir(folder) if i.endswith(".jpg") or i.endswith(".png")
|
| 43 |
+
]
|
| 44 |
+
file_list = [(split_file_name(file_name), file_name) for file_name in file_list]
|
| 45 |
+
file_list = [i[1] for i in sorted(file_list)]
|
| 46 |
+
file_list = [os.path.join(folder, i) for i in file_list]
|
| 47 |
+
return file_list
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class LowMemoryImageFolder:
|
| 51 |
+
def __init__(self, folder, file_list=None):
|
| 52 |
+
if file_list is None:
|
| 53 |
+
self.file_list = search_for_images(folder)
|
| 54 |
+
else:
|
| 55 |
+
self.file_list = [
|
| 56 |
+
os.path.join(folder, file_name) for file_name in file_list
|
| 57 |
+
]
|
| 58 |
+
|
| 59 |
+
def __len__(self):
|
| 60 |
+
return len(self.file_list)
|
| 61 |
+
|
| 62 |
+
def __getitem__(self, item):
|
| 63 |
+
return Image.open(self.file_list[item]).convert("RGB")
|
| 64 |
+
|
| 65 |
+
def __del__(self):
|
| 66 |
+
pass
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def crop_and_resize(image, height, width):
|
| 70 |
+
image = np.array(image)
|
| 71 |
+
image_height, image_width, _ = image.shape
|
| 72 |
+
if image_height / image_width < height / width:
|
| 73 |
+
croped_width = int(image_height / height * width)
|
| 74 |
+
left = (image_width - croped_width) // 2
|
| 75 |
+
image = image[:, left : left + croped_width]
|
| 76 |
+
image = Image.fromarray(image).resize((width, height))
|
| 77 |
+
else:
|
| 78 |
+
croped_height = int(image_width / width * height)
|
| 79 |
+
left = (image_height - croped_height) // 2
|
| 80 |
+
image = image[left : left + croped_height, :]
|
| 81 |
+
image = Image.fromarray(image).resize((width, height))
|
| 82 |
+
return image
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class VideoData:
|
| 86 |
+
def __init__(
|
| 87 |
+
self, video_file=None, image_folder=None, height=None, width=None, **kwargs
|
| 88 |
+
):
|
| 89 |
+
if video_file is not None:
|
| 90 |
+
self.data_type = "video"
|
| 91 |
+
self.data = LowMemoryVideo(video_file, **kwargs)
|
| 92 |
+
elif image_folder is not None:
|
| 93 |
+
self.data_type = "images"
|
| 94 |
+
self.data = LowMemoryImageFolder(image_folder, **kwargs)
|
| 95 |
+
else:
|
| 96 |
+
raise ValueError("Cannot open video or image folder")
|
| 97 |
+
self.length = None
|
| 98 |
+
self.set_shape(height, width)
|
| 99 |
+
|
| 100 |
+
def raw_data(self):
|
| 101 |
+
frames = []
|
| 102 |
+
for i in range(self.__len__()):
|
| 103 |
+
frames.append(self.__getitem__(i))
|
| 104 |
+
return frames
|
| 105 |
+
|
| 106 |
+
def set_length(self, length):
|
| 107 |
+
self.length = length
|
| 108 |
+
|
| 109 |
+
def set_shape(self, height, width):
|
| 110 |
+
self.height = height
|
| 111 |
+
self.width = width
|
| 112 |
+
|
| 113 |
+
def __len__(self):
|
| 114 |
+
if self.length is None:
|
| 115 |
+
return len(self.data)
|
| 116 |
+
else:
|
| 117 |
+
return self.length
|
| 118 |
+
|
| 119 |
+
def shape(self):
|
| 120 |
+
if self.height is not None and self.width is not None:
|
| 121 |
+
return self.height, self.width
|
| 122 |
+
else:
|
| 123 |
+
height, width, _ = self.__getitem__(0).shape
|
| 124 |
+
return height, width
|
| 125 |
+
|
| 126 |
+
def __getitem__(self, item):
|
| 127 |
+
frame = self.data.__getitem__(item)
|
| 128 |
+
width, height = frame.size
|
| 129 |
+
if self.height is not None and self.width is not None:
|
| 130 |
+
if self.height != height or self.width != width:
|
| 131 |
+
frame = crop_and_resize(frame, self.height, self.width)
|
| 132 |
+
return frame
|
| 133 |
+
|
| 134 |
+
def __del__(self):
|
| 135 |
+
pass
|
| 136 |
+
|
| 137 |
+
def save_images(self, folder):
|
| 138 |
+
os.makedirs(folder, exist_ok=True)
|
| 139 |
+
for i in tqdm(range(self.__len__()), desc="Saving images"):
|
| 140 |
+
frame = self.__getitem__(i)
|
| 141 |
+
frame.save(os.path.join(folder, f"{i}.png"))
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None):
|
| 145 |
+
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
| 146 |
+
writer = imageio.get_writer(
|
| 147 |
+
save_path, fps=fps, quality=quality, ffmpeg_params=ffmpeg_params
|
| 148 |
+
)
|
| 149 |
+
for frame in tqdm(frames, desc="Saving video"):
|
| 150 |
+
frame = np.array(frame)
|
| 151 |
+
writer.append_data(frame)
|
| 152 |
+
writer.close()
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def save_frames(frames, save_path):
|
| 156 |
+
os.makedirs(save_path, exist_ok=True)
|
| 157 |
+
for i, frame in enumerate(tqdm(frames, desc="Saving images")):
|
| 158 |
+
frame.save(os.path.join(save_path, f"{i}.png"))
|
distributed/__init__.py
ADDED
|
File without changes
|
distributed/xdit_context_parallel.py
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from typing import Optional
|
| 3 |
+
from einops import rearrange
|
| 4 |
+
from xfuser.core.distributed import (
|
| 5 |
+
get_sequence_parallel_rank,
|
| 6 |
+
get_sequence_parallel_world_size,
|
| 7 |
+
get_sp_group,
|
| 8 |
+
)
|
| 9 |
+
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def sinusoidal_embedding_1d(dim, position):
|
| 13 |
+
sinusoid = torch.outer(
|
| 14 |
+
position.type(torch.float64),
|
| 15 |
+
torch.pow(
|
| 16 |
+
10000,
|
| 17 |
+
-torch.arange(dim // 2, dtype=torch.float64, device=position.device).div(
|
| 18 |
+
dim // 2
|
| 19 |
+
),
|
| 20 |
+
),
|
| 21 |
+
)
|
| 22 |
+
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
|
| 23 |
+
return x.to(position.dtype)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def pad_freqs(original_tensor, target_len):
|
| 27 |
+
seq_len, s1, s2 = original_tensor.shape
|
| 28 |
+
pad_size = target_len - seq_len
|
| 29 |
+
padding_tensor = torch.ones(
|
| 30 |
+
pad_size, s1, s2, dtype=original_tensor.dtype, device=original_tensor.device
|
| 31 |
+
)
|
| 32 |
+
padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
|
| 33 |
+
return padded_tensor
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def rope_apply(x, freqs, num_heads):
|
| 37 |
+
x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
|
| 38 |
+
s_per_rank = x.shape[1]
|
| 39 |
+
|
| 40 |
+
x_out = torch.view_as_complex(
|
| 41 |
+
x.to(torch.float64).reshape(x.shape[0], x.shape[1], x.shape[2], -1, 2)
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
sp_size = get_sequence_parallel_world_size()
|
| 45 |
+
sp_rank = get_sequence_parallel_rank()
|
| 46 |
+
freqs = pad_freqs(freqs, s_per_rank * sp_size)
|
| 47 |
+
freqs_rank = freqs[(sp_rank * s_per_rank) : ((sp_rank + 1) * s_per_rank), :, :]
|
| 48 |
+
|
| 49 |
+
x_out = torch.view_as_real(x_out * freqs_rank).flatten(2)
|
| 50 |
+
return x_out.to(x.dtype)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def usp_dit_forward(
|
| 54 |
+
self,
|
| 55 |
+
x: torch.Tensor,
|
| 56 |
+
timestep: torch.Tensor,
|
| 57 |
+
context: torch.Tensor,
|
| 58 |
+
clip_feature: Optional[torch.Tensor] = None,
|
| 59 |
+
y: Optional[torch.Tensor] = None,
|
| 60 |
+
use_gradient_checkpointing: bool = False,
|
| 61 |
+
use_gradient_checkpointing_offload: bool = False,
|
| 62 |
+
**kwargs,
|
| 63 |
+
):
|
| 64 |
+
t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep))
|
| 65 |
+
t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
|
| 66 |
+
context = self.text_embedding(context)
|
| 67 |
+
|
| 68 |
+
if self.has_image_input:
|
| 69 |
+
x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w)
|
| 70 |
+
clip_embdding = self.img_emb(clip_feature)
|
| 71 |
+
context = torch.cat([clip_embdding, context], dim=1)
|
| 72 |
+
|
| 73 |
+
x, (f, h, w) = self.patchify(x)
|
| 74 |
+
|
| 75 |
+
freqs = (
|
| 76 |
+
torch.cat(
|
| 77 |
+
[
|
| 78 |
+
self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
| 79 |
+
self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
| 80 |
+
self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
|
| 81 |
+
],
|
| 82 |
+
dim=-1,
|
| 83 |
+
)
|
| 84 |
+
.reshape(f * h * w, 1, -1)
|
| 85 |
+
.to(x.device)
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
def create_custom_forward(module):
|
| 89 |
+
def custom_forward(*inputs):
|
| 90 |
+
return module(*inputs)
|
| 91 |
+
|
| 92 |
+
return custom_forward
|
| 93 |
+
|
| 94 |
+
# Context Parallel
|
| 95 |
+
x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[
|
| 96 |
+
get_sequence_parallel_rank()
|
| 97 |
+
]
|
| 98 |
+
|
| 99 |
+
for block in self.blocks:
|
| 100 |
+
if self.training and use_gradient_checkpointing:
|
| 101 |
+
if use_gradient_checkpointing_offload:
|
| 102 |
+
with torch.autograd.graph.save_on_cpu():
|
| 103 |
+
x = torch.utils.checkpoint.checkpoint(
|
| 104 |
+
create_custom_forward(block),
|
| 105 |
+
x,
|
| 106 |
+
context,
|
| 107 |
+
t_mod,
|
| 108 |
+
freqs,
|
| 109 |
+
use_reentrant=False,
|
| 110 |
+
)
|
| 111 |
+
else:
|
| 112 |
+
x = torch.utils.checkpoint.checkpoint(
|
| 113 |
+
create_custom_forward(block),
|
| 114 |
+
x,
|
| 115 |
+
context,
|
| 116 |
+
t_mod,
|
| 117 |
+
freqs,
|
| 118 |
+
use_reentrant=False,
|
| 119 |
+
)
|
| 120 |
+
else:
|
| 121 |
+
x = block(x, context, t_mod, freqs)
|
| 122 |
+
|
| 123 |
+
x = self.head(x, t)
|
| 124 |
+
|
| 125 |
+
# Context Parallel
|
| 126 |
+
x = get_sp_group().all_gather(x, dim=1)
|
| 127 |
+
|
| 128 |
+
# unpatchify
|
| 129 |
+
x = self.unpatchify(x, (f, h, w))
|
| 130 |
+
return x
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def usp_attn_forward(self, x, freqs):
|
| 134 |
+
q = self.norm_q(self.q(x))
|
| 135 |
+
k = self.norm_k(self.k(x))
|
| 136 |
+
v = self.v(x)
|
| 137 |
+
|
| 138 |
+
q = rope_apply(q, freqs, self.num_heads)
|
| 139 |
+
k = rope_apply(k, freqs, self.num_heads)
|
| 140 |
+
q = rearrange(q, "b s (n d) -> b s n d", n=self.num_heads)
|
| 141 |
+
k = rearrange(k, "b s (n d) -> b s n d", n=self.num_heads)
|
| 142 |
+
v = rearrange(v, "b s (n d) -> b s n d", n=self.num_heads)
|
| 143 |
+
|
| 144 |
+
x = xFuserLongContextAttention()(
|
| 145 |
+
None,
|
| 146 |
+
query=q,
|
| 147 |
+
key=k,
|
| 148 |
+
value=v,
|
| 149 |
+
)
|
| 150 |
+
x = x.flatten(2)
|
| 151 |
+
|
| 152 |
+
del q, k, v
|
| 153 |
+
torch.cuda.empty_cache()
|
| 154 |
+
return self.o(x)
|
download_models.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
from huggingface_hub import snapshot_download
|
| 3 |
+
|
| 4 |
+
def main(use_vace: bool):
|
| 5 |
+
if use_vace:
|
| 6 |
+
snapshot_download("Wan-AI/Wan2.1-VACE-14B", local_dir="checkpoints/VACE/")
|
| 7 |
+
else:
|
| 8 |
+
snapshot_download("Wan-AI/Wan2.1-T2V-14B", local_dir="checkpoints/base_model/")
|
| 9 |
+
|
| 10 |
+
snapshot_download(
|
| 11 |
+
"DIAMONIK7777/antelopev2",
|
| 12 |
+
local_dir="checkpoints/antelopev2/models/antelopev2"
|
| 13 |
+
)
|
| 14 |
+
snapshot_download("BowenXue/Stand-In", local_dir="checkpoints/Stand-In/")
|
| 15 |
+
|
| 16 |
+
if __name__ == "__main__":
|
| 17 |
+
parser = argparse.ArgumentParser(description="Download models with or without VACE.")
|
| 18 |
+
parser.add_argument("--vace", action="store_true", help="Use VACE model instead of T2V.")
|
| 19 |
+
args = parser.parse_args()
|
| 20 |
+
|
| 21 |
+
main(args.vace)
|
infer.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from data.video import save_video
|
| 3 |
+
from wan_loader import load_wan_pipe
|
| 4 |
+
from models.set_condition_branch import set_stand_in
|
| 5 |
+
from preprocessor import FaceProcessor
|
| 6 |
+
import argparse
|
| 7 |
+
|
| 8 |
+
parser = argparse.ArgumentParser()
|
| 9 |
+
|
| 10 |
+
parser.add_argument(
|
| 11 |
+
"--ip_image",
|
| 12 |
+
type=str,
|
| 13 |
+
default="test/input/lecun.jpg",
|
| 14 |
+
help="Input face image path or URL",
|
| 15 |
+
)
|
| 16 |
+
parser.add_argument(
|
| 17 |
+
"--prompt",
|
| 18 |
+
type=str,
|
| 19 |
+
default="一位男性舒适地坐在书桌前,正对着镜头,仿佛在与屏幕前的亲友对话。他的眼神专注而温柔,嘴角带着自然的笑意。背景是他精心布置的个人空间,墙上贴着照片和一张世界地图,传达出一种亲密而现代的沟通感。",
|
| 20 |
+
help="Text prompt for video generation",
|
| 21 |
+
)
|
| 22 |
+
parser.add_argument(
|
| 23 |
+
"--output", type=str, default="test/output/lecun.mp4", help="Output video file path"
|
| 24 |
+
)
|
| 25 |
+
parser.add_argument(
|
| 26 |
+
"--seed", type=int, default=0, help="Random seed for reproducibility"
|
| 27 |
+
)
|
| 28 |
+
parser.add_argument(
|
| 29 |
+
"--num_inference_steps", type=int, default=20, help="Number of inference steps"
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
parser.add_argument(
|
| 33 |
+
"--negative_prompt",
|
| 34 |
+
type=str,
|
| 35 |
+
default="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
| 36 |
+
help="Negative prompt to avoid unwanted features",
|
| 37 |
+
)
|
| 38 |
+
parser.add_argument("--tiled", action="store_true", help="Enable tiled mode")
|
| 39 |
+
parser.add_argument(
|
| 40 |
+
"--fps", type=int, default=25, help="Frames per second for output video"
|
| 41 |
+
)
|
| 42 |
+
parser.add_argument(
|
| 43 |
+
"--quality", type=int, default=9, help="Output video quality (1-9)"
|
| 44 |
+
)
|
| 45 |
+
parser.add_argument(
|
| 46 |
+
"--base_path",
|
| 47 |
+
type=str,
|
| 48 |
+
default="checkpoints/base_model/",
|
| 49 |
+
help="Path to base model checkpoint",
|
| 50 |
+
)
|
| 51 |
+
parser.add_argument(
|
| 52 |
+
"--stand_in_path",
|
| 53 |
+
type=str,
|
| 54 |
+
default="checkpoints/Stand-In/Stand-In_wan2.1_T2V_14B_ver1.0.ckpt",
|
| 55 |
+
help="Path to LoRA weights checkpoint",
|
| 56 |
+
)
|
| 57 |
+
parser.add_argument(
|
| 58 |
+
"--antelopv2_path",
|
| 59 |
+
type=str,
|
| 60 |
+
default="checkpoints/antelopev2",
|
| 61 |
+
help="Path to AntelopeV2 model checkpoint",
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
args = parser.parse_args()
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
face_processor = FaceProcessor(antelopv2_path=args.antelopv2_path)
|
| 68 |
+
ip_image = face_processor.process(args.ip_image)
|
| 69 |
+
|
| 70 |
+
pipe = load_wan_pipe(base_path=args.base_path, torch_dtype=torch.bfloat16)
|
| 71 |
+
|
| 72 |
+
set_stand_in(
|
| 73 |
+
pipe,
|
| 74 |
+
model_path=args.stand_in_path,
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
video = pipe(
|
| 78 |
+
prompt=args.prompt,
|
| 79 |
+
negative_prompt=args.negative_prompt,
|
| 80 |
+
seed=args.seed,
|
| 81 |
+
ip_image=ip_image,
|
| 82 |
+
num_inference_steps=args.num_inference_steps,
|
| 83 |
+
tiled=args.tiled,
|
| 84 |
+
)
|
| 85 |
+
save_video(video, args.output, fps=args.fps, quality=args.quality)
|
infer_face_swap.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from data.video import save_video
|
| 3 |
+
from wan_loader import load_wan_pipe
|
| 4 |
+
from models.set_condition_branch import set_stand_in
|
| 5 |
+
from preprocessor import FaceProcessor, VideoMaskGenerator
|
| 6 |
+
import argparse
|
| 7 |
+
|
| 8 |
+
parser = argparse.ArgumentParser()
|
| 9 |
+
|
| 10 |
+
parser.add_argument(
|
| 11 |
+
"--ip_image",
|
| 12 |
+
type=str,
|
| 13 |
+
default="test/input/ruonan.jpg",
|
| 14 |
+
help="Input face image path or URL",
|
| 15 |
+
)
|
| 16 |
+
parser.add_argument(
|
| 17 |
+
"--input_video",
|
| 18 |
+
type=str,
|
| 19 |
+
default="test/input/woman.mp4",
|
| 20 |
+
help="Input video path",
|
| 21 |
+
)
|
| 22 |
+
parser.add_argument(
|
| 23 |
+
"--denoising_strength",
|
| 24 |
+
type=float,
|
| 25 |
+
default=0.85,
|
| 26 |
+
help="The lower denoising strength represents a higher similarity to the original video.",
|
| 27 |
+
)
|
| 28 |
+
parser.add_argument(
|
| 29 |
+
"--prompt",
|
| 30 |
+
type=str,
|
| 31 |
+
default="The video features a woman standing in front of a large screen displaying the words "
|
| 32 |
+
"Tech Minute"
|
| 33 |
+
" and the logo for CNET. She is wearing a purple top and appears to be presenting or speaking about technology-related topics. The background includes a cityscape with tall buildings, suggesting an urban setting. The woman seems to be engaged in a discussion or providing information on technology news or trends. The overall atmosphere is professional and informative, likely aimed at educating viewers about the latest developments in the tech industry.",
|
| 34 |
+
help="Text prompt for video generation",
|
| 35 |
+
)
|
| 36 |
+
parser.add_argument(
|
| 37 |
+
"--output",
|
| 38 |
+
type=str,
|
| 39 |
+
default="test/output/ruonan.mp4",
|
| 40 |
+
help="Output video file path",
|
| 41 |
+
)
|
| 42 |
+
parser.add_argument(
|
| 43 |
+
"--seed", type=int, default=0, help="Random seed for reproducibility"
|
| 44 |
+
)
|
| 45 |
+
parser.add_argument(
|
| 46 |
+
"--num_inference_steps", type=int, default=20, help="Number of inference steps"
|
| 47 |
+
)
|
| 48 |
+
parser.add_argument(
|
| 49 |
+
"--force_background_consistency",
|
| 50 |
+
type=bool,
|
| 51 |
+
default=False,
|
| 52 |
+
help="Set to True to force background consistency across generated frames.",
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
parser.add_argument(
|
| 56 |
+
"--negative_prompt",
|
| 57 |
+
type=str,
|
| 58 |
+
default="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
| 59 |
+
help="Negative prompt to avoid unwanted features",
|
| 60 |
+
)
|
| 61 |
+
parser.add_argument("--tiled", action="store_true", help="Enable tiled mode")
|
| 62 |
+
parser.add_argument(
|
| 63 |
+
"--fps", type=int, default=25, help="Frames per second for output video"
|
| 64 |
+
)
|
| 65 |
+
parser.add_argument(
|
| 66 |
+
"--quality", type=int, default=9, help="Output video quality (1-9)"
|
| 67 |
+
)
|
| 68 |
+
parser.add_argument(
|
| 69 |
+
"--base_path",
|
| 70 |
+
type=str,
|
| 71 |
+
default="checkpoints/base_model/",
|
| 72 |
+
help="Path to base model checkpoint",
|
| 73 |
+
)
|
| 74 |
+
parser.add_argument(
|
| 75 |
+
"--stand_in_path",
|
| 76 |
+
type=str,
|
| 77 |
+
default="checkpoints/Stand-In/Stand-In_wan2.1_T2V_14B_ver1.0.ckpt",
|
| 78 |
+
help="Path to LoRA weights checkpoint",
|
| 79 |
+
)
|
| 80 |
+
parser.add_argument(
|
| 81 |
+
"--antelopv2_path",
|
| 82 |
+
type=str,
|
| 83 |
+
default="checkpoints/antelopev2",
|
| 84 |
+
help="Path to AntelopeV2 model checkpoint",
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
args = parser.parse_args()
|
| 88 |
+
|
| 89 |
+
face_processor = FaceProcessor(antelopv2_path=args.antelopv2_path)
|
| 90 |
+
videomask_generator = VideoMaskGenerator(antelopv2_path=args.antelopv2_path)
|
| 91 |
+
|
| 92 |
+
ip_image, ip_image_rgba = face_processor.process(args.ip_image, extra_input=True)
|
| 93 |
+
input_video, face_mask, width, height, num_frames = videomask_generator.process(args.input_video, ip_image_rgba, random_horizontal_flip_chance=0.05, dilation_kernel_size=10)
|
| 94 |
+
|
| 95 |
+
pipe = load_wan_pipe(
|
| 96 |
+
base_path=args.base_path, face_swap=True, torch_dtype=torch.bfloat16
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
set_stand_in(
|
| 100 |
+
pipe,
|
| 101 |
+
model_path=args.stand_in_path,
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
video = pipe(
|
| 105 |
+
prompt=args.prompt,
|
| 106 |
+
negative_prompt=args.negative_prompt,
|
| 107 |
+
seed=args.seed,
|
| 108 |
+
width=width,
|
| 109 |
+
height=height,
|
| 110 |
+
num_frames=num_frames,
|
| 111 |
+
denoising_strength=args.denoising_strength,
|
| 112 |
+
ip_image=ip_image,
|
| 113 |
+
face_mask=face_mask,
|
| 114 |
+
input_video=input_video,
|
| 115 |
+
num_inference_steps=args.num_inference_steps,
|
| 116 |
+
tiled=args.tiled,
|
| 117 |
+
force_background_consistency=args.force_background_consistency
|
| 118 |
+
)
|
| 119 |
+
save_video(video, args.output, fps=args.fps, quality=args.quality)
|
infer_with_lora.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from data.video import save_video
|
| 3 |
+
from wan_loader import load_wan_pipe
|
| 4 |
+
from models.set_condition_branch import set_stand_in
|
| 5 |
+
from preprocessor import FaceProcessor
|
| 6 |
+
import argparse
|
| 7 |
+
|
| 8 |
+
parser = argparse.ArgumentParser()
|
| 9 |
+
|
| 10 |
+
parser.add_argument(
|
| 11 |
+
"--ip_image",
|
| 12 |
+
type=str,
|
| 13 |
+
default="test/input/lecun.jpg",
|
| 14 |
+
help="Input face image path or URL",
|
| 15 |
+
)
|
| 16 |
+
parser.add_argument(
|
| 17 |
+
"--lora_path", type=str, required=True, help="Text prompt for video generation"
|
| 18 |
+
)
|
| 19 |
+
parser.add_argument(
|
| 20 |
+
"--prompt",
|
| 21 |
+
type=str,
|
| 22 |
+
default="Close-up of a young man with dark hair tied back, wearing a white kimono adorned with a red floral pattern. He sits against a backdrop of sliding doors with blue accents. His expression shifts from neutral to a slight smile, then to a surprised look. The camera remains static, focusing on his face and upper body as he appears to be reacting to something off-screen. The lighting is soft and natural, suggesting daytime.",
|
| 23 |
+
help="Text prompt for video generation",
|
| 24 |
+
)
|
| 25 |
+
parser.add_argument(
|
| 26 |
+
"--output", type=str, default="test/output/lecun.mp4", help="Output video file path"
|
| 27 |
+
)
|
| 28 |
+
parser.add_argument(
|
| 29 |
+
"--seed", type=int, default=0, help="Random seed for reproducibility"
|
| 30 |
+
)
|
| 31 |
+
parser.add_argument(
|
| 32 |
+
"--num_inference_steps", type=int, default=20, help="Number of inference steps"
|
| 33 |
+
)
|
| 34 |
+
parser.add_argument("--lora_scale", type=float, default=1.0, help="Lora Scale")
|
| 35 |
+
|
| 36 |
+
parser.add_argument(
|
| 37 |
+
"--negative_prompt",
|
| 38 |
+
type=str,
|
| 39 |
+
default="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
| 40 |
+
help="Negative prompt to avoid unwanted features",
|
| 41 |
+
)
|
| 42 |
+
parser.add_argument("--tiled", action="store_true", help="Enable tiled mode")
|
| 43 |
+
parser.add_argument(
|
| 44 |
+
"--fps", type=int, default=25, help="Frames per second for output video"
|
| 45 |
+
)
|
| 46 |
+
parser.add_argument(
|
| 47 |
+
"--quality", type=int, default=9, help="Output video quality (1-9)"
|
| 48 |
+
)
|
| 49 |
+
parser.add_argument(
|
| 50 |
+
"--base_path",
|
| 51 |
+
type=str,
|
| 52 |
+
default="checkpoints/base_model/",
|
| 53 |
+
help="Path to base model checkpoint",
|
| 54 |
+
)
|
| 55 |
+
parser.add_argument(
|
| 56 |
+
"--stand_in_path",
|
| 57 |
+
type=str,
|
| 58 |
+
default="checkpoints/Stand-In/Stand-In_wan2.1_T2V_14B_ver1.0.ckpt",
|
| 59 |
+
help="Path to LoRA weights checkpoint",
|
| 60 |
+
)
|
| 61 |
+
parser.add_argument(
|
| 62 |
+
"--antelopv2_path",
|
| 63 |
+
type=str,
|
| 64 |
+
default="checkpoints/antelopev2",
|
| 65 |
+
help="Path to AntelopeV2 model checkpoint",
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
args = parser.parse_args()
|
| 69 |
+
|
| 70 |
+
face_processor = FaceProcessor(antelopv2_path=args.antelopv2_path)
|
| 71 |
+
ip_image = face_processor.process(args.ip_image)
|
| 72 |
+
|
| 73 |
+
pipe = load_wan_pipe(base_path=args.base_path, torch_dtype=torch.bfloat16)
|
| 74 |
+
|
| 75 |
+
pipe.load_lora(
|
| 76 |
+
pipe.dit,
|
| 77 |
+
args.lora_path,
|
| 78 |
+
alpha=1,
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
set_stand_in(
|
| 82 |
+
pipe,
|
| 83 |
+
model_path=args.stand_in_path,
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
video = pipe(
|
| 87 |
+
prompt=args.prompt,
|
| 88 |
+
negative_prompt=args.negative_prompt,
|
| 89 |
+
seed=args.seed,
|
| 90 |
+
ip_image=ip_image,
|
| 91 |
+
num_inference_steps=args.num_inference_steps,
|
| 92 |
+
tiled=args.tiled,
|
| 93 |
+
)
|
| 94 |
+
save_video(video, args.output, fps=args.fps, quality=args.quality)
|
infer_with_vace.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from data.video import save_video
|
| 3 |
+
from wan_loader import load_wan_pipe
|
| 4 |
+
from models.set_condition_branch import set_stand_in
|
| 5 |
+
from preprocessor import FaceProcessor
|
| 6 |
+
import argparse
|
| 7 |
+
|
| 8 |
+
parser = argparse.ArgumentParser()
|
| 9 |
+
|
| 10 |
+
parser.add_argument(
|
| 11 |
+
"--ip_image",
|
| 12 |
+
type=str,
|
| 13 |
+
default="test/input/first_frame.png",
|
| 14 |
+
help="Input face image path or URL",
|
| 15 |
+
)
|
| 16 |
+
parser.add_argument(
|
| 17 |
+
"--reference_video",
|
| 18 |
+
type=str,
|
| 19 |
+
default="test/input/pose.mp4",
|
| 20 |
+
help="reference_video path",
|
| 21 |
+
)
|
| 22 |
+
parser.add_argument(
|
| 23 |
+
"--reference_image",
|
| 24 |
+
default="test/input/first_frame.png",
|
| 25 |
+
type=str,
|
| 26 |
+
help="reference_video path",
|
| 27 |
+
)
|
| 28 |
+
parser.add_argument(
|
| 29 |
+
"--vace_scale",
|
| 30 |
+
type=float,
|
| 31 |
+
default=0.8,
|
| 32 |
+
help="Scaling factor for VACE.",
|
| 33 |
+
)
|
| 34 |
+
parser.add_argument(
|
| 35 |
+
"--prompt",
|
| 36 |
+
type=str,
|
| 37 |
+
default="一个女人举起双手",
|
| 38 |
+
help="Text prompt for video generation",
|
| 39 |
+
)
|
| 40 |
+
parser.add_argument(
|
| 41 |
+
"--output", type=str, default="test/output/woman.mp4", help="Output video file path"
|
| 42 |
+
)
|
| 43 |
+
parser.add_argument(
|
| 44 |
+
"--seed", type=int, default=0, help="Random seed for reproducibility"
|
| 45 |
+
)
|
| 46 |
+
parser.add_argument(
|
| 47 |
+
"--num_inference_steps", type=int, default=20, help="Number of inference steps"
|
| 48 |
+
)
|
| 49 |
+
parser.add_argument(
|
| 50 |
+
"--vace_path",
|
| 51 |
+
type=str,
|
| 52 |
+
default="checkpoints/VACE/",
|
| 53 |
+
help="Path to base model checkpoint",
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
parser.add_argument(
|
| 57 |
+
"--negative_prompt",
|
| 58 |
+
type=str,
|
| 59 |
+
default="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
| 60 |
+
help="Negative prompt to avoid unwanted features",
|
| 61 |
+
)
|
| 62 |
+
parser.add_argument("--tiled", action="store_true", help="Enable tiled mode")
|
| 63 |
+
parser.add_argument(
|
| 64 |
+
"--fps", type=int, default=25, help="Frames per second for output video"
|
| 65 |
+
)
|
| 66 |
+
parser.add_argument(
|
| 67 |
+
"--quality", type=int, default=9, help="Output video quality (1-9)"
|
| 68 |
+
)
|
| 69 |
+
parser.add_argument(
|
| 70 |
+
"--stand_in_path",
|
| 71 |
+
type=str,
|
| 72 |
+
default="checkpoints/Stand-In/Stand-In_wan2.1_T2V_14B_ver1.0.ckpt",
|
| 73 |
+
help="Path to LoRA weights checkpoint",
|
| 74 |
+
)
|
| 75 |
+
parser.add_argument(
|
| 76 |
+
"--antelopv2_path",
|
| 77 |
+
type=str,
|
| 78 |
+
default="checkpoints/antelopev2",
|
| 79 |
+
help="Path to AntelopeV2 model checkpoint",
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
args = parser.parse_args()
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
face_processor = FaceProcessor(antelopv2_path=args.antelopv2_path)
|
| 86 |
+
ip_image = face_processor.process(args.ip_image)
|
| 87 |
+
|
| 88 |
+
pipe = load_wan_pipe(base_path=args.vace_path, use_vace=True, torch_dtype=torch.bfloat16)
|
| 89 |
+
|
| 90 |
+
set_stand_in(
|
| 91 |
+
pipe,
|
| 92 |
+
model_path=args.stand_in_path,
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
video = pipe(
|
| 96 |
+
prompt=args.prompt,
|
| 97 |
+
vace_video=args.reference_video,
|
| 98 |
+
vace_reference_image=args.reference_image,
|
| 99 |
+
negative_prompt=args.negative_prompt,
|
| 100 |
+
vace_scale=args.vace_scale,
|
| 101 |
+
seed=args.seed,
|
| 102 |
+
ip_image=ip_image,
|
| 103 |
+
num_inference_steps=args.num_inference_steps,
|
| 104 |
+
tiled=args.tiled,
|
| 105 |
+
)
|
| 106 |
+
save_video(video, args.output, fps=args.fps, quality=args.quality)
|
lora/__init__.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class GeneralLoRALoader:
|
| 5 |
+
def __init__(self, device="cpu", torch_dtype=torch.float32):
|
| 6 |
+
self.device = device
|
| 7 |
+
self.torch_dtype = torch_dtype
|
| 8 |
+
|
| 9 |
+
def get_name_dict(self, lora_state_dict):
|
| 10 |
+
lora_name_dict = {}
|
| 11 |
+
|
| 12 |
+
has_lora_A = any(k.endswith(".lora_A.weight") for k in lora_state_dict)
|
| 13 |
+
has_lora_down = any(k.endswith(".lora_down.weight") for k in lora_state_dict)
|
| 14 |
+
|
| 15 |
+
if has_lora_A:
|
| 16 |
+
lora_a_keys = [k for k in lora_state_dict if k.endswith(".lora_A.weight")]
|
| 17 |
+
for lora_a_key in lora_a_keys:
|
| 18 |
+
base_name = lora_a_key.replace(".lora_A.weight", "")
|
| 19 |
+
lora_b_key = base_name + ".lora_B.weight"
|
| 20 |
+
|
| 21 |
+
if lora_b_key in lora_state_dict:
|
| 22 |
+
target_name = base_name.replace("diffusion_model.", "", 1)
|
| 23 |
+
lora_name_dict[target_name] = (lora_b_key, lora_a_key)
|
| 24 |
+
|
| 25 |
+
elif has_lora_down:
|
| 26 |
+
lora_down_keys = [
|
| 27 |
+
k for k in lora_state_dict if k.endswith(".lora_down.weight")
|
| 28 |
+
]
|
| 29 |
+
for lora_down_key in lora_down_keys:
|
| 30 |
+
base_name = lora_down_key.replace(".lora_down.weight", "")
|
| 31 |
+
lora_up_key = base_name + ".lora_up.weight"
|
| 32 |
+
|
| 33 |
+
if lora_up_key in lora_state_dict:
|
| 34 |
+
target_name = base_name.replace("lora_unet_", "").replace("_", ".")
|
| 35 |
+
target_name = target_name.replace(".attn.", "_attn.")
|
| 36 |
+
lora_name_dict[target_name] = (lora_up_key, lora_down_key)
|
| 37 |
+
|
| 38 |
+
else:
|
| 39 |
+
print(
|
| 40 |
+
"Warning: No recognizable LoRA key names found in state_dict (neither 'lora_A' nor 'lora_down')."
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
return lora_name_dict
|
| 44 |
+
|
| 45 |
+
def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):
|
| 46 |
+
lora_name_dict = self.get_name_dict(state_dict_lora)
|
| 47 |
+
updated_num = 0
|
| 48 |
+
|
| 49 |
+
lora_target_names = set(lora_name_dict.keys())
|
| 50 |
+
model_layer_names = {
|
| 51 |
+
name for name, module in model.named_modules() if hasattr(module, "weight")
|
| 52 |
+
}
|
| 53 |
+
matched_names = lora_target_names.intersection(model_layer_names)
|
| 54 |
+
unmatched_lora_names = lora_target_names - model_layer_names
|
| 55 |
+
|
| 56 |
+
print(f"Successfully matched {len(matched_names)} layers.")
|
| 57 |
+
if unmatched_lora_names:
|
| 58 |
+
print(
|
| 59 |
+
f"Warning: {len(unmatched_lora_names)} LoRA layers not matched and will be ignored."
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
for name, module in model.named_modules():
|
| 63 |
+
if name in matched_names:
|
| 64 |
+
lora_b_key, lora_a_key = lora_name_dict[name]
|
| 65 |
+
weight_up = state_dict_lora[lora_b_key].to(
|
| 66 |
+
device=self.device, dtype=self.torch_dtype
|
| 67 |
+
)
|
| 68 |
+
weight_down = state_dict_lora[lora_a_key].to(
|
| 69 |
+
device=self.device, dtype=self.torch_dtype
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
if len(weight_up.shape) == 4:
|
| 73 |
+
weight_up = weight_up.squeeze(3).squeeze(2)
|
| 74 |
+
weight_down = weight_down.squeeze(3).squeeze(2)
|
| 75 |
+
weight_lora = alpha * torch.mm(weight_up, weight_down).unsqueeze(
|
| 76 |
+
2
|
| 77 |
+
).unsqueeze(3)
|
| 78 |
+
else:
|
| 79 |
+
weight_lora = alpha * torch.mm(weight_up, weight_down)
|
| 80 |
+
|
| 81 |
+
if module.weight.shape != weight_lora.shape:
|
| 82 |
+
print(f"Error: Shape mismatch for layer '{name}'! Skipping update.")
|
| 83 |
+
continue
|
| 84 |
+
|
| 85 |
+
module.weight.data = (
|
| 86 |
+
module.weight.data.to(weight_lora.device, dtype=weight_lora.dtype)
|
| 87 |
+
+ weight_lora
|
| 88 |
+
)
|
| 89 |
+
updated_num += 1
|
| 90 |
+
|
| 91 |
+
print(f"LoRA loading complete, updated {updated_num} tensors in total.\n")
|
models/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .model_manager import *
|
models/attention.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from einops import rearrange
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def low_version_attention(query, key, value, attn_bias=None):
|
| 6 |
+
scale = 1 / query.shape[-1] ** 0.5
|
| 7 |
+
query = query * scale
|
| 8 |
+
attn = torch.matmul(query, key.transpose(-2, -1))
|
| 9 |
+
if attn_bias is not None:
|
| 10 |
+
attn = attn + attn_bias
|
| 11 |
+
attn = attn.softmax(-1)
|
| 12 |
+
return attn @ value
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class Attention(torch.nn.Module):
|
| 16 |
+
def __init__(
|
| 17 |
+
self,
|
| 18 |
+
q_dim,
|
| 19 |
+
num_heads,
|
| 20 |
+
head_dim,
|
| 21 |
+
kv_dim=None,
|
| 22 |
+
bias_q=False,
|
| 23 |
+
bias_kv=False,
|
| 24 |
+
bias_out=False,
|
| 25 |
+
):
|
| 26 |
+
super().__init__()
|
| 27 |
+
dim_inner = head_dim * num_heads
|
| 28 |
+
kv_dim = kv_dim if kv_dim is not None else q_dim
|
| 29 |
+
self.num_heads = num_heads
|
| 30 |
+
self.head_dim = head_dim
|
| 31 |
+
|
| 32 |
+
self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
|
| 33 |
+
self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
|
| 34 |
+
self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
|
| 35 |
+
self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)
|
| 36 |
+
|
| 37 |
+
def interact_with_ipadapter(self, hidden_states, q, ip_k, ip_v, scale=1.0):
|
| 38 |
+
batch_size = q.shape[0]
|
| 39 |
+
ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
| 40 |
+
ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
| 41 |
+
ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(
|
| 42 |
+
q, ip_k, ip_v
|
| 43 |
+
)
|
| 44 |
+
hidden_states = hidden_states + scale * ip_hidden_states
|
| 45 |
+
return hidden_states
|
| 46 |
+
|
| 47 |
+
def torch_forward(
|
| 48 |
+
self,
|
| 49 |
+
hidden_states,
|
| 50 |
+
encoder_hidden_states=None,
|
| 51 |
+
attn_mask=None,
|
| 52 |
+
ipadapter_kwargs=None,
|
| 53 |
+
qkv_preprocessor=None,
|
| 54 |
+
):
|
| 55 |
+
if encoder_hidden_states is None:
|
| 56 |
+
encoder_hidden_states = hidden_states
|
| 57 |
+
|
| 58 |
+
batch_size = encoder_hidden_states.shape[0]
|
| 59 |
+
|
| 60 |
+
q = self.to_q(hidden_states)
|
| 61 |
+
k = self.to_k(encoder_hidden_states)
|
| 62 |
+
v = self.to_v(encoder_hidden_states)
|
| 63 |
+
|
| 64 |
+
q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
| 65 |
+
k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
| 66 |
+
v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
| 67 |
+
|
| 68 |
+
if qkv_preprocessor is not None:
|
| 69 |
+
q, k, v = qkv_preprocessor(q, k, v)
|
| 70 |
+
|
| 71 |
+
hidden_states = torch.nn.functional.scaled_dot_product_attention(
|
| 72 |
+
q, k, v, attn_mask=attn_mask
|
| 73 |
+
)
|
| 74 |
+
if ipadapter_kwargs is not None:
|
| 75 |
+
hidden_states = self.interact_with_ipadapter(
|
| 76 |
+
hidden_states, q, **ipadapter_kwargs
|
| 77 |
+
)
|
| 78 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(
|
| 79 |
+
batch_size, -1, self.num_heads * self.head_dim
|
| 80 |
+
)
|
| 81 |
+
hidden_states = hidden_states.to(q.dtype)
|
| 82 |
+
|
| 83 |
+
hidden_states = self.to_out(hidden_states)
|
| 84 |
+
|
| 85 |
+
return hidden_states
|
| 86 |
+
|
| 87 |
+
def xformers_forward(
|
| 88 |
+
self, hidden_states, encoder_hidden_states=None, attn_mask=None
|
| 89 |
+
):
|
| 90 |
+
if encoder_hidden_states is None:
|
| 91 |
+
encoder_hidden_states = hidden_states
|
| 92 |
+
|
| 93 |
+
q = self.to_q(hidden_states)
|
| 94 |
+
k = self.to_k(encoder_hidden_states)
|
| 95 |
+
v = self.to_v(encoder_hidden_states)
|
| 96 |
+
|
| 97 |
+
q = rearrange(q, "b f (n d) -> (b n) f d", n=self.num_heads)
|
| 98 |
+
k = rearrange(k, "b f (n d) -> (b n) f d", n=self.num_heads)
|
| 99 |
+
v = rearrange(v, "b f (n d) -> (b n) f d", n=self.num_heads)
|
| 100 |
+
|
| 101 |
+
if attn_mask is not None:
|
| 102 |
+
hidden_states = low_version_attention(q, k, v, attn_bias=attn_mask)
|
| 103 |
+
else:
|
| 104 |
+
import xformers.ops as xops
|
| 105 |
+
|
| 106 |
+
hidden_states = xops.memory_efficient_attention(q, k, v)
|
| 107 |
+
hidden_states = rearrange(
|
| 108 |
+
hidden_states, "(b n) f d -> b f (n d)", n=self.num_heads
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
hidden_states = hidden_states.to(q.dtype)
|
| 112 |
+
hidden_states = self.to_out(hidden_states)
|
| 113 |
+
|
| 114 |
+
return hidden_states
|
| 115 |
+
|
| 116 |
+
def forward(
|
| 117 |
+
self,
|
| 118 |
+
hidden_states,
|
| 119 |
+
encoder_hidden_states=None,
|
| 120 |
+
attn_mask=None,
|
| 121 |
+
ipadapter_kwargs=None,
|
| 122 |
+
qkv_preprocessor=None,
|
| 123 |
+
):
|
| 124 |
+
return self.torch_forward(
|
| 125 |
+
hidden_states,
|
| 126 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 127 |
+
attn_mask=attn_mask,
|
| 128 |
+
ipadapter_kwargs=ipadapter_kwargs,
|
| 129 |
+
qkv_preprocessor=qkv_preprocessor,
|
| 130 |
+
)
|
models/downloader.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from huggingface_hub import hf_hub_download, snapshot_download
|
| 2 |
+
import os, shutil
|
| 3 |
+
from typing_extensions import Literal, TypeAlias
|
| 4 |
+
from typing import List
|
| 5 |
+
from configs.model_config import (
|
| 6 |
+
preset_models_on_huggingface,
|
| 7 |
+
preset_models_on_modelscope,
|
| 8 |
+
Preset_model_id,
|
| 9 |
+
)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def download_from_modelscope(model_id, origin_file_path, local_dir):
|
| 13 |
+
os.makedirs(local_dir, exist_ok=True)
|
| 14 |
+
file_name = os.path.basename(origin_file_path)
|
| 15 |
+
if file_name in os.listdir(local_dir):
|
| 16 |
+
print(f" {file_name} has been already in {local_dir}.")
|
| 17 |
+
else:
|
| 18 |
+
print(f" Start downloading {os.path.join(local_dir, file_name)}")
|
| 19 |
+
snapshot_download(
|
| 20 |
+
model_id, allow_file_pattern=origin_file_path, local_dir=local_dir
|
| 21 |
+
)
|
| 22 |
+
downloaded_file_path = os.path.join(local_dir, origin_file_path)
|
| 23 |
+
target_file_path = os.path.join(local_dir, os.path.split(origin_file_path)[-1])
|
| 24 |
+
if downloaded_file_path != target_file_path:
|
| 25 |
+
shutil.move(downloaded_file_path, target_file_path)
|
| 26 |
+
shutil.rmtree(os.path.join(local_dir, origin_file_path.split("/")[0]))
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def download_from_huggingface(model_id, origin_file_path, local_dir):
|
| 30 |
+
os.makedirs(local_dir, exist_ok=True)
|
| 31 |
+
file_name = os.path.basename(origin_file_path)
|
| 32 |
+
if file_name in os.listdir(local_dir):
|
| 33 |
+
print(f" {file_name} has been already in {local_dir}.")
|
| 34 |
+
else:
|
| 35 |
+
print(f" Start downloading {os.path.join(local_dir, file_name)}")
|
| 36 |
+
hf_hub_download(model_id, origin_file_path, local_dir=local_dir)
|
| 37 |
+
downloaded_file_path = os.path.join(local_dir, origin_file_path)
|
| 38 |
+
target_file_path = os.path.join(local_dir, file_name)
|
| 39 |
+
if downloaded_file_path != target_file_path:
|
| 40 |
+
shutil.move(downloaded_file_path, target_file_path)
|
| 41 |
+
shutil.rmtree(os.path.join(local_dir, origin_file_path.split("/")[0]))
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
Preset_model_website: TypeAlias = Literal[
|
| 45 |
+
"HuggingFace",
|
| 46 |
+
"ModelScope",
|
| 47 |
+
]
|
| 48 |
+
website_to_preset_models = {
|
| 49 |
+
"HuggingFace": preset_models_on_huggingface,
|
| 50 |
+
"ModelScope": preset_models_on_modelscope,
|
| 51 |
+
}
|
| 52 |
+
website_to_download_fn = {
|
| 53 |
+
"HuggingFace": download_from_huggingface,
|
| 54 |
+
"ModelScope": download_from_modelscope,
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def download_customized_models(
|
| 59 |
+
model_id,
|
| 60 |
+
origin_file_path,
|
| 61 |
+
local_dir,
|
| 62 |
+
downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
|
| 63 |
+
):
|
| 64 |
+
downloaded_files = []
|
| 65 |
+
for website in downloading_priority:
|
| 66 |
+
# Check if the file is downloaded.
|
| 67 |
+
file_to_download = os.path.join(local_dir, os.path.basename(origin_file_path))
|
| 68 |
+
if file_to_download in downloaded_files:
|
| 69 |
+
continue
|
| 70 |
+
# Download
|
| 71 |
+
website_to_download_fn[website](model_id, origin_file_path, local_dir)
|
| 72 |
+
if os.path.basename(origin_file_path) in os.listdir(local_dir):
|
| 73 |
+
downloaded_files.append(file_to_download)
|
| 74 |
+
return downloaded_files
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def download_models(
|
| 78 |
+
model_id_list: List[Preset_model_id] = [],
|
| 79 |
+
downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
|
| 80 |
+
):
|
| 81 |
+
print(f"Downloading models: {model_id_list}")
|
| 82 |
+
downloaded_files = []
|
| 83 |
+
load_files = []
|
| 84 |
+
|
| 85 |
+
for model_id in model_id_list:
|
| 86 |
+
for website in downloading_priority:
|
| 87 |
+
if model_id in website_to_preset_models[website]:
|
| 88 |
+
# Parse model metadata
|
| 89 |
+
model_metadata = website_to_preset_models[website][model_id]
|
| 90 |
+
if isinstance(model_metadata, list):
|
| 91 |
+
file_data = model_metadata
|
| 92 |
+
else:
|
| 93 |
+
file_data = model_metadata.get("file_list", [])
|
| 94 |
+
|
| 95 |
+
# Try downloading the model from this website.
|
| 96 |
+
model_files = []
|
| 97 |
+
for model_id, origin_file_path, local_dir in file_data:
|
| 98 |
+
# Check if the file is downloaded.
|
| 99 |
+
file_to_download = os.path.join(
|
| 100 |
+
local_dir, os.path.basename(origin_file_path)
|
| 101 |
+
)
|
| 102 |
+
if file_to_download in downloaded_files:
|
| 103 |
+
continue
|
| 104 |
+
# Download
|
| 105 |
+
website_to_download_fn[website](
|
| 106 |
+
model_id, origin_file_path, local_dir
|
| 107 |
+
)
|
| 108 |
+
if os.path.basename(origin_file_path) in os.listdir(local_dir):
|
| 109 |
+
downloaded_files.append(file_to_download)
|
| 110 |
+
model_files.append(file_to_download)
|
| 111 |
+
|
| 112 |
+
# If the model is successfully downloaded, break.
|
| 113 |
+
if len(model_files) > 0:
|
| 114 |
+
if (
|
| 115 |
+
isinstance(model_metadata, dict)
|
| 116 |
+
and "load_path" in model_metadata
|
| 117 |
+
):
|
| 118 |
+
model_files = model_metadata["load_path"]
|
| 119 |
+
load_files.extend(model_files)
|
| 120 |
+
break
|
| 121 |
+
|
| 122 |
+
return load_files
|
models/model_manager.py
ADDED
|
@@ -0,0 +1,610 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os, torch, json, importlib
|
| 2 |
+
from typing import List
|
| 3 |
+
|
| 4 |
+
from .downloader import (
|
| 5 |
+
download_models,
|
| 6 |
+
download_customized_models,
|
| 7 |
+
Preset_model_id,
|
| 8 |
+
Preset_model_website,
|
| 9 |
+
)
|
| 10 |
+
|
| 11 |
+
from configs.model_config import (
|
| 12 |
+
model_loader_configs,
|
| 13 |
+
huggingface_model_loader_configs,
|
| 14 |
+
patch_model_loader_configs,
|
| 15 |
+
)
|
| 16 |
+
from .utils import (
|
| 17 |
+
load_state_dict,
|
| 18 |
+
init_weights_on_device,
|
| 19 |
+
hash_state_dict_keys,
|
| 20 |
+
split_state_dict_with_prefix,
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def load_model_from_single_file(
|
| 25 |
+
state_dict, model_names, model_classes, model_resource, torch_dtype, device
|
| 26 |
+
):
|
| 27 |
+
loaded_model_names, loaded_models = [], []
|
| 28 |
+
for model_name, model_class in zip(model_names, model_classes):
|
| 29 |
+
print(f" model_name: {model_name} model_class: {model_class.__name__}")
|
| 30 |
+
state_dict_converter = model_class.state_dict_converter()
|
| 31 |
+
if model_resource == "civitai":
|
| 32 |
+
state_dict_results = state_dict_converter.from_civitai(state_dict)
|
| 33 |
+
elif model_resource == "diffusers":
|
| 34 |
+
state_dict_results = state_dict_converter.from_diffusers(state_dict)
|
| 35 |
+
if isinstance(state_dict_results, tuple):
|
| 36 |
+
model_state_dict, extra_kwargs = state_dict_results
|
| 37 |
+
print(
|
| 38 |
+
f" This model is initialized with extra kwargs: {extra_kwargs}"
|
| 39 |
+
)
|
| 40 |
+
else:
|
| 41 |
+
model_state_dict, extra_kwargs = state_dict_results, {}
|
| 42 |
+
torch_dtype = (
|
| 43 |
+
torch.float32
|
| 44 |
+
if extra_kwargs.get("upcast_to_float32", False)
|
| 45 |
+
else torch_dtype
|
| 46 |
+
)
|
| 47 |
+
with init_weights_on_device():
|
| 48 |
+
model = model_class(**extra_kwargs)
|
| 49 |
+
if hasattr(model, "eval"):
|
| 50 |
+
model = model.eval()
|
| 51 |
+
model.load_state_dict(model_state_dict, assign=True)
|
| 52 |
+
model = model.to(dtype=torch_dtype, device=device)
|
| 53 |
+
loaded_model_names.append(model_name)
|
| 54 |
+
loaded_models.append(model)
|
| 55 |
+
return loaded_model_names, loaded_models
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def load_model_from_huggingface_folder(
|
| 59 |
+
file_path, model_names, model_classes, torch_dtype, device
|
| 60 |
+
):
|
| 61 |
+
loaded_model_names, loaded_models = [], []
|
| 62 |
+
for model_name, model_class in zip(model_names, model_classes):
|
| 63 |
+
if torch_dtype in [torch.float32, torch.float16, torch.bfloat16]:
|
| 64 |
+
model = model_class.from_pretrained(
|
| 65 |
+
file_path, torch_dtype=torch_dtype
|
| 66 |
+
).eval()
|
| 67 |
+
else:
|
| 68 |
+
model = model_class.from_pretrained(file_path).eval().to(dtype=torch_dtype)
|
| 69 |
+
if torch_dtype == torch.float16 and hasattr(model, "half"):
|
| 70 |
+
model = model.half()
|
| 71 |
+
try:
|
| 72 |
+
model = model.to(device=device)
|
| 73 |
+
except:
|
| 74 |
+
pass
|
| 75 |
+
loaded_model_names.append(model_name)
|
| 76 |
+
loaded_models.append(model)
|
| 77 |
+
return loaded_model_names, loaded_models
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def load_single_patch_model_from_single_file(
|
| 81 |
+
state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device
|
| 82 |
+
):
|
| 83 |
+
print(
|
| 84 |
+
f" model_name: {model_name} model_class: {model_class.__name__} extra_kwargs: {extra_kwargs}"
|
| 85 |
+
)
|
| 86 |
+
base_state_dict = base_model.state_dict()
|
| 87 |
+
base_model.to("cpu")
|
| 88 |
+
del base_model
|
| 89 |
+
model = model_class(**extra_kwargs)
|
| 90 |
+
model.load_state_dict(base_state_dict, strict=False)
|
| 91 |
+
model.load_state_dict(state_dict, strict=False)
|
| 92 |
+
model.to(dtype=torch_dtype, device=device)
|
| 93 |
+
return model
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def load_patch_model_from_single_file(
|
| 97 |
+
state_dict,
|
| 98 |
+
model_names,
|
| 99 |
+
model_classes,
|
| 100 |
+
extra_kwargs,
|
| 101 |
+
model_manager,
|
| 102 |
+
torch_dtype,
|
| 103 |
+
device,
|
| 104 |
+
):
|
| 105 |
+
loaded_model_names, loaded_models = [], []
|
| 106 |
+
for model_name, model_class in zip(model_names, model_classes):
|
| 107 |
+
while True:
|
| 108 |
+
for model_id in range(len(model_manager.model)):
|
| 109 |
+
base_model_name = model_manager.model_name[model_id]
|
| 110 |
+
if base_model_name == model_name:
|
| 111 |
+
base_model_path = model_manager.model_path[model_id]
|
| 112 |
+
base_model = model_manager.model[model_id]
|
| 113 |
+
print(
|
| 114 |
+
f" Adding patch model to {base_model_name} ({base_model_path})"
|
| 115 |
+
)
|
| 116 |
+
patched_model = load_single_patch_model_from_single_file(
|
| 117 |
+
state_dict,
|
| 118 |
+
model_name,
|
| 119 |
+
model_class,
|
| 120 |
+
base_model,
|
| 121 |
+
extra_kwargs,
|
| 122 |
+
torch_dtype,
|
| 123 |
+
device,
|
| 124 |
+
)
|
| 125 |
+
loaded_model_names.append(base_model_name)
|
| 126 |
+
loaded_models.append(patched_model)
|
| 127 |
+
model_manager.model.pop(model_id)
|
| 128 |
+
model_manager.model_path.pop(model_id)
|
| 129 |
+
model_manager.model_name.pop(model_id)
|
| 130 |
+
break
|
| 131 |
+
else:
|
| 132 |
+
break
|
| 133 |
+
return loaded_model_names, loaded_models
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
class ModelDetectorTemplate:
|
| 137 |
+
def __init__(self):
|
| 138 |
+
pass
|
| 139 |
+
|
| 140 |
+
def match(self, file_path="", state_dict={}):
|
| 141 |
+
return False
|
| 142 |
+
|
| 143 |
+
def load(
|
| 144 |
+
self,
|
| 145 |
+
file_path="",
|
| 146 |
+
state_dict={},
|
| 147 |
+
device="cuda",
|
| 148 |
+
torch_dtype=torch.float16,
|
| 149 |
+
**kwargs,
|
| 150 |
+
):
|
| 151 |
+
return [], []
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
class ModelDetectorFromSingleFile:
|
| 155 |
+
def __init__(self, model_loader_configs=[]):
|
| 156 |
+
self.keys_hash_with_shape_dict = {}
|
| 157 |
+
self.keys_hash_dict = {}
|
| 158 |
+
for metadata in model_loader_configs:
|
| 159 |
+
self.add_model_metadata(*metadata)
|
| 160 |
+
|
| 161 |
+
def add_model_metadata(
|
| 162 |
+
self,
|
| 163 |
+
keys_hash,
|
| 164 |
+
keys_hash_with_shape,
|
| 165 |
+
model_names,
|
| 166 |
+
model_classes,
|
| 167 |
+
model_resource,
|
| 168 |
+
):
|
| 169 |
+
self.keys_hash_with_shape_dict[keys_hash_with_shape] = (
|
| 170 |
+
model_names,
|
| 171 |
+
model_classes,
|
| 172 |
+
model_resource,
|
| 173 |
+
)
|
| 174 |
+
if keys_hash is not None:
|
| 175 |
+
self.keys_hash_dict[keys_hash] = (
|
| 176 |
+
model_names,
|
| 177 |
+
model_classes,
|
| 178 |
+
model_resource,
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
def match(self, file_path="", state_dict={}):
|
| 182 |
+
if isinstance(file_path, str) and os.path.isdir(file_path):
|
| 183 |
+
return False
|
| 184 |
+
if len(state_dict) == 0:
|
| 185 |
+
state_dict = load_state_dict(file_path)
|
| 186 |
+
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
|
| 187 |
+
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
|
| 188 |
+
return True
|
| 189 |
+
keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
|
| 190 |
+
if keys_hash in self.keys_hash_dict:
|
| 191 |
+
return True
|
| 192 |
+
return False
|
| 193 |
+
|
| 194 |
+
def load(
|
| 195 |
+
self,
|
| 196 |
+
file_path="",
|
| 197 |
+
state_dict={},
|
| 198 |
+
device="cuda",
|
| 199 |
+
torch_dtype=torch.float16,
|
| 200 |
+
**kwargs,
|
| 201 |
+
):
|
| 202 |
+
if len(state_dict) == 0:
|
| 203 |
+
state_dict = load_state_dict(file_path)
|
| 204 |
+
|
| 205 |
+
# Load models with strict matching
|
| 206 |
+
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
|
| 207 |
+
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
|
| 208 |
+
model_names, model_classes, model_resource = self.keys_hash_with_shape_dict[
|
| 209 |
+
keys_hash_with_shape
|
| 210 |
+
]
|
| 211 |
+
loaded_model_names, loaded_models = load_model_from_single_file(
|
| 212 |
+
state_dict,
|
| 213 |
+
model_names,
|
| 214 |
+
model_classes,
|
| 215 |
+
model_resource,
|
| 216 |
+
torch_dtype,
|
| 217 |
+
device,
|
| 218 |
+
)
|
| 219 |
+
return loaded_model_names, loaded_models
|
| 220 |
+
|
| 221 |
+
# Load models without strict matching
|
| 222 |
+
# (the shape of parameters may be inconsistent, and the state_dict_converter will modify the model architecture)
|
| 223 |
+
keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
|
| 224 |
+
if keys_hash in self.keys_hash_dict:
|
| 225 |
+
model_names, model_classes, model_resource = self.keys_hash_dict[keys_hash]
|
| 226 |
+
loaded_model_names, loaded_models = load_model_from_single_file(
|
| 227 |
+
state_dict,
|
| 228 |
+
model_names,
|
| 229 |
+
model_classes,
|
| 230 |
+
model_resource,
|
| 231 |
+
torch_dtype,
|
| 232 |
+
device,
|
| 233 |
+
)
|
| 234 |
+
return loaded_model_names, loaded_models
|
| 235 |
+
|
| 236 |
+
return loaded_model_names, loaded_models
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
class ModelDetectorFromSplitedSingleFile(ModelDetectorFromSingleFile):
|
| 240 |
+
def __init__(self, model_loader_configs=[]):
|
| 241 |
+
super().__init__(model_loader_configs)
|
| 242 |
+
|
| 243 |
+
def match(self, file_path="", state_dict={}):
|
| 244 |
+
if isinstance(file_path, str) and os.path.isdir(file_path):
|
| 245 |
+
return False
|
| 246 |
+
if len(state_dict) == 0:
|
| 247 |
+
state_dict = load_state_dict(file_path)
|
| 248 |
+
splited_state_dict = split_state_dict_with_prefix(state_dict)
|
| 249 |
+
for sub_state_dict in splited_state_dict:
|
| 250 |
+
if super().match(file_path, sub_state_dict):
|
| 251 |
+
return True
|
| 252 |
+
return False
|
| 253 |
+
|
| 254 |
+
def load(
|
| 255 |
+
self,
|
| 256 |
+
file_path="",
|
| 257 |
+
state_dict={},
|
| 258 |
+
device="cuda",
|
| 259 |
+
torch_dtype=torch.float16,
|
| 260 |
+
**kwargs,
|
| 261 |
+
):
|
| 262 |
+
# Split the state_dict and load from each component
|
| 263 |
+
splited_state_dict = split_state_dict_with_prefix(state_dict)
|
| 264 |
+
valid_state_dict = {}
|
| 265 |
+
for sub_state_dict in splited_state_dict:
|
| 266 |
+
if super().match(file_path, sub_state_dict):
|
| 267 |
+
valid_state_dict.update(sub_state_dict)
|
| 268 |
+
if super().match(file_path, valid_state_dict):
|
| 269 |
+
loaded_model_names, loaded_models = super().load(
|
| 270 |
+
file_path, valid_state_dict, device, torch_dtype
|
| 271 |
+
)
|
| 272 |
+
else:
|
| 273 |
+
loaded_model_names, loaded_models = [], []
|
| 274 |
+
for sub_state_dict in splited_state_dict:
|
| 275 |
+
if super().match(file_path, sub_state_dict):
|
| 276 |
+
loaded_model_names_, loaded_models_ = super().load(
|
| 277 |
+
file_path, valid_state_dict, device, torch_dtype
|
| 278 |
+
)
|
| 279 |
+
loaded_model_names += loaded_model_names_
|
| 280 |
+
loaded_models += loaded_models_
|
| 281 |
+
return loaded_model_names, loaded_models
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
class ModelDetectorFromHuggingfaceFolder:
|
| 285 |
+
def __init__(self, model_loader_configs=[]):
|
| 286 |
+
self.architecture_dict = {}
|
| 287 |
+
for metadata in model_loader_configs:
|
| 288 |
+
self.add_model_metadata(*metadata)
|
| 289 |
+
|
| 290 |
+
def add_model_metadata(
|
| 291 |
+
self, architecture, huggingface_lib, model_name, redirected_architecture
|
| 292 |
+
):
|
| 293 |
+
self.architecture_dict[architecture] = (
|
| 294 |
+
huggingface_lib,
|
| 295 |
+
model_name,
|
| 296 |
+
redirected_architecture,
|
| 297 |
+
)
|
| 298 |
+
|
| 299 |
+
def match(self, file_path="", state_dict={}):
|
| 300 |
+
if not isinstance(file_path, str) or os.path.isfile(file_path):
|
| 301 |
+
return False
|
| 302 |
+
file_list = os.listdir(file_path)
|
| 303 |
+
if "config.json" not in file_list:
|
| 304 |
+
return False
|
| 305 |
+
with open(os.path.join(file_path, "config.json"), "r") as f:
|
| 306 |
+
config = json.load(f)
|
| 307 |
+
if "architectures" not in config and "_class_name" not in config:
|
| 308 |
+
return False
|
| 309 |
+
return True
|
| 310 |
+
|
| 311 |
+
def load(
|
| 312 |
+
self,
|
| 313 |
+
file_path="",
|
| 314 |
+
state_dict={},
|
| 315 |
+
device="cuda",
|
| 316 |
+
torch_dtype=torch.float16,
|
| 317 |
+
**kwargs,
|
| 318 |
+
):
|
| 319 |
+
with open(os.path.join(file_path, "config.json"), "r") as f:
|
| 320 |
+
config = json.load(f)
|
| 321 |
+
loaded_model_names, loaded_models = [], []
|
| 322 |
+
architectures = (
|
| 323 |
+
config["architectures"]
|
| 324 |
+
if "architectures" in config
|
| 325 |
+
else [config["_class_name"]]
|
| 326 |
+
)
|
| 327 |
+
for architecture in architectures:
|
| 328 |
+
huggingface_lib, model_name, redirected_architecture = (
|
| 329 |
+
self.architecture_dict[architecture]
|
| 330 |
+
)
|
| 331 |
+
if redirected_architecture is not None:
|
| 332 |
+
architecture = redirected_architecture
|
| 333 |
+
model_class = importlib.import_module(huggingface_lib).__getattribute__(
|
| 334 |
+
architecture
|
| 335 |
+
)
|
| 336 |
+
loaded_model_names_, loaded_models_ = load_model_from_huggingface_folder(
|
| 337 |
+
file_path, [model_name], [model_class], torch_dtype, device
|
| 338 |
+
)
|
| 339 |
+
loaded_model_names += loaded_model_names_
|
| 340 |
+
loaded_models += loaded_models_
|
| 341 |
+
return loaded_model_names, loaded_models
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
class ModelDetectorFromPatchedSingleFile:
|
| 345 |
+
def __init__(self, model_loader_configs=[]):
|
| 346 |
+
self.keys_hash_with_shape_dict = {}
|
| 347 |
+
for metadata in model_loader_configs:
|
| 348 |
+
self.add_model_metadata(*metadata)
|
| 349 |
+
|
| 350 |
+
def add_model_metadata(
|
| 351 |
+
self, keys_hash_with_shape, model_name, model_class, extra_kwargs
|
| 352 |
+
):
|
| 353 |
+
self.keys_hash_with_shape_dict[keys_hash_with_shape] = (
|
| 354 |
+
model_name,
|
| 355 |
+
model_class,
|
| 356 |
+
extra_kwargs,
|
| 357 |
+
)
|
| 358 |
+
|
| 359 |
+
def match(self, file_path="", state_dict={}):
|
| 360 |
+
if not isinstance(file_path, str) or os.path.isdir(file_path):
|
| 361 |
+
return False
|
| 362 |
+
if len(state_dict) == 0:
|
| 363 |
+
state_dict = load_state_dict(file_path)
|
| 364 |
+
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
|
| 365 |
+
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
|
| 366 |
+
return True
|
| 367 |
+
return False
|
| 368 |
+
|
| 369 |
+
def load(
|
| 370 |
+
self,
|
| 371 |
+
file_path="",
|
| 372 |
+
state_dict={},
|
| 373 |
+
device="cuda",
|
| 374 |
+
torch_dtype=torch.float16,
|
| 375 |
+
model_manager=None,
|
| 376 |
+
**kwargs,
|
| 377 |
+
):
|
| 378 |
+
if len(state_dict) == 0:
|
| 379 |
+
state_dict = load_state_dict(file_path)
|
| 380 |
+
|
| 381 |
+
# Load models with strict matching
|
| 382 |
+
loaded_model_names, loaded_models = [], []
|
| 383 |
+
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
|
| 384 |
+
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
|
| 385 |
+
model_names, model_classes, extra_kwargs = self.keys_hash_with_shape_dict[
|
| 386 |
+
keys_hash_with_shape
|
| 387 |
+
]
|
| 388 |
+
loaded_model_names_, loaded_models_ = load_patch_model_from_single_file(
|
| 389 |
+
state_dict,
|
| 390 |
+
model_names,
|
| 391 |
+
model_classes,
|
| 392 |
+
extra_kwargs,
|
| 393 |
+
model_manager,
|
| 394 |
+
torch_dtype,
|
| 395 |
+
device,
|
| 396 |
+
)
|
| 397 |
+
loaded_model_names += loaded_model_names_
|
| 398 |
+
loaded_models += loaded_models_
|
| 399 |
+
return loaded_model_names, loaded_models
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
class ModelManager:
|
| 403 |
+
def __init__(
|
| 404 |
+
self,
|
| 405 |
+
torch_dtype=torch.float16,
|
| 406 |
+
device="cuda",
|
| 407 |
+
model_id_list: List[Preset_model_id] = [],
|
| 408 |
+
downloading_priority: List[Preset_model_website] = [
|
| 409 |
+
"ModelScope",
|
| 410 |
+
"HuggingFace",
|
| 411 |
+
],
|
| 412 |
+
file_path_list: List[str] = [],
|
| 413 |
+
):
|
| 414 |
+
self.torch_dtype = torch_dtype
|
| 415 |
+
self.device = device
|
| 416 |
+
self.model = []
|
| 417 |
+
self.model_path = []
|
| 418 |
+
self.model_name = []
|
| 419 |
+
downloaded_files = (
|
| 420 |
+
download_models(model_id_list, downloading_priority)
|
| 421 |
+
if len(model_id_list) > 0
|
| 422 |
+
else []
|
| 423 |
+
)
|
| 424 |
+
self.model_detector = [
|
| 425 |
+
ModelDetectorFromSingleFile(model_loader_configs),
|
| 426 |
+
ModelDetectorFromSplitedSingleFile(model_loader_configs),
|
| 427 |
+
ModelDetectorFromHuggingfaceFolder(huggingface_model_loader_configs),
|
| 428 |
+
ModelDetectorFromPatchedSingleFile(patch_model_loader_configs),
|
| 429 |
+
]
|
| 430 |
+
self.load_models(downloaded_files + file_path_list)
|
| 431 |
+
|
| 432 |
+
def load_model_from_single_file(
|
| 433 |
+
self,
|
| 434 |
+
file_path="",
|
| 435 |
+
state_dict={},
|
| 436 |
+
model_names=[],
|
| 437 |
+
model_classes=[],
|
| 438 |
+
model_resource=None,
|
| 439 |
+
):
|
| 440 |
+
print(f"Loading models from file: {file_path}")
|
| 441 |
+
if len(state_dict) == 0:
|
| 442 |
+
state_dict = load_state_dict(file_path)
|
| 443 |
+
model_names, models = load_model_from_single_file(
|
| 444 |
+
state_dict,
|
| 445 |
+
model_names,
|
| 446 |
+
model_classes,
|
| 447 |
+
model_resource,
|
| 448 |
+
self.torch_dtype,
|
| 449 |
+
self.device,
|
| 450 |
+
)
|
| 451 |
+
for model_name, model in zip(model_names, models):
|
| 452 |
+
self.model.append(model)
|
| 453 |
+
self.model_path.append(file_path)
|
| 454 |
+
self.model_name.append(model_name)
|
| 455 |
+
print(f" The following models are loaded: {model_names}.")
|
| 456 |
+
|
| 457 |
+
def load_model_from_huggingface_folder(
|
| 458 |
+
self, file_path="", model_names=[], model_classes=[]
|
| 459 |
+
):
|
| 460 |
+
print(f"Loading models from folder: {file_path}")
|
| 461 |
+
model_names, models = load_model_from_huggingface_folder(
|
| 462 |
+
file_path, model_names, model_classes, self.torch_dtype, self.device
|
| 463 |
+
)
|
| 464 |
+
for model_name, model in zip(model_names, models):
|
| 465 |
+
self.model.append(model)
|
| 466 |
+
self.model_path.append(file_path)
|
| 467 |
+
self.model_name.append(model_name)
|
| 468 |
+
print(f" The following models are loaded: {model_names}.")
|
| 469 |
+
|
| 470 |
+
def load_patch_model_from_single_file(
|
| 471 |
+
self,
|
| 472 |
+
file_path="",
|
| 473 |
+
state_dict={},
|
| 474 |
+
model_names=[],
|
| 475 |
+
model_classes=[],
|
| 476 |
+
extra_kwargs={},
|
| 477 |
+
):
|
| 478 |
+
print(f"Loading patch models from file: {file_path}")
|
| 479 |
+
model_names, models = load_patch_model_from_single_file(
|
| 480 |
+
state_dict,
|
| 481 |
+
model_names,
|
| 482 |
+
model_classes,
|
| 483 |
+
extra_kwargs,
|
| 484 |
+
self,
|
| 485 |
+
self.torch_dtype,
|
| 486 |
+
self.device,
|
| 487 |
+
)
|
| 488 |
+
for model_name, model in zip(model_names, models):
|
| 489 |
+
self.model.append(model)
|
| 490 |
+
self.model_path.append(file_path)
|
| 491 |
+
self.model_name.append(model_name)
|
| 492 |
+
print(f" The following patched models are loaded: {model_names}.")
|
| 493 |
+
|
| 494 |
+
def load_lora(self, file_path="", state_dict={}, lora_alpha=1.0):
|
| 495 |
+
if isinstance(file_path, list):
|
| 496 |
+
for file_path_ in file_path:
|
| 497 |
+
self.load_lora(file_path_, state_dict=state_dict, lora_alpha=lora_alpha)
|
| 498 |
+
else:
|
| 499 |
+
print(f"Loading LoRA models from file: {file_path}")
|
| 500 |
+
is_loaded = False
|
| 501 |
+
if len(state_dict) == 0:
|
| 502 |
+
state_dict = load_state_dict(file_path)
|
| 503 |
+
for model_name, model, model_path in zip(
|
| 504 |
+
self.model_name, self.model, self.model_path
|
| 505 |
+
):
|
| 506 |
+
for lora in get_lora_loaders():
|
| 507 |
+
match_results = lora.match(model, state_dict)
|
| 508 |
+
if match_results is not None:
|
| 509 |
+
print(f" Adding LoRA to {model_name} ({model_path}).")
|
| 510 |
+
lora_prefix, model_resource = match_results
|
| 511 |
+
lora.load(
|
| 512 |
+
model,
|
| 513 |
+
state_dict,
|
| 514 |
+
lora_prefix,
|
| 515 |
+
alpha=lora_alpha,
|
| 516 |
+
model_resource=model_resource,
|
| 517 |
+
)
|
| 518 |
+
is_loaded = True
|
| 519 |
+
break
|
| 520 |
+
if not is_loaded:
|
| 521 |
+
print(f" Cannot load LoRA: {file_path}")
|
| 522 |
+
|
| 523 |
+
def load_model(self, file_path, model_names=None, device=None, torch_dtype=None):
|
| 524 |
+
print(f"Loading models from: {file_path}")
|
| 525 |
+
if device is None:
|
| 526 |
+
device = self.device
|
| 527 |
+
if torch_dtype is None:
|
| 528 |
+
torch_dtype = self.torch_dtype
|
| 529 |
+
if isinstance(file_path, list):
|
| 530 |
+
state_dict = {}
|
| 531 |
+
for path in file_path:
|
| 532 |
+
state_dict.update(load_state_dict(path))
|
| 533 |
+
elif os.path.isfile(file_path):
|
| 534 |
+
state_dict = load_state_dict(file_path)
|
| 535 |
+
else:
|
| 536 |
+
state_dict = None
|
| 537 |
+
for model_detector in self.model_detector:
|
| 538 |
+
if model_detector.match(file_path, state_dict):
|
| 539 |
+
model_names, models = model_detector.load(
|
| 540 |
+
file_path,
|
| 541 |
+
state_dict,
|
| 542 |
+
device=device,
|
| 543 |
+
torch_dtype=torch_dtype,
|
| 544 |
+
allowed_model_names=model_names,
|
| 545 |
+
model_manager=self,
|
| 546 |
+
)
|
| 547 |
+
for model_name, model in zip(model_names, models):
|
| 548 |
+
self.model.append(model)
|
| 549 |
+
self.model_path.append(file_path)
|
| 550 |
+
self.model_name.append(model_name)
|
| 551 |
+
print(f" The following models are loaded: {model_names}.")
|
| 552 |
+
break
|
| 553 |
+
else:
|
| 554 |
+
print(f" We cannot detect the model type. No models are loaded.")
|
| 555 |
+
|
| 556 |
+
def load_models(
|
| 557 |
+
self, file_path_list, model_names=None, device=None, torch_dtype=None
|
| 558 |
+
):
|
| 559 |
+
for file_path in file_path_list:
|
| 560 |
+
self.load_model(
|
| 561 |
+
file_path, model_names, device=device, torch_dtype=torch_dtype
|
| 562 |
+
)
|
| 563 |
+
|
| 564 |
+
def fetch_model(
|
| 565 |
+
self, model_name, file_path=None, require_model_path=False, index=None
|
| 566 |
+
):
|
| 567 |
+
fetched_models = []
|
| 568 |
+
fetched_model_paths = []
|
| 569 |
+
for model, model_path, model_name_ in zip(
|
| 570 |
+
self.model, self.model_path, self.model_name
|
| 571 |
+
):
|
| 572 |
+
if file_path is not None and file_path != model_path:
|
| 573 |
+
continue
|
| 574 |
+
if model_name == model_name_:
|
| 575 |
+
fetched_models.append(model)
|
| 576 |
+
fetched_model_paths.append(model_path)
|
| 577 |
+
if len(fetched_models) == 0:
|
| 578 |
+
print(f"No {model_name} models available.")
|
| 579 |
+
return None
|
| 580 |
+
if len(fetched_models) == 1:
|
| 581 |
+
print(f"Using {model_name} from {fetched_model_paths[0]}.")
|
| 582 |
+
model = fetched_models[0]
|
| 583 |
+
path = fetched_model_paths[0]
|
| 584 |
+
else:
|
| 585 |
+
if index is None:
|
| 586 |
+
model = fetched_models[0]
|
| 587 |
+
path = fetched_model_paths[0]
|
| 588 |
+
print(
|
| 589 |
+
f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[0]}."
|
| 590 |
+
)
|
| 591 |
+
elif isinstance(index, int):
|
| 592 |
+
model = fetched_models[:index]
|
| 593 |
+
path = fetched_model_paths[:index]
|
| 594 |
+
print(
|
| 595 |
+
f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[:index]}."
|
| 596 |
+
)
|
| 597 |
+
else:
|
| 598 |
+
model = fetched_models
|
| 599 |
+
path = fetched_model_paths
|
| 600 |
+
print(
|
| 601 |
+
f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths}."
|
| 602 |
+
)
|
| 603 |
+
if require_model_path:
|
| 604 |
+
return model, path
|
| 605 |
+
else:
|
| 606 |
+
return model
|
| 607 |
+
|
| 608 |
+
def to(self, device):
|
| 609 |
+
for model in self.model:
|
| 610 |
+
model.to(device)
|
models/set_condition_branch.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def set_stand_in(pipe, train=False, model_path=None):
|
| 5 |
+
for block in pipe.dit.blocks:
|
| 6 |
+
block.self_attn.init_lora(train)
|
| 7 |
+
if model_path is not None:
|
| 8 |
+
print(f"Loading Stand-In weights from: {model_path}")
|
| 9 |
+
load_lora_weights_into_pipe(pipe, model_path)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def load_lora_weights_into_pipe(pipe, ckpt_path, strict=True):
|
| 13 |
+
ckpt = torch.load(ckpt_path, map_location="cpu")
|
| 14 |
+
state_dict = ckpt.get("state_dict", ckpt)
|
| 15 |
+
|
| 16 |
+
model = {}
|
| 17 |
+
for i, block in enumerate(pipe.dit.blocks):
|
| 18 |
+
prefix = f"blocks.{i}.self_attn."
|
| 19 |
+
attn = block.self_attn
|
| 20 |
+
for name in ["q_loras", "k_loras", "v_loras"]:
|
| 21 |
+
for sub in ["down", "up"]:
|
| 22 |
+
key = f"{prefix}{name}.{sub}.weight"
|
| 23 |
+
if hasattr(getattr(attn, name), sub):
|
| 24 |
+
model[key] = getattr(getattr(attn, name), sub).weight
|
| 25 |
+
else:
|
| 26 |
+
if strict:
|
| 27 |
+
raise KeyError(f"Missing module: {key}")
|
| 28 |
+
|
| 29 |
+
for k, param in state_dict.items():
|
| 30 |
+
if k in model:
|
| 31 |
+
if model[k].shape != param.shape:
|
| 32 |
+
if strict:
|
| 33 |
+
raise ValueError(
|
| 34 |
+
f"Shape mismatch: {k} | {model[k].shape} vs {param.shape}"
|
| 35 |
+
)
|
| 36 |
+
else:
|
| 37 |
+
continue
|
| 38 |
+
model[k].data.copy_(param)
|
| 39 |
+
else:
|
| 40 |
+
if strict:
|
| 41 |
+
raise KeyError(f"Unexpected key in ckpt: {k}")
|
models/tiler.py
ADDED
|
@@ -0,0 +1,333 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from einops import rearrange, repeat
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class TileWorker:
|
| 6 |
+
def __init__(self):
|
| 7 |
+
pass
|
| 8 |
+
|
| 9 |
+
def mask(self, height, width, border_width):
|
| 10 |
+
# Create a mask with shape (height, width).
|
| 11 |
+
# The centre area is filled with 1, and the border line is filled with values in range (0, 1].
|
| 12 |
+
x = torch.arange(height).repeat(width, 1).T
|
| 13 |
+
y = torch.arange(width).repeat(height, 1)
|
| 14 |
+
mask = torch.stack([x + 1, height - x, y + 1, width - y]).min(dim=0).values
|
| 15 |
+
mask = (mask / border_width).clip(0, 1)
|
| 16 |
+
return mask
|
| 17 |
+
|
| 18 |
+
def tile(self, model_input, tile_size, tile_stride, tile_device, tile_dtype):
|
| 19 |
+
# Convert a tensor (b, c, h, w) to (b, c, tile_size, tile_size, tile_num)
|
| 20 |
+
batch_size, channel, _, _ = model_input.shape
|
| 21 |
+
model_input = model_input.to(device=tile_device, dtype=tile_dtype)
|
| 22 |
+
unfold_operator = torch.nn.Unfold(
|
| 23 |
+
kernel_size=(tile_size, tile_size), stride=(tile_stride, tile_stride)
|
| 24 |
+
)
|
| 25 |
+
model_input = unfold_operator(model_input)
|
| 26 |
+
model_input = model_input.view((batch_size, channel, tile_size, tile_size, -1))
|
| 27 |
+
|
| 28 |
+
return model_input
|
| 29 |
+
|
| 30 |
+
def tiled_inference(
|
| 31 |
+
self,
|
| 32 |
+
forward_fn,
|
| 33 |
+
model_input,
|
| 34 |
+
tile_batch_size,
|
| 35 |
+
inference_device,
|
| 36 |
+
inference_dtype,
|
| 37 |
+
tile_device,
|
| 38 |
+
tile_dtype,
|
| 39 |
+
):
|
| 40 |
+
# Call y=forward_fn(x) for each tile
|
| 41 |
+
tile_num = model_input.shape[-1]
|
| 42 |
+
model_output_stack = []
|
| 43 |
+
|
| 44 |
+
for tile_id in range(0, tile_num, tile_batch_size):
|
| 45 |
+
# process input
|
| 46 |
+
tile_id_ = min(tile_id + tile_batch_size, tile_num)
|
| 47 |
+
x = model_input[:, :, :, :, tile_id:tile_id_]
|
| 48 |
+
x = x.to(device=inference_device, dtype=inference_dtype)
|
| 49 |
+
x = rearrange(x, "b c h w n -> (n b) c h w")
|
| 50 |
+
|
| 51 |
+
# process output
|
| 52 |
+
y = forward_fn(x)
|
| 53 |
+
y = rearrange(y, "(n b) c h w -> b c h w n", n=tile_id_ - tile_id)
|
| 54 |
+
y = y.to(device=tile_device, dtype=tile_dtype)
|
| 55 |
+
model_output_stack.append(y)
|
| 56 |
+
|
| 57 |
+
model_output = torch.concat(model_output_stack, dim=-1)
|
| 58 |
+
return model_output
|
| 59 |
+
|
| 60 |
+
def io_scale(self, model_output, tile_size):
|
| 61 |
+
# Determine the size modification happened in forward_fn
|
| 62 |
+
# We only consider the same scale on height and width.
|
| 63 |
+
io_scale = model_output.shape[2] / tile_size
|
| 64 |
+
return io_scale
|
| 65 |
+
|
| 66 |
+
def untile(
|
| 67 |
+
self,
|
| 68 |
+
model_output,
|
| 69 |
+
height,
|
| 70 |
+
width,
|
| 71 |
+
tile_size,
|
| 72 |
+
tile_stride,
|
| 73 |
+
border_width,
|
| 74 |
+
tile_device,
|
| 75 |
+
tile_dtype,
|
| 76 |
+
):
|
| 77 |
+
# The reversed function of tile
|
| 78 |
+
mask = self.mask(tile_size, tile_size, border_width)
|
| 79 |
+
mask = mask.to(device=tile_device, dtype=tile_dtype)
|
| 80 |
+
mask = rearrange(mask, "h w -> 1 1 h w 1")
|
| 81 |
+
model_output = model_output * mask
|
| 82 |
+
|
| 83 |
+
fold_operator = torch.nn.Fold(
|
| 84 |
+
output_size=(height, width),
|
| 85 |
+
kernel_size=(tile_size, tile_size),
|
| 86 |
+
stride=(tile_stride, tile_stride),
|
| 87 |
+
)
|
| 88 |
+
mask = repeat(mask[0, 0, :, :, 0], "h w -> 1 (h w) n", n=model_output.shape[-1])
|
| 89 |
+
model_output = rearrange(model_output, "b c h w n -> b (c h w) n")
|
| 90 |
+
model_output = fold_operator(model_output) / fold_operator(mask)
|
| 91 |
+
|
| 92 |
+
return model_output
|
| 93 |
+
|
| 94 |
+
def tiled_forward(
|
| 95 |
+
self,
|
| 96 |
+
forward_fn,
|
| 97 |
+
model_input,
|
| 98 |
+
tile_size,
|
| 99 |
+
tile_stride,
|
| 100 |
+
tile_batch_size=1,
|
| 101 |
+
tile_device="cpu",
|
| 102 |
+
tile_dtype=torch.float32,
|
| 103 |
+
border_width=None,
|
| 104 |
+
):
|
| 105 |
+
# Prepare
|
| 106 |
+
inference_device, inference_dtype = model_input.device, model_input.dtype
|
| 107 |
+
height, width = model_input.shape[2], model_input.shape[3]
|
| 108 |
+
border_width = int(tile_stride * 0.5) if border_width is None else border_width
|
| 109 |
+
|
| 110 |
+
# tile
|
| 111 |
+
model_input = self.tile(
|
| 112 |
+
model_input, tile_size, tile_stride, tile_device, tile_dtype
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
# inference
|
| 116 |
+
model_output = self.tiled_inference(
|
| 117 |
+
forward_fn,
|
| 118 |
+
model_input,
|
| 119 |
+
tile_batch_size,
|
| 120 |
+
inference_device,
|
| 121 |
+
inference_dtype,
|
| 122 |
+
tile_device,
|
| 123 |
+
tile_dtype,
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
# resize
|
| 127 |
+
io_scale = self.io_scale(model_output, tile_size)
|
| 128 |
+
height, width = int(height * io_scale), int(width * io_scale)
|
| 129 |
+
tile_size, tile_stride = int(tile_size * io_scale), int(tile_stride * io_scale)
|
| 130 |
+
border_width = int(border_width * io_scale)
|
| 131 |
+
|
| 132 |
+
# untile
|
| 133 |
+
model_output = self.untile(
|
| 134 |
+
model_output,
|
| 135 |
+
height,
|
| 136 |
+
width,
|
| 137 |
+
tile_size,
|
| 138 |
+
tile_stride,
|
| 139 |
+
border_width,
|
| 140 |
+
tile_device,
|
| 141 |
+
tile_dtype,
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
# Done!
|
| 145 |
+
model_output = model_output.to(device=inference_device, dtype=inference_dtype)
|
| 146 |
+
return model_output
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
class FastTileWorker:
|
| 150 |
+
def __init__(self):
|
| 151 |
+
pass
|
| 152 |
+
|
| 153 |
+
def build_mask(self, data, is_bound):
|
| 154 |
+
_, _, H, W = data.shape
|
| 155 |
+
h = repeat(torch.arange(H), "H -> H W", H=H, W=W)
|
| 156 |
+
w = repeat(torch.arange(W), "W -> H W", H=H, W=W)
|
| 157 |
+
border_width = (H + W) // 4
|
| 158 |
+
pad = torch.ones_like(h) * border_width
|
| 159 |
+
mask = (
|
| 160 |
+
torch.stack(
|
| 161 |
+
[
|
| 162 |
+
pad if is_bound[0] else h + 1,
|
| 163 |
+
pad if is_bound[1] else H - h,
|
| 164 |
+
pad if is_bound[2] else w + 1,
|
| 165 |
+
pad if is_bound[3] else W - w,
|
| 166 |
+
]
|
| 167 |
+
)
|
| 168 |
+
.min(dim=0)
|
| 169 |
+
.values
|
| 170 |
+
)
|
| 171 |
+
mask = mask.clip(1, border_width)
|
| 172 |
+
mask = (mask / border_width).to(dtype=data.dtype, device=data.device)
|
| 173 |
+
mask = rearrange(mask, "H W -> 1 H W")
|
| 174 |
+
return mask
|
| 175 |
+
|
| 176 |
+
def tiled_forward(
|
| 177 |
+
self,
|
| 178 |
+
forward_fn,
|
| 179 |
+
model_input,
|
| 180 |
+
tile_size,
|
| 181 |
+
tile_stride,
|
| 182 |
+
tile_device="cpu",
|
| 183 |
+
tile_dtype=torch.float32,
|
| 184 |
+
border_width=None,
|
| 185 |
+
):
|
| 186 |
+
# Prepare
|
| 187 |
+
B, C, H, W = model_input.shape
|
| 188 |
+
border_width = int(tile_stride * 0.5) if border_width is None else border_width
|
| 189 |
+
weight = torch.zeros((1, 1, H, W), dtype=tile_dtype, device=tile_device)
|
| 190 |
+
values = torch.zeros((B, C, H, W), dtype=tile_dtype, device=tile_device)
|
| 191 |
+
|
| 192 |
+
# Split tasks
|
| 193 |
+
tasks = []
|
| 194 |
+
for h in range(0, H, tile_stride):
|
| 195 |
+
for w in range(0, W, tile_stride):
|
| 196 |
+
if (h - tile_stride >= 0 and h - tile_stride + tile_size >= H) or (
|
| 197 |
+
w - tile_stride >= 0 and w - tile_stride + tile_size >= W
|
| 198 |
+
):
|
| 199 |
+
continue
|
| 200 |
+
h_, w_ = h + tile_size, w + tile_size
|
| 201 |
+
if h_ > H:
|
| 202 |
+
h, h_ = H - tile_size, H
|
| 203 |
+
if w_ > W:
|
| 204 |
+
w, w_ = W - tile_size, W
|
| 205 |
+
tasks.append((h, h_, w, w_))
|
| 206 |
+
|
| 207 |
+
# Run
|
| 208 |
+
for hl, hr, wl, wr in tasks:
|
| 209 |
+
# Forward
|
| 210 |
+
hidden_states_batch = forward_fn(hl, hr, wl, wr).to(
|
| 211 |
+
dtype=tile_dtype, device=tile_device
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
mask = self.build_mask(
|
| 215 |
+
hidden_states_batch, is_bound=(hl == 0, hr >= H, wl == 0, wr >= W)
|
| 216 |
+
)
|
| 217 |
+
values[:, :, hl:hr, wl:wr] += hidden_states_batch * mask
|
| 218 |
+
weight[:, :, hl:hr, wl:wr] += mask
|
| 219 |
+
values /= weight
|
| 220 |
+
return values
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
class TileWorker2Dto3D:
|
| 224 |
+
"""
|
| 225 |
+
Process 3D tensors, but only enable TileWorker on 2D.
|
| 226 |
+
"""
|
| 227 |
+
|
| 228 |
+
def __init__(self):
|
| 229 |
+
pass
|
| 230 |
+
|
| 231 |
+
def build_mask(self, T, H, W, dtype, device, is_bound, border_width):
|
| 232 |
+
t = repeat(torch.arange(T), "T -> T H W", T=T, H=H, W=W)
|
| 233 |
+
h = repeat(torch.arange(H), "H -> T H W", T=T, H=H, W=W)
|
| 234 |
+
w = repeat(torch.arange(W), "W -> T H W", T=T, H=H, W=W)
|
| 235 |
+
border_width = (H + W) // 4 if border_width is None else border_width
|
| 236 |
+
pad = torch.ones_like(h) * border_width
|
| 237 |
+
mask = (
|
| 238 |
+
torch.stack(
|
| 239 |
+
[
|
| 240 |
+
pad if is_bound[0] else t + 1,
|
| 241 |
+
pad if is_bound[1] else T - t,
|
| 242 |
+
pad if is_bound[2] else h + 1,
|
| 243 |
+
pad if is_bound[3] else H - h,
|
| 244 |
+
pad if is_bound[4] else w + 1,
|
| 245 |
+
pad if is_bound[5] else W - w,
|
| 246 |
+
]
|
| 247 |
+
)
|
| 248 |
+
.min(dim=0)
|
| 249 |
+
.values
|
| 250 |
+
)
|
| 251 |
+
mask = mask.clip(1, border_width)
|
| 252 |
+
mask = (mask / border_width).to(dtype=dtype, device=device)
|
| 253 |
+
mask = rearrange(mask, "T H W -> 1 1 T H W")
|
| 254 |
+
return mask
|
| 255 |
+
|
| 256 |
+
def tiled_forward(
|
| 257 |
+
self,
|
| 258 |
+
forward_fn,
|
| 259 |
+
model_input,
|
| 260 |
+
tile_size,
|
| 261 |
+
tile_stride,
|
| 262 |
+
tile_device="cpu",
|
| 263 |
+
tile_dtype=torch.float32,
|
| 264 |
+
computation_device="cuda",
|
| 265 |
+
computation_dtype=torch.float32,
|
| 266 |
+
border_width=None,
|
| 267 |
+
scales=[1, 1, 1, 1],
|
| 268 |
+
progress_bar=lambda x: x,
|
| 269 |
+
):
|
| 270 |
+
B, C, T, H, W = model_input.shape
|
| 271 |
+
scale_C, scale_T, scale_H, scale_W = scales
|
| 272 |
+
tile_size_H, tile_size_W = tile_size
|
| 273 |
+
tile_stride_H, tile_stride_W = tile_stride
|
| 274 |
+
|
| 275 |
+
value = torch.zeros(
|
| 276 |
+
(B, int(C * scale_C), int(T * scale_T), int(H * scale_H), int(W * scale_W)),
|
| 277 |
+
dtype=tile_dtype,
|
| 278 |
+
device=tile_device,
|
| 279 |
+
)
|
| 280 |
+
weight = torch.zeros(
|
| 281 |
+
(1, 1, int(T * scale_T), int(H * scale_H), int(W * scale_W)),
|
| 282 |
+
dtype=tile_dtype,
|
| 283 |
+
device=tile_device,
|
| 284 |
+
)
|
| 285 |
+
|
| 286 |
+
# Split tasks
|
| 287 |
+
tasks = []
|
| 288 |
+
for h in range(0, H, tile_stride_H):
|
| 289 |
+
for w in range(0, W, tile_stride_W):
|
| 290 |
+
if (
|
| 291 |
+
h - tile_stride_H >= 0 and h - tile_stride_H + tile_size_H >= H
|
| 292 |
+
) or (w - tile_stride_W >= 0 and w - tile_stride_W + tile_size_W >= W):
|
| 293 |
+
continue
|
| 294 |
+
h_, w_ = h + tile_size_H, w + tile_size_W
|
| 295 |
+
if h_ > H:
|
| 296 |
+
h, h_ = max(H - tile_size_H, 0), H
|
| 297 |
+
if w_ > W:
|
| 298 |
+
w, w_ = max(W - tile_size_W, 0), W
|
| 299 |
+
tasks.append((h, h_, w, w_))
|
| 300 |
+
|
| 301 |
+
# Run
|
| 302 |
+
for hl, hr, wl, wr in progress_bar(tasks):
|
| 303 |
+
mask = self.build_mask(
|
| 304 |
+
int(T * scale_T),
|
| 305 |
+
int((hr - hl) * scale_H),
|
| 306 |
+
int((wr - wl) * scale_W),
|
| 307 |
+
tile_dtype,
|
| 308 |
+
tile_device,
|
| 309 |
+
is_bound=(True, True, hl == 0, hr >= H, wl == 0, wr >= W),
|
| 310 |
+
border_width=border_width,
|
| 311 |
+
)
|
| 312 |
+
grid_input = model_input[:, :, :, hl:hr, wl:wr].to(
|
| 313 |
+
dtype=computation_dtype, device=computation_device
|
| 314 |
+
)
|
| 315 |
+
grid_output = forward_fn(grid_input).to(
|
| 316 |
+
dtype=tile_dtype, device=tile_device
|
| 317 |
+
)
|
| 318 |
+
value[
|
| 319 |
+
:,
|
| 320 |
+
:,
|
| 321 |
+
:,
|
| 322 |
+
int(hl * scale_H) : int(hr * scale_H),
|
| 323 |
+
int(wl * scale_W) : int(wr * scale_W),
|
| 324 |
+
] += grid_output * mask
|
| 325 |
+
weight[
|
| 326 |
+
:,
|
| 327 |
+
:,
|
| 328 |
+
:,
|
| 329 |
+
int(hl * scale_H) : int(hr * scale_H),
|
| 330 |
+
int(wl * scale_W) : int(wr * scale_W),
|
| 331 |
+
] += mask
|
| 332 |
+
value = value / weight
|
| 333 |
+
return value
|
models/utils.py
ADDED
|
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch, os
|
| 2 |
+
from safetensors import safe_open
|
| 3 |
+
from contextlib import contextmanager
|
| 4 |
+
import hashlib
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@contextmanager
|
| 8 |
+
def init_weights_on_device(device=torch.device("meta"), include_buffers: bool = False):
|
| 9 |
+
old_register_parameter = torch.nn.Module.register_parameter
|
| 10 |
+
if include_buffers:
|
| 11 |
+
old_register_buffer = torch.nn.Module.register_buffer
|
| 12 |
+
|
| 13 |
+
def register_empty_parameter(module, name, param):
|
| 14 |
+
old_register_parameter(module, name, param)
|
| 15 |
+
if param is not None:
|
| 16 |
+
param_cls = type(module._parameters[name])
|
| 17 |
+
kwargs = module._parameters[name].__dict__
|
| 18 |
+
kwargs["requires_grad"] = param.requires_grad
|
| 19 |
+
module._parameters[name] = param_cls(
|
| 20 |
+
module._parameters[name].to(device), **kwargs
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
def register_empty_buffer(module, name, buffer, persistent=True):
|
| 24 |
+
old_register_buffer(module, name, buffer, persistent=persistent)
|
| 25 |
+
if buffer is not None:
|
| 26 |
+
module._buffers[name] = module._buffers[name].to(device)
|
| 27 |
+
|
| 28 |
+
def patch_tensor_constructor(fn):
|
| 29 |
+
def wrapper(*args, **kwargs):
|
| 30 |
+
kwargs["device"] = device
|
| 31 |
+
return fn(*args, **kwargs)
|
| 32 |
+
|
| 33 |
+
return wrapper
|
| 34 |
+
|
| 35 |
+
if include_buffers:
|
| 36 |
+
tensor_constructors_to_patch = {
|
| 37 |
+
torch_function_name: getattr(torch, torch_function_name)
|
| 38 |
+
for torch_function_name in ["empty", "zeros", "ones", "full"]
|
| 39 |
+
}
|
| 40 |
+
else:
|
| 41 |
+
tensor_constructors_to_patch = {}
|
| 42 |
+
|
| 43 |
+
try:
|
| 44 |
+
torch.nn.Module.register_parameter = register_empty_parameter
|
| 45 |
+
if include_buffers:
|
| 46 |
+
torch.nn.Module.register_buffer = register_empty_buffer
|
| 47 |
+
for torch_function_name in tensor_constructors_to_patch.keys():
|
| 48 |
+
setattr(
|
| 49 |
+
torch,
|
| 50 |
+
torch_function_name,
|
| 51 |
+
patch_tensor_constructor(getattr(torch, torch_function_name)),
|
| 52 |
+
)
|
| 53 |
+
yield
|
| 54 |
+
finally:
|
| 55 |
+
torch.nn.Module.register_parameter = old_register_parameter
|
| 56 |
+
if include_buffers:
|
| 57 |
+
torch.nn.Module.register_buffer = old_register_buffer
|
| 58 |
+
for (
|
| 59 |
+
torch_function_name,
|
| 60 |
+
old_torch_function,
|
| 61 |
+
) in tensor_constructors_to_patch.items():
|
| 62 |
+
setattr(torch, torch_function_name, old_torch_function)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def load_state_dict_from_folder(file_path, torch_dtype=None):
|
| 66 |
+
state_dict = {}
|
| 67 |
+
for file_name in os.listdir(file_path):
|
| 68 |
+
if "." in file_name and file_name.split(".")[-1] in [
|
| 69 |
+
"safetensors",
|
| 70 |
+
"bin",
|
| 71 |
+
"ckpt",
|
| 72 |
+
"pth",
|
| 73 |
+
"pt",
|
| 74 |
+
]:
|
| 75 |
+
state_dict.update(
|
| 76 |
+
load_state_dict(
|
| 77 |
+
os.path.join(file_path, file_name), torch_dtype=torch_dtype
|
| 78 |
+
)
|
| 79 |
+
)
|
| 80 |
+
return state_dict
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def load_state_dict(file_path, torch_dtype=None, device="cpu"):
|
| 84 |
+
if file_path.endswith(".safetensors"):
|
| 85 |
+
return load_state_dict_from_safetensors(
|
| 86 |
+
file_path, torch_dtype=torch_dtype, device=device
|
| 87 |
+
)
|
| 88 |
+
else:
|
| 89 |
+
return load_state_dict_from_bin(
|
| 90 |
+
file_path, torch_dtype=torch_dtype, device=device
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"):
|
| 95 |
+
state_dict = {}
|
| 96 |
+
with safe_open(file_path, framework="pt", device=str(device)) as f:
|
| 97 |
+
for k in f.keys():
|
| 98 |
+
state_dict[k] = f.get_tensor(k)
|
| 99 |
+
if torch_dtype is not None:
|
| 100 |
+
state_dict[k] = state_dict[k].to(torch_dtype)
|
| 101 |
+
return state_dict
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def load_state_dict_from_bin(file_path, torch_dtype=None, device="cpu"):
|
| 105 |
+
state_dict = torch.load(file_path, map_location=device, weights_only=True)
|
| 106 |
+
if torch_dtype is not None:
|
| 107 |
+
for i in state_dict:
|
| 108 |
+
if isinstance(state_dict[i], torch.Tensor):
|
| 109 |
+
state_dict[i] = state_dict[i].to(torch_dtype)
|
| 110 |
+
return state_dict
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def search_for_embeddings(state_dict):
|
| 114 |
+
embeddings = []
|
| 115 |
+
for k in state_dict:
|
| 116 |
+
if isinstance(state_dict[k], torch.Tensor):
|
| 117 |
+
embeddings.append(state_dict[k])
|
| 118 |
+
elif isinstance(state_dict[k], dict):
|
| 119 |
+
embeddings += search_for_embeddings(state_dict[k])
|
| 120 |
+
return embeddings
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def search_parameter(param, state_dict):
|
| 124 |
+
for name, param_ in state_dict.items():
|
| 125 |
+
if param.numel() == param_.numel():
|
| 126 |
+
if param.shape == param_.shape:
|
| 127 |
+
if torch.dist(param, param_) < 1e-3:
|
| 128 |
+
return name
|
| 129 |
+
else:
|
| 130 |
+
if torch.dist(param.flatten(), param_.flatten()) < 1e-3:
|
| 131 |
+
return name
|
| 132 |
+
return None
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def build_rename_dict(source_state_dict, target_state_dict, split_qkv=False):
|
| 136 |
+
matched_keys = set()
|
| 137 |
+
with torch.no_grad():
|
| 138 |
+
for name in source_state_dict:
|
| 139 |
+
rename = search_parameter(source_state_dict[name], target_state_dict)
|
| 140 |
+
if rename is not None:
|
| 141 |
+
print(f'"{name}": "{rename}",')
|
| 142 |
+
matched_keys.add(rename)
|
| 143 |
+
elif (
|
| 144 |
+
split_qkv
|
| 145 |
+
and len(source_state_dict[name].shape) >= 1
|
| 146 |
+
and source_state_dict[name].shape[0] % 3 == 0
|
| 147 |
+
):
|
| 148 |
+
length = source_state_dict[name].shape[0] // 3
|
| 149 |
+
rename = []
|
| 150 |
+
for i in range(3):
|
| 151 |
+
rename.append(
|
| 152 |
+
search_parameter(
|
| 153 |
+
source_state_dict[name][i * length : i * length + length],
|
| 154 |
+
target_state_dict,
|
| 155 |
+
)
|
| 156 |
+
)
|
| 157 |
+
if None not in rename:
|
| 158 |
+
print(f'"{name}": {rename},')
|
| 159 |
+
for rename_ in rename:
|
| 160 |
+
matched_keys.add(rename_)
|
| 161 |
+
for name in target_state_dict:
|
| 162 |
+
if name not in matched_keys:
|
| 163 |
+
print("Cannot find", name, target_state_dict[name].shape)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def search_for_files(folder, extensions):
|
| 167 |
+
files = []
|
| 168 |
+
if os.path.isdir(folder):
|
| 169 |
+
for file in sorted(os.listdir(folder)):
|
| 170 |
+
files += search_for_files(os.path.join(folder, file), extensions)
|
| 171 |
+
elif os.path.isfile(folder):
|
| 172 |
+
for extension in extensions:
|
| 173 |
+
if folder.endswith(extension):
|
| 174 |
+
files.append(folder)
|
| 175 |
+
break
|
| 176 |
+
return files
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
|
| 180 |
+
keys = []
|
| 181 |
+
for key, value in state_dict.items():
|
| 182 |
+
if isinstance(key, str):
|
| 183 |
+
if isinstance(value, torch.Tensor):
|
| 184 |
+
if with_shape:
|
| 185 |
+
shape = "_".join(map(str, list(value.shape)))
|
| 186 |
+
keys.append(key + ":" + shape)
|
| 187 |
+
keys.append(key)
|
| 188 |
+
elif isinstance(value, dict):
|
| 189 |
+
keys.append(
|
| 190 |
+
key
|
| 191 |
+
+ "|"
|
| 192 |
+
+ convert_state_dict_keys_to_single_str(
|
| 193 |
+
value, with_shape=with_shape
|
| 194 |
+
)
|
| 195 |
+
)
|
| 196 |
+
keys.sort()
|
| 197 |
+
keys_str = ",".join(keys)
|
| 198 |
+
return keys_str
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def split_state_dict_with_prefix(state_dict):
|
| 202 |
+
keys = sorted([key for key in state_dict if isinstance(key, str)])
|
| 203 |
+
prefix_dict = {}
|
| 204 |
+
for key in keys:
|
| 205 |
+
prefix = key if "." not in key else key.split(".")[0]
|
| 206 |
+
if prefix not in prefix_dict:
|
| 207 |
+
prefix_dict[prefix] = []
|
| 208 |
+
prefix_dict[prefix].append(key)
|
| 209 |
+
state_dicts = []
|
| 210 |
+
for prefix, keys in prefix_dict.items():
|
| 211 |
+
sub_state_dict = {key: state_dict[key] for key in keys}
|
| 212 |
+
state_dicts.append(sub_state_dict)
|
| 213 |
+
return state_dicts
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def hash_state_dict_keys(state_dict, with_shape=True):
|
| 217 |
+
keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
|
| 218 |
+
keys_str = keys_str.encode(encoding="UTF-8")
|
| 219 |
+
return hashlib.md5(keys_str).hexdigest()
|
models/wan_video_camera_controller.py
ADDED
|
@@ -0,0 +1,290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import numpy as np
|
| 4 |
+
from einops import rearrange
|
| 5 |
+
import os
|
| 6 |
+
from typing_extensions import Literal
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class SimpleAdapter(nn.Module):
|
| 10 |
+
def __init__(self, in_dim, out_dim, kernel_size, stride, num_residual_blocks=1):
|
| 11 |
+
super(SimpleAdapter, self).__init__()
|
| 12 |
+
|
| 13 |
+
# Pixel Unshuffle: reduce spatial dimensions by a factor of 8
|
| 14 |
+
self.pixel_unshuffle = nn.PixelUnshuffle(downscale_factor=8)
|
| 15 |
+
|
| 16 |
+
# Convolution: reduce spatial dimensions by a factor
|
| 17 |
+
# of 2 (without overlap)
|
| 18 |
+
self.conv = nn.Conv2d(
|
| 19 |
+
in_dim * 64, out_dim, kernel_size=kernel_size, stride=stride, padding=0
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
# Residual blocks for feature extraction
|
| 23 |
+
self.residual_blocks = nn.Sequential(
|
| 24 |
+
*[ResidualBlock(out_dim) for _ in range(num_residual_blocks)]
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
def forward(self, x):
|
| 28 |
+
# Reshape to merge the frame dimension into batch
|
| 29 |
+
bs, c, f, h, w = x.size()
|
| 30 |
+
x = x.permute(0, 2, 1, 3, 4).contiguous().view(bs * f, c, h, w)
|
| 31 |
+
|
| 32 |
+
# Pixel Unshuffle operation
|
| 33 |
+
x_unshuffled = self.pixel_unshuffle(x)
|
| 34 |
+
|
| 35 |
+
# Convolution operation
|
| 36 |
+
x_conv = self.conv(x_unshuffled)
|
| 37 |
+
|
| 38 |
+
# Feature extraction with residual blocks
|
| 39 |
+
out = self.residual_blocks(x_conv)
|
| 40 |
+
|
| 41 |
+
# Reshape to restore original bf dimension
|
| 42 |
+
out = out.view(bs, f, out.size(1), out.size(2), out.size(3))
|
| 43 |
+
|
| 44 |
+
# Permute dimensions to reorder (if needed), e.g., swap channels and feature frames
|
| 45 |
+
out = out.permute(0, 2, 1, 3, 4)
|
| 46 |
+
|
| 47 |
+
return out
|
| 48 |
+
|
| 49 |
+
def process_camera_coordinates(
|
| 50 |
+
self,
|
| 51 |
+
direction: Literal[
|
| 52 |
+
"Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown"
|
| 53 |
+
],
|
| 54 |
+
length: int,
|
| 55 |
+
height: int,
|
| 56 |
+
width: int,
|
| 57 |
+
speed: float = 1 / 54,
|
| 58 |
+
origin=(
|
| 59 |
+
0,
|
| 60 |
+
0.532139961,
|
| 61 |
+
0.946026558,
|
| 62 |
+
0.5,
|
| 63 |
+
0.5,
|
| 64 |
+
0,
|
| 65 |
+
0,
|
| 66 |
+
1,
|
| 67 |
+
0,
|
| 68 |
+
0,
|
| 69 |
+
0,
|
| 70 |
+
0,
|
| 71 |
+
1,
|
| 72 |
+
0,
|
| 73 |
+
0,
|
| 74 |
+
0,
|
| 75 |
+
0,
|
| 76 |
+
1,
|
| 77 |
+
0,
|
| 78 |
+
),
|
| 79 |
+
):
|
| 80 |
+
if origin is None:
|
| 81 |
+
origin = (
|
| 82 |
+
0,
|
| 83 |
+
0.532139961,
|
| 84 |
+
0.946026558,
|
| 85 |
+
0.5,
|
| 86 |
+
0.5,
|
| 87 |
+
0,
|
| 88 |
+
0,
|
| 89 |
+
1,
|
| 90 |
+
0,
|
| 91 |
+
0,
|
| 92 |
+
0,
|
| 93 |
+
0,
|
| 94 |
+
1,
|
| 95 |
+
0,
|
| 96 |
+
0,
|
| 97 |
+
0,
|
| 98 |
+
0,
|
| 99 |
+
1,
|
| 100 |
+
0,
|
| 101 |
+
)
|
| 102 |
+
coordinates = generate_camera_coordinates(direction, length, speed, origin)
|
| 103 |
+
plucker_embedding = process_pose_file(coordinates, width, height)
|
| 104 |
+
return plucker_embedding
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
class ResidualBlock(nn.Module):
|
| 108 |
+
def __init__(self, dim):
|
| 109 |
+
super(ResidualBlock, self).__init__()
|
| 110 |
+
self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, padding=1)
|
| 111 |
+
self.relu = nn.ReLU(inplace=True)
|
| 112 |
+
self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, padding=1)
|
| 113 |
+
|
| 114 |
+
def forward(self, x):
|
| 115 |
+
residual = x
|
| 116 |
+
out = self.relu(self.conv1(x))
|
| 117 |
+
out = self.conv2(out)
|
| 118 |
+
out += residual
|
| 119 |
+
return out
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
class Camera(object):
|
| 123 |
+
"""Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py"""
|
| 124 |
+
|
| 125 |
+
def __init__(self, entry):
|
| 126 |
+
fx, fy, cx, cy = entry[1:5]
|
| 127 |
+
self.fx = fx
|
| 128 |
+
self.fy = fy
|
| 129 |
+
self.cx = cx
|
| 130 |
+
self.cy = cy
|
| 131 |
+
w2c_mat = np.array(entry[7:]).reshape(3, 4)
|
| 132 |
+
w2c_mat_4x4 = np.eye(4)
|
| 133 |
+
w2c_mat_4x4[:3, :] = w2c_mat
|
| 134 |
+
self.w2c_mat = w2c_mat_4x4
|
| 135 |
+
self.c2w_mat = np.linalg.inv(w2c_mat_4x4)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def get_relative_pose(cam_params):
|
| 139 |
+
"""Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py"""
|
| 140 |
+
abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]
|
| 141 |
+
abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]
|
| 142 |
+
cam_to_origin = 0
|
| 143 |
+
target_cam_c2w = np.array(
|
| 144 |
+
[[1, 0, 0, 0], [0, 1, 0, -cam_to_origin], [0, 0, 1, 0], [0, 0, 0, 1]]
|
| 145 |
+
)
|
| 146 |
+
abs2rel = target_cam_c2w @ abs_w2cs[0]
|
| 147 |
+
ret_poses = [
|
| 148 |
+
target_cam_c2w,
|
| 149 |
+
] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]
|
| 150 |
+
ret_poses = np.array(ret_poses, dtype=np.float32)
|
| 151 |
+
return ret_poses
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def custom_meshgrid(*args):
|
| 155 |
+
# torch>=2.0.0 only
|
| 156 |
+
return torch.meshgrid(*args, indexing="ij")
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def ray_condition(K, c2w, H, W, device):
|
| 160 |
+
"""Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py"""
|
| 161 |
+
# c2w: B, V, 4, 4
|
| 162 |
+
# K: B, V, 4
|
| 163 |
+
|
| 164 |
+
B = K.shape[0]
|
| 165 |
+
|
| 166 |
+
j, i = custom_meshgrid(
|
| 167 |
+
torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
|
| 168 |
+
torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
|
| 169 |
+
)
|
| 170 |
+
i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]
|
| 171 |
+
j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]
|
| 172 |
+
|
| 173 |
+
fx, fy, cx, cy = K.chunk(4, dim=-1) # B,V, 1
|
| 174 |
+
|
| 175 |
+
zs = torch.ones_like(i) # [B, HxW]
|
| 176 |
+
xs = (i - cx) / fx * zs
|
| 177 |
+
ys = (j - cy) / fy * zs
|
| 178 |
+
zs = zs.expand_as(ys)
|
| 179 |
+
|
| 180 |
+
directions = torch.stack((xs, ys, zs), dim=-1) # B, V, HW, 3
|
| 181 |
+
directions = directions / directions.norm(dim=-1, keepdim=True) # B, V, HW, 3
|
| 182 |
+
|
| 183 |
+
rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2) # B, V, 3, HW
|
| 184 |
+
rays_o = c2w[..., :3, 3] # B, V, 3
|
| 185 |
+
rays_o = rays_o[:, :, None].expand_as(rays_d) # B, V, 3, HW
|
| 186 |
+
# c2w @ dirctions
|
| 187 |
+
rays_dxo = torch.linalg.cross(rays_o, rays_d)
|
| 188 |
+
plucker = torch.cat([rays_dxo, rays_d], dim=-1)
|
| 189 |
+
plucker = plucker.reshape(B, c2w.shape[1], H, W, 6) # B, V, H, W, 6
|
| 190 |
+
# plucker = plucker.permute(0, 1, 4, 2, 3)
|
| 191 |
+
return plucker
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def process_pose_file(
|
| 195 |
+
cam_params,
|
| 196 |
+
width=672,
|
| 197 |
+
height=384,
|
| 198 |
+
original_pose_width=1280,
|
| 199 |
+
original_pose_height=720,
|
| 200 |
+
device="cpu",
|
| 201 |
+
return_poses=False,
|
| 202 |
+
):
|
| 203 |
+
if return_poses:
|
| 204 |
+
return cam_params
|
| 205 |
+
else:
|
| 206 |
+
cam_params = [Camera(cam_param) for cam_param in cam_params]
|
| 207 |
+
|
| 208 |
+
sample_wh_ratio = width / height
|
| 209 |
+
pose_wh_ratio = (
|
| 210 |
+
original_pose_width / original_pose_height
|
| 211 |
+
) # Assuming placeholder ratios, change as needed
|
| 212 |
+
|
| 213 |
+
if pose_wh_ratio > sample_wh_ratio:
|
| 214 |
+
resized_ori_w = height * pose_wh_ratio
|
| 215 |
+
for cam_param in cam_params:
|
| 216 |
+
cam_param.fx = resized_ori_w * cam_param.fx / width
|
| 217 |
+
else:
|
| 218 |
+
resized_ori_h = width / pose_wh_ratio
|
| 219 |
+
for cam_param in cam_params:
|
| 220 |
+
cam_param.fy = resized_ori_h * cam_param.fy / height
|
| 221 |
+
|
| 222 |
+
intrinsic = np.asarray(
|
| 223 |
+
[
|
| 224 |
+
[
|
| 225 |
+
cam_param.fx * width,
|
| 226 |
+
cam_param.fy * height,
|
| 227 |
+
cam_param.cx * width,
|
| 228 |
+
cam_param.cy * height,
|
| 229 |
+
]
|
| 230 |
+
for cam_param in cam_params
|
| 231 |
+
],
|
| 232 |
+
dtype=np.float32,
|
| 233 |
+
)
|
| 234 |
+
|
| 235 |
+
K = torch.as_tensor(intrinsic)[None] # [1, 1, 4]
|
| 236 |
+
c2ws = get_relative_pose(
|
| 237 |
+
cam_params
|
| 238 |
+
) # Assuming this function is defined elsewhere
|
| 239 |
+
c2ws = torch.as_tensor(c2ws)[None] # [1, n_frame, 4, 4]
|
| 240 |
+
plucker_embedding = (
|
| 241 |
+
ray_condition(K, c2ws, height, width, device=device)[0]
|
| 242 |
+
.permute(0, 3, 1, 2)
|
| 243 |
+
.contiguous()
|
| 244 |
+
) # V, 6, H, W
|
| 245 |
+
plucker_embedding = plucker_embedding[None]
|
| 246 |
+
plucker_embedding = rearrange(plucker_embedding, "b f c h w -> b f h w c")[0]
|
| 247 |
+
return plucker_embedding
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def generate_camera_coordinates(
|
| 251 |
+
direction: Literal[
|
| 252 |
+
"Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown"
|
| 253 |
+
],
|
| 254 |
+
length: int,
|
| 255 |
+
speed: float = 1 / 54,
|
| 256 |
+
origin=(
|
| 257 |
+
0,
|
| 258 |
+
0.532139961,
|
| 259 |
+
0.946026558,
|
| 260 |
+
0.5,
|
| 261 |
+
0.5,
|
| 262 |
+
0,
|
| 263 |
+
0,
|
| 264 |
+
1,
|
| 265 |
+
0,
|
| 266 |
+
0,
|
| 267 |
+
0,
|
| 268 |
+
0,
|
| 269 |
+
1,
|
| 270 |
+
0,
|
| 271 |
+
0,
|
| 272 |
+
0,
|
| 273 |
+
0,
|
| 274 |
+
1,
|
| 275 |
+
0,
|
| 276 |
+
),
|
| 277 |
+
):
|
| 278 |
+
coordinates = [list(origin)]
|
| 279 |
+
while len(coordinates) < length:
|
| 280 |
+
coor = coordinates[-1].copy()
|
| 281 |
+
if "Left" in direction:
|
| 282 |
+
coor[9] += speed
|
| 283 |
+
if "Right" in direction:
|
| 284 |
+
coor[9] -= speed
|
| 285 |
+
if "Up" in direction:
|
| 286 |
+
coor[13] += speed
|
| 287 |
+
if "Down" in direction:
|
| 288 |
+
coor[13] -= speed
|
| 289 |
+
coordinates.append(coor)
|
| 290 |
+
return coordinates
|
models/wan_video_dit.py
ADDED
|
@@ -0,0 +1,952 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import math
|
| 5 |
+
from typing import Tuple, Optional
|
| 6 |
+
from einops import rearrange
|
| 7 |
+
from .utils import hash_state_dict_keys
|
| 8 |
+
from .wan_video_camera_controller import SimpleAdapter
|
| 9 |
+
|
| 10 |
+
try:
|
| 11 |
+
import flash_attn_interface
|
| 12 |
+
|
| 13 |
+
FLASH_ATTN_3_AVAILABLE = True
|
| 14 |
+
except ModuleNotFoundError:
|
| 15 |
+
FLASH_ATTN_3_AVAILABLE = False
|
| 16 |
+
|
| 17 |
+
try:
|
| 18 |
+
import flash_attn
|
| 19 |
+
|
| 20 |
+
FLASH_ATTN_2_AVAILABLE = True
|
| 21 |
+
except ModuleNotFoundError:
|
| 22 |
+
FLASH_ATTN_2_AVAILABLE = False
|
| 23 |
+
|
| 24 |
+
try:
|
| 25 |
+
from sageattention import sageattn
|
| 26 |
+
|
| 27 |
+
SAGE_ATTN_AVAILABLE = True
|
| 28 |
+
except ModuleNotFoundError:
|
| 29 |
+
SAGE_ATTN_AVAILABLE = False
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def flash_attention(
|
| 33 |
+
q: torch.Tensor,
|
| 34 |
+
k: torch.Tensor,
|
| 35 |
+
v: torch.Tensor,
|
| 36 |
+
num_heads: int,
|
| 37 |
+
compatibility_mode=False,
|
| 38 |
+
):
|
| 39 |
+
if compatibility_mode:
|
| 40 |
+
q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
|
| 41 |
+
k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
|
| 42 |
+
v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
|
| 43 |
+
x = F.scaled_dot_product_attention(q, k, v)
|
| 44 |
+
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
|
| 45 |
+
elif FLASH_ATTN_3_AVAILABLE:
|
| 46 |
+
q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
|
| 47 |
+
k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
|
| 48 |
+
v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
|
| 49 |
+
x = flash_attn_interface.flash_attn_func(q, k, v)
|
| 50 |
+
if isinstance(x, tuple):
|
| 51 |
+
x = x[0]
|
| 52 |
+
x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
|
| 53 |
+
elif FLASH_ATTN_2_AVAILABLE:
|
| 54 |
+
q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
|
| 55 |
+
k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
|
| 56 |
+
v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
|
| 57 |
+
x = flash_attn.flash_attn_func(q, k, v)
|
| 58 |
+
x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
|
| 59 |
+
elif SAGE_ATTN_AVAILABLE:
|
| 60 |
+
q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
|
| 61 |
+
k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
|
| 62 |
+
v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
|
| 63 |
+
x = sageattn(q, k, v, tensor_layout="HND", is_causal=False)
|
| 64 |
+
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
|
| 65 |
+
else:
|
| 66 |
+
q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
|
| 67 |
+
k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
|
| 68 |
+
v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
|
| 69 |
+
x = F.scaled_dot_product_attention(q, k, v)
|
| 70 |
+
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
|
| 71 |
+
return x
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
|
| 75 |
+
return x * (1 + scale) + shift
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def sinusoidal_embedding_1d(dim, position):
|
| 79 |
+
sinusoid = torch.outer(
|
| 80 |
+
position.type(torch.float64),
|
| 81 |
+
torch.pow(
|
| 82 |
+
10000,
|
| 83 |
+
-torch.arange(dim // 2, dtype=torch.float64, device=position.device).div(
|
| 84 |
+
dim // 2
|
| 85 |
+
),
|
| 86 |
+
),
|
| 87 |
+
)
|
| 88 |
+
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
|
| 89 |
+
return x.to(position.dtype)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def precompute_freqs_cis_3d(dim: int, end: int = 1024, theta: float = 10000.0):
|
| 93 |
+
# 3d rope precompute
|
| 94 |
+
f_freqs_cis = precompute_freqs_cis(dim - 2 * (dim // 3), end + 1, theta)
|
| 95 |
+
h_freqs_cis = precompute_freqs_cis(dim // 3, end, theta)
|
| 96 |
+
w_freqs_cis = precompute_freqs_cis(dim // 3, end, theta)
|
| 97 |
+
return f_freqs_cis, h_freqs_cis, w_freqs_cis
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0):
|
| 101 |
+
# 1d rope precompute
|
| 102 |
+
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].double() / dim))
|
| 103 |
+
###################################################### add f = -1
|
| 104 |
+
positions = torch.arange(-1, end, device=freqs.device)
|
| 105 |
+
freqs = torch.outer(positions, freqs)
|
| 106 |
+
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
|
| 107 |
+
######################################################
|
| 108 |
+
return freqs_cis
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def rope_apply(x, freqs, num_heads):
|
| 112 |
+
x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
|
| 113 |
+
x_out = torch.view_as_complex(
|
| 114 |
+
x.to(torch.float64).reshape(x.shape[0], x.shape[1], x.shape[2], -1, 2)
|
| 115 |
+
)
|
| 116 |
+
x_out = torch.view_as_real(x_out * freqs).flatten(2)
|
| 117 |
+
return x_out.to(x.dtype)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
class RMSNorm(nn.Module):
|
| 121 |
+
def __init__(self, dim, eps=1e-5):
|
| 122 |
+
super().__init__()
|
| 123 |
+
self.eps = eps
|
| 124 |
+
self.weight = nn.Parameter(torch.ones(dim))
|
| 125 |
+
|
| 126 |
+
def norm(self, x):
|
| 127 |
+
return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
|
| 128 |
+
|
| 129 |
+
def forward(self, x):
|
| 130 |
+
dtype = x.dtype
|
| 131 |
+
return self.norm(x.float()).to(dtype) * self.weight
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
class AttentionModule(nn.Module):
|
| 135 |
+
def __init__(self, num_heads):
|
| 136 |
+
super().__init__()
|
| 137 |
+
self.num_heads = num_heads
|
| 138 |
+
|
| 139 |
+
def forward(self, q, k, v):
|
| 140 |
+
x = flash_attention(q=q, k=k, v=v, num_heads=self.num_heads)
|
| 141 |
+
return x
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
class LoRALinearLayer(nn.Module):
|
| 145 |
+
def __init__(
|
| 146 |
+
self,
|
| 147 |
+
in_features: int,
|
| 148 |
+
out_features: int,
|
| 149 |
+
rank: int = 128,
|
| 150 |
+
device="cuda",
|
| 151 |
+
dtype: Optional[torch.dtype] = torch.float32,
|
| 152 |
+
):
|
| 153 |
+
super().__init__()
|
| 154 |
+
self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
|
| 155 |
+
self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
|
| 156 |
+
self.rank = rank
|
| 157 |
+
self.out_features = out_features
|
| 158 |
+
self.in_features = in_features
|
| 159 |
+
|
| 160 |
+
nn.init.normal_(self.down.weight, std=1 / rank)
|
| 161 |
+
nn.init.zeros_(self.up.weight)
|
| 162 |
+
|
| 163 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 164 |
+
orig_dtype = hidden_states.dtype
|
| 165 |
+
dtype = self.down.weight.dtype
|
| 166 |
+
|
| 167 |
+
down_hidden_states = self.down(hidden_states.to(dtype))
|
| 168 |
+
up_hidden_states = self.up(down_hidden_states)
|
| 169 |
+
return up_hidden_states.to(orig_dtype)
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
class SelfAttention(nn.Module):
|
| 173 |
+
def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
|
| 174 |
+
super().__init__()
|
| 175 |
+
self.dim = dim
|
| 176 |
+
self.num_heads = num_heads
|
| 177 |
+
self.head_dim = dim // num_heads
|
| 178 |
+
|
| 179 |
+
self.q = nn.Linear(dim, dim)
|
| 180 |
+
self.k = nn.Linear(dim, dim)
|
| 181 |
+
self.v = nn.Linear(dim, dim)
|
| 182 |
+
self.o = nn.Linear(dim, dim)
|
| 183 |
+
self.norm_q = RMSNorm(dim, eps=eps)
|
| 184 |
+
self.norm_k = RMSNorm(dim, eps=eps)
|
| 185 |
+
|
| 186 |
+
self.attn = AttentionModule(self.num_heads)
|
| 187 |
+
|
| 188 |
+
self.kv_cache = None
|
| 189 |
+
self.cond_size = None
|
| 190 |
+
|
| 191 |
+
def init_lora(self, train=False):
|
| 192 |
+
dim = self.dim
|
| 193 |
+
self.q_loras = LoRALinearLayer(dim, dim, rank=128)
|
| 194 |
+
self.k_loras = LoRALinearLayer(dim, dim, rank=128)
|
| 195 |
+
self.v_loras = LoRALinearLayer(dim, dim, rank=128)
|
| 196 |
+
|
| 197 |
+
requires_grad = train
|
| 198 |
+
for lora in [self.q_loras, self.k_loras, self.v_loras]:
|
| 199 |
+
for param in lora.parameters():
|
| 200 |
+
param.requires_grad = requires_grad
|
| 201 |
+
|
| 202 |
+
def forward(self, x, freqs):
|
| 203 |
+
if self.cond_size is not None:
|
| 204 |
+
if self.kv_cache is None:
|
| 205 |
+
x_main, x_ip = x[:, : -self.cond_size], x[:, -self.cond_size :]
|
| 206 |
+
split_point = freqs.shape[0] - self.cond_size
|
| 207 |
+
freqs_main = freqs[:split_point]
|
| 208 |
+
freqs_ip = freqs[split_point:]
|
| 209 |
+
|
| 210 |
+
q_main = self.norm_q(self.q(x_main))
|
| 211 |
+
k_main = self.norm_k(self.k(x_main))
|
| 212 |
+
v_main = self.v(x_main)
|
| 213 |
+
|
| 214 |
+
q_main = rope_apply(q_main, freqs_main, self.num_heads)
|
| 215 |
+
k_main = rope_apply(k_main, freqs_main, self.num_heads)
|
| 216 |
+
|
| 217 |
+
q_ip = self.norm_q(self.q(x_ip) + self.q_loras(x_ip))
|
| 218 |
+
k_ip = self.norm_k(self.k(x_ip) + self.k_loras(x_ip))
|
| 219 |
+
v_ip = self.v(x_ip) + self.v_loras(x_ip)
|
| 220 |
+
|
| 221 |
+
q_ip = rope_apply(q_ip, freqs_ip, self.num_heads)
|
| 222 |
+
k_ip = rope_apply(k_ip, freqs_ip, self.num_heads)
|
| 223 |
+
self.kv_cache = {"k_ip": k_ip.detach(), "v_ip": v_ip.detach()}
|
| 224 |
+
full_k = torch.concat([k_main, k_ip], dim=1)
|
| 225 |
+
full_v = torch.concat([v_main, v_ip], dim=1)
|
| 226 |
+
cond_out = self.attn(q_ip, k_ip, v_ip)
|
| 227 |
+
main_out = self.attn(q_main, full_k, full_v)
|
| 228 |
+
out = torch.concat([main_out, cond_out], dim=1)
|
| 229 |
+
return self.o(out)
|
| 230 |
+
|
| 231 |
+
else:
|
| 232 |
+
k_ip = self.kv_cache["k_ip"]
|
| 233 |
+
v_ip = self.kv_cache["v_ip"]
|
| 234 |
+
q_main = self.norm_q(self.q(x))
|
| 235 |
+
k_main = self.norm_k(self.k(x))
|
| 236 |
+
v_main = self.v(x)
|
| 237 |
+
q_main = rope_apply(q_main, freqs, self.num_heads)
|
| 238 |
+
k_main = rope_apply(k_main, freqs, self.num_heads)
|
| 239 |
+
|
| 240 |
+
full_k = torch.concat([k_main, k_ip], dim=1)
|
| 241 |
+
full_v = torch.concat([v_main, v_ip], dim=1)
|
| 242 |
+
x = self.attn(q_main, full_k, full_v)
|
| 243 |
+
return self.o(x)
|
| 244 |
+
else:
|
| 245 |
+
q = self.norm_q(self.q(x))
|
| 246 |
+
k = self.norm_k(self.k(x))
|
| 247 |
+
v = self.v(x)
|
| 248 |
+
q = rope_apply(q, freqs, self.num_heads)
|
| 249 |
+
k = rope_apply(k, freqs, self.num_heads)
|
| 250 |
+
x = self.attn(q, k, v)
|
| 251 |
+
return self.o(x)
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
class CrossAttention(nn.Module):
|
| 255 |
+
def __init__(
|
| 256 |
+
self, dim: int, num_heads: int, eps: float = 1e-6, has_image_input: bool = False
|
| 257 |
+
):
|
| 258 |
+
super().__init__()
|
| 259 |
+
self.dim = dim
|
| 260 |
+
self.num_heads = num_heads
|
| 261 |
+
self.head_dim = dim // num_heads
|
| 262 |
+
|
| 263 |
+
self.q = nn.Linear(dim, dim)
|
| 264 |
+
self.k = nn.Linear(dim, dim)
|
| 265 |
+
self.v = nn.Linear(dim, dim)
|
| 266 |
+
self.o = nn.Linear(dim, dim)
|
| 267 |
+
self.norm_q = RMSNorm(dim, eps=eps)
|
| 268 |
+
self.norm_k = RMSNorm(dim, eps=eps)
|
| 269 |
+
self.has_image_input = has_image_input
|
| 270 |
+
if has_image_input:
|
| 271 |
+
self.k_img = nn.Linear(dim, dim)
|
| 272 |
+
self.v_img = nn.Linear(dim, dim)
|
| 273 |
+
self.norm_k_img = RMSNorm(dim, eps=eps)
|
| 274 |
+
|
| 275 |
+
self.attn = AttentionModule(self.num_heads)
|
| 276 |
+
|
| 277 |
+
def forward(self, x: torch.Tensor, y: torch.Tensor):
|
| 278 |
+
if self.has_image_input:
|
| 279 |
+
img = y[:, :257]
|
| 280 |
+
ctx = y[:, 257:]
|
| 281 |
+
else:
|
| 282 |
+
ctx = y
|
| 283 |
+
q = self.norm_q(self.q(x))
|
| 284 |
+
k = self.norm_k(self.k(ctx))
|
| 285 |
+
v = self.v(ctx)
|
| 286 |
+
x = self.attn(q, k, v)
|
| 287 |
+
if self.has_image_input:
|
| 288 |
+
k_img = self.norm_k_img(self.k_img(img))
|
| 289 |
+
v_img = self.v_img(img)
|
| 290 |
+
y = flash_attention(q, k_img, v_img, num_heads=self.num_heads)
|
| 291 |
+
x = x + y
|
| 292 |
+
return self.o(x)
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
class GateModule(nn.Module):
|
| 296 |
+
def __init__(
|
| 297 |
+
self,
|
| 298 |
+
):
|
| 299 |
+
super().__init__()
|
| 300 |
+
|
| 301 |
+
def forward(self, x, gate, residual):
|
| 302 |
+
return x + gate * residual
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
class DiTBlock(nn.Module):
|
| 306 |
+
def __init__(
|
| 307 |
+
self,
|
| 308 |
+
has_image_input: bool,
|
| 309 |
+
dim: int,
|
| 310 |
+
num_heads: int,
|
| 311 |
+
ffn_dim: int,
|
| 312 |
+
eps: float = 1e-6,
|
| 313 |
+
):
|
| 314 |
+
super().__init__()
|
| 315 |
+
self.dim = dim
|
| 316 |
+
self.num_heads = num_heads
|
| 317 |
+
self.ffn_dim = ffn_dim
|
| 318 |
+
|
| 319 |
+
self.self_attn = SelfAttention(dim, num_heads, eps)
|
| 320 |
+
self.cross_attn = CrossAttention(
|
| 321 |
+
dim, num_heads, eps, has_image_input=has_image_input
|
| 322 |
+
)
|
| 323 |
+
self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
|
| 324 |
+
self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
|
| 325 |
+
self.norm3 = nn.LayerNorm(dim, eps=eps)
|
| 326 |
+
self.ffn = nn.Sequential(
|
| 327 |
+
nn.Linear(dim, ffn_dim),
|
| 328 |
+
nn.GELU(approximate="tanh"),
|
| 329 |
+
nn.Linear(ffn_dim, dim),
|
| 330 |
+
)
|
| 331 |
+
self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
|
| 332 |
+
self.gate = GateModule()
|
| 333 |
+
|
| 334 |
+
def forward(self, x, context, t_mod, freqs, x_ip=None, t_mod_ip=None):
|
| 335 |
+
# msa: multi-head self-attention mlp: multi-layer perceptron
|
| 336 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
| 337 |
+
self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod
|
| 338 |
+
).chunk(6, dim=1)
|
| 339 |
+
|
| 340 |
+
input_x = modulate(self.norm1(x), shift_msa, scale_msa)
|
| 341 |
+
|
| 342 |
+
if x_ip is not None:
|
| 343 |
+
(
|
| 344 |
+
shift_msa_ip,
|
| 345 |
+
scale_msa_ip,
|
| 346 |
+
gate_msa_ip,
|
| 347 |
+
shift_mlp_ip,
|
| 348 |
+
scale_mlp_ip,
|
| 349 |
+
gate_mlp_ip,
|
| 350 |
+
) = (
|
| 351 |
+
self.modulation.to(dtype=t_mod_ip.dtype, device=t_mod_ip.device)
|
| 352 |
+
+ t_mod_ip
|
| 353 |
+
).chunk(6, dim=1)
|
| 354 |
+
input_x_ip = modulate(
|
| 355 |
+
self.norm1(x_ip), shift_msa_ip, scale_msa_ip
|
| 356 |
+
) # [1, 1024, 5120]
|
| 357 |
+
self.self_attn.cond_size = input_x_ip.shape[1]
|
| 358 |
+
input_x = torch.concat([input_x, input_x_ip], dim=1)
|
| 359 |
+
self.self_attn.kv_cache = None
|
| 360 |
+
|
| 361 |
+
attn_out = self.self_attn(input_x, freqs)
|
| 362 |
+
if x_ip is not None:
|
| 363 |
+
attn_out, attn_out_ip = (
|
| 364 |
+
attn_out[:, : -self.self_attn.cond_size],
|
| 365 |
+
attn_out[:, -self.self_attn.cond_size :],
|
| 366 |
+
)
|
| 367 |
+
|
| 368 |
+
x = self.gate(x, gate_msa, attn_out)
|
| 369 |
+
x = x + self.cross_attn(self.norm3(x), context)
|
| 370 |
+
input_x = modulate(self.norm2(x), shift_mlp, scale_mlp)
|
| 371 |
+
x = self.gate(x, gate_mlp, self.ffn(input_x))
|
| 372 |
+
|
| 373 |
+
if x_ip is not None:
|
| 374 |
+
x_ip = self.gate(x_ip, gate_msa_ip, attn_out_ip)
|
| 375 |
+
input_x_ip = modulate(self.norm2(x_ip), shift_mlp_ip, scale_mlp_ip)
|
| 376 |
+
x_ip = self.gate(x_ip, gate_mlp_ip, self.ffn(input_x_ip))
|
| 377 |
+
return x, x_ip
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
class MLP(torch.nn.Module):
|
| 381 |
+
def __init__(self, in_dim, out_dim, has_pos_emb=False):
|
| 382 |
+
super().__init__()
|
| 383 |
+
self.proj = torch.nn.Sequential(
|
| 384 |
+
nn.LayerNorm(in_dim),
|
| 385 |
+
nn.Linear(in_dim, in_dim),
|
| 386 |
+
nn.GELU(),
|
| 387 |
+
nn.Linear(in_dim, out_dim),
|
| 388 |
+
nn.LayerNorm(out_dim),
|
| 389 |
+
)
|
| 390 |
+
self.has_pos_emb = has_pos_emb
|
| 391 |
+
if has_pos_emb:
|
| 392 |
+
self.emb_pos = torch.nn.Parameter(torch.zeros((1, 514, 1280)))
|
| 393 |
+
|
| 394 |
+
def forward(self, x):
|
| 395 |
+
if self.has_pos_emb:
|
| 396 |
+
x = x + self.emb_pos.to(dtype=x.dtype, device=x.device)
|
| 397 |
+
return self.proj(x)
|
| 398 |
+
|
| 399 |
+
|
| 400 |
+
class Head(nn.Module):
|
| 401 |
+
def __init__(
|
| 402 |
+
self, dim: int, out_dim: int, patch_size: Tuple[int, int, int], eps: float
|
| 403 |
+
):
|
| 404 |
+
super().__init__()
|
| 405 |
+
self.dim = dim
|
| 406 |
+
self.patch_size = patch_size
|
| 407 |
+
self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
|
| 408 |
+
self.head = nn.Linear(dim, out_dim * math.prod(patch_size))
|
| 409 |
+
self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
|
| 410 |
+
|
| 411 |
+
def forward(self, x, t_mod):
|
| 412 |
+
if len(t_mod.shape) == 3:
|
| 413 |
+
shift, scale = (
|
| 414 |
+
self.modulation.unsqueeze(0).to(dtype=t_mod.dtype, device=t_mod.device)
|
| 415 |
+
+ t_mod.unsqueeze(2)
|
| 416 |
+
).chunk(2, dim=2)
|
| 417 |
+
x = self.head(self.norm(x) * (1 + scale.squeeze(2)) + shift.squeeze(2))
|
| 418 |
+
else:
|
| 419 |
+
shift, scale = (
|
| 420 |
+
self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod
|
| 421 |
+
).chunk(2, dim=1)
|
| 422 |
+
x = self.head(self.norm(x) * (1 + scale) + shift)
|
| 423 |
+
return x
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
class WanModel(torch.nn.Module):
|
| 427 |
+
def __init__(
|
| 428 |
+
self,
|
| 429 |
+
dim: int,
|
| 430 |
+
in_dim: int,
|
| 431 |
+
ffn_dim: int,
|
| 432 |
+
out_dim: int,
|
| 433 |
+
text_dim: int,
|
| 434 |
+
freq_dim: int,
|
| 435 |
+
eps: float,
|
| 436 |
+
patch_size: Tuple[int, int, int],
|
| 437 |
+
num_heads: int,
|
| 438 |
+
num_layers: int,
|
| 439 |
+
has_image_input: bool,
|
| 440 |
+
has_image_pos_emb: bool = False,
|
| 441 |
+
has_ref_conv: bool = False,
|
| 442 |
+
add_control_adapter: bool = False,
|
| 443 |
+
in_dim_control_adapter: int = 24,
|
| 444 |
+
seperated_timestep: bool = False,
|
| 445 |
+
require_vae_embedding: bool = True,
|
| 446 |
+
require_clip_embedding: bool = True,
|
| 447 |
+
fuse_vae_embedding_in_latents: bool = False,
|
| 448 |
+
):
|
| 449 |
+
super().__init__()
|
| 450 |
+
self.dim = dim
|
| 451 |
+
self.freq_dim = freq_dim
|
| 452 |
+
self.has_image_input = has_image_input
|
| 453 |
+
self.patch_size = patch_size
|
| 454 |
+
self.seperated_timestep = seperated_timestep
|
| 455 |
+
self.require_vae_embedding = require_vae_embedding
|
| 456 |
+
self.require_clip_embedding = require_clip_embedding
|
| 457 |
+
self.fuse_vae_embedding_in_latents = fuse_vae_embedding_in_latents
|
| 458 |
+
|
| 459 |
+
self.patch_embedding = nn.Conv3d(
|
| 460 |
+
in_dim, dim, kernel_size=patch_size, stride=patch_size
|
| 461 |
+
)
|
| 462 |
+
self.text_embedding = nn.Sequential(
|
| 463 |
+
nn.Linear(text_dim, dim), nn.GELU(approximate="tanh"), nn.Linear(dim, dim)
|
| 464 |
+
)
|
| 465 |
+
self.time_embedding = nn.Sequential(
|
| 466 |
+
nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim)
|
| 467 |
+
)
|
| 468 |
+
self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
|
| 469 |
+
self.blocks = nn.ModuleList(
|
| 470 |
+
[
|
| 471 |
+
DiTBlock(has_image_input, dim, num_heads, ffn_dim, eps)
|
| 472 |
+
for _ in range(num_layers)
|
| 473 |
+
]
|
| 474 |
+
)
|
| 475 |
+
self.head = Head(dim, out_dim, patch_size, eps)
|
| 476 |
+
head_dim = dim // num_heads
|
| 477 |
+
self.freqs = precompute_freqs_cis_3d(head_dim)
|
| 478 |
+
|
| 479 |
+
if has_image_input:
|
| 480 |
+
self.img_emb = MLP(
|
| 481 |
+
1280, dim, has_pos_emb=has_image_pos_emb
|
| 482 |
+
) # clip_feature_dim = 1280
|
| 483 |
+
if has_ref_conv:
|
| 484 |
+
self.ref_conv = nn.Conv2d(16, dim, kernel_size=(2, 2), stride=(2, 2))
|
| 485 |
+
self.has_image_pos_emb = has_image_pos_emb
|
| 486 |
+
self.has_ref_conv = has_ref_conv
|
| 487 |
+
if add_control_adapter:
|
| 488 |
+
self.control_adapter = SimpleAdapter(
|
| 489 |
+
in_dim_control_adapter,
|
| 490 |
+
dim,
|
| 491 |
+
kernel_size=patch_size[1:],
|
| 492 |
+
stride=patch_size[1:],
|
| 493 |
+
)
|
| 494 |
+
else:
|
| 495 |
+
self.control_adapter = None
|
| 496 |
+
|
| 497 |
+
def patchify(
|
| 498 |
+
self, x: torch.Tensor, control_camera_latents_input: torch.Tensor = None
|
| 499 |
+
):
|
| 500 |
+
x = self.patch_embedding(x)
|
| 501 |
+
if (
|
| 502 |
+
self.control_adapter is not None
|
| 503 |
+
and control_camera_latents_input is not None
|
| 504 |
+
):
|
| 505 |
+
y_camera = self.control_adapter(control_camera_latents_input)
|
| 506 |
+
x = [u + v for u, v in zip(x, y_camera)]
|
| 507 |
+
x = x[0].unsqueeze(0)
|
| 508 |
+
grid_size = x.shape[2:]
|
| 509 |
+
x = rearrange(x, "b c f h w -> b (f h w) c").contiguous()
|
| 510 |
+
return x, grid_size # x, grid_size: (f, h, w)
|
| 511 |
+
|
| 512 |
+
def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor):
|
| 513 |
+
return rearrange(
|
| 514 |
+
x,
|
| 515 |
+
"b (f h w) (x y z c) -> b c (f x) (h y) (w z)",
|
| 516 |
+
f=grid_size[0],
|
| 517 |
+
h=grid_size[1],
|
| 518 |
+
w=grid_size[2],
|
| 519 |
+
x=self.patch_size[0],
|
| 520 |
+
y=self.patch_size[1],
|
| 521 |
+
z=self.patch_size[2],
|
| 522 |
+
)
|
| 523 |
+
|
| 524 |
+
def forward(
|
| 525 |
+
self,
|
| 526 |
+
x: torch.Tensor,
|
| 527 |
+
timestep: torch.Tensor,
|
| 528 |
+
context: torch.Tensor,
|
| 529 |
+
clip_feature: Optional[torch.Tensor] = None,
|
| 530 |
+
y: Optional[torch.Tensor] = None,
|
| 531 |
+
use_gradient_checkpointing: bool = False,
|
| 532 |
+
use_gradient_checkpointing_offload: bool = False,
|
| 533 |
+
ip_image=None,
|
| 534 |
+
**kwargs,
|
| 535 |
+
):
|
| 536 |
+
x_ip = None
|
| 537 |
+
t_mod_ip = None
|
| 538 |
+
t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep))
|
| 539 |
+
t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
|
| 540 |
+
context = self.text_embedding(context)
|
| 541 |
+
|
| 542 |
+
if ip_image is not None:
|
| 543 |
+
timestep_ip = torch.zeros_like(timestep) # [B] with 0s
|
| 544 |
+
t_ip = self.time_embedding(
|
| 545 |
+
sinusoidal_embedding_1d(self.freq_dim, timestep_ip)
|
| 546 |
+
)
|
| 547 |
+
t_mod_ip = self.time_projection(t_ip).unflatten(1, (6, self.dim))
|
| 548 |
+
x, (f, h, w) = self.patchify(x)
|
| 549 |
+
|
| 550 |
+
offset = 1
|
| 551 |
+
freqs = (
|
| 552 |
+
torch.cat(
|
| 553 |
+
[
|
| 554 |
+
self.freqs[0][offset : f + offset]
|
| 555 |
+
.view(f, 1, 1, -1)
|
| 556 |
+
.expand(f, h, w, -1),
|
| 557 |
+
self.freqs[1][offset : h + offset]
|
| 558 |
+
.view(1, h, 1, -1)
|
| 559 |
+
.expand(f, h, w, -1),
|
| 560 |
+
self.freqs[2][offset : w + offset]
|
| 561 |
+
.view(1, 1, w, -1)
|
| 562 |
+
.expand(f, h, w, -1),
|
| 563 |
+
],
|
| 564 |
+
dim=-1,
|
| 565 |
+
)
|
| 566 |
+
.reshape(f * h * w, 1, -1)
|
| 567 |
+
.to(x.device)
|
| 568 |
+
)
|
| 569 |
+
|
| 570 |
+
############################################################################################
|
| 571 |
+
if ip_image is not None:
|
| 572 |
+
if ip_image.dim() == 6 and ip_image.shape[3] == 1:
|
| 573 |
+
ip_image = ip_image.squeeze(1)
|
| 574 |
+
x_ip, (f_ip, h_ip, w_ip) = self.patchify(
|
| 575 |
+
ip_image
|
| 576 |
+
) # x_ip [1, 1024, 5120] [B, N, D] f_ip = 1 h_ip = 32 w_ip = 32
|
| 577 |
+
freqs_ip = (
|
| 578 |
+
torch.cat(
|
| 579 |
+
[
|
| 580 |
+
self.freqs[0][0]
|
| 581 |
+
.view(f_ip, 1, 1, -1)
|
| 582 |
+
.expand(f_ip, h_ip, w_ip, -1),
|
| 583 |
+
self.freqs[1][h + offset : h + offset + h_ip]
|
| 584 |
+
.view(1, h_ip, 1, -1)
|
| 585 |
+
.expand(f_ip, h_ip, w_ip, -1),
|
| 586 |
+
self.freqs[2][w + offset : w + offset + w_ip]
|
| 587 |
+
.view(1, 1, w_ip, -1)
|
| 588 |
+
.expand(f_ip, h_ip, w_ip, -1),
|
| 589 |
+
],
|
| 590 |
+
dim=-1,
|
| 591 |
+
)
|
| 592 |
+
.reshape(f_ip * h_ip * w_ip, 1, -1)
|
| 593 |
+
.to(x_ip.device)
|
| 594 |
+
)
|
| 595 |
+
freqs = torch.cat([freqs, freqs_ip], dim=0)
|
| 596 |
+
|
| 597 |
+
############################################################################################
|
| 598 |
+
def create_custom_forward(module):
|
| 599 |
+
def custom_forward(*inputs):
|
| 600 |
+
return module(*inputs)
|
| 601 |
+
|
| 602 |
+
return custom_forward
|
| 603 |
+
|
| 604 |
+
for block in self.blocks:
|
| 605 |
+
if self.training and use_gradient_checkpointing:
|
| 606 |
+
if use_gradient_checkpointing_offload:
|
| 607 |
+
with torch.autograd.graph.save_on_cpu():
|
| 608 |
+
x, x_ip = torch.utils.checkpoint.checkpoint(
|
| 609 |
+
create_custom_forward(block),
|
| 610 |
+
x,
|
| 611 |
+
context,
|
| 612 |
+
t_mod,
|
| 613 |
+
freqs,
|
| 614 |
+
x_ip,
|
| 615 |
+
t_mod_ip,
|
| 616 |
+
use_reentrant=False,
|
| 617 |
+
)
|
| 618 |
+
else:
|
| 619 |
+
x, x_ip = torch.utils.checkpoint.checkpoint(
|
| 620 |
+
create_custom_forward(block),
|
| 621 |
+
x,
|
| 622 |
+
context,
|
| 623 |
+
t_mod,
|
| 624 |
+
freqs,
|
| 625 |
+
x_ip,
|
| 626 |
+
t_mod_ip,
|
| 627 |
+
use_reentrant=False,
|
| 628 |
+
)
|
| 629 |
+
else:
|
| 630 |
+
x, x_ip = block(x, context, t_mod, freqs, x_ip, t_mod_ip)
|
| 631 |
+
|
| 632 |
+
x = self.head(x, t)
|
| 633 |
+
x = self.unpatchify(x, (f, h, w))
|
| 634 |
+
return x
|
| 635 |
+
|
| 636 |
+
@staticmethod
|
| 637 |
+
def state_dict_converter():
|
| 638 |
+
return WanModelStateDictConverter()
|
| 639 |
+
|
| 640 |
+
|
| 641 |
+
class WanModelStateDictConverter:
|
| 642 |
+
def __init__(self):
|
| 643 |
+
pass
|
| 644 |
+
|
| 645 |
+
def from_diffusers(self, state_dict):
|
| 646 |
+
rename_dict = {
|
| 647 |
+
"blocks.0.attn1.norm_k.weight": "blocks.0.self_attn.norm_k.weight",
|
| 648 |
+
"blocks.0.attn1.norm_q.weight": "blocks.0.self_attn.norm_q.weight",
|
| 649 |
+
"blocks.0.attn1.to_k.bias": "blocks.0.self_attn.k.bias",
|
| 650 |
+
"blocks.0.attn1.to_k.weight": "blocks.0.self_attn.k.weight",
|
| 651 |
+
"blocks.0.attn1.to_out.0.bias": "blocks.0.self_attn.o.bias",
|
| 652 |
+
"blocks.0.attn1.to_out.0.weight": "blocks.0.self_attn.o.weight",
|
| 653 |
+
"blocks.0.attn1.to_q.bias": "blocks.0.self_attn.q.bias",
|
| 654 |
+
"blocks.0.attn1.to_q.weight": "blocks.0.self_attn.q.weight",
|
| 655 |
+
"blocks.0.attn1.to_v.bias": "blocks.0.self_attn.v.bias",
|
| 656 |
+
"blocks.0.attn1.to_v.weight": "blocks.0.self_attn.v.weight",
|
| 657 |
+
"blocks.0.attn2.norm_k.weight": "blocks.0.cross_attn.norm_k.weight",
|
| 658 |
+
"blocks.0.attn2.norm_q.weight": "blocks.0.cross_attn.norm_q.weight",
|
| 659 |
+
"blocks.0.attn2.to_k.bias": "blocks.0.cross_attn.k.bias",
|
| 660 |
+
"blocks.0.attn2.to_k.weight": "blocks.0.cross_attn.k.weight",
|
| 661 |
+
"blocks.0.attn2.to_out.0.bias": "blocks.0.cross_attn.o.bias",
|
| 662 |
+
"blocks.0.attn2.to_out.0.weight": "blocks.0.cross_attn.o.weight",
|
| 663 |
+
"blocks.0.attn2.to_q.bias": "blocks.0.cross_attn.q.bias",
|
| 664 |
+
"blocks.0.attn2.to_q.weight": "blocks.0.cross_attn.q.weight",
|
| 665 |
+
"blocks.0.attn2.to_v.bias": "blocks.0.cross_attn.v.bias",
|
| 666 |
+
"blocks.0.attn2.to_v.weight": "blocks.0.cross_attn.v.weight",
|
| 667 |
+
"blocks.0.ffn.net.0.proj.bias": "blocks.0.ffn.0.bias",
|
| 668 |
+
"blocks.0.ffn.net.0.proj.weight": "blocks.0.ffn.0.weight",
|
| 669 |
+
"blocks.0.ffn.net.2.bias": "blocks.0.ffn.2.bias",
|
| 670 |
+
"blocks.0.ffn.net.2.weight": "blocks.0.ffn.2.weight",
|
| 671 |
+
"blocks.0.norm2.bias": "blocks.0.norm3.bias",
|
| 672 |
+
"blocks.0.norm2.weight": "blocks.0.norm3.weight",
|
| 673 |
+
"blocks.0.scale_shift_table": "blocks.0.modulation",
|
| 674 |
+
"condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias",
|
| 675 |
+
"condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight",
|
| 676 |
+
"condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias",
|
| 677 |
+
"condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight",
|
| 678 |
+
"condition_embedder.time_embedder.linear_1.bias": "time_embedding.0.bias",
|
| 679 |
+
"condition_embedder.time_embedder.linear_1.weight": "time_embedding.0.weight",
|
| 680 |
+
"condition_embedder.time_embedder.linear_2.bias": "time_embedding.2.bias",
|
| 681 |
+
"condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight",
|
| 682 |
+
"condition_embedder.time_proj.bias": "time_projection.1.bias",
|
| 683 |
+
"condition_embedder.time_proj.weight": "time_projection.1.weight",
|
| 684 |
+
"patch_embedding.bias": "patch_embedding.bias",
|
| 685 |
+
"patch_embedding.weight": "patch_embedding.weight",
|
| 686 |
+
"scale_shift_table": "head.modulation",
|
| 687 |
+
"proj_out.bias": "head.head.bias",
|
| 688 |
+
"proj_out.weight": "head.head.weight",
|
| 689 |
+
}
|
| 690 |
+
state_dict_ = {}
|
| 691 |
+
for name, param in state_dict.items():
|
| 692 |
+
if name in rename_dict:
|
| 693 |
+
state_dict_[rename_dict[name]] = param
|
| 694 |
+
else:
|
| 695 |
+
name_ = ".".join(name.split(".")[:1] + ["0"] + name.split(".")[2:])
|
| 696 |
+
if name_ in rename_dict:
|
| 697 |
+
name_ = rename_dict[name_]
|
| 698 |
+
name_ = ".".join(
|
| 699 |
+
name_.split(".")[:1]
|
| 700 |
+
+ [name.split(".")[1]]
|
| 701 |
+
+ name_.split(".")[2:]
|
| 702 |
+
)
|
| 703 |
+
state_dict_[name_] = param
|
| 704 |
+
if hash_state_dict_keys(state_dict) == "cb104773c6c2cb6df4f9529ad5c60d0b":
|
| 705 |
+
config = {
|
| 706 |
+
"model_type": "t2v",
|
| 707 |
+
"patch_size": (1, 2, 2),
|
| 708 |
+
"text_len": 512,
|
| 709 |
+
"in_dim": 16,
|
| 710 |
+
"dim": 5120,
|
| 711 |
+
"ffn_dim": 13824,
|
| 712 |
+
"freq_dim": 256,
|
| 713 |
+
"text_dim": 4096,
|
| 714 |
+
"out_dim": 16,
|
| 715 |
+
"num_heads": 40,
|
| 716 |
+
"num_layers": 40,
|
| 717 |
+
"window_size": (-1, -1),
|
| 718 |
+
"qk_norm": True,
|
| 719 |
+
"cross_attn_norm": True,
|
| 720 |
+
"eps": 1e-6,
|
| 721 |
+
}
|
| 722 |
+
else:
|
| 723 |
+
config = {}
|
| 724 |
+
return state_dict_, config
|
| 725 |
+
|
| 726 |
+
def from_civitai(self, state_dict):
|
| 727 |
+
state_dict = {
|
| 728 |
+
name: param
|
| 729 |
+
for name, param in state_dict.items()
|
| 730 |
+
if not name.startswith("vace")
|
| 731 |
+
}
|
| 732 |
+
if hash_state_dict_keys(state_dict) == "9269f8db9040a9d860eaca435be61814":
|
| 733 |
+
config = {
|
| 734 |
+
"has_image_input": False,
|
| 735 |
+
"patch_size": [1, 2, 2],
|
| 736 |
+
"in_dim": 16,
|
| 737 |
+
"dim": 1536,
|
| 738 |
+
"ffn_dim": 8960,
|
| 739 |
+
"freq_dim": 256,
|
| 740 |
+
"text_dim": 4096,
|
| 741 |
+
"out_dim": 16,
|
| 742 |
+
"num_heads": 12,
|
| 743 |
+
"num_layers": 30,
|
| 744 |
+
"eps": 1e-6,
|
| 745 |
+
}
|
| 746 |
+
elif hash_state_dict_keys(state_dict) == "aafcfd9672c3a2456dc46e1cb6e52c70":
|
| 747 |
+
config = {
|
| 748 |
+
"has_image_input": False,
|
| 749 |
+
"patch_size": [1, 2, 2],
|
| 750 |
+
"in_dim": 16,
|
| 751 |
+
"dim": 5120,
|
| 752 |
+
"ffn_dim": 13824,
|
| 753 |
+
"freq_dim": 256,
|
| 754 |
+
"text_dim": 4096,
|
| 755 |
+
"out_dim": 16,
|
| 756 |
+
"num_heads": 40,
|
| 757 |
+
"num_layers": 40,
|
| 758 |
+
"eps": 1e-6,
|
| 759 |
+
}
|
| 760 |
+
elif hash_state_dict_keys(state_dict) == "6bfcfb3b342cb286ce886889d519a77e":
|
| 761 |
+
config = {
|
| 762 |
+
"has_image_input": True,
|
| 763 |
+
"patch_size": [1, 2, 2],
|
| 764 |
+
"in_dim": 36,
|
| 765 |
+
"dim": 5120,
|
| 766 |
+
"ffn_dim": 13824,
|
| 767 |
+
"freq_dim": 256,
|
| 768 |
+
"text_dim": 4096,
|
| 769 |
+
"out_dim": 16,
|
| 770 |
+
"num_heads": 40,
|
| 771 |
+
"num_layers": 40,
|
| 772 |
+
"eps": 1e-6,
|
| 773 |
+
}
|
| 774 |
+
elif hash_state_dict_keys(state_dict) == "6d6ccde6845b95ad9114ab993d917893":
|
| 775 |
+
config = {
|
| 776 |
+
"has_image_input": True,
|
| 777 |
+
"patch_size": [1, 2, 2],
|
| 778 |
+
"in_dim": 36,
|
| 779 |
+
"dim": 1536,
|
| 780 |
+
"ffn_dim": 8960,
|
| 781 |
+
"freq_dim": 256,
|
| 782 |
+
"text_dim": 4096,
|
| 783 |
+
"out_dim": 16,
|
| 784 |
+
"num_heads": 12,
|
| 785 |
+
"num_layers": 30,
|
| 786 |
+
"eps": 1e-6,
|
| 787 |
+
}
|
| 788 |
+
elif hash_state_dict_keys(state_dict) == "6bfcfb3b342cb286ce886889d519a77e":
|
| 789 |
+
config = {
|
| 790 |
+
"has_image_input": True,
|
| 791 |
+
"patch_size": [1, 2, 2],
|
| 792 |
+
"in_dim": 36,
|
| 793 |
+
"dim": 5120,
|
| 794 |
+
"ffn_dim": 13824,
|
| 795 |
+
"freq_dim": 256,
|
| 796 |
+
"text_dim": 4096,
|
| 797 |
+
"out_dim": 16,
|
| 798 |
+
"num_heads": 40,
|
| 799 |
+
"num_layers": 40,
|
| 800 |
+
"eps": 1e-6,
|
| 801 |
+
}
|
| 802 |
+
elif hash_state_dict_keys(state_dict) == "349723183fc063b2bfc10bb2835cf677":
|
| 803 |
+
# 1.3B PAI control
|
| 804 |
+
config = {
|
| 805 |
+
"has_image_input": True,
|
| 806 |
+
"patch_size": [1, 2, 2],
|
| 807 |
+
"in_dim": 48,
|
| 808 |
+
"dim": 1536,
|
| 809 |
+
"ffn_dim": 8960,
|
| 810 |
+
"freq_dim": 256,
|
| 811 |
+
"text_dim": 4096,
|
| 812 |
+
"out_dim": 16,
|
| 813 |
+
"num_heads": 12,
|
| 814 |
+
"num_layers": 30,
|
| 815 |
+
"eps": 1e-6,
|
| 816 |
+
}
|
| 817 |
+
elif hash_state_dict_keys(state_dict) == "efa44cddf936c70abd0ea28b6cbe946c":
|
| 818 |
+
# 14B PAI control
|
| 819 |
+
config = {
|
| 820 |
+
"has_image_input": True,
|
| 821 |
+
"patch_size": [1, 2, 2],
|
| 822 |
+
"in_dim": 48,
|
| 823 |
+
"dim": 5120,
|
| 824 |
+
"ffn_dim": 13824,
|
| 825 |
+
"freq_dim": 256,
|
| 826 |
+
"text_dim": 4096,
|
| 827 |
+
"out_dim": 16,
|
| 828 |
+
"num_heads": 40,
|
| 829 |
+
"num_layers": 40,
|
| 830 |
+
"eps": 1e-6,
|
| 831 |
+
}
|
| 832 |
+
elif hash_state_dict_keys(state_dict) == "3ef3b1f8e1dab83d5b71fd7b617f859f":
|
| 833 |
+
config = {
|
| 834 |
+
"has_image_input": True,
|
| 835 |
+
"patch_size": [1, 2, 2],
|
| 836 |
+
"in_dim": 36,
|
| 837 |
+
"dim": 5120,
|
| 838 |
+
"ffn_dim": 13824,
|
| 839 |
+
"freq_dim": 256,
|
| 840 |
+
"text_dim": 4096,
|
| 841 |
+
"out_dim": 16,
|
| 842 |
+
"num_heads": 40,
|
| 843 |
+
"num_layers": 40,
|
| 844 |
+
"eps": 1e-6,
|
| 845 |
+
"has_image_pos_emb": True,
|
| 846 |
+
}
|
| 847 |
+
elif hash_state_dict_keys(state_dict) == "70ddad9d3a133785da5ea371aae09504":
|
| 848 |
+
# 1.3B PAI control v1.1
|
| 849 |
+
config = {
|
| 850 |
+
"has_image_input": True,
|
| 851 |
+
"patch_size": [1, 2, 2],
|
| 852 |
+
"in_dim": 48,
|
| 853 |
+
"dim": 1536,
|
| 854 |
+
"ffn_dim": 8960,
|
| 855 |
+
"freq_dim": 256,
|
| 856 |
+
"text_dim": 4096,
|
| 857 |
+
"out_dim": 16,
|
| 858 |
+
"num_heads": 12,
|
| 859 |
+
"num_layers": 30,
|
| 860 |
+
"eps": 1e-6,
|
| 861 |
+
"has_ref_conv": True,
|
| 862 |
+
}
|
| 863 |
+
elif hash_state_dict_keys(state_dict) == "26bde73488a92e64cc20b0a7485b9e5b":
|
| 864 |
+
# 14B PAI control v1.1
|
| 865 |
+
config = {
|
| 866 |
+
"has_image_input": True,
|
| 867 |
+
"patch_size": [1, 2, 2],
|
| 868 |
+
"in_dim": 48,
|
| 869 |
+
"dim": 5120,
|
| 870 |
+
"ffn_dim": 13824,
|
| 871 |
+
"freq_dim": 256,
|
| 872 |
+
"text_dim": 4096,
|
| 873 |
+
"out_dim": 16,
|
| 874 |
+
"num_heads": 40,
|
| 875 |
+
"num_layers": 40,
|
| 876 |
+
"eps": 1e-6,
|
| 877 |
+
"has_ref_conv": True,
|
| 878 |
+
}
|
| 879 |
+
elif hash_state_dict_keys(state_dict) == "ac6a5aa74f4a0aab6f64eb9a72f19901":
|
| 880 |
+
# 1.3B PAI control-camera v1.1
|
| 881 |
+
config = {
|
| 882 |
+
"has_image_input": True,
|
| 883 |
+
"patch_size": [1, 2, 2],
|
| 884 |
+
"in_dim": 32,
|
| 885 |
+
"dim": 1536,
|
| 886 |
+
"ffn_dim": 8960,
|
| 887 |
+
"freq_dim": 256,
|
| 888 |
+
"text_dim": 4096,
|
| 889 |
+
"out_dim": 16,
|
| 890 |
+
"num_heads": 12,
|
| 891 |
+
"num_layers": 30,
|
| 892 |
+
"eps": 1e-6,
|
| 893 |
+
"has_ref_conv": False,
|
| 894 |
+
"add_control_adapter": True,
|
| 895 |
+
"in_dim_control_adapter": 24,
|
| 896 |
+
}
|
| 897 |
+
elif hash_state_dict_keys(state_dict) == "b61c605c2adbd23124d152ed28e049ae":
|
| 898 |
+
# 14B PAI control-camera v1.1
|
| 899 |
+
config = {
|
| 900 |
+
"has_image_input": True,
|
| 901 |
+
"patch_size": [1, 2, 2],
|
| 902 |
+
"in_dim": 32,
|
| 903 |
+
"dim": 5120,
|
| 904 |
+
"ffn_dim": 13824,
|
| 905 |
+
"freq_dim": 256,
|
| 906 |
+
"text_dim": 4096,
|
| 907 |
+
"out_dim": 16,
|
| 908 |
+
"num_heads": 40,
|
| 909 |
+
"num_layers": 40,
|
| 910 |
+
"eps": 1e-6,
|
| 911 |
+
"has_ref_conv": False,
|
| 912 |
+
"add_control_adapter": True,
|
| 913 |
+
"in_dim_control_adapter": 24,
|
| 914 |
+
}
|
| 915 |
+
elif hash_state_dict_keys(state_dict) == "1f5ab7703c6fc803fdded85ff040c316":
|
| 916 |
+
# Wan-AI/Wan2.2-TI2V-5B
|
| 917 |
+
config = {
|
| 918 |
+
"has_image_input": False,
|
| 919 |
+
"patch_size": [1, 2, 2],
|
| 920 |
+
"in_dim": 48,
|
| 921 |
+
"dim": 3072,
|
| 922 |
+
"ffn_dim": 14336,
|
| 923 |
+
"freq_dim": 256,
|
| 924 |
+
"text_dim": 4096,
|
| 925 |
+
"out_dim": 48,
|
| 926 |
+
"num_heads": 24,
|
| 927 |
+
"num_layers": 30,
|
| 928 |
+
"eps": 1e-6,
|
| 929 |
+
"seperated_timestep": True,
|
| 930 |
+
"require_clip_embedding": False,
|
| 931 |
+
"require_vae_embedding": False,
|
| 932 |
+
"fuse_vae_embedding_in_latents": True,
|
| 933 |
+
}
|
| 934 |
+
elif hash_state_dict_keys(state_dict) == "5b013604280dd715f8457c6ed6d6a626":
|
| 935 |
+
# Wan-AI/Wan2.2-I2V-A14B
|
| 936 |
+
config = {
|
| 937 |
+
"has_image_input": False,
|
| 938 |
+
"patch_size": [1, 2, 2],
|
| 939 |
+
"in_dim": 36,
|
| 940 |
+
"dim": 5120,
|
| 941 |
+
"ffn_dim": 13824,
|
| 942 |
+
"freq_dim": 256,
|
| 943 |
+
"text_dim": 4096,
|
| 944 |
+
"out_dim": 16,
|
| 945 |
+
"num_heads": 40,
|
| 946 |
+
"num_layers": 40,
|
| 947 |
+
"eps": 1e-6,
|
| 948 |
+
"require_clip_embedding": False,
|
| 949 |
+
}
|
| 950 |
+
else:
|
| 951 |
+
config = {}
|
| 952 |
+
return state_dict, config
|
models/wan_video_image_encoder.py
ADDED
|
@@ -0,0 +1,957 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Concise re-implementation of
|
| 3 |
+
``https://github.com/openai/CLIP'' and
|
| 4 |
+
``https://github.com/mlfoundations/open_clip''.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import math
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
import torchvision.transforms as T
|
| 12 |
+
from .wan_video_dit import flash_attention
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class SelfAttention(nn.Module):
|
| 16 |
+
def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
|
| 17 |
+
assert dim % num_heads == 0
|
| 18 |
+
super().__init__()
|
| 19 |
+
self.dim = dim
|
| 20 |
+
self.num_heads = num_heads
|
| 21 |
+
self.head_dim = dim // num_heads
|
| 22 |
+
self.eps = eps
|
| 23 |
+
|
| 24 |
+
# layers
|
| 25 |
+
self.q = nn.Linear(dim, dim)
|
| 26 |
+
self.k = nn.Linear(dim, dim)
|
| 27 |
+
self.v = nn.Linear(dim, dim)
|
| 28 |
+
self.o = nn.Linear(dim, dim)
|
| 29 |
+
self.dropout = nn.Dropout(dropout)
|
| 30 |
+
|
| 31 |
+
def forward(self, x, mask):
|
| 32 |
+
"""
|
| 33 |
+
x: [B, L, C].
|
| 34 |
+
"""
|
| 35 |
+
b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
|
| 36 |
+
|
| 37 |
+
# compute query, key, value
|
| 38 |
+
q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
|
| 39 |
+
k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
|
| 40 |
+
v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
|
| 41 |
+
|
| 42 |
+
# compute attention
|
| 43 |
+
p = self.dropout.p if self.training else 0.0
|
| 44 |
+
x = F.scaled_dot_product_attention(q, k, v, mask, p)
|
| 45 |
+
x = x.permute(0, 2, 1, 3).reshape(b, s, c)
|
| 46 |
+
|
| 47 |
+
# output
|
| 48 |
+
x = self.o(x)
|
| 49 |
+
x = self.dropout(x)
|
| 50 |
+
return x
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class AttentionBlock(nn.Module):
|
| 54 |
+
def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
|
| 55 |
+
super().__init__()
|
| 56 |
+
self.dim = dim
|
| 57 |
+
self.num_heads = num_heads
|
| 58 |
+
self.post_norm = post_norm
|
| 59 |
+
self.eps = eps
|
| 60 |
+
|
| 61 |
+
# layers
|
| 62 |
+
self.attn = SelfAttention(dim, num_heads, dropout, eps)
|
| 63 |
+
self.norm1 = nn.LayerNorm(dim, eps=eps)
|
| 64 |
+
self.ffn = nn.Sequential(
|
| 65 |
+
nn.Linear(dim, dim * 4),
|
| 66 |
+
nn.GELU(),
|
| 67 |
+
nn.Linear(dim * 4, dim),
|
| 68 |
+
nn.Dropout(dropout),
|
| 69 |
+
)
|
| 70 |
+
self.norm2 = nn.LayerNorm(dim, eps=eps)
|
| 71 |
+
|
| 72 |
+
def forward(self, x, mask):
|
| 73 |
+
if self.post_norm:
|
| 74 |
+
x = self.norm1(x + self.attn(x, mask))
|
| 75 |
+
x = self.norm2(x + self.ffn(x))
|
| 76 |
+
else:
|
| 77 |
+
x = x + self.attn(self.norm1(x), mask)
|
| 78 |
+
x = x + self.ffn(self.norm2(x))
|
| 79 |
+
return x
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class XLMRoberta(nn.Module):
|
| 83 |
+
"""
|
| 84 |
+
XLMRobertaModel with no pooler and no LM head.
|
| 85 |
+
"""
|
| 86 |
+
|
| 87 |
+
def __init__(
|
| 88 |
+
self,
|
| 89 |
+
vocab_size=250002,
|
| 90 |
+
max_seq_len=514,
|
| 91 |
+
type_size=1,
|
| 92 |
+
pad_id=1,
|
| 93 |
+
dim=1024,
|
| 94 |
+
num_heads=16,
|
| 95 |
+
num_layers=24,
|
| 96 |
+
post_norm=True,
|
| 97 |
+
dropout=0.1,
|
| 98 |
+
eps=1e-5,
|
| 99 |
+
):
|
| 100 |
+
super().__init__()
|
| 101 |
+
self.vocab_size = vocab_size
|
| 102 |
+
self.max_seq_len = max_seq_len
|
| 103 |
+
self.type_size = type_size
|
| 104 |
+
self.pad_id = pad_id
|
| 105 |
+
self.dim = dim
|
| 106 |
+
self.num_heads = num_heads
|
| 107 |
+
self.num_layers = num_layers
|
| 108 |
+
self.post_norm = post_norm
|
| 109 |
+
self.eps = eps
|
| 110 |
+
|
| 111 |
+
# embeddings
|
| 112 |
+
self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id)
|
| 113 |
+
self.type_embedding = nn.Embedding(type_size, dim)
|
| 114 |
+
self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id)
|
| 115 |
+
self.dropout = nn.Dropout(dropout)
|
| 116 |
+
|
| 117 |
+
# blocks
|
| 118 |
+
self.blocks = nn.ModuleList(
|
| 119 |
+
[
|
| 120 |
+
AttentionBlock(dim, num_heads, post_norm, dropout, eps)
|
| 121 |
+
for _ in range(num_layers)
|
| 122 |
+
]
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
# norm layer
|
| 126 |
+
self.norm = nn.LayerNorm(dim, eps=eps)
|
| 127 |
+
|
| 128 |
+
def forward(self, ids):
|
| 129 |
+
"""
|
| 130 |
+
ids: [B, L] of torch.LongTensor.
|
| 131 |
+
"""
|
| 132 |
+
b, s = ids.shape
|
| 133 |
+
mask = ids.ne(self.pad_id).long()
|
| 134 |
+
|
| 135 |
+
# embeddings
|
| 136 |
+
x = (
|
| 137 |
+
self.token_embedding(ids)
|
| 138 |
+
+ self.type_embedding(torch.zeros_like(ids))
|
| 139 |
+
+ self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask)
|
| 140 |
+
)
|
| 141 |
+
if self.post_norm:
|
| 142 |
+
x = self.norm(x)
|
| 143 |
+
x = self.dropout(x)
|
| 144 |
+
|
| 145 |
+
# blocks
|
| 146 |
+
mask = torch.where(mask.view(b, 1, 1, s).gt(0), 0.0, torch.finfo(x.dtype).min)
|
| 147 |
+
for block in self.blocks:
|
| 148 |
+
x = block(x, mask)
|
| 149 |
+
|
| 150 |
+
# output
|
| 151 |
+
if not self.post_norm:
|
| 152 |
+
x = self.norm(x)
|
| 153 |
+
return x
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def xlm_roberta_large(pretrained=False, return_tokenizer=False, device="cpu", **kwargs):
|
| 157 |
+
"""
|
| 158 |
+
XLMRobertaLarge adapted from Huggingface.
|
| 159 |
+
"""
|
| 160 |
+
# params
|
| 161 |
+
cfg = dict(
|
| 162 |
+
vocab_size=250002,
|
| 163 |
+
max_seq_len=514,
|
| 164 |
+
type_size=1,
|
| 165 |
+
pad_id=1,
|
| 166 |
+
dim=1024,
|
| 167 |
+
num_heads=16,
|
| 168 |
+
num_layers=24,
|
| 169 |
+
post_norm=True,
|
| 170 |
+
dropout=0.1,
|
| 171 |
+
eps=1e-5,
|
| 172 |
+
)
|
| 173 |
+
cfg.update(**kwargs)
|
| 174 |
+
|
| 175 |
+
# init model
|
| 176 |
+
if pretrained:
|
| 177 |
+
from sora import DOWNLOAD_TO_CACHE
|
| 178 |
+
|
| 179 |
+
# init a meta model
|
| 180 |
+
with torch.device("meta"):
|
| 181 |
+
model = XLMRoberta(**cfg)
|
| 182 |
+
|
| 183 |
+
# load checkpoint
|
| 184 |
+
model.load_state_dict(
|
| 185 |
+
torch.load(
|
| 186 |
+
DOWNLOAD_TO_CACHE("models/xlm_roberta/xlm_roberta_large.pth"),
|
| 187 |
+
map_location=device,
|
| 188 |
+
),
|
| 189 |
+
assign=True,
|
| 190 |
+
)
|
| 191 |
+
else:
|
| 192 |
+
# init a model on device
|
| 193 |
+
with torch.device(device):
|
| 194 |
+
model = XLMRoberta(**cfg)
|
| 195 |
+
|
| 196 |
+
# init tokenizer
|
| 197 |
+
if return_tokenizer:
|
| 198 |
+
from sora.data import HuggingfaceTokenizer
|
| 199 |
+
|
| 200 |
+
tokenizer = HuggingfaceTokenizer(
|
| 201 |
+
name="xlm-roberta-large", seq_len=model.text_len, clean="whitespace"
|
| 202 |
+
)
|
| 203 |
+
return model, tokenizer
|
| 204 |
+
else:
|
| 205 |
+
return model
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def pos_interpolate(pos, seq_len):
|
| 209 |
+
if pos.size(1) == seq_len:
|
| 210 |
+
return pos
|
| 211 |
+
else:
|
| 212 |
+
src_grid = int(math.sqrt(pos.size(1)))
|
| 213 |
+
tar_grid = int(math.sqrt(seq_len))
|
| 214 |
+
n = pos.size(1) - src_grid * src_grid
|
| 215 |
+
return torch.cat(
|
| 216 |
+
[
|
| 217 |
+
pos[:, :n],
|
| 218 |
+
F.interpolate(
|
| 219 |
+
pos[:, n:]
|
| 220 |
+
.float()
|
| 221 |
+
.reshape(1, src_grid, src_grid, -1)
|
| 222 |
+
.permute(0, 3, 1, 2),
|
| 223 |
+
size=(tar_grid, tar_grid),
|
| 224 |
+
mode="bicubic",
|
| 225 |
+
align_corners=False,
|
| 226 |
+
)
|
| 227 |
+
.flatten(2)
|
| 228 |
+
.transpose(1, 2),
|
| 229 |
+
],
|
| 230 |
+
dim=1,
|
| 231 |
+
)
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
class QuickGELU(nn.Module):
|
| 235 |
+
def forward(self, x):
|
| 236 |
+
return x * torch.sigmoid(1.702 * x)
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
class LayerNorm(nn.LayerNorm):
|
| 240 |
+
def forward(self, x):
|
| 241 |
+
return super().forward(x).type_as(x)
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
class SelfAttention(nn.Module):
|
| 245 |
+
def __init__(
|
| 246 |
+
self, dim, num_heads, causal=False, attn_dropout=0.0, proj_dropout=0.0
|
| 247 |
+
):
|
| 248 |
+
assert dim % num_heads == 0
|
| 249 |
+
super().__init__()
|
| 250 |
+
self.dim = dim
|
| 251 |
+
self.num_heads = num_heads
|
| 252 |
+
self.head_dim = dim // num_heads
|
| 253 |
+
self.causal = causal
|
| 254 |
+
self.attn_dropout = attn_dropout
|
| 255 |
+
self.proj_dropout = proj_dropout
|
| 256 |
+
|
| 257 |
+
# layers
|
| 258 |
+
self.to_qkv = nn.Linear(dim, dim * 3)
|
| 259 |
+
self.proj = nn.Linear(dim, dim)
|
| 260 |
+
|
| 261 |
+
def forward(self, x):
|
| 262 |
+
"""
|
| 263 |
+
x: [B, L, C].
|
| 264 |
+
"""
|
| 265 |
+
# compute query, key, value
|
| 266 |
+
q, k, v = self.to_qkv(x).chunk(3, dim=-1)
|
| 267 |
+
|
| 268 |
+
# compute attention
|
| 269 |
+
x = flash_attention(q, k, v, num_heads=self.num_heads, compatibility_mode=True)
|
| 270 |
+
|
| 271 |
+
# output
|
| 272 |
+
x = self.proj(x)
|
| 273 |
+
x = F.dropout(x, self.proj_dropout, self.training)
|
| 274 |
+
return x
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
class SwiGLU(nn.Module):
|
| 278 |
+
def __init__(self, dim, mid_dim):
|
| 279 |
+
super().__init__()
|
| 280 |
+
self.dim = dim
|
| 281 |
+
self.mid_dim = mid_dim
|
| 282 |
+
|
| 283 |
+
# layers
|
| 284 |
+
self.fc1 = nn.Linear(dim, mid_dim)
|
| 285 |
+
self.fc2 = nn.Linear(dim, mid_dim)
|
| 286 |
+
self.fc3 = nn.Linear(mid_dim, dim)
|
| 287 |
+
|
| 288 |
+
def forward(self, x):
|
| 289 |
+
x = F.silu(self.fc1(x)) * self.fc2(x)
|
| 290 |
+
x = self.fc3(x)
|
| 291 |
+
return x
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
class AttentionBlock(nn.Module):
|
| 295 |
+
def __init__(
|
| 296 |
+
self,
|
| 297 |
+
dim,
|
| 298 |
+
mlp_ratio,
|
| 299 |
+
num_heads,
|
| 300 |
+
post_norm=False,
|
| 301 |
+
causal=False,
|
| 302 |
+
activation="quick_gelu",
|
| 303 |
+
attn_dropout=0.0,
|
| 304 |
+
proj_dropout=0.0,
|
| 305 |
+
norm_eps=1e-5,
|
| 306 |
+
):
|
| 307 |
+
assert activation in ["quick_gelu", "gelu", "swi_glu"]
|
| 308 |
+
super().__init__()
|
| 309 |
+
self.dim = dim
|
| 310 |
+
self.mlp_ratio = mlp_ratio
|
| 311 |
+
self.num_heads = num_heads
|
| 312 |
+
self.post_norm = post_norm
|
| 313 |
+
self.causal = causal
|
| 314 |
+
self.norm_eps = norm_eps
|
| 315 |
+
|
| 316 |
+
# layers
|
| 317 |
+
self.norm1 = LayerNorm(dim, eps=norm_eps)
|
| 318 |
+
self.attn = SelfAttention(dim, num_heads, causal, attn_dropout, proj_dropout)
|
| 319 |
+
self.norm2 = LayerNorm(dim, eps=norm_eps)
|
| 320 |
+
if activation == "swi_glu":
|
| 321 |
+
self.mlp = SwiGLU(dim, int(dim * mlp_ratio))
|
| 322 |
+
else:
|
| 323 |
+
self.mlp = nn.Sequential(
|
| 324 |
+
nn.Linear(dim, int(dim * mlp_ratio)),
|
| 325 |
+
QuickGELU() if activation == "quick_gelu" else nn.GELU(),
|
| 326 |
+
nn.Linear(int(dim * mlp_ratio), dim),
|
| 327 |
+
nn.Dropout(proj_dropout),
|
| 328 |
+
)
|
| 329 |
+
|
| 330 |
+
def forward(self, x):
|
| 331 |
+
if self.post_norm:
|
| 332 |
+
x = x + self.norm1(self.attn(x))
|
| 333 |
+
x = x + self.norm2(self.mlp(x))
|
| 334 |
+
else:
|
| 335 |
+
x = x + self.attn(self.norm1(x))
|
| 336 |
+
x = x + self.mlp(self.norm2(x))
|
| 337 |
+
return x
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
class AttentionPool(nn.Module):
|
| 341 |
+
def __init__(
|
| 342 |
+
self,
|
| 343 |
+
dim,
|
| 344 |
+
mlp_ratio,
|
| 345 |
+
num_heads,
|
| 346 |
+
activation="gelu",
|
| 347 |
+
proj_dropout=0.0,
|
| 348 |
+
norm_eps=1e-5,
|
| 349 |
+
):
|
| 350 |
+
assert dim % num_heads == 0
|
| 351 |
+
super().__init__()
|
| 352 |
+
self.dim = dim
|
| 353 |
+
self.mlp_ratio = mlp_ratio
|
| 354 |
+
self.num_heads = num_heads
|
| 355 |
+
self.head_dim = dim // num_heads
|
| 356 |
+
self.proj_dropout = proj_dropout
|
| 357 |
+
self.norm_eps = norm_eps
|
| 358 |
+
|
| 359 |
+
# layers
|
| 360 |
+
gain = 1.0 / math.sqrt(dim)
|
| 361 |
+
self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
|
| 362 |
+
self.to_q = nn.Linear(dim, dim)
|
| 363 |
+
self.to_kv = nn.Linear(dim, dim * 2)
|
| 364 |
+
self.proj = nn.Linear(dim, dim)
|
| 365 |
+
self.norm = LayerNorm(dim, eps=norm_eps)
|
| 366 |
+
self.mlp = nn.Sequential(
|
| 367 |
+
nn.Linear(dim, int(dim * mlp_ratio)),
|
| 368 |
+
QuickGELU() if activation == "quick_gelu" else nn.GELU(),
|
| 369 |
+
nn.Linear(int(dim * mlp_ratio), dim),
|
| 370 |
+
nn.Dropout(proj_dropout),
|
| 371 |
+
)
|
| 372 |
+
|
| 373 |
+
def forward(self, x):
|
| 374 |
+
"""
|
| 375 |
+
x: [B, L, C].
|
| 376 |
+
"""
|
| 377 |
+
b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
|
| 378 |
+
|
| 379 |
+
# compute query, key, value
|
| 380 |
+
q = self.to_q(self.cls_embedding).view(1, 1, n * d).expand(b, -1, -1)
|
| 381 |
+
k, v = self.to_kv(x).chunk(2, dim=-1)
|
| 382 |
+
|
| 383 |
+
# compute attention
|
| 384 |
+
x = flash_attention(q, k, v, num_heads=self.num_heads, compatibility_mode=True)
|
| 385 |
+
x = x.reshape(b, 1, c)
|
| 386 |
+
|
| 387 |
+
# output
|
| 388 |
+
x = self.proj(x)
|
| 389 |
+
x = F.dropout(x, self.proj_dropout, self.training)
|
| 390 |
+
|
| 391 |
+
# mlp
|
| 392 |
+
x = x + self.mlp(self.norm(x))
|
| 393 |
+
return x[:, 0]
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
class VisionTransformer(nn.Module):
|
| 397 |
+
def __init__(
|
| 398 |
+
self,
|
| 399 |
+
image_size=224,
|
| 400 |
+
patch_size=16,
|
| 401 |
+
dim=768,
|
| 402 |
+
mlp_ratio=4,
|
| 403 |
+
out_dim=512,
|
| 404 |
+
num_heads=12,
|
| 405 |
+
num_layers=12,
|
| 406 |
+
pool_type="token",
|
| 407 |
+
pre_norm=True,
|
| 408 |
+
post_norm=False,
|
| 409 |
+
activation="quick_gelu",
|
| 410 |
+
attn_dropout=0.0,
|
| 411 |
+
proj_dropout=0.0,
|
| 412 |
+
embedding_dropout=0.0,
|
| 413 |
+
norm_eps=1e-5,
|
| 414 |
+
):
|
| 415 |
+
if image_size % patch_size != 0:
|
| 416 |
+
print("[WARNING] image_size is not divisible by patch_size", flush=True)
|
| 417 |
+
assert pool_type in ("token", "token_fc", "attn_pool")
|
| 418 |
+
out_dim = out_dim or dim
|
| 419 |
+
super().__init__()
|
| 420 |
+
self.image_size = image_size
|
| 421 |
+
self.patch_size = patch_size
|
| 422 |
+
self.num_patches = (image_size // patch_size) ** 2
|
| 423 |
+
self.dim = dim
|
| 424 |
+
self.mlp_ratio = mlp_ratio
|
| 425 |
+
self.out_dim = out_dim
|
| 426 |
+
self.num_heads = num_heads
|
| 427 |
+
self.num_layers = num_layers
|
| 428 |
+
self.pool_type = pool_type
|
| 429 |
+
self.post_norm = post_norm
|
| 430 |
+
self.norm_eps = norm_eps
|
| 431 |
+
|
| 432 |
+
# embeddings
|
| 433 |
+
gain = 1.0 / math.sqrt(dim)
|
| 434 |
+
self.patch_embedding = nn.Conv2d(
|
| 435 |
+
3, dim, kernel_size=patch_size, stride=patch_size, bias=not pre_norm
|
| 436 |
+
)
|
| 437 |
+
if pool_type in ("token", "token_fc"):
|
| 438 |
+
self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
|
| 439 |
+
self.pos_embedding = nn.Parameter(
|
| 440 |
+
gain
|
| 441 |
+
* torch.randn(
|
| 442 |
+
1,
|
| 443 |
+
self.num_patches + (1 if pool_type in ("token", "token_fc") else 0),
|
| 444 |
+
dim,
|
| 445 |
+
)
|
| 446 |
+
)
|
| 447 |
+
self.dropout = nn.Dropout(embedding_dropout)
|
| 448 |
+
|
| 449 |
+
# transformer
|
| 450 |
+
self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
|
| 451 |
+
self.transformer = nn.Sequential(
|
| 452 |
+
*[
|
| 453 |
+
AttentionBlock(
|
| 454 |
+
dim,
|
| 455 |
+
mlp_ratio,
|
| 456 |
+
num_heads,
|
| 457 |
+
post_norm,
|
| 458 |
+
False,
|
| 459 |
+
activation,
|
| 460 |
+
attn_dropout,
|
| 461 |
+
proj_dropout,
|
| 462 |
+
norm_eps,
|
| 463 |
+
)
|
| 464 |
+
for _ in range(num_layers)
|
| 465 |
+
]
|
| 466 |
+
)
|
| 467 |
+
self.post_norm = LayerNorm(dim, eps=norm_eps)
|
| 468 |
+
|
| 469 |
+
# head
|
| 470 |
+
if pool_type == "token":
|
| 471 |
+
self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
|
| 472 |
+
elif pool_type == "token_fc":
|
| 473 |
+
self.head = nn.Linear(dim, out_dim)
|
| 474 |
+
elif pool_type == "attn_pool":
|
| 475 |
+
self.head = AttentionPool(
|
| 476 |
+
dim, mlp_ratio, num_heads, activation, proj_dropout, norm_eps
|
| 477 |
+
)
|
| 478 |
+
|
| 479 |
+
def forward(self, x, interpolation=False, use_31_block=False):
|
| 480 |
+
b = x.size(0)
|
| 481 |
+
|
| 482 |
+
# embeddings
|
| 483 |
+
x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
|
| 484 |
+
if self.pool_type in ("token", "token_fc"):
|
| 485 |
+
x = torch.cat(
|
| 486 |
+
[
|
| 487 |
+
self.cls_embedding.expand(b, -1, -1).to(
|
| 488 |
+
dtype=x.dtype, device=x.device
|
| 489 |
+
),
|
| 490 |
+
x,
|
| 491 |
+
],
|
| 492 |
+
dim=1,
|
| 493 |
+
)
|
| 494 |
+
if interpolation:
|
| 495 |
+
e = pos_interpolate(self.pos_embedding, x.size(1))
|
| 496 |
+
else:
|
| 497 |
+
e = self.pos_embedding
|
| 498 |
+
e = e.to(dtype=x.dtype, device=x.device)
|
| 499 |
+
x = self.dropout(x + e)
|
| 500 |
+
if self.pre_norm is not None:
|
| 501 |
+
x = self.pre_norm(x)
|
| 502 |
+
|
| 503 |
+
# transformer
|
| 504 |
+
if use_31_block:
|
| 505 |
+
x = self.transformer[:-1](x)
|
| 506 |
+
return x
|
| 507 |
+
else:
|
| 508 |
+
x = self.transformer(x)
|
| 509 |
+
return x
|
| 510 |
+
|
| 511 |
+
|
| 512 |
+
class CLIP(nn.Module):
|
| 513 |
+
def __init__(
|
| 514 |
+
self,
|
| 515 |
+
embed_dim=512,
|
| 516 |
+
image_size=224,
|
| 517 |
+
patch_size=16,
|
| 518 |
+
vision_dim=768,
|
| 519 |
+
vision_mlp_ratio=4,
|
| 520 |
+
vision_heads=12,
|
| 521 |
+
vision_layers=12,
|
| 522 |
+
vision_pool="token",
|
| 523 |
+
vision_pre_norm=True,
|
| 524 |
+
vision_post_norm=False,
|
| 525 |
+
vocab_size=49408,
|
| 526 |
+
text_len=77,
|
| 527 |
+
text_dim=512,
|
| 528 |
+
text_mlp_ratio=4,
|
| 529 |
+
text_heads=8,
|
| 530 |
+
text_layers=12,
|
| 531 |
+
text_causal=True,
|
| 532 |
+
text_pool="argmax",
|
| 533 |
+
text_head_bias=False,
|
| 534 |
+
logit_bias=None,
|
| 535 |
+
activation="quick_gelu",
|
| 536 |
+
attn_dropout=0.0,
|
| 537 |
+
proj_dropout=0.0,
|
| 538 |
+
embedding_dropout=0.0,
|
| 539 |
+
norm_eps=1e-5,
|
| 540 |
+
):
|
| 541 |
+
super().__init__()
|
| 542 |
+
self.embed_dim = embed_dim
|
| 543 |
+
self.image_size = image_size
|
| 544 |
+
self.patch_size = patch_size
|
| 545 |
+
self.vision_dim = vision_dim
|
| 546 |
+
self.vision_mlp_ratio = vision_mlp_ratio
|
| 547 |
+
self.vision_heads = vision_heads
|
| 548 |
+
self.vision_layers = vision_layers
|
| 549 |
+
self.vision_pool = vision_pool
|
| 550 |
+
self.vision_pre_norm = vision_pre_norm
|
| 551 |
+
self.vision_post_norm = vision_post_norm
|
| 552 |
+
self.vocab_size = vocab_size
|
| 553 |
+
self.text_len = text_len
|
| 554 |
+
self.text_dim = text_dim
|
| 555 |
+
self.text_mlp_ratio = text_mlp_ratio
|
| 556 |
+
self.text_heads = text_heads
|
| 557 |
+
self.text_layers = text_layers
|
| 558 |
+
self.text_causal = text_causal
|
| 559 |
+
self.text_pool = text_pool
|
| 560 |
+
self.text_head_bias = text_head_bias
|
| 561 |
+
self.norm_eps = norm_eps
|
| 562 |
+
|
| 563 |
+
# models
|
| 564 |
+
self.visual = VisionTransformer(
|
| 565 |
+
image_size=image_size,
|
| 566 |
+
patch_size=patch_size,
|
| 567 |
+
dim=vision_dim,
|
| 568 |
+
mlp_ratio=vision_mlp_ratio,
|
| 569 |
+
out_dim=embed_dim,
|
| 570 |
+
num_heads=vision_heads,
|
| 571 |
+
num_layers=vision_layers,
|
| 572 |
+
pool_type=vision_pool,
|
| 573 |
+
pre_norm=vision_pre_norm,
|
| 574 |
+
post_norm=vision_post_norm,
|
| 575 |
+
activation=activation,
|
| 576 |
+
attn_dropout=attn_dropout,
|
| 577 |
+
proj_dropout=proj_dropout,
|
| 578 |
+
embedding_dropout=embedding_dropout,
|
| 579 |
+
norm_eps=norm_eps,
|
| 580 |
+
)
|
| 581 |
+
self.textual = TextTransformer(
|
| 582 |
+
vocab_size=vocab_size,
|
| 583 |
+
text_len=text_len,
|
| 584 |
+
dim=text_dim,
|
| 585 |
+
mlp_ratio=text_mlp_ratio,
|
| 586 |
+
out_dim=embed_dim,
|
| 587 |
+
num_heads=text_heads,
|
| 588 |
+
num_layers=text_layers,
|
| 589 |
+
causal=text_causal,
|
| 590 |
+
pool_type=text_pool,
|
| 591 |
+
head_bias=text_head_bias,
|
| 592 |
+
activation=activation,
|
| 593 |
+
attn_dropout=attn_dropout,
|
| 594 |
+
proj_dropout=proj_dropout,
|
| 595 |
+
embedding_dropout=embedding_dropout,
|
| 596 |
+
norm_eps=norm_eps,
|
| 597 |
+
)
|
| 598 |
+
self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
|
| 599 |
+
if logit_bias is not None:
|
| 600 |
+
self.logit_bias = nn.Parameter(logit_bias * torch.ones([]))
|
| 601 |
+
|
| 602 |
+
# initialize weights
|
| 603 |
+
self.init_weights()
|
| 604 |
+
|
| 605 |
+
def forward(self, imgs, txt_ids):
|
| 606 |
+
"""
|
| 607 |
+
imgs: [B, 3, H, W] of torch.float32.
|
| 608 |
+
- mean: [0.48145466, 0.4578275, 0.40821073]
|
| 609 |
+
- std: [0.26862954, 0.26130258, 0.27577711]
|
| 610 |
+
txt_ids: [B, L] of torch.long. Encoded by data.CLIPTokenizer.
|
| 611 |
+
"""
|
| 612 |
+
xi = self.visual(imgs)
|
| 613 |
+
xt = self.textual(txt_ids)
|
| 614 |
+
return xi, xt
|
| 615 |
+
|
| 616 |
+
def init_weights(self):
|
| 617 |
+
# embeddings
|
| 618 |
+
nn.init.normal_(self.textual.token_embedding.weight, std=0.02)
|
| 619 |
+
nn.init.normal_(self.visual.patch_embedding.weight, std=0.1)
|
| 620 |
+
|
| 621 |
+
# attentions
|
| 622 |
+
for modality in ["visual", "textual"]:
|
| 623 |
+
dim = self.vision_dim if modality == "visual" else self.text_dim
|
| 624 |
+
transformer = getattr(self, modality).transformer
|
| 625 |
+
proj_gain = (1.0 / math.sqrt(dim)) * (1.0 / math.sqrt(2 * len(transformer)))
|
| 626 |
+
attn_gain = 1.0 / math.sqrt(dim)
|
| 627 |
+
mlp_gain = 1.0 / math.sqrt(2.0 * dim)
|
| 628 |
+
for block in transformer:
|
| 629 |
+
nn.init.normal_(block.attn.to_qkv.weight, std=attn_gain)
|
| 630 |
+
nn.init.normal_(block.attn.proj.weight, std=proj_gain)
|
| 631 |
+
nn.init.normal_(block.mlp[0].weight, std=mlp_gain)
|
| 632 |
+
nn.init.normal_(block.mlp[2].weight, std=proj_gain)
|
| 633 |
+
|
| 634 |
+
def param_groups(self):
|
| 635 |
+
groups = [
|
| 636 |
+
{
|
| 637 |
+
"params": [
|
| 638 |
+
p
|
| 639 |
+
for n, p in self.named_parameters()
|
| 640 |
+
if "norm" in n or n.endswith("bias")
|
| 641 |
+
],
|
| 642 |
+
"weight_decay": 0.0,
|
| 643 |
+
},
|
| 644 |
+
{
|
| 645 |
+
"params": [
|
| 646 |
+
p
|
| 647 |
+
for n, p in self.named_parameters()
|
| 648 |
+
if not ("norm" in n or n.endswith("bias"))
|
| 649 |
+
]
|
| 650 |
+
},
|
| 651 |
+
]
|
| 652 |
+
return groups
|
| 653 |
+
|
| 654 |
+
|
| 655 |
+
class XLMRobertaWithHead(XLMRoberta):
|
| 656 |
+
def __init__(self, **kwargs):
|
| 657 |
+
self.out_dim = kwargs.pop("out_dim")
|
| 658 |
+
super().__init__(**kwargs)
|
| 659 |
+
|
| 660 |
+
# head
|
| 661 |
+
mid_dim = (self.dim + self.out_dim) // 2
|
| 662 |
+
self.head = nn.Sequential(
|
| 663 |
+
nn.Linear(self.dim, mid_dim, bias=False),
|
| 664 |
+
nn.GELU(),
|
| 665 |
+
nn.Linear(mid_dim, self.out_dim, bias=False),
|
| 666 |
+
)
|
| 667 |
+
|
| 668 |
+
def forward(self, ids):
|
| 669 |
+
# xlm-roberta
|
| 670 |
+
x = super().forward(ids)
|
| 671 |
+
|
| 672 |
+
# average pooling
|
| 673 |
+
mask = ids.ne(self.pad_id).unsqueeze(-1).to(x)
|
| 674 |
+
x = (x * mask).sum(dim=1) / mask.sum(dim=1)
|
| 675 |
+
|
| 676 |
+
# head
|
| 677 |
+
x = self.head(x)
|
| 678 |
+
return x
|
| 679 |
+
|
| 680 |
+
|
| 681 |
+
class XLMRobertaCLIP(nn.Module):
|
| 682 |
+
def __init__(
|
| 683 |
+
self,
|
| 684 |
+
embed_dim=1024,
|
| 685 |
+
image_size=224,
|
| 686 |
+
patch_size=14,
|
| 687 |
+
vision_dim=1280,
|
| 688 |
+
vision_mlp_ratio=4,
|
| 689 |
+
vision_heads=16,
|
| 690 |
+
vision_layers=32,
|
| 691 |
+
vision_pool="token",
|
| 692 |
+
vision_pre_norm=True,
|
| 693 |
+
vision_post_norm=False,
|
| 694 |
+
activation="gelu",
|
| 695 |
+
vocab_size=250002,
|
| 696 |
+
max_text_len=514,
|
| 697 |
+
type_size=1,
|
| 698 |
+
pad_id=1,
|
| 699 |
+
text_dim=1024,
|
| 700 |
+
text_heads=16,
|
| 701 |
+
text_layers=24,
|
| 702 |
+
text_post_norm=True,
|
| 703 |
+
text_dropout=0.1,
|
| 704 |
+
attn_dropout=0.0,
|
| 705 |
+
proj_dropout=0.0,
|
| 706 |
+
embedding_dropout=0.0,
|
| 707 |
+
norm_eps=1e-5,
|
| 708 |
+
):
|
| 709 |
+
super().__init__()
|
| 710 |
+
self.embed_dim = embed_dim
|
| 711 |
+
self.image_size = image_size
|
| 712 |
+
self.patch_size = patch_size
|
| 713 |
+
self.vision_dim = vision_dim
|
| 714 |
+
self.vision_mlp_ratio = vision_mlp_ratio
|
| 715 |
+
self.vision_heads = vision_heads
|
| 716 |
+
self.vision_layers = vision_layers
|
| 717 |
+
self.vision_pre_norm = vision_pre_norm
|
| 718 |
+
self.vision_post_norm = vision_post_norm
|
| 719 |
+
self.activation = activation
|
| 720 |
+
self.vocab_size = vocab_size
|
| 721 |
+
self.max_text_len = max_text_len
|
| 722 |
+
self.type_size = type_size
|
| 723 |
+
self.pad_id = pad_id
|
| 724 |
+
self.text_dim = text_dim
|
| 725 |
+
self.text_heads = text_heads
|
| 726 |
+
self.text_layers = text_layers
|
| 727 |
+
self.text_post_norm = text_post_norm
|
| 728 |
+
self.norm_eps = norm_eps
|
| 729 |
+
|
| 730 |
+
# models
|
| 731 |
+
self.visual = VisionTransformer(
|
| 732 |
+
image_size=image_size,
|
| 733 |
+
patch_size=patch_size,
|
| 734 |
+
dim=vision_dim,
|
| 735 |
+
mlp_ratio=vision_mlp_ratio,
|
| 736 |
+
out_dim=embed_dim,
|
| 737 |
+
num_heads=vision_heads,
|
| 738 |
+
num_layers=vision_layers,
|
| 739 |
+
pool_type=vision_pool,
|
| 740 |
+
pre_norm=vision_pre_norm,
|
| 741 |
+
post_norm=vision_post_norm,
|
| 742 |
+
activation=activation,
|
| 743 |
+
attn_dropout=attn_dropout,
|
| 744 |
+
proj_dropout=proj_dropout,
|
| 745 |
+
embedding_dropout=embedding_dropout,
|
| 746 |
+
norm_eps=norm_eps,
|
| 747 |
+
)
|
| 748 |
+
self.textual = None
|
| 749 |
+
self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
|
| 750 |
+
|
| 751 |
+
def forward(self, imgs, txt_ids):
|
| 752 |
+
"""
|
| 753 |
+
imgs: [B, 3, H, W] of torch.float32.
|
| 754 |
+
- mean: [0.48145466, 0.4578275, 0.40821073]
|
| 755 |
+
- std: [0.26862954, 0.26130258, 0.27577711]
|
| 756 |
+
txt_ids: [B, L] of torch.long.
|
| 757 |
+
Encoded by data.CLIPTokenizer.
|
| 758 |
+
"""
|
| 759 |
+
xi = self.visual(imgs)
|
| 760 |
+
xt = self.textual(txt_ids)
|
| 761 |
+
return xi, xt
|
| 762 |
+
|
| 763 |
+
def param_groups(self):
|
| 764 |
+
groups = [
|
| 765 |
+
{
|
| 766 |
+
"params": [
|
| 767 |
+
p
|
| 768 |
+
for n, p in self.named_parameters()
|
| 769 |
+
if "norm" in n or n.endswith("bias")
|
| 770 |
+
],
|
| 771 |
+
"weight_decay": 0.0,
|
| 772 |
+
},
|
| 773 |
+
{
|
| 774 |
+
"params": [
|
| 775 |
+
p
|
| 776 |
+
for n, p in self.named_parameters()
|
| 777 |
+
if not ("norm" in n or n.endswith("bias"))
|
| 778 |
+
]
|
| 779 |
+
},
|
| 780 |
+
]
|
| 781 |
+
return groups
|
| 782 |
+
|
| 783 |
+
|
| 784 |
+
def _clip(
|
| 785 |
+
pretrained=False,
|
| 786 |
+
pretrained_name=None,
|
| 787 |
+
model_cls=CLIP,
|
| 788 |
+
return_transforms=False,
|
| 789 |
+
return_tokenizer=False,
|
| 790 |
+
tokenizer_padding="eos",
|
| 791 |
+
dtype=torch.float32,
|
| 792 |
+
device="cpu",
|
| 793 |
+
**kwargs,
|
| 794 |
+
):
|
| 795 |
+
# init model
|
| 796 |
+
if pretrained and pretrained_name:
|
| 797 |
+
from sora import BUCKET, DOWNLOAD_TO_CACHE
|
| 798 |
+
|
| 799 |
+
# init a meta model
|
| 800 |
+
with torch.device("meta"):
|
| 801 |
+
model = model_cls(**kwargs)
|
| 802 |
+
|
| 803 |
+
# checkpoint path
|
| 804 |
+
checkpoint = f"models/clip/{pretrained_name}"
|
| 805 |
+
if dtype in (torch.float16, torch.bfloat16):
|
| 806 |
+
suffix = "-" + {torch.float16: "fp16", torch.bfloat16: "bf16"}[dtype]
|
| 807 |
+
if object_exists(BUCKET, f"{checkpoint}{suffix}.pth"):
|
| 808 |
+
checkpoint = f"{checkpoint}{suffix}"
|
| 809 |
+
checkpoint += ".pth"
|
| 810 |
+
|
| 811 |
+
# load
|
| 812 |
+
model.load_state_dict(
|
| 813 |
+
torch.load(DOWNLOAD_TO_CACHE(checkpoint), map_location=device),
|
| 814 |
+
assign=True,
|
| 815 |
+
strict=False,
|
| 816 |
+
)
|
| 817 |
+
else:
|
| 818 |
+
# init a model on device
|
| 819 |
+
with torch.device(device):
|
| 820 |
+
model = model_cls(**kwargs)
|
| 821 |
+
|
| 822 |
+
# set device
|
| 823 |
+
output = (model,)
|
| 824 |
+
|
| 825 |
+
# init transforms
|
| 826 |
+
if return_transforms:
|
| 827 |
+
# mean and std
|
| 828 |
+
if "siglip" in pretrained_name.lower():
|
| 829 |
+
mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
|
| 830 |
+
else:
|
| 831 |
+
mean = [0.48145466, 0.4578275, 0.40821073]
|
| 832 |
+
std = [0.26862954, 0.26130258, 0.27577711]
|
| 833 |
+
|
| 834 |
+
# transforms
|
| 835 |
+
transforms = T.Compose(
|
| 836 |
+
[
|
| 837 |
+
T.Resize(
|
| 838 |
+
(model.image_size, model.image_size),
|
| 839 |
+
interpolation=T.InterpolationMode.BICUBIC,
|
| 840 |
+
),
|
| 841 |
+
T.ToTensor(),
|
| 842 |
+
T.Normalize(mean=mean, std=std),
|
| 843 |
+
]
|
| 844 |
+
)
|
| 845 |
+
output += (transforms,)
|
| 846 |
+
|
| 847 |
+
# init tokenizer
|
| 848 |
+
if return_tokenizer:
|
| 849 |
+
from sora import data
|
| 850 |
+
|
| 851 |
+
if "siglip" in pretrained_name.lower():
|
| 852 |
+
tokenizer = data.HuggingfaceTokenizer(
|
| 853 |
+
name=f"timm/{pretrained_name}",
|
| 854 |
+
seq_len=model.text_len,
|
| 855 |
+
clean="canonicalize",
|
| 856 |
+
)
|
| 857 |
+
elif "xlm" in pretrained_name.lower():
|
| 858 |
+
tokenizer = data.HuggingfaceTokenizer(
|
| 859 |
+
name="xlm-roberta-large",
|
| 860 |
+
seq_len=model.max_text_len - 2,
|
| 861 |
+
clean="whitespace",
|
| 862 |
+
)
|
| 863 |
+
elif "mba" in pretrained_name.lower():
|
| 864 |
+
tokenizer = data.HuggingfaceTokenizer(
|
| 865 |
+
name="facebook/xlm-roberta-xl",
|
| 866 |
+
seq_len=model.max_text_len - 2,
|
| 867 |
+
clean="whitespace",
|
| 868 |
+
)
|
| 869 |
+
else:
|
| 870 |
+
tokenizer = data.CLIPTokenizer(
|
| 871 |
+
seq_len=model.text_len, padding=tokenizer_padding
|
| 872 |
+
)
|
| 873 |
+
output += (tokenizer,)
|
| 874 |
+
return output[0] if len(output) == 1 else output
|
| 875 |
+
|
| 876 |
+
|
| 877 |
+
def clip_xlm_roberta_vit_h_14(
|
| 878 |
+
pretrained=False,
|
| 879 |
+
pretrained_name="open-clip-xlm-roberta-large-vit-huge-14",
|
| 880 |
+
**kwargs,
|
| 881 |
+
):
|
| 882 |
+
cfg = dict(
|
| 883 |
+
embed_dim=1024,
|
| 884 |
+
image_size=224,
|
| 885 |
+
patch_size=14,
|
| 886 |
+
vision_dim=1280,
|
| 887 |
+
vision_mlp_ratio=4,
|
| 888 |
+
vision_heads=16,
|
| 889 |
+
vision_layers=32,
|
| 890 |
+
vision_pool="token",
|
| 891 |
+
activation="gelu",
|
| 892 |
+
vocab_size=250002,
|
| 893 |
+
max_text_len=514,
|
| 894 |
+
type_size=1,
|
| 895 |
+
pad_id=1,
|
| 896 |
+
text_dim=1024,
|
| 897 |
+
text_heads=16,
|
| 898 |
+
text_layers=24,
|
| 899 |
+
text_post_norm=True,
|
| 900 |
+
text_dropout=0.1,
|
| 901 |
+
attn_dropout=0.0,
|
| 902 |
+
proj_dropout=0.0,
|
| 903 |
+
embedding_dropout=0.0,
|
| 904 |
+
)
|
| 905 |
+
cfg.update(**kwargs)
|
| 906 |
+
return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg)
|
| 907 |
+
|
| 908 |
+
|
| 909 |
+
class WanImageEncoder(torch.nn.Module):
|
| 910 |
+
def __init__(self):
|
| 911 |
+
super().__init__()
|
| 912 |
+
# init model
|
| 913 |
+
self.model, self.transforms = clip_xlm_roberta_vit_h_14(
|
| 914 |
+
pretrained=False,
|
| 915 |
+
return_transforms=True,
|
| 916 |
+
return_tokenizer=False,
|
| 917 |
+
dtype=torch.float32,
|
| 918 |
+
device="cpu",
|
| 919 |
+
)
|
| 920 |
+
|
| 921 |
+
def encode_image(self, videos):
|
| 922 |
+
# preprocess
|
| 923 |
+
size = (self.model.image_size,) * 2
|
| 924 |
+
videos = torch.cat(
|
| 925 |
+
[
|
| 926 |
+
F.interpolate(u, size=size, mode="bicubic", align_corners=False)
|
| 927 |
+
for u in videos
|
| 928 |
+
]
|
| 929 |
+
)
|
| 930 |
+
videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))
|
| 931 |
+
|
| 932 |
+
# forward
|
| 933 |
+
dtype = next(iter(self.model.visual.parameters())).dtype
|
| 934 |
+
videos = videos.to(dtype)
|
| 935 |
+
out = self.model.visual(videos, use_31_block=True)
|
| 936 |
+
return out
|
| 937 |
+
|
| 938 |
+
@staticmethod
|
| 939 |
+
def state_dict_converter():
|
| 940 |
+
return WanImageEncoderStateDictConverter()
|
| 941 |
+
|
| 942 |
+
|
| 943 |
+
class WanImageEncoderStateDictConverter:
|
| 944 |
+
def __init__(self):
|
| 945 |
+
pass
|
| 946 |
+
|
| 947 |
+
def from_diffusers(self, state_dict):
|
| 948 |
+
return state_dict
|
| 949 |
+
|
| 950 |
+
def from_civitai(self, state_dict):
|
| 951 |
+
state_dict_ = {}
|
| 952 |
+
for name, param in state_dict.items():
|
| 953 |
+
if name.startswith("textual."):
|
| 954 |
+
continue
|
| 955 |
+
name = "model." + name
|
| 956 |
+
state_dict_[name] = param
|
| 957 |
+
return state_dict_
|
models/wan_video_motion_controller.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from .wan_video_dit import sinusoidal_embedding_1d
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class WanMotionControllerModel(torch.nn.Module):
|
| 7 |
+
def __init__(self, freq_dim=256, dim=1536):
|
| 8 |
+
super().__init__()
|
| 9 |
+
self.freq_dim = freq_dim
|
| 10 |
+
self.linear = nn.Sequential(
|
| 11 |
+
nn.Linear(freq_dim, dim),
|
| 12 |
+
nn.SiLU(),
|
| 13 |
+
nn.Linear(dim, dim),
|
| 14 |
+
nn.SiLU(),
|
| 15 |
+
nn.Linear(dim, dim * 6),
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
def forward(self, motion_bucket_id):
|
| 19 |
+
emb = sinusoidal_embedding_1d(self.freq_dim, motion_bucket_id * 10)
|
| 20 |
+
emb = self.linear(emb)
|
| 21 |
+
return emb
|
| 22 |
+
|
| 23 |
+
def init(self):
|
| 24 |
+
state_dict = self.linear[-1].state_dict()
|
| 25 |
+
state_dict = {i: state_dict[i] * 0 for i in state_dict}
|
| 26 |
+
self.linear[-1].load_state_dict(state_dict)
|
| 27 |
+
|
| 28 |
+
@staticmethod
|
| 29 |
+
def state_dict_converter():
|
| 30 |
+
return WanMotionControllerModelDictConverter()
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class WanMotionControllerModelDictConverter:
|
| 34 |
+
def __init__(self):
|
| 35 |
+
pass
|
| 36 |
+
|
| 37 |
+
def from_diffusers(self, state_dict):
|
| 38 |
+
return state_dict
|
| 39 |
+
|
| 40 |
+
def from_civitai(self, state_dict):
|
| 41 |
+
return state_dict
|
models/wan_video_text_encoder.py
ADDED
|
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def fp16_clamp(x):
|
| 9 |
+
if x.dtype == torch.float16 and torch.isinf(x).any():
|
| 10 |
+
clamp = torch.finfo(x.dtype).max - 1000
|
| 11 |
+
x = torch.clamp(x, min=-clamp, max=clamp)
|
| 12 |
+
return x
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class GELU(nn.Module):
|
| 16 |
+
def forward(self, x):
|
| 17 |
+
return (
|
| 18 |
+
0.5
|
| 19 |
+
* x
|
| 20 |
+
* (
|
| 21 |
+
1.0
|
| 22 |
+
+ torch.tanh(
|
| 23 |
+
math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))
|
| 24 |
+
)
|
| 25 |
+
)
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class T5LayerNorm(nn.Module):
|
| 30 |
+
def __init__(self, dim, eps=1e-6):
|
| 31 |
+
super(T5LayerNorm, self).__init__()
|
| 32 |
+
self.dim = dim
|
| 33 |
+
self.eps = eps
|
| 34 |
+
self.weight = nn.Parameter(torch.ones(dim))
|
| 35 |
+
|
| 36 |
+
def forward(self, x):
|
| 37 |
+
x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + self.eps)
|
| 38 |
+
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
| 39 |
+
x = x.type_as(self.weight)
|
| 40 |
+
return self.weight * x
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class T5Attention(nn.Module):
|
| 44 |
+
def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
|
| 45 |
+
assert dim_attn % num_heads == 0
|
| 46 |
+
super(T5Attention, self).__init__()
|
| 47 |
+
self.dim = dim
|
| 48 |
+
self.dim_attn = dim_attn
|
| 49 |
+
self.num_heads = num_heads
|
| 50 |
+
self.head_dim = dim_attn // num_heads
|
| 51 |
+
|
| 52 |
+
# layers
|
| 53 |
+
self.q = nn.Linear(dim, dim_attn, bias=False)
|
| 54 |
+
self.k = nn.Linear(dim, dim_attn, bias=False)
|
| 55 |
+
self.v = nn.Linear(dim, dim_attn, bias=False)
|
| 56 |
+
self.o = nn.Linear(dim_attn, dim, bias=False)
|
| 57 |
+
self.dropout = nn.Dropout(dropout)
|
| 58 |
+
|
| 59 |
+
def forward(self, x, context=None, mask=None, pos_bias=None):
|
| 60 |
+
"""
|
| 61 |
+
x: [B, L1, C].
|
| 62 |
+
context: [B, L2, C] or None.
|
| 63 |
+
mask: [B, L2] or [B, L1, L2] or None.
|
| 64 |
+
"""
|
| 65 |
+
# check inputs
|
| 66 |
+
context = x if context is None else context
|
| 67 |
+
b, n, c = x.size(0), self.num_heads, self.head_dim
|
| 68 |
+
|
| 69 |
+
# compute query, key, value
|
| 70 |
+
q = self.q(x).view(b, -1, n, c)
|
| 71 |
+
k = self.k(context).view(b, -1, n, c)
|
| 72 |
+
v = self.v(context).view(b, -1, n, c)
|
| 73 |
+
|
| 74 |
+
# attention bias
|
| 75 |
+
attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
|
| 76 |
+
if pos_bias is not None:
|
| 77 |
+
attn_bias += pos_bias
|
| 78 |
+
if mask is not None:
|
| 79 |
+
assert mask.ndim in [2, 3]
|
| 80 |
+
mask = mask.view(b, 1, 1, -1) if mask.ndim == 2 else mask.unsqueeze(1)
|
| 81 |
+
attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)
|
| 82 |
+
|
| 83 |
+
# compute attention (T5 does not use scaling)
|
| 84 |
+
attn = torch.einsum("binc,bjnc->bnij", q, k) + attn_bias
|
| 85 |
+
attn = F.softmax(attn.float(), dim=-1).type_as(attn)
|
| 86 |
+
x = torch.einsum("bnij,bjnc->binc", attn, v)
|
| 87 |
+
|
| 88 |
+
# output
|
| 89 |
+
x = x.reshape(b, -1, n * c)
|
| 90 |
+
x = self.o(x)
|
| 91 |
+
x = self.dropout(x)
|
| 92 |
+
return x
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class T5FeedForward(nn.Module):
|
| 96 |
+
def __init__(self, dim, dim_ffn, dropout=0.1):
|
| 97 |
+
super(T5FeedForward, self).__init__()
|
| 98 |
+
self.dim = dim
|
| 99 |
+
self.dim_ffn = dim_ffn
|
| 100 |
+
|
| 101 |
+
# layers
|
| 102 |
+
self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
|
| 103 |
+
self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
|
| 104 |
+
self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
|
| 105 |
+
self.dropout = nn.Dropout(dropout)
|
| 106 |
+
|
| 107 |
+
def forward(self, x):
|
| 108 |
+
x = self.fc1(x) * self.gate(x)
|
| 109 |
+
x = self.dropout(x)
|
| 110 |
+
x = self.fc2(x)
|
| 111 |
+
x = self.dropout(x)
|
| 112 |
+
return x
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class T5SelfAttention(nn.Module):
|
| 116 |
+
def __init__(
|
| 117 |
+
self,
|
| 118 |
+
dim,
|
| 119 |
+
dim_attn,
|
| 120 |
+
dim_ffn,
|
| 121 |
+
num_heads,
|
| 122 |
+
num_buckets,
|
| 123 |
+
shared_pos=True,
|
| 124 |
+
dropout=0.1,
|
| 125 |
+
):
|
| 126 |
+
super(T5SelfAttention, self).__init__()
|
| 127 |
+
self.dim = dim
|
| 128 |
+
self.dim_attn = dim_attn
|
| 129 |
+
self.dim_ffn = dim_ffn
|
| 130 |
+
self.num_heads = num_heads
|
| 131 |
+
self.num_buckets = num_buckets
|
| 132 |
+
self.shared_pos = shared_pos
|
| 133 |
+
|
| 134 |
+
# layers
|
| 135 |
+
self.norm1 = T5LayerNorm(dim)
|
| 136 |
+
self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
|
| 137 |
+
self.norm2 = T5LayerNorm(dim)
|
| 138 |
+
self.ffn = T5FeedForward(dim, dim_ffn, dropout)
|
| 139 |
+
self.pos_embedding = (
|
| 140 |
+
None
|
| 141 |
+
if shared_pos
|
| 142 |
+
else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
def forward(self, x, mask=None, pos_bias=None):
|
| 146 |
+
e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1))
|
| 147 |
+
x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
|
| 148 |
+
x = fp16_clamp(x + self.ffn(self.norm2(x)))
|
| 149 |
+
return x
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
class T5RelativeEmbedding(nn.Module):
|
| 153 |
+
def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
|
| 154 |
+
super(T5RelativeEmbedding, self).__init__()
|
| 155 |
+
self.num_buckets = num_buckets
|
| 156 |
+
self.num_heads = num_heads
|
| 157 |
+
self.bidirectional = bidirectional
|
| 158 |
+
self.max_dist = max_dist
|
| 159 |
+
|
| 160 |
+
# layers
|
| 161 |
+
self.embedding = nn.Embedding(num_buckets, num_heads)
|
| 162 |
+
|
| 163 |
+
def forward(self, lq, lk):
|
| 164 |
+
device = self.embedding.weight.device
|
| 165 |
+
# rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \
|
| 166 |
+
# torch.arange(lq).unsqueeze(1).to(device)
|
| 167 |
+
rel_pos = torch.arange(lk, device=device).unsqueeze(0) - torch.arange(
|
| 168 |
+
lq, device=device
|
| 169 |
+
).unsqueeze(1)
|
| 170 |
+
rel_pos = self._relative_position_bucket(rel_pos)
|
| 171 |
+
rel_pos_embeds = self.embedding(rel_pos)
|
| 172 |
+
rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(0) # [1, N, Lq, Lk]
|
| 173 |
+
return rel_pos_embeds.contiguous()
|
| 174 |
+
|
| 175 |
+
def _relative_position_bucket(self, rel_pos):
|
| 176 |
+
# preprocess
|
| 177 |
+
if self.bidirectional:
|
| 178 |
+
num_buckets = self.num_buckets // 2
|
| 179 |
+
rel_buckets = (rel_pos > 0).long() * num_buckets
|
| 180 |
+
rel_pos = torch.abs(rel_pos)
|
| 181 |
+
else:
|
| 182 |
+
num_buckets = self.num_buckets
|
| 183 |
+
rel_buckets = 0
|
| 184 |
+
rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))
|
| 185 |
+
|
| 186 |
+
# embeddings for small and large positions
|
| 187 |
+
max_exact = num_buckets // 2
|
| 188 |
+
rel_pos_large = (
|
| 189 |
+
max_exact
|
| 190 |
+
+ (
|
| 191 |
+
torch.log(rel_pos.float() / max_exact)
|
| 192 |
+
/ math.log(self.max_dist / max_exact)
|
| 193 |
+
* (num_buckets - max_exact)
|
| 194 |
+
).long()
|
| 195 |
+
)
|
| 196 |
+
rel_pos_large = torch.min(
|
| 197 |
+
rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1)
|
| 198 |
+
)
|
| 199 |
+
rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
|
| 200 |
+
return rel_buckets
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def init_weights(m):
|
| 204 |
+
if isinstance(m, T5LayerNorm):
|
| 205 |
+
nn.init.ones_(m.weight)
|
| 206 |
+
elif isinstance(m, T5FeedForward):
|
| 207 |
+
nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)
|
| 208 |
+
nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)
|
| 209 |
+
nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
|
| 210 |
+
elif isinstance(m, T5Attention):
|
| 211 |
+
nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn) ** -0.5)
|
| 212 |
+
nn.init.normal_(m.k.weight, std=m.dim**-0.5)
|
| 213 |
+
nn.init.normal_(m.v.weight, std=m.dim**-0.5)
|
| 214 |
+
nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn) ** -0.5)
|
| 215 |
+
elif isinstance(m, T5RelativeEmbedding):
|
| 216 |
+
nn.init.normal_(
|
| 217 |
+
m.embedding.weight, std=(2 * m.num_buckets * m.num_heads) ** -0.5
|
| 218 |
+
)
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
class WanTextEncoder(torch.nn.Module):
|
| 222 |
+
def __init__(
|
| 223 |
+
self,
|
| 224 |
+
vocab=256384,
|
| 225 |
+
dim=4096,
|
| 226 |
+
dim_attn=4096,
|
| 227 |
+
dim_ffn=10240,
|
| 228 |
+
num_heads=64,
|
| 229 |
+
num_layers=24,
|
| 230 |
+
num_buckets=32,
|
| 231 |
+
shared_pos=False,
|
| 232 |
+
dropout=0.1,
|
| 233 |
+
):
|
| 234 |
+
super(WanTextEncoder, self).__init__()
|
| 235 |
+
self.dim = dim
|
| 236 |
+
self.dim_attn = dim_attn
|
| 237 |
+
self.dim_ffn = dim_ffn
|
| 238 |
+
self.num_heads = num_heads
|
| 239 |
+
self.num_layers = num_layers
|
| 240 |
+
self.num_buckets = num_buckets
|
| 241 |
+
self.shared_pos = shared_pos
|
| 242 |
+
|
| 243 |
+
# layers
|
| 244 |
+
self.token_embedding = (
|
| 245 |
+
vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim)
|
| 246 |
+
)
|
| 247 |
+
self.pos_embedding = (
|
| 248 |
+
T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
|
| 249 |
+
if shared_pos
|
| 250 |
+
else None
|
| 251 |
+
)
|
| 252 |
+
self.dropout = nn.Dropout(dropout)
|
| 253 |
+
self.blocks = nn.ModuleList(
|
| 254 |
+
[
|
| 255 |
+
T5SelfAttention(
|
| 256 |
+
dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout
|
| 257 |
+
)
|
| 258 |
+
for _ in range(num_layers)
|
| 259 |
+
]
|
| 260 |
+
)
|
| 261 |
+
self.norm = T5LayerNorm(dim)
|
| 262 |
+
|
| 263 |
+
# initialize weights
|
| 264 |
+
self.apply(init_weights)
|
| 265 |
+
|
| 266 |
+
def forward(self, ids, mask=None):
|
| 267 |
+
x = self.token_embedding(ids)
|
| 268 |
+
x = self.dropout(x)
|
| 269 |
+
e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None
|
| 270 |
+
for block in self.blocks:
|
| 271 |
+
x = block(x, mask, pos_bias=e)
|
| 272 |
+
x = self.norm(x)
|
| 273 |
+
x = self.dropout(x)
|
| 274 |
+
return x
|
| 275 |
+
|
| 276 |
+
@staticmethod
|
| 277 |
+
def state_dict_converter():
|
| 278 |
+
return WanTextEncoderStateDictConverter()
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
class WanTextEncoderStateDictConverter:
|
| 282 |
+
def __init__(self):
|
| 283 |
+
pass
|
| 284 |
+
|
| 285 |
+
def from_diffusers(self, state_dict):
|
| 286 |
+
return state_dict
|
| 287 |
+
|
| 288 |
+
def from_civitai(self, state_dict):
|
| 289 |
+
return state_dict
|
models/wan_video_vace.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from .wan_video_dit import DiTBlock
|
| 3 |
+
from .utils import hash_state_dict_keys
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class VaceWanAttentionBlock(DiTBlock):
|
| 7 |
+
def __init__(self, has_image_input, dim, num_heads, ffn_dim, eps=1e-6, block_id=0):
|
| 8 |
+
super().__init__(has_image_input, dim, num_heads, ffn_dim, eps=eps)
|
| 9 |
+
self.block_id = block_id
|
| 10 |
+
if block_id == 0:
|
| 11 |
+
self.before_proj = torch.nn.Linear(self.dim, self.dim)
|
| 12 |
+
self.after_proj = torch.nn.Linear(self.dim, self.dim)
|
| 13 |
+
|
| 14 |
+
def forward(self, c, x, context, t_mod, freqs):
|
| 15 |
+
if self.block_id == 0:
|
| 16 |
+
c = self.before_proj(c) + x
|
| 17 |
+
all_c = []
|
| 18 |
+
else:
|
| 19 |
+
all_c = list(torch.unbind(c))
|
| 20 |
+
c = all_c.pop(-1)
|
| 21 |
+
c, _ = super().forward(c, context, t_mod, freqs)
|
| 22 |
+
c_skip = self.after_proj(c)
|
| 23 |
+
all_c += [c_skip, c]
|
| 24 |
+
c = torch.stack(all_c)
|
| 25 |
+
return c
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class VaceWanModel(torch.nn.Module):
|
| 29 |
+
def __init__(
|
| 30 |
+
self,
|
| 31 |
+
vace_layers=(0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28),
|
| 32 |
+
vace_in_dim=96,
|
| 33 |
+
patch_size=(1, 2, 2),
|
| 34 |
+
has_image_input=False,
|
| 35 |
+
dim=1536,
|
| 36 |
+
num_heads=12,
|
| 37 |
+
ffn_dim=8960,
|
| 38 |
+
eps=1e-6,
|
| 39 |
+
):
|
| 40 |
+
super().__init__()
|
| 41 |
+
self.vace_layers = vace_layers
|
| 42 |
+
self.vace_in_dim = vace_in_dim
|
| 43 |
+
self.vace_layers_mapping = {i: n for n, i in enumerate(self.vace_layers)}
|
| 44 |
+
|
| 45 |
+
# vace blocks
|
| 46 |
+
self.vace_blocks = torch.nn.ModuleList(
|
| 47 |
+
[
|
| 48 |
+
VaceWanAttentionBlock(
|
| 49 |
+
has_image_input, dim, num_heads, ffn_dim, eps, block_id=i
|
| 50 |
+
)
|
| 51 |
+
for i in self.vace_layers
|
| 52 |
+
]
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
# vace patch embeddings
|
| 56 |
+
self.vace_patch_embedding = torch.nn.Conv3d(
|
| 57 |
+
vace_in_dim, dim, kernel_size=patch_size, stride=patch_size
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
def forward(
|
| 61 |
+
self,
|
| 62 |
+
x,
|
| 63 |
+
vace_context,
|
| 64 |
+
context,
|
| 65 |
+
t_mod,
|
| 66 |
+
freqs,
|
| 67 |
+
use_gradient_checkpointing: bool = False,
|
| 68 |
+
use_gradient_checkpointing_offload: bool = False,
|
| 69 |
+
):
|
| 70 |
+
c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
|
| 71 |
+
c = [u.flatten(2).transpose(1, 2) for u in c]
|
| 72 |
+
c = torch.cat(
|
| 73 |
+
[
|
| 74 |
+
torch.cat([u, u.new_zeros(1, x.shape[1] - u.size(1), u.size(2))], dim=1)
|
| 75 |
+
for u in c
|
| 76 |
+
]
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
def create_custom_forward(module):
|
| 80 |
+
def custom_forward(*inputs):
|
| 81 |
+
return module(*inputs)
|
| 82 |
+
|
| 83 |
+
return custom_forward
|
| 84 |
+
|
| 85 |
+
for block in self.vace_blocks:
|
| 86 |
+
if use_gradient_checkpointing_offload:
|
| 87 |
+
with torch.autograd.graph.save_on_cpu():
|
| 88 |
+
c = torch.utils.checkpoint.checkpoint(
|
| 89 |
+
create_custom_forward(block),
|
| 90 |
+
c,
|
| 91 |
+
x,
|
| 92 |
+
context,
|
| 93 |
+
t_mod,
|
| 94 |
+
freqs,
|
| 95 |
+
use_reentrant=False,
|
| 96 |
+
)
|
| 97 |
+
elif use_gradient_checkpointing:
|
| 98 |
+
c = torch.utils.checkpoint.checkpoint(
|
| 99 |
+
create_custom_forward(block),
|
| 100 |
+
c,
|
| 101 |
+
x,
|
| 102 |
+
context,
|
| 103 |
+
t_mod,
|
| 104 |
+
freqs,
|
| 105 |
+
use_reentrant=False,
|
| 106 |
+
)
|
| 107 |
+
else:
|
| 108 |
+
c = block(c, x, context, t_mod, freqs)
|
| 109 |
+
hints = torch.unbind(c)[:-1]
|
| 110 |
+
return hints
|
| 111 |
+
|
| 112 |
+
@staticmethod
|
| 113 |
+
def state_dict_converter():
|
| 114 |
+
return VaceWanModelDictConverter()
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class VaceWanModelDictConverter:
|
| 118 |
+
def __init__(self):
|
| 119 |
+
pass
|
| 120 |
+
|
| 121 |
+
def from_civitai(self, state_dict):
|
| 122 |
+
state_dict_ = {
|
| 123 |
+
name: param for name, param in state_dict.items() if name.startswith("vace")
|
| 124 |
+
}
|
| 125 |
+
if (
|
| 126 |
+
hash_state_dict_keys(state_dict_) == "3b2726384e4f64837bdf216eea3f310d"
|
| 127 |
+
): # vace 14B
|
| 128 |
+
config = {
|
| 129 |
+
"vace_layers": (0, 5, 10, 15, 20, 25, 30, 35),
|
| 130 |
+
"vace_in_dim": 96,
|
| 131 |
+
"patch_size": (1, 2, 2),
|
| 132 |
+
"has_image_input": False,
|
| 133 |
+
"dim": 5120,
|
| 134 |
+
"num_heads": 40,
|
| 135 |
+
"ffn_dim": 13824,
|
| 136 |
+
"eps": 1e-06,
|
| 137 |
+
}
|
| 138 |
+
else:
|
| 139 |
+
config = {}
|
| 140 |
+
return state_dict_, config
|
models/wan_video_vae.py
ADDED
|
@@ -0,0 +1,1634 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from einops import rearrange, repeat
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
|
| 8 |
+
CACHE_T = 2
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def check_is_instance(model, module_class):
|
| 12 |
+
if isinstance(model, module_class):
|
| 13 |
+
return True
|
| 14 |
+
if hasattr(model, "module") and isinstance(model.module, module_class):
|
| 15 |
+
return True
|
| 16 |
+
return False
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def block_causal_mask(x, block_size):
|
| 20 |
+
# params
|
| 21 |
+
b, n, s, _, device = *x.size(), x.device
|
| 22 |
+
assert s % block_size == 0
|
| 23 |
+
num_blocks = s // block_size
|
| 24 |
+
|
| 25 |
+
# build mask
|
| 26 |
+
mask = torch.zeros(b, n, s, s, dtype=torch.bool, device=device)
|
| 27 |
+
for i in range(num_blocks):
|
| 28 |
+
mask[:, :, i * block_size : (i + 1) * block_size, : (i + 1) * block_size] = 1
|
| 29 |
+
return mask
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class CausalConv3d(nn.Conv3d):
|
| 33 |
+
"""
|
| 34 |
+
Causal 3d convolusion.
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
def __init__(self, *args, **kwargs):
|
| 38 |
+
super().__init__(*args, **kwargs)
|
| 39 |
+
self._padding = (
|
| 40 |
+
self.padding[2],
|
| 41 |
+
self.padding[2],
|
| 42 |
+
self.padding[1],
|
| 43 |
+
self.padding[1],
|
| 44 |
+
2 * self.padding[0],
|
| 45 |
+
0,
|
| 46 |
+
)
|
| 47 |
+
self.padding = (0, 0, 0)
|
| 48 |
+
|
| 49 |
+
def forward(self, x, cache_x=None):
|
| 50 |
+
padding = list(self._padding)
|
| 51 |
+
if cache_x is not None and self._padding[4] > 0:
|
| 52 |
+
cache_x = cache_x.to(x.device)
|
| 53 |
+
x = torch.cat([cache_x, x], dim=2)
|
| 54 |
+
padding[4] -= cache_x.shape[2]
|
| 55 |
+
x = F.pad(x, padding)
|
| 56 |
+
|
| 57 |
+
return super().forward(x)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class RMS_norm(nn.Module):
|
| 61 |
+
def __init__(self, dim, channel_first=True, images=True, bias=False):
|
| 62 |
+
super().__init__()
|
| 63 |
+
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
|
| 64 |
+
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
|
| 65 |
+
|
| 66 |
+
self.channel_first = channel_first
|
| 67 |
+
self.scale = dim**0.5
|
| 68 |
+
self.gamma = nn.Parameter(torch.ones(shape))
|
| 69 |
+
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
|
| 70 |
+
|
| 71 |
+
def forward(self, x):
|
| 72 |
+
return (
|
| 73 |
+
F.normalize(x, dim=(1 if self.channel_first else -1))
|
| 74 |
+
* self.scale
|
| 75 |
+
* self.gamma
|
| 76 |
+
+ self.bias
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class Upsample(nn.Upsample):
|
| 81 |
+
def forward(self, x):
|
| 82 |
+
"""
|
| 83 |
+
Fix bfloat16 support for nearest neighbor interpolation.
|
| 84 |
+
"""
|
| 85 |
+
return super().forward(x.float()).type_as(x)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
class Resample(nn.Module):
|
| 89 |
+
def __init__(self, dim, mode):
|
| 90 |
+
assert mode in (
|
| 91 |
+
"none",
|
| 92 |
+
"upsample2d",
|
| 93 |
+
"upsample3d",
|
| 94 |
+
"downsample2d",
|
| 95 |
+
"downsample3d",
|
| 96 |
+
)
|
| 97 |
+
super().__init__()
|
| 98 |
+
self.dim = dim
|
| 99 |
+
self.mode = mode
|
| 100 |
+
|
| 101 |
+
# layers
|
| 102 |
+
if mode == "upsample2d":
|
| 103 |
+
self.resample = nn.Sequential(
|
| 104 |
+
Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
|
| 105 |
+
nn.Conv2d(dim, dim // 2, 3, padding=1),
|
| 106 |
+
)
|
| 107 |
+
elif mode == "upsample3d":
|
| 108 |
+
self.resample = nn.Sequential(
|
| 109 |
+
Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
|
| 110 |
+
nn.Conv2d(dim, dim // 2, 3, padding=1),
|
| 111 |
+
)
|
| 112 |
+
self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
|
| 113 |
+
|
| 114 |
+
elif mode == "downsample2d":
|
| 115 |
+
self.resample = nn.Sequential(
|
| 116 |
+
nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))
|
| 117 |
+
)
|
| 118 |
+
elif mode == "downsample3d":
|
| 119 |
+
self.resample = nn.Sequential(
|
| 120 |
+
nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))
|
| 121 |
+
)
|
| 122 |
+
self.time_conv = CausalConv3d(
|
| 123 |
+
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
else:
|
| 127 |
+
self.resample = nn.Identity()
|
| 128 |
+
|
| 129 |
+
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
| 130 |
+
b, c, t, h, w = x.size()
|
| 131 |
+
if self.mode == "upsample3d":
|
| 132 |
+
if feat_cache is not None:
|
| 133 |
+
idx = feat_idx[0]
|
| 134 |
+
if feat_cache[idx] is None:
|
| 135 |
+
feat_cache[idx] = "Rep"
|
| 136 |
+
feat_idx[0] += 1
|
| 137 |
+
else:
|
| 138 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
| 139 |
+
if (
|
| 140 |
+
cache_x.shape[2] < 2
|
| 141 |
+
and feat_cache[idx] is not None
|
| 142 |
+
and feat_cache[idx] != "Rep"
|
| 143 |
+
):
|
| 144 |
+
# cache last frame of last two chunk
|
| 145 |
+
cache_x = torch.cat(
|
| 146 |
+
[
|
| 147 |
+
feat_cache[idx][:, :, -1, :, :]
|
| 148 |
+
.unsqueeze(2)
|
| 149 |
+
.to(cache_x.device),
|
| 150 |
+
cache_x,
|
| 151 |
+
],
|
| 152 |
+
dim=2,
|
| 153 |
+
)
|
| 154 |
+
if (
|
| 155 |
+
cache_x.shape[2] < 2
|
| 156 |
+
and feat_cache[idx] is not None
|
| 157 |
+
and feat_cache[idx] == "Rep"
|
| 158 |
+
):
|
| 159 |
+
cache_x = torch.cat(
|
| 160 |
+
[torch.zeros_like(cache_x).to(cache_x.device), cache_x],
|
| 161 |
+
dim=2,
|
| 162 |
+
)
|
| 163 |
+
if feat_cache[idx] == "Rep":
|
| 164 |
+
x = self.time_conv(x)
|
| 165 |
+
else:
|
| 166 |
+
x = self.time_conv(x, feat_cache[idx])
|
| 167 |
+
feat_cache[idx] = cache_x
|
| 168 |
+
feat_idx[0] += 1
|
| 169 |
+
|
| 170 |
+
x = x.reshape(b, 2, c, t, h, w)
|
| 171 |
+
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
|
| 172 |
+
x = x.reshape(b, c, t * 2, h, w)
|
| 173 |
+
t = x.shape[2]
|
| 174 |
+
x = rearrange(x, "b c t h w -> (b t) c h w")
|
| 175 |
+
x = self.resample(x)
|
| 176 |
+
x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
|
| 177 |
+
|
| 178 |
+
if self.mode == "downsample3d":
|
| 179 |
+
if feat_cache is not None:
|
| 180 |
+
idx = feat_idx[0]
|
| 181 |
+
if feat_cache[idx] is None:
|
| 182 |
+
feat_cache[idx] = x.clone()
|
| 183 |
+
feat_idx[0] += 1
|
| 184 |
+
else:
|
| 185 |
+
cache_x = x[:, :, -1:, :, :].clone()
|
| 186 |
+
x = self.time_conv(
|
| 187 |
+
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)
|
| 188 |
+
)
|
| 189 |
+
feat_cache[idx] = cache_x
|
| 190 |
+
feat_idx[0] += 1
|
| 191 |
+
return x
|
| 192 |
+
|
| 193 |
+
def init_weight(self, conv):
|
| 194 |
+
conv_weight = conv.weight
|
| 195 |
+
nn.init.zeros_(conv_weight)
|
| 196 |
+
c1, c2, t, h, w = conv_weight.size()
|
| 197 |
+
one_matrix = torch.eye(c1, c2)
|
| 198 |
+
init_matrix = one_matrix
|
| 199 |
+
nn.init.zeros_(conv_weight)
|
| 200 |
+
conv_weight.data[:, :, 1, 0, 0] = init_matrix
|
| 201 |
+
conv.weight.data.copy_(conv_weight)
|
| 202 |
+
nn.init.zeros_(conv.bias.data)
|
| 203 |
+
|
| 204 |
+
def init_weight2(self, conv):
|
| 205 |
+
conv_weight = conv.weight.data
|
| 206 |
+
nn.init.zeros_(conv_weight)
|
| 207 |
+
c1, c2, t, h, w = conv_weight.size()
|
| 208 |
+
init_matrix = torch.eye(c1 // 2, c2)
|
| 209 |
+
conv_weight[: c1 // 2, :, -1, 0, 0] = init_matrix
|
| 210 |
+
conv_weight[c1 // 2 :, :, -1, 0, 0] = init_matrix
|
| 211 |
+
conv.weight.data.copy_(conv_weight)
|
| 212 |
+
nn.init.zeros_(conv.bias.data)
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def patchify(x, patch_size):
|
| 216 |
+
if patch_size == 1:
|
| 217 |
+
return x
|
| 218 |
+
if x.dim() == 4:
|
| 219 |
+
x = rearrange(x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size, r=patch_size)
|
| 220 |
+
elif x.dim() == 5:
|
| 221 |
+
x = rearrange(
|
| 222 |
+
x, "b c f (h q) (w r) -> b (c r q) f h w", q=patch_size, r=patch_size
|
| 223 |
+
)
|
| 224 |
+
else:
|
| 225 |
+
raise ValueError(f"Invalid input shape: {x.shape}")
|
| 226 |
+
return x
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def unpatchify(x, patch_size):
|
| 230 |
+
if patch_size == 1:
|
| 231 |
+
return x
|
| 232 |
+
if x.dim() == 4:
|
| 233 |
+
x = rearrange(x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size)
|
| 234 |
+
elif x.dim() == 5:
|
| 235 |
+
x = rearrange(
|
| 236 |
+
x, "b (c r q) f h w -> b c f (h q) (w r)", q=patch_size, r=patch_size
|
| 237 |
+
)
|
| 238 |
+
return x
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
class Resample38(Resample):
|
| 242 |
+
def __init__(self, dim, mode):
|
| 243 |
+
assert mode in (
|
| 244 |
+
"none",
|
| 245 |
+
"upsample2d",
|
| 246 |
+
"upsample3d",
|
| 247 |
+
"downsample2d",
|
| 248 |
+
"downsample3d",
|
| 249 |
+
)
|
| 250 |
+
super(Resample, self).__init__()
|
| 251 |
+
self.dim = dim
|
| 252 |
+
self.mode = mode
|
| 253 |
+
|
| 254 |
+
# layers
|
| 255 |
+
if mode == "upsample2d":
|
| 256 |
+
self.resample = nn.Sequential(
|
| 257 |
+
Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
|
| 258 |
+
nn.Conv2d(dim, dim, 3, padding=1),
|
| 259 |
+
)
|
| 260 |
+
elif mode == "upsample3d":
|
| 261 |
+
self.resample = nn.Sequential(
|
| 262 |
+
Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
|
| 263 |
+
nn.Conv2d(dim, dim, 3, padding=1),
|
| 264 |
+
)
|
| 265 |
+
self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
|
| 266 |
+
elif mode == "downsample2d":
|
| 267 |
+
self.resample = nn.Sequential(
|
| 268 |
+
nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))
|
| 269 |
+
)
|
| 270 |
+
elif mode == "downsample3d":
|
| 271 |
+
self.resample = nn.Sequential(
|
| 272 |
+
nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))
|
| 273 |
+
)
|
| 274 |
+
self.time_conv = CausalConv3d(
|
| 275 |
+
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)
|
| 276 |
+
)
|
| 277 |
+
else:
|
| 278 |
+
self.resample = nn.Identity()
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
class ResidualBlock(nn.Module):
|
| 282 |
+
def __init__(self, in_dim, out_dim, dropout=0.0):
|
| 283 |
+
super().__init__()
|
| 284 |
+
self.in_dim = in_dim
|
| 285 |
+
self.out_dim = out_dim
|
| 286 |
+
|
| 287 |
+
# layers
|
| 288 |
+
self.residual = nn.Sequential(
|
| 289 |
+
RMS_norm(in_dim, images=False),
|
| 290 |
+
nn.SiLU(),
|
| 291 |
+
CausalConv3d(in_dim, out_dim, 3, padding=1),
|
| 292 |
+
RMS_norm(out_dim, images=False),
|
| 293 |
+
nn.SiLU(),
|
| 294 |
+
nn.Dropout(dropout),
|
| 295 |
+
CausalConv3d(out_dim, out_dim, 3, padding=1),
|
| 296 |
+
)
|
| 297 |
+
self.shortcut = (
|
| 298 |
+
CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()
|
| 299 |
+
)
|
| 300 |
+
|
| 301 |
+
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
| 302 |
+
h = self.shortcut(x)
|
| 303 |
+
for layer in self.residual:
|
| 304 |
+
if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
|
| 305 |
+
idx = feat_idx[0]
|
| 306 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
| 307 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
| 308 |
+
# cache last frame of last two chunk
|
| 309 |
+
cache_x = torch.cat(
|
| 310 |
+
[
|
| 311 |
+
feat_cache[idx][:, :, -1, :, :]
|
| 312 |
+
.unsqueeze(2)
|
| 313 |
+
.to(cache_x.device),
|
| 314 |
+
cache_x,
|
| 315 |
+
],
|
| 316 |
+
dim=2,
|
| 317 |
+
)
|
| 318 |
+
x = layer(x, feat_cache[idx])
|
| 319 |
+
feat_cache[idx] = cache_x
|
| 320 |
+
feat_idx[0] += 1
|
| 321 |
+
else:
|
| 322 |
+
x = layer(x)
|
| 323 |
+
return x + h
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
class AttentionBlock(nn.Module):
|
| 327 |
+
"""
|
| 328 |
+
Causal self-attention with a single head.
|
| 329 |
+
"""
|
| 330 |
+
|
| 331 |
+
def __init__(self, dim):
|
| 332 |
+
super().__init__()
|
| 333 |
+
self.dim = dim
|
| 334 |
+
|
| 335 |
+
# layers
|
| 336 |
+
self.norm = RMS_norm(dim)
|
| 337 |
+
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
|
| 338 |
+
self.proj = nn.Conv2d(dim, dim, 1)
|
| 339 |
+
|
| 340 |
+
# zero out the last layer params
|
| 341 |
+
nn.init.zeros_(self.proj.weight)
|
| 342 |
+
|
| 343 |
+
def forward(self, x):
|
| 344 |
+
identity = x
|
| 345 |
+
b, c, t, h, w = x.size()
|
| 346 |
+
x = rearrange(x, "b c t h w -> (b t) c h w")
|
| 347 |
+
x = self.norm(x)
|
| 348 |
+
# compute query, key, value
|
| 349 |
+
q, k, v = (
|
| 350 |
+
self.to_qkv(x)
|
| 351 |
+
.reshape(b * t, 1, c * 3, -1)
|
| 352 |
+
.permute(0, 1, 3, 2)
|
| 353 |
+
.contiguous()
|
| 354 |
+
.chunk(3, dim=-1)
|
| 355 |
+
)
|
| 356 |
+
|
| 357 |
+
# apply attention
|
| 358 |
+
x = F.scaled_dot_product_attention(
|
| 359 |
+
q,
|
| 360 |
+
k,
|
| 361 |
+
v,
|
| 362 |
+
# attn_mask=block_causal_mask(q, block_size=h * w)
|
| 363 |
+
)
|
| 364 |
+
x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
|
| 365 |
+
|
| 366 |
+
# output
|
| 367 |
+
x = self.proj(x)
|
| 368 |
+
x = rearrange(x, "(b t) c h w-> b c t h w", t=t)
|
| 369 |
+
return x + identity
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
class AvgDown3D(nn.Module):
|
| 373 |
+
def __init__(
|
| 374 |
+
self,
|
| 375 |
+
in_channels,
|
| 376 |
+
out_channels,
|
| 377 |
+
factor_t,
|
| 378 |
+
factor_s=1,
|
| 379 |
+
):
|
| 380 |
+
super().__init__()
|
| 381 |
+
self.in_channels = in_channels
|
| 382 |
+
self.out_channels = out_channels
|
| 383 |
+
self.factor_t = factor_t
|
| 384 |
+
self.factor_s = factor_s
|
| 385 |
+
self.factor = self.factor_t * self.factor_s * self.factor_s
|
| 386 |
+
|
| 387 |
+
assert in_channels * self.factor % out_channels == 0
|
| 388 |
+
self.group_size = in_channels * self.factor // out_channels
|
| 389 |
+
|
| 390 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 391 |
+
pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t
|
| 392 |
+
pad = (0, 0, 0, 0, pad_t, 0)
|
| 393 |
+
x = F.pad(x, pad)
|
| 394 |
+
B, C, T, H, W = x.shape
|
| 395 |
+
x = x.view(
|
| 396 |
+
B,
|
| 397 |
+
C,
|
| 398 |
+
T // self.factor_t,
|
| 399 |
+
self.factor_t,
|
| 400 |
+
H // self.factor_s,
|
| 401 |
+
self.factor_s,
|
| 402 |
+
W // self.factor_s,
|
| 403 |
+
self.factor_s,
|
| 404 |
+
)
|
| 405 |
+
x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
|
| 406 |
+
x = x.view(
|
| 407 |
+
B,
|
| 408 |
+
C * self.factor,
|
| 409 |
+
T // self.factor_t,
|
| 410 |
+
H // self.factor_s,
|
| 411 |
+
W // self.factor_s,
|
| 412 |
+
)
|
| 413 |
+
x = x.view(
|
| 414 |
+
B,
|
| 415 |
+
self.out_channels,
|
| 416 |
+
self.group_size,
|
| 417 |
+
T // self.factor_t,
|
| 418 |
+
H // self.factor_s,
|
| 419 |
+
W // self.factor_s,
|
| 420 |
+
)
|
| 421 |
+
x = x.mean(dim=2)
|
| 422 |
+
return x
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
class DupUp3D(nn.Module):
|
| 426 |
+
def __init__(
|
| 427 |
+
self,
|
| 428 |
+
in_channels: int,
|
| 429 |
+
out_channels: int,
|
| 430 |
+
factor_t,
|
| 431 |
+
factor_s=1,
|
| 432 |
+
):
|
| 433 |
+
super().__init__()
|
| 434 |
+
self.in_channels = in_channels
|
| 435 |
+
self.out_channels = out_channels
|
| 436 |
+
|
| 437 |
+
self.factor_t = factor_t
|
| 438 |
+
self.factor_s = factor_s
|
| 439 |
+
self.factor = self.factor_t * self.factor_s * self.factor_s
|
| 440 |
+
|
| 441 |
+
assert out_channels * self.factor % in_channels == 0
|
| 442 |
+
self.repeats = out_channels * self.factor // in_channels
|
| 443 |
+
|
| 444 |
+
def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor:
|
| 445 |
+
x = x.repeat_interleave(self.repeats, dim=1)
|
| 446 |
+
x = x.view(
|
| 447 |
+
x.size(0),
|
| 448 |
+
self.out_channels,
|
| 449 |
+
self.factor_t,
|
| 450 |
+
self.factor_s,
|
| 451 |
+
self.factor_s,
|
| 452 |
+
x.size(2),
|
| 453 |
+
x.size(3),
|
| 454 |
+
x.size(4),
|
| 455 |
+
)
|
| 456 |
+
x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
|
| 457 |
+
x = x.view(
|
| 458 |
+
x.size(0),
|
| 459 |
+
self.out_channels,
|
| 460 |
+
x.size(2) * self.factor_t,
|
| 461 |
+
x.size(4) * self.factor_s,
|
| 462 |
+
x.size(6) * self.factor_s,
|
| 463 |
+
)
|
| 464 |
+
if first_chunk:
|
| 465 |
+
x = x[:, :, self.factor_t - 1 :, :, :]
|
| 466 |
+
return x
|
| 467 |
+
|
| 468 |
+
|
| 469 |
+
class Down_ResidualBlock(nn.Module):
|
| 470 |
+
def __init__(
|
| 471 |
+
self, in_dim, out_dim, dropout, mult, temperal_downsample=False, down_flag=False
|
| 472 |
+
):
|
| 473 |
+
super().__init__()
|
| 474 |
+
|
| 475 |
+
# Shortcut path with downsample
|
| 476 |
+
self.avg_shortcut = AvgDown3D(
|
| 477 |
+
in_dim,
|
| 478 |
+
out_dim,
|
| 479 |
+
factor_t=2 if temperal_downsample else 1,
|
| 480 |
+
factor_s=2 if down_flag else 1,
|
| 481 |
+
)
|
| 482 |
+
|
| 483 |
+
# Main path with residual blocks and downsample
|
| 484 |
+
downsamples = []
|
| 485 |
+
for _ in range(mult):
|
| 486 |
+
downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
|
| 487 |
+
in_dim = out_dim
|
| 488 |
+
|
| 489 |
+
# Add the final downsample block
|
| 490 |
+
if down_flag:
|
| 491 |
+
mode = "downsample3d" if temperal_downsample else "downsample2d"
|
| 492 |
+
downsamples.append(Resample38(out_dim, mode=mode))
|
| 493 |
+
|
| 494 |
+
self.downsamples = nn.Sequential(*downsamples)
|
| 495 |
+
|
| 496 |
+
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
| 497 |
+
x_copy = x.clone()
|
| 498 |
+
for module in self.downsamples:
|
| 499 |
+
x = module(x, feat_cache, feat_idx)
|
| 500 |
+
|
| 501 |
+
return x + self.avg_shortcut(x_copy)
|
| 502 |
+
|
| 503 |
+
|
| 504 |
+
class Up_ResidualBlock(nn.Module):
|
| 505 |
+
def __init__(
|
| 506 |
+
self, in_dim, out_dim, dropout, mult, temperal_upsample=False, up_flag=False
|
| 507 |
+
):
|
| 508 |
+
super().__init__()
|
| 509 |
+
# Shortcut path with upsample
|
| 510 |
+
if up_flag:
|
| 511 |
+
self.avg_shortcut = DupUp3D(
|
| 512 |
+
in_dim,
|
| 513 |
+
out_dim,
|
| 514 |
+
factor_t=2 if temperal_upsample else 1,
|
| 515 |
+
factor_s=2 if up_flag else 1,
|
| 516 |
+
)
|
| 517 |
+
else:
|
| 518 |
+
self.avg_shortcut = None
|
| 519 |
+
|
| 520 |
+
# Main path with residual blocks and upsample
|
| 521 |
+
upsamples = []
|
| 522 |
+
for _ in range(mult):
|
| 523 |
+
upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
|
| 524 |
+
in_dim = out_dim
|
| 525 |
+
|
| 526 |
+
# Add the final upsample block
|
| 527 |
+
if up_flag:
|
| 528 |
+
mode = "upsample3d" if temperal_upsample else "upsample2d"
|
| 529 |
+
upsamples.append(Resample38(out_dim, mode=mode))
|
| 530 |
+
|
| 531 |
+
self.upsamples = nn.Sequential(*upsamples)
|
| 532 |
+
|
| 533 |
+
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
|
| 534 |
+
x_main = x.clone()
|
| 535 |
+
for module in self.upsamples:
|
| 536 |
+
x_main = module(x_main, feat_cache, feat_idx)
|
| 537 |
+
if self.avg_shortcut is not None:
|
| 538 |
+
x_shortcut = self.avg_shortcut(x, first_chunk)
|
| 539 |
+
return x_main + x_shortcut
|
| 540 |
+
else:
|
| 541 |
+
return x_main
|
| 542 |
+
|
| 543 |
+
|
| 544 |
+
class Encoder3d(nn.Module):
|
| 545 |
+
def __init__(
|
| 546 |
+
self,
|
| 547 |
+
dim=128,
|
| 548 |
+
z_dim=4,
|
| 549 |
+
dim_mult=[1, 2, 4, 4],
|
| 550 |
+
num_res_blocks=2,
|
| 551 |
+
attn_scales=[],
|
| 552 |
+
temperal_downsample=[True, True, False],
|
| 553 |
+
dropout=0.0,
|
| 554 |
+
):
|
| 555 |
+
super().__init__()
|
| 556 |
+
self.dim = dim
|
| 557 |
+
self.z_dim = z_dim
|
| 558 |
+
self.dim_mult = dim_mult
|
| 559 |
+
self.num_res_blocks = num_res_blocks
|
| 560 |
+
self.attn_scales = attn_scales
|
| 561 |
+
self.temperal_downsample = temperal_downsample
|
| 562 |
+
|
| 563 |
+
# dimensions
|
| 564 |
+
dims = [dim * u for u in [1] + dim_mult]
|
| 565 |
+
scale = 1.0
|
| 566 |
+
|
| 567 |
+
# init block
|
| 568 |
+
self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
|
| 569 |
+
|
| 570 |
+
# downsample blocks
|
| 571 |
+
downsamples = []
|
| 572 |
+
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
| 573 |
+
# residual (+attention) blocks
|
| 574 |
+
for _ in range(num_res_blocks):
|
| 575 |
+
downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
|
| 576 |
+
if scale in attn_scales:
|
| 577 |
+
downsamples.append(AttentionBlock(out_dim))
|
| 578 |
+
in_dim = out_dim
|
| 579 |
+
|
| 580 |
+
# downsample block
|
| 581 |
+
if i != len(dim_mult) - 1:
|
| 582 |
+
mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
|
| 583 |
+
downsamples.append(Resample(out_dim, mode=mode))
|
| 584 |
+
scale /= 2.0
|
| 585 |
+
self.downsamples = nn.Sequential(*downsamples)
|
| 586 |
+
|
| 587 |
+
# middle blocks
|
| 588 |
+
self.middle = nn.Sequential(
|
| 589 |
+
ResidualBlock(out_dim, out_dim, dropout),
|
| 590 |
+
AttentionBlock(out_dim),
|
| 591 |
+
ResidualBlock(out_dim, out_dim, dropout),
|
| 592 |
+
)
|
| 593 |
+
|
| 594 |
+
# output blocks
|
| 595 |
+
self.head = nn.Sequential(
|
| 596 |
+
RMS_norm(out_dim, images=False),
|
| 597 |
+
nn.SiLU(),
|
| 598 |
+
CausalConv3d(out_dim, z_dim, 3, padding=1),
|
| 599 |
+
)
|
| 600 |
+
|
| 601 |
+
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
| 602 |
+
if feat_cache is not None:
|
| 603 |
+
idx = feat_idx[0]
|
| 604 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
| 605 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
| 606 |
+
# cache last frame of last two chunk
|
| 607 |
+
cache_x = torch.cat(
|
| 608 |
+
[
|
| 609 |
+
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),
|
| 610 |
+
cache_x,
|
| 611 |
+
],
|
| 612 |
+
dim=2,
|
| 613 |
+
)
|
| 614 |
+
x = self.conv1(x, feat_cache[idx])
|
| 615 |
+
feat_cache[idx] = cache_x
|
| 616 |
+
feat_idx[0] += 1
|
| 617 |
+
else:
|
| 618 |
+
x = self.conv1(x)
|
| 619 |
+
|
| 620 |
+
## downsamples
|
| 621 |
+
for layer in self.downsamples:
|
| 622 |
+
if feat_cache is not None:
|
| 623 |
+
x = layer(x, feat_cache, feat_idx)
|
| 624 |
+
else:
|
| 625 |
+
x = layer(x)
|
| 626 |
+
|
| 627 |
+
## middle
|
| 628 |
+
for layer in self.middle:
|
| 629 |
+
if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
|
| 630 |
+
x = layer(x, feat_cache, feat_idx)
|
| 631 |
+
else:
|
| 632 |
+
x = layer(x)
|
| 633 |
+
|
| 634 |
+
## head
|
| 635 |
+
for layer in self.head:
|
| 636 |
+
if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
|
| 637 |
+
idx = feat_idx[0]
|
| 638 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
| 639 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
| 640 |
+
# cache last frame of last two chunk
|
| 641 |
+
cache_x = torch.cat(
|
| 642 |
+
[
|
| 643 |
+
feat_cache[idx][:, :, -1, :, :]
|
| 644 |
+
.unsqueeze(2)
|
| 645 |
+
.to(cache_x.device),
|
| 646 |
+
cache_x,
|
| 647 |
+
],
|
| 648 |
+
dim=2,
|
| 649 |
+
)
|
| 650 |
+
x = layer(x, feat_cache[idx])
|
| 651 |
+
feat_cache[idx] = cache_x
|
| 652 |
+
feat_idx[0] += 1
|
| 653 |
+
else:
|
| 654 |
+
x = layer(x)
|
| 655 |
+
return x
|
| 656 |
+
|
| 657 |
+
|
| 658 |
+
class Encoder3d_38(nn.Module):
|
| 659 |
+
def __init__(
|
| 660 |
+
self,
|
| 661 |
+
dim=128,
|
| 662 |
+
z_dim=4,
|
| 663 |
+
dim_mult=[1, 2, 4, 4],
|
| 664 |
+
num_res_blocks=2,
|
| 665 |
+
attn_scales=[],
|
| 666 |
+
temperal_downsample=[False, True, True],
|
| 667 |
+
dropout=0.0,
|
| 668 |
+
):
|
| 669 |
+
super().__init__()
|
| 670 |
+
self.dim = dim
|
| 671 |
+
self.z_dim = z_dim
|
| 672 |
+
self.dim_mult = dim_mult
|
| 673 |
+
self.num_res_blocks = num_res_blocks
|
| 674 |
+
self.attn_scales = attn_scales
|
| 675 |
+
self.temperal_downsample = temperal_downsample
|
| 676 |
+
|
| 677 |
+
# dimensions
|
| 678 |
+
dims = [dim * u for u in [1] + dim_mult]
|
| 679 |
+
scale = 1.0
|
| 680 |
+
|
| 681 |
+
# init block
|
| 682 |
+
self.conv1 = CausalConv3d(12, dims[0], 3, padding=1)
|
| 683 |
+
|
| 684 |
+
# downsample blocks
|
| 685 |
+
downsamples = []
|
| 686 |
+
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
| 687 |
+
t_down_flag = (
|
| 688 |
+
temperal_downsample[i] if i < len(temperal_downsample) else False
|
| 689 |
+
)
|
| 690 |
+
downsamples.append(
|
| 691 |
+
Down_ResidualBlock(
|
| 692 |
+
in_dim=in_dim,
|
| 693 |
+
out_dim=out_dim,
|
| 694 |
+
dropout=dropout,
|
| 695 |
+
mult=num_res_blocks,
|
| 696 |
+
temperal_downsample=t_down_flag,
|
| 697 |
+
down_flag=i != len(dim_mult) - 1,
|
| 698 |
+
)
|
| 699 |
+
)
|
| 700 |
+
scale /= 2.0
|
| 701 |
+
self.downsamples = nn.Sequential(*downsamples)
|
| 702 |
+
|
| 703 |
+
# middle blocks
|
| 704 |
+
self.middle = nn.Sequential(
|
| 705 |
+
ResidualBlock(out_dim, out_dim, dropout),
|
| 706 |
+
AttentionBlock(out_dim),
|
| 707 |
+
ResidualBlock(out_dim, out_dim, dropout),
|
| 708 |
+
)
|
| 709 |
+
|
| 710 |
+
# # output blocks
|
| 711 |
+
self.head = nn.Sequential(
|
| 712 |
+
RMS_norm(out_dim, images=False),
|
| 713 |
+
nn.SiLU(),
|
| 714 |
+
CausalConv3d(out_dim, z_dim, 3, padding=1),
|
| 715 |
+
)
|
| 716 |
+
|
| 717 |
+
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
| 718 |
+
if feat_cache is not None:
|
| 719 |
+
idx = feat_idx[0]
|
| 720 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
| 721 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
| 722 |
+
cache_x = torch.cat(
|
| 723 |
+
[
|
| 724 |
+
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),
|
| 725 |
+
cache_x,
|
| 726 |
+
],
|
| 727 |
+
dim=2,
|
| 728 |
+
)
|
| 729 |
+
x = self.conv1(x, feat_cache[idx])
|
| 730 |
+
feat_cache[idx] = cache_x
|
| 731 |
+
feat_idx[0] += 1
|
| 732 |
+
else:
|
| 733 |
+
x = self.conv1(x)
|
| 734 |
+
|
| 735 |
+
## downsamples
|
| 736 |
+
for layer in self.downsamples:
|
| 737 |
+
if feat_cache is not None:
|
| 738 |
+
x = layer(x, feat_cache, feat_idx)
|
| 739 |
+
else:
|
| 740 |
+
x = layer(x)
|
| 741 |
+
|
| 742 |
+
## middle
|
| 743 |
+
for layer in self.middle:
|
| 744 |
+
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
| 745 |
+
x = layer(x, feat_cache, feat_idx)
|
| 746 |
+
else:
|
| 747 |
+
x = layer(x)
|
| 748 |
+
|
| 749 |
+
## head
|
| 750 |
+
for layer in self.head:
|
| 751 |
+
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
| 752 |
+
idx = feat_idx[0]
|
| 753 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
| 754 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
| 755 |
+
cache_x = torch.cat(
|
| 756 |
+
[
|
| 757 |
+
feat_cache[idx][:, :, -1, :, :]
|
| 758 |
+
.unsqueeze(2)
|
| 759 |
+
.to(cache_x.device),
|
| 760 |
+
cache_x,
|
| 761 |
+
],
|
| 762 |
+
dim=2,
|
| 763 |
+
)
|
| 764 |
+
x = layer(x, feat_cache[idx])
|
| 765 |
+
feat_cache[idx] = cache_x
|
| 766 |
+
feat_idx[0] += 1
|
| 767 |
+
else:
|
| 768 |
+
x = layer(x)
|
| 769 |
+
|
| 770 |
+
return x
|
| 771 |
+
|
| 772 |
+
|
| 773 |
+
class Decoder3d(nn.Module):
|
| 774 |
+
def __init__(
|
| 775 |
+
self,
|
| 776 |
+
dim=128,
|
| 777 |
+
z_dim=4,
|
| 778 |
+
dim_mult=[1, 2, 4, 4],
|
| 779 |
+
num_res_blocks=2,
|
| 780 |
+
attn_scales=[],
|
| 781 |
+
temperal_upsample=[False, True, True],
|
| 782 |
+
dropout=0.0,
|
| 783 |
+
):
|
| 784 |
+
super().__init__()
|
| 785 |
+
self.dim = dim
|
| 786 |
+
self.z_dim = z_dim
|
| 787 |
+
self.dim_mult = dim_mult
|
| 788 |
+
self.num_res_blocks = num_res_blocks
|
| 789 |
+
self.attn_scales = attn_scales
|
| 790 |
+
self.temperal_upsample = temperal_upsample
|
| 791 |
+
|
| 792 |
+
# dimensions
|
| 793 |
+
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
|
| 794 |
+
scale = 1.0 / 2 ** (len(dim_mult) - 2)
|
| 795 |
+
|
| 796 |
+
# init block
|
| 797 |
+
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
|
| 798 |
+
|
| 799 |
+
# middle blocks
|
| 800 |
+
self.middle = nn.Sequential(
|
| 801 |
+
ResidualBlock(dims[0], dims[0], dropout),
|
| 802 |
+
AttentionBlock(dims[0]),
|
| 803 |
+
ResidualBlock(dims[0], dims[0], dropout),
|
| 804 |
+
)
|
| 805 |
+
|
| 806 |
+
# upsample blocks
|
| 807 |
+
upsamples = []
|
| 808 |
+
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
| 809 |
+
# residual (+attention) blocks
|
| 810 |
+
if i == 1 or i == 2 or i == 3:
|
| 811 |
+
in_dim = in_dim // 2
|
| 812 |
+
for _ in range(num_res_blocks + 1):
|
| 813 |
+
upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
|
| 814 |
+
if scale in attn_scales:
|
| 815 |
+
upsamples.append(AttentionBlock(out_dim))
|
| 816 |
+
in_dim = out_dim
|
| 817 |
+
|
| 818 |
+
# upsample block
|
| 819 |
+
if i != len(dim_mult) - 1:
|
| 820 |
+
mode = "upsample3d" if temperal_upsample[i] else "upsample2d"
|
| 821 |
+
upsamples.append(Resample(out_dim, mode=mode))
|
| 822 |
+
scale *= 2.0
|
| 823 |
+
self.upsamples = nn.Sequential(*upsamples)
|
| 824 |
+
|
| 825 |
+
# output blocks
|
| 826 |
+
self.head = nn.Sequential(
|
| 827 |
+
RMS_norm(out_dim, images=False),
|
| 828 |
+
nn.SiLU(),
|
| 829 |
+
CausalConv3d(out_dim, 3, 3, padding=1),
|
| 830 |
+
)
|
| 831 |
+
|
| 832 |
+
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
| 833 |
+
## conv1
|
| 834 |
+
if feat_cache is not None:
|
| 835 |
+
idx = feat_idx[0]
|
| 836 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
| 837 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
| 838 |
+
# cache last frame of last two chunk
|
| 839 |
+
cache_x = torch.cat(
|
| 840 |
+
[
|
| 841 |
+
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),
|
| 842 |
+
cache_x,
|
| 843 |
+
],
|
| 844 |
+
dim=2,
|
| 845 |
+
)
|
| 846 |
+
x = self.conv1(x, feat_cache[idx])
|
| 847 |
+
feat_cache[idx] = cache_x
|
| 848 |
+
feat_idx[0] += 1
|
| 849 |
+
else:
|
| 850 |
+
x = self.conv1(x)
|
| 851 |
+
|
| 852 |
+
## middle
|
| 853 |
+
for layer in self.middle:
|
| 854 |
+
if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
|
| 855 |
+
x = layer(x, feat_cache, feat_idx)
|
| 856 |
+
else:
|
| 857 |
+
x = layer(x)
|
| 858 |
+
|
| 859 |
+
## upsamples
|
| 860 |
+
for layer in self.upsamples:
|
| 861 |
+
if feat_cache is not None:
|
| 862 |
+
x = layer(x, feat_cache, feat_idx)
|
| 863 |
+
else:
|
| 864 |
+
x = layer(x)
|
| 865 |
+
|
| 866 |
+
## head
|
| 867 |
+
for layer in self.head:
|
| 868 |
+
if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
|
| 869 |
+
idx = feat_idx[0]
|
| 870 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
| 871 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
| 872 |
+
# cache last frame of last two chunk
|
| 873 |
+
cache_x = torch.cat(
|
| 874 |
+
[
|
| 875 |
+
feat_cache[idx][:, :, -1, :, :]
|
| 876 |
+
.unsqueeze(2)
|
| 877 |
+
.to(cache_x.device),
|
| 878 |
+
cache_x,
|
| 879 |
+
],
|
| 880 |
+
dim=2,
|
| 881 |
+
)
|
| 882 |
+
x = layer(x, feat_cache[idx])
|
| 883 |
+
feat_cache[idx] = cache_x
|
| 884 |
+
feat_idx[0] += 1
|
| 885 |
+
else:
|
| 886 |
+
x = layer(x)
|
| 887 |
+
return x
|
| 888 |
+
|
| 889 |
+
|
| 890 |
+
class Decoder3d_38(nn.Module):
|
| 891 |
+
def __init__(
|
| 892 |
+
self,
|
| 893 |
+
dim=128,
|
| 894 |
+
z_dim=4,
|
| 895 |
+
dim_mult=[1, 2, 4, 4],
|
| 896 |
+
num_res_blocks=2,
|
| 897 |
+
attn_scales=[],
|
| 898 |
+
temperal_upsample=[False, True, True],
|
| 899 |
+
dropout=0.0,
|
| 900 |
+
):
|
| 901 |
+
super().__init__()
|
| 902 |
+
self.dim = dim
|
| 903 |
+
self.z_dim = z_dim
|
| 904 |
+
self.dim_mult = dim_mult
|
| 905 |
+
self.num_res_blocks = num_res_blocks
|
| 906 |
+
self.attn_scales = attn_scales
|
| 907 |
+
self.temperal_upsample = temperal_upsample
|
| 908 |
+
|
| 909 |
+
# dimensions
|
| 910 |
+
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
|
| 911 |
+
scale = 1.0 / 2 ** (len(dim_mult) - 2)
|
| 912 |
+
# init block
|
| 913 |
+
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
|
| 914 |
+
|
| 915 |
+
# middle blocks
|
| 916 |
+
self.middle = nn.Sequential(
|
| 917 |
+
ResidualBlock(dims[0], dims[0], dropout),
|
| 918 |
+
AttentionBlock(dims[0]),
|
| 919 |
+
ResidualBlock(dims[0], dims[0], dropout),
|
| 920 |
+
)
|
| 921 |
+
|
| 922 |
+
# upsample blocks
|
| 923 |
+
upsamples = []
|
| 924 |
+
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
| 925 |
+
t_up_flag = temperal_upsample[i] if i < len(temperal_upsample) else False
|
| 926 |
+
upsamples.append(
|
| 927 |
+
Up_ResidualBlock(
|
| 928 |
+
in_dim=in_dim,
|
| 929 |
+
out_dim=out_dim,
|
| 930 |
+
dropout=dropout,
|
| 931 |
+
mult=num_res_blocks + 1,
|
| 932 |
+
temperal_upsample=t_up_flag,
|
| 933 |
+
up_flag=i != len(dim_mult) - 1,
|
| 934 |
+
)
|
| 935 |
+
)
|
| 936 |
+
self.upsamples = nn.Sequential(*upsamples)
|
| 937 |
+
|
| 938 |
+
# output blocks
|
| 939 |
+
self.head = nn.Sequential(
|
| 940 |
+
RMS_norm(out_dim, images=False),
|
| 941 |
+
nn.SiLU(),
|
| 942 |
+
CausalConv3d(out_dim, 12, 3, padding=1),
|
| 943 |
+
)
|
| 944 |
+
|
| 945 |
+
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
|
| 946 |
+
if feat_cache is not None:
|
| 947 |
+
idx = feat_idx[0]
|
| 948 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
| 949 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
| 950 |
+
cache_x = torch.cat(
|
| 951 |
+
[
|
| 952 |
+
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),
|
| 953 |
+
cache_x,
|
| 954 |
+
],
|
| 955 |
+
dim=2,
|
| 956 |
+
)
|
| 957 |
+
x = self.conv1(x, feat_cache[idx])
|
| 958 |
+
feat_cache[idx] = cache_x
|
| 959 |
+
feat_idx[0] += 1
|
| 960 |
+
else:
|
| 961 |
+
x = self.conv1(x)
|
| 962 |
+
|
| 963 |
+
for layer in self.middle:
|
| 964 |
+
if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
|
| 965 |
+
x = layer(x, feat_cache, feat_idx)
|
| 966 |
+
else:
|
| 967 |
+
x = layer(x)
|
| 968 |
+
|
| 969 |
+
## upsamples
|
| 970 |
+
for layer in self.upsamples:
|
| 971 |
+
if feat_cache is not None:
|
| 972 |
+
x = layer(x, feat_cache, feat_idx, first_chunk)
|
| 973 |
+
else:
|
| 974 |
+
x = layer(x)
|
| 975 |
+
|
| 976 |
+
## head
|
| 977 |
+
for layer in self.head:
|
| 978 |
+
if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
|
| 979 |
+
idx = feat_idx[0]
|
| 980 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
| 981 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
| 982 |
+
cache_x = torch.cat(
|
| 983 |
+
[
|
| 984 |
+
feat_cache[idx][:, :, -1, :, :]
|
| 985 |
+
.unsqueeze(2)
|
| 986 |
+
.to(cache_x.device),
|
| 987 |
+
cache_x,
|
| 988 |
+
],
|
| 989 |
+
dim=2,
|
| 990 |
+
)
|
| 991 |
+
x = layer(x, feat_cache[idx])
|
| 992 |
+
feat_cache[idx] = cache_x
|
| 993 |
+
feat_idx[0] += 1
|
| 994 |
+
else:
|
| 995 |
+
x = layer(x)
|
| 996 |
+
return x
|
| 997 |
+
|
| 998 |
+
|
| 999 |
+
def count_conv3d(model):
|
| 1000 |
+
count = 0
|
| 1001 |
+
for m in model.modules():
|
| 1002 |
+
if isinstance(m, CausalConv3d):
|
| 1003 |
+
count += 1
|
| 1004 |
+
return count
|
| 1005 |
+
|
| 1006 |
+
|
| 1007 |
+
class VideoVAE_(nn.Module):
|
| 1008 |
+
def __init__(
|
| 1009 |
+
self,
|
| 1010 |
+
dim=96,
|
| 1011 |
+
z_dim=16,
|
| 1012 |
+
dim_mult=[1, 2, 4, 4],
|
| 1013 |
+
num_res_blocks=2,
|
| 1014 |
+
attn_scales=[],
|
| 1015 |
+
temperal_downsample=[False, True, True],
|
| 1016 |
+
dropout=0.0,
|
| 1017 |
+
):
|
| 1018 |
+
super().__init__()
|
| 1019 |
+
self.dim = dim
|
| 1020 |
+
self.z_dim = z_dim
|
| 1021 |
+
self.dim_mult = dim_mult
|
| 1022 |
+
self.num_res_blocks = num_res_blocks
|
| 1023 |
+
self.attn_scales = attn_scales
|
| 1024 |
+
self.temperal_downsample = temperal_downsample
|
| 1025 |
+
self.temperal_upsample = temperal_downsample[::-1]
|
| 1026 |
+
|
| 1027 |
+
# modules
|
| 1028 |
+
self.encoder = Encoder3d(
|
| 1029 |
+
dim,
|
| 1030 |
+
z_dim * 2,
|
| 1031 |
+
dim_mult,
|
| 1032 |
+
num_res_blocks,
|
| 1033 |
+
attn_scales,
|
| 1034 |
+
self.temperal_downsample,
|
| 1035 |
+
dropout,
|
| 1036 |
+
)
|
| 1037 |
+
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
|
| 1038 |
+
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
|
| 1039 |
+
self.decoder = Decoder3d(
|
| 1040 |
+
dim,
|
| 1041 |
+
z_dim,
|
| 1042 |
+
dim_mult,
|
| 1043 |
+
num_res_blocks,
|
| 1044 |
+
attn_scales,
|
| 1045 |
+
self.temperal_upsample,
|
| 1046 |
+
dropout,
|
| 1047 |
+
)
|
| 1048 |
+
|
| 1049 |
+
def forward(self, x):
|
| 1050 |
+
mu, log_var = self.encode(x)
|
| 1051 |
+
z = self.reparameterize(mu, log_var)
|
| 1052 |
+
x_recon = self.decode(z)
|
| 1053 |
+
return x_recon, mu, log_var
|
| 1054 |
+
|
| 1055 |
+
def encode(self, x, scale):
|
| 1056 |
+
self.clear_cache()
|
| 1057 |
+
## cache
|
| 1058 |
+
t = x.shape[2]
|
| 1059 |
+
iter_ = 1 + (t - 1) // 4
|
| 1060 |
+
|
| 1061 |
+
for i in range(iter_):
|
| 1062 |
+
self._enc_conv_idx = [0]
|
| 1063 |
+
if i == 0:
|
| 1064 |
+
out = self.encoder(
|
| 1065 |
+
x[:, :, :1, :, :],
|
| 1066 |
+
feat_cache=self._enc_feat_map,
|
| 1067 |
+
feat_idx=self._enc_conv_idx,
|
| 1068 |
+
)
|
| 1069 |
+
else:
|
| 1070 |
+
out_ = self.encoder(
|
| 1071 |
+
x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :],
|
| 1072 |
+
feat_cache=self._enc_feat_map,
|
| 1073 |
+
feat_idx=self._enc_conv_idx,
|
| 1074 |
+
)
|
| 1075 |
+
out = torch.cat([out, out_], 2)
|
| 1076 |
+
mu, log_var = self.conv1(out).chunk(2, dim=1)
|
| 1077 |
+
if isinstance(scale[0], torch.Tensor):
|
| 1078 |
+
scale = [s.to(dtype=mu.dtype, device=mu.device) for s in scale]
|
| 1079 |
+
mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
|
| 1080 |
+
1, self.z_dim, 1, 1, 1
|
| 1081 |
+
)
|
| 1082 |
+
else:
|
| 1083 |
+
scale = scale.to(dtype=mu.dtype, device=mu.device)
|
| 1084 |
+
mu = (mu - scale[0]) * scale[1]
|
| 1085 |
+
return mu
|
| 1086 |
+
|
| 1087 |
+
def decode(self, z, scale):
|
| 1088 |
+
self.clear_cache()
|
| 1089 |
+
# z: [b,c,t,h,w]
|
| 1090 |
+
if isinstance(scale[0], torch.Tensor):
|
| 1091 |
+
scale = [s.to(dtype=z.dtype, device=z.device) for s in scale]
|
| 1092 |
+
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
|
| 1093 |
+
1, self.z_dim, 1, 1, 1
|
| 1094 |
+
)
|
| 1095 |
+
else:
|
| 1096 |
+
scale = scale.to(dtype=z.dtype, device=z.device)
|
| 1097 |
+
z = z / scale[1] + scale[0]
|
| 1098 |
+
iter_ = z.shape[2]
|
| 1099 |
+
x = self.conv2(z)
|
| 1100 |
+
for i in range(iter_):
|
| 1101 |
+
self._conv_idx = [0]
|
| 1102 |
+
if i == 0:
|
| 1103 |
+
out = self.decoder(
|
| 1104 |
+
x[:, :, i : i + 1, :, :],
|
| 1105 |
+
feat_cache=self._feat_map,
|
| 1106 |
+
feat_idx=self._conv_idx,
|
| 1107 |
+
)
|
| 1108 |
+
else:
|
| 1109 |
+
out_ = self.decoder(
|
| 1110 |
+
x[:, :, i : i + 1, :, :],
|
| 1111 |
+
feat_cache=self._feat_map,
|
| 1112 |
+
feat_idx=self._conv_idx,
|
| 1113 |
+
)
|
| 1114 |
+
out = torch.cat([out, out_], 2) # may add tensor offload
|
| 1115 |
+
return out
|
| 1116 |
+
|
| 1117 |
+
def reparameterize(self, mu, log_var):
|
| 1118 |
+
std = torch.exp(0.5 * log_var)
|
| 1119 |
+
eps = torch.randn_like(std)
|
| 1120 |
+
return eps * std + mu
|
| 1121 |
+
|
| 1122 |
+
def sample(self, imgs, deterministic=False):
|
| 1123 |
+
mu, log_var = self.encode(imgs)
|
| 1124 |
+
if deterministic:
|
| 1125 |
+
return mu
|
| 1126 |
+
std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
|
| 1127 |
+
return mu + std * torch.randn_like(std)
|
| 1128 |
+
|
| 1129 |
+
def clear_cache(self):
|
| 1130 |
+
self._conv_num = count_conv3d(self.decoder)
|
| 1131 |
+
self._conv_idx = [0]
|
| 1132 |
+
self._feat_map = [None] * self._conv_num
|
| 1133 |
+
# cache encode
|
| 1134 |
+
self._enc_conv_num = count_conv3d(self.encoder)
|
| 1135 |
+
self._enc_conv_idx = [0]
|
| 1136 |
+
self._enc_feat_map = [None] * self._enc_conv_num
|
| 1137 |
+
|
| 1138 |
+
|
| 1139 |
+
class WanVideoVAE(nn.Module):
|
| 1140 |
+
def __init__(self, z_dim=16):
|
| 1141 |
+
super().__init__()
|
| 1142 |
+
|
| 1143 |
+
mean = [
|
| 1144 |
+
-0.7571,
|
| 1145 |
+
-0.7089,
|
| 1146 |
+
-0.9113,
|
| 1147 |
+
0.1075,
|
| 1148 |
+
-0.1745,
|
| 1149 |
+
0.9653,
|
| 1150 |
+
-0.1517,
|
| 1151 |
+
1.5508,
|
| 1152 |
+
0.4134,
|
| 1153 |
+
-0.0715,
|
| 1154 |
+
0.5517,
|
| 1155 |
+
-0.3632,
|
| 1156 |
+
-0.1922,
|
| 1157 |
+
-0.9497,
|
| 1158 |
+
0.2503,
|
| 1159 |
+
-0.2921,
|
| 1160 |
+
]
|
| 1161 |
+
std = [
|
| 1162 |
+
2.8184,
|
| 1163 |
+
1.4541,
|
| 1164 |
+
2.3275,
|
| 1165 |
+
2.6558,
|
| 1166 |
+
1.2196,
|
| 1167 |
+
1.7708,
|
| 1168 |
+
2.6052,
|
| 1169 |
+
2.0743,
|
| 1170 |
+
3.2687,
|
| 1171 |
+
2.1526,
|
| 1172 |
+
2.8652,
|
| 1173 |
+
1.5579,
|
| 1174 |
+
1.6382,
|
| 1175 |
+
1.1253,
|
| 1176 |
+
2.8251,
|
| 1177 |
+
1.9160,
|
| 1178 |
+
]
|
| 1179 |
+
self.mean = torch.tensor(mean)
|
| 1180 |
+
self.std = torch.tensor(std)
|
| 1181 |
+
self.scale = [self.mean, 1.0 / self.std]
|
| 1182 |
+
|
| 1183 |
+
# init model
|
| 1184 |
+
self.model = VideoVAE_(z_dim=z_dim).eval().requires_grad_(False)
|
| 1185 |
+
self.upsampling_factor = 8
|
| 1186 |
+
self.z_dim = z_dim
|
| 1187 |
+
|
| 1188 |
+
def build_1d_mask(self, length, left_bound, right_bound, border_width):
|
| 1189 |
+
x = torch.ones((length,))
|
| 1190 |
+
if not left_bound:
|
| 1191 |
+
x[:border_width] = (torch.arange(border_width) + 1) / border_width
|
| 1192 |
+
if not right_bound:
|
| 1193 |
+
x[-border_width:] = torch.flip(
|
| 1194 |
+
(torch.arange(border_width) + 1) / border_width, dims=(0,)
|
| 1195 |
+
)
|
| 1196 |
+
return x
|
| 1197 |
+
|
| 1198 |
+
def build_mask(self, data, is_bound, border_width):
|
| 1199 |
+
_, _, _, H, W = data.shape
|
| 1200 |
+
h = self.build_1d_mask(H, is_bound[0], is_bound[1], border_width[0])
|
| 1201 |
+
w = self.build_1d_mask(W, is_bound[2], is_bound[3], border_width[1])
|
| 1202 |
+
|
| 1203 |
+
h = repeat(h, "H -> H W", H=H, W=W)
|
| 1204 |
+
w = repeat(w, "W -> H W", H=H, W=W)
|
| 1205 |
+
|
| 1206 |
+
mask = torch.stack([h, w]).min(dim=0).values
|
| 1207 |
+
mask = rearrange(mask, "H W -> 1 1 1 H W")
|
| 1208 |
+
return mask
|
| 1209 |
+
|
| 1210 |
+
def tiled_decode(self, hidden_states, device, tile_size, tile_stride):
|
| 1211 |
+
_, _, T, H, W = hidden_states.shape
|
| 1212 |
+
size_h, size_w = tile_size
|
| 1213 |
+
stride_h, stride_w = tile_stride
|
| 1214 |
+
|
| 1215 |
+
# Split tasks
|
| 1216 |
+
tasks = []
|
| 1217 |
+
for h in range(0, H, stride_h):
|
| 1218 |
+
if h - stride_h >= 0 and h - stride_h + size_h >= H:
|
| 1219 |
+
continue
|
| 1220 |
+
for w in range(0, W, stride_w):
|
| 1221 |
+
if w - stride_w >= 0 and w - stride_w + size_w >= W:
|
| 1222 |
+
continue
|
| 1223 |
+
h_, w_ = h + size_h, w + size_w
|
| 1224 |
+
tasks.append((h, h_, w, w_))
|
| 1225 |
+
|
| 1226 |
+
data_device = "cpu"
|
| 1227 |
+
computation_device = device
|
| 1228 |
+
|
| 1229 |
+
out_T = T * 4 - 3
|
| 1230 |
+
weight = torch.zeros(
|
| 1231 |
+
(1, 1, out_T, H * self.upsampling_factor, W * self.upsampling_factor),
|
| 1232 |
+
dtype=hidden_states.dtype,
|
| 1233 |
+
device=data_device,
|
| 1234 |
+
)
|
| 1235 |
+
values = torch.zeros(
|
| 1236 |
+
(1, 3, out_T, H * self.upsampling_factor, W * self.upsampling_factor),
|
| 1237 |
+
dtype=hidden_states.dtype,
|
| 1238 |
+
device=data_device,
|
| 1239 |
+
)
|
| 1240 |
+
|
| 1241 |
+
for h, h_, w, w_ in tqdm(tasks, desc="VAE decoding"):
|
| 1242 |
+
hidden_states_batch = hidden_states[:, :, :, h:h_, w:w_].to(
|
| 1243 |
+
computation_device
|
| 1244 |
+
)
|
| 1245 |
+
hidden_states_batch = self.model.decode(hidden_states_batch, self.scale).to(
|
| 1246 |
+
data_device
|
| 1247 |
+
)
|
| 1248 |
+
|
| 1249 |
+
mask = self.build_mask(
|
| 1250 |
+
hidden_states_batch,
|
| 1251 |
+
is_bound=(h == 0, h_ >= H, w == 0, w_ >= W),
|
| 1252 |
+
border_width=(
|
| 1253 |
+
(size_h - stride_h) * self.upsampling_factor,
|
| 1254 |
+
(size_w - stride_w) * self.upsampling_factor,
|
| 1255 |
+
),
|
| 1256 |
+
).to(dtype=hidden_states.dtype, device=data_device)
|
| 1257 |
+
|
| 1258 |
+
target_h = h * self.upsampling_factor
|
| 1259 |
+
target_w = w * self.upsampling_factor
|
| 1260 |
+
values[
|
| 1261 |
+
:,
|
| 1262 |
+
:,
|
| 1263 |
+
:,
|
| 1264 |
+
target_h : target_h + hidden_states_batch.shape[3],
|
| 1265 |
+
target_w : target_w + hidden_states_batch.shape[4],
|
| 1266 |
+
] += hidden_states_batch * mask
|
| 1267 |
+
weight[
|
| 1268 |
+
:,
|
| 1269 |
+
:,
|
| 1270 |
+
:,
|
| 1271 |
+
target_h : target_h + hidden_states_batch.shape[3],
|
| 1272 |
+
target_w : target_w + hidden_states_batch.shape[4],
|
| 1273 |
+
] += mask
|
| 1274 |
+
values = values / weight
|
| 1275 |
+
values = values.clamp_(-1, 1)
|
| 1276 |
+
return values
|
| 1277 |
+
|
| 1278 |
+
def tiled_encode(self, video, device, tile_size, tile_stride):
|
| 1279 |
+
_, _, T, H, W = video.shape
|
| 1280 |
+
size_h, size_w = tile_size
|
| 1281 |
+
stride_h, stride_w = tile_stride
|
| 1282 |
+
|
| 1283 |
+
# Split tasks
|
| 1284 |
+
tasks = []
|
| 1285 |
+
for h in range(0, H, stride_h):
|
| 1286 |
+
if h - stride_h >= 0 and h - stride_h + size_h >= H:
|
| 1287 |
+
continue
|
| 1288 |
+
for w in range(0, W, stride_w):
|
| 1289 |
+
if w - stride_w >= 0 and w - stride_w + size_w >= W:
|
| 1290 |
+
continue
|
| 1291 |
+
h_, w_ = h + size_h, w + size_w
|
| 1292 |
+
tasks.append((h, h_, w, w_))
|
| 1293 |
+
|
| 1294 |
+
data_device = "cpu"
|
| 1295 |
+
computation_device = device
|
| 1296 |
+
|
| 1297 |
+
out_T = (T + 3) // 4
|
| 1298 |
+
weight = torch.zeros(
|
| 1299 |
+
(1, 1, out_T, H // self.upsampling_factor, W // self.upsampling_factor),
|
| 1300 |
+
dtype=video.dtype,
|
| 1301 |
+
device=data_device,
|
| 1302 |
+
)
|
| 1303 |
+
values = torch.zeros(
|
| 1304 |
+
(
|
| 1305 |
+
1,
|
| 1306 |
+
self.z_dim,
|
| 1307 |
+
out_T,
|
| 1308 |
+
H // self.upsampling_factor,
|
| 1309 |
+
W // self.upsampling_factor,
|
| 1310 |
+
),
|
| 1311 |
+
dtype=video.dtype,
|
| 1312 |
+
device=data_device,
|
| 1313 |
+
)
|
| 1314 |
+
|
| 1315 |
+
for h, h_, w, w_ in tqdm(tasks, desc="VAE encoding"):
|
| 1316 |
+
hidden_states_batch = video[:, :, :, h:h_, w:w_].to(computation_device)
|
| 1317 |
+
hidden_states_batch = self.model.encode(hidden_states_batch, self.scale).to(
|
| 1318 |
+
data_device
|
| 1319 |
+
)
|
| 1320 |
+
|
| 1321 |
+
mask = self.build_mask(
|
| 1322 |
+
hidden_states_batch,
|
| 1323 |
+
is_bound=(h == 0, h_ >= H, w == 0, w_ >= W),
|
| 1324 |
+
border_width=(
|
| 1325 |
+
(size_h - stride_h) // self.upsampling_factor,
|
| 1326 |
+
(size_w - stride_w) // self.upsampling_factor,
|
| 1327 |
+
),
|
| 1328 |
+
).to(dtype=video.dtype, device=data_device)
|
| 1329 |
+
|
| 1330 |
+
target_h = h // self.upsampling_factor
|
| 1331 |
+
target_w = w // self.upsampling_factor
|
| 1332 |
+
values[
|
| 1333 |
+
:,
|
| 1334 |
+
:,
|
| 1335 |
+
:,
|
| 1336 |
+
target_h : target_h + hidden_states_batch.shape[3],
|
| 1337 |
+
target_w : target_w + hidden_states_batch.shape[4],
|
| 1338 |
+
] += hidden_states_batch * mask
|
| 1339 |
+
weight[
|
| 1340 |
+
:,
|
| 1341 |
+
:,
|
| 1342 |
+
:,
|
| 1343 |
+
target_h : target_h + hidden_states_batch.shape[3],
|
| 1344 |
+
target_w : target_w + hidden_states_batch.shape[4],
|
| 1345 |
+
] += mask
|
| 1346 |
+
values = values / weight
|
| 1347 |
+
return values
|
| 1348 |
+
|
| 1349 |
+
def single_encode(self, video, device):
|
| 1350 |
+
video = video.to(device)
|
| 1351 |
+
x = self.model.encode(video, self.scale)
|
| 1352 |
+
return x
|
| 1353 |
+
|
| 1354 |
+
def single_decode(self, hidden_state, device):
|
| 1355 |
+
hidden_state = hidden_state.to(device)
|
| 1356 |
+
video = self.model.decode(hidden_state, self.scale)
|
| 1357 |
+
return video.clamp_(-1, 1)
|
| 1358 |
+
|
| 1359 |
+
def encode(
|
| 1360 |
+
self, videos, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)
|
| 1361 |
+
):
|
| 1362 |
+
videos = [video.to("cpu") for video in videos]
|
| 1363 |
+
hidden_states = []
|
| 1364 |
+
for video in videos:
|
| 1365 |
+
video = video.unsqueeze(0)
|
| 1366 |
+
if tiled:
|
| 1367 |
+
tile_size = (
|
| 1368 |
+
tile_size[0] * self.upsampling_factor,
|
| 1369 |
+
tile_size[1] * self.upsampling_factor,
|
| 1370 |
+
)
|
| 1371 |
+
tile_stride = (
|
| 1372 |
+
tile_stride[0] * self.upsampling_factor,
|
| 1373 |
+
tile_stride[1] * self.upsampling_factor,
|
| 1374 |
+
)
|
| 1375 |
+
hidden_state = self.tiled_encode(video, device, tile_size, tile_stride)
|
| 1376 |
+
else:
|
| 1377 |
+
hidden_state = self.single_encode(video, device)
|
| 1378 |
+
hidden_state = hidden_state.squeeze(0)
|
| 1379 |
+
hidden_states.append(hidden_state)
|
| 1380 |
+
hidden_states = torch.stack(hidden_states)
|
| 1381 |
+
return hidden_states
|
| 1382 |
+
|
| 1383 |
+
def decode(
|
| 1384 |
+
self,
|
| 1385 |
+
hidden_states,
|
| 1386 |
+
device,
|
| 1387 |
+
tiled=False,
|
| 1388 |
+
tile_size=(34, 34),
|
| 1389 |
+
tile_stride=(18, 16),
|
| 1390 |
+
):
|
| 1391 |
+
if tiled:
|
| 1392 |
+
video = self.tiled_decode(hidden_states, device, tile_size, tile_stride)
|
| 1393 |
+
else:
|
| 1394 |
+
video = self.single_decode(hidden_states, device)
|
| 1395 |
+
return video
|
| 1396 |
+
|
| 1397 |
+
@staticmethod
|
| 1398 |
+
def state_dict_converter():
|
| 1399 |
+
return WanVideoVAEStateDictConverter()
|
| 1400 |
+
|
| 1401 |
+
|
| 1402 |
+
class WanVideoVAEStateDictConverter:
|
| 1403 |
+
def __init__(self):
|
| 1404 |
+
pass
|
| 1405 |
+
|
| 1406 |
+
def from_civitai(self, state_dict):
|
| 1407 |
+
state_dict_ = {}
|
| 1408 |
+
if "model_state" in state_dict:
|
| 1409 |
+
state_dict = state_dict["model_state"]
|
| 1410 |
+
for name in state_dict:
|
| 1411 |
+
state_dict_["model." + name] = state_dict[name]
|
| 1412 |
+
return state_dict_
|
| 1413 |
+
|
| 1414 |
+
|
| 1415 |
+
class VideoVAE38_(VideoVAE_):
|
| 1416 |
+
def __init__(
|
| 1417 |
+
self,
|
| 1418 |
+
dim=160,
|
| 1419 |
+
z_dim=48,
|
| 1420 |
+
dec_dim=256,
|
| 1421 |
+
dim_mult=[1, 2, 4, 4],
|
| 1422 |
+
num_res_blocks=2,
|
| 1423 |
+
attn_scales=[],
|
| 1424 |
+
temperal_downsample=[False, True, True],
|
| 1425 |
+
dropout=0.0,
|
| 1426 |
+
):
|
| 1427 |
+
super(VideoVAE_, self).__init__()
|
| 1428 |
+
self.dim = dim
|
| 1429 |
+
self.z_dim = z_dim
|
| 1430 |
+
self.dim_mult = dim_mult
|
| 1431 |
+
self.num_res_blocks = num_res_blocks
|
| 1432 |
+
self.attn_scales = attn_scales
|
| 1433 |
+
self.temperal_downsample = temperal_downsample
|
| 1434 |
+
self.temperal_upsample = temperal_downsample[::-1]
|
| 1435 |
+
|
| 1436 |
+
# modules
|
| 1437 |
+
self.encoder = Encoder3d_38(
|
| 1438 |
+
dim,
|
| 1439 |
+
z_dim * 2,
|
| 1440 |
+
dim_mult,
|
| 1441 |
+
num_res_blocks,
|
| 1442 |
+
attn_scales,
|
| 1443 |
+
self.temperal_downsample,
|
| 1444 |
+
dropout,
|
| 1445 |
+
)
|
| 1446 |
+
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
|
| 1447 |
+
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
|
| 1448 |
+
self.decoder = Decoder3d_38(
|
| 1449 |
+
dec_dim,
|
| 1450 |
+
z_dim,
|
| 1451 |
+
dim_mult,
|
| 1452 |
+
num_res_blocks,
|
| 1453 |
+
attn_scales,
|
| 1454 |
+
self.temperal_upsample,
|
| 1455 |
+
dropout,
|
| 1456 |
+
)
|
| 1457 |
+
|
| 1458 |
+
def encode(self, x, scale):
|
| 1459 |
+
self.clear_cache()
|
| 1460 |
+
x = patchify(x, patch_size=2)
|
| 1461 |
+
t = x.shape[2]
|
| 1462 |
+
iter_ = 1 + (t - 1) // 4
|
| 1463 |
+
for i in range(iter_):
|
| 1464 |
+
self._enc_conv_idx = [0]
|
| 1465 |
+
if i == 0:
|
| 1466 |
+
out = self.encoder(
|
| 1467 |
+
x[:, :, :1, :, :],
|
| 1468 |
+
feat_cache=self._enc_feat_map,
|
| 1469 |
+
feat_idx=self._enc_conv_idx,
|
| 1470 |
+
)
|
| 1471 |
+
else:
|
| 1472 |
+
out_ = self.encoder(
|
| 1473 |
+
x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :],
|
| 1474 |
+
feat_cache=self._enc_feat_map,
|
| 1475 |
+
feat_idx=self._enc_conv_idx,
|
| 1476 |
+
)
|
| 1477 |
+
out = torch.cat([out, out_], 2)
|
| 1478 |
+
mu, log_var = self.conv1(out).chunk(2, dim=1)
|
| 1479 |
+
if isinstance(scale[0], torch.Tensor):
|
| 1480 |
+
scale = [s.to(dtype=mu.dtype, device=mu.device) for s in scale]
|
| 1481 |
+
mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
|
| 1482 |
+
1, self.z_dim, 1, 1, 1
|
| 1483 |
+
)
|
| 1484 |
+
else:
|
| 1485 |
+
scale = scale.to(dtype=mu.dtype, device=mu.device)
|
| 1486 |
+
mu = (mu - scale[0]) * scale[1]
|
| 1487 |
+
self.clear_cache()
|
| 1488 |
+
return mu
|
| 1489 |
+
|
| 1490 |
+
def decode(self, z, scale):
|
| 1491 |
+
self.clear_cache()
|
| 1492 |
+
if isinstance(scale[0], torch.Tensor):
|
| 1493 |
+
scale = [s.to(dtype=z.dtype, device=z.device) for s in scale]
|
| 1494 |
+
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
|
| 1495 |
+
1, self.z_dim, 1, 1, 1
|
| 1496 |
+
)
|
| 1497 |
+
else:
|
| 1498 |
+
scale = scale.to(dtype=z.dtype, device=z.device)
|
| 1499 |
+
z = z / scale[1] + scale[0]
|
| 1500 |
+
iter_ = z.shape[2]
|
| 1501 |
+
x = self.conv2(z)
|
| 1502 |
+
for i in range(iter_):
|
| 1503 |
+
self._conv_idx = [0]
|
| 1504 |
+
if i == 0:
|
| 1505 |
+
out = self.decoder(
|
| 1506 |
+
x[:, :, i : i + 1, :, :],
|
| 1507 |
+
feat_cache=self._feat_map,
|
| 1508 |
+
feat_idx=self._conv_idx,
|
| 1509 |
+
first_chunk=True,
|
| 1510 |
+
)
|
| 1511 |
+
else:
|
| 1512 |
+
out_ = self.decoder(
|
| 1513 |
+
x[:, :, i : i + 1, :, :],
|
| 1514 |
+
feat_cache=self._feat_map,
|
| 1515 |
+
feat_idx=self._conv_idx,
|
| 1516 |
+
)
|
| 1517 |
+
out = torch.cat([out, out_], 2)
|
| 1518 |
+
out = unpatchify(out, patch_size=2)
|
| 1519 |
+
self.clear_cache()
|
| 1520 |
+
return out
|
| 1521 |
+
|
| 1522 |
+
|
| 1523 |
+
class WanVideoVAE38(WanVideoVAE):
|
| 1524 |
+
def __init__(self, z_dim=48, dim=160):
|
| 1525 |
+
super(WanVideoVAE, self).__init__()
|
| 1526 |
+
|
| 1527 |
+
mean = [
|
| 1528 |
+
-0.2289,
|
| 1529 |
+
-0.0052,
|
| 1530 |
+
-0.1323,
|
| 1531 |
+
-0.2339,
|
| 1532 |
+
-0.2799,
|
| 1533 |
+
0.0174,
|
| 1534 |
+
0.1838,
|
| 1535 |
+
0.1557,
|
| 1536 |
+
-0.1382,
|
| 1537 |
+
0.0542,
|
| 1538 |
+
0.2813,
|
| 1539 |
+
0.0891,
|
| 1540 |
+
0.1570,
|
| 1541 |
+
-0.0098,
|
| 1542 |
+
0.0375,
|
| 1543 |
+
-0.1825,
|
| 1544 |
+
-0.2246,
|
| 1545 |
+
-0.1207,
|
| 1546 |
+
-0.0698,
|
| 1547 |
+
0.5109,
|
| 1548 |
+
0.2665,
|
| 1549 |
+
-0.2108,
|
| 1550 |
+
-0.2158,
|
| 1551 |
+
0.2502,
|
| 1552 |
+
-0.2055,
|
| 1553 |
+
-0.0322,
|
| 1554 |
+
0.1109,
|
| 1555 |
+
0.1567,
|
| 1556 |
+
-0.0729,
|
| 1557 |
+
0.0899,
|
| 1558 |
+
-0.2799,
|
| 1559 |
+
-0.1230,
|
| 1560 |
+
-0.0313,
|
| 1561 |
+
-0.1649,
|
| 1562 |
+
0.0117,
|
| 1563 |
+
0.0723,
|
| 1564 |
+
-0.2839,
|
| 1565 |
+
-0.2083,
|
| 1566 |
+
-0.0520,
|
| 1567 |
+
0.3748,
|
| 1568 |
+
0.0152,
|
| 1569 |
+
0.1957,
|
| 1570 |
+
0.1433,
|
| 1571 |
+
-0.2944,
|
| 1572 |
+
0.3573,
|
| 1573 |
+
-0.0548,
|
| 1574 |
+
-0.1681,
|
| 1575 |
+
-0.0667,
|
| 1576 |
+
]
|
| 1577 |
+
std = [
|
| 1578 |
+
0.4765,
|
| 1579 |
+
1.0364,
|
| 1580 |
+
0.4514,
|
| 1581 |
+
1.1677,
|
| 1582 |
+
0.5313,
|
| 1583 |
+
0.4990,
|
| 1584 |
+
0.4818,
|
| 1585 |
+
0.5013,
|
| 1586 |
+
0.8158,
|
| 1587 |
+
1.0344,
|
| 1588 |
+
0.5894,
|
| 1589 |
+
1.0901,
|
| 1590 |
+
0.6885,
|
| 1591 |
+
0.6165,
|
| 1592 |
+
0.8454,
|
| 1593 |
+
0.4978,
|
| 1594 |
+
0.5759,
|
| 1595 |
+
0.3523,
|
| 1596 |
+
0.7135,
|
| 1597 |
+
0.6804,
|
| 1598 |
+
0.5833,
|
| 1599 |
+
1.4146,
|
| 1600 |
+
0.8986,
|
| 1601 |
+
0.5659,
|
| 1602 |
+
0.7069,
|
| 1603 |
+
0.5338,
|
| 1604 |
+
0.4889,
|
| 1605 |
+
0.4917,
|
| 1606 |
+
0.4069,
|
| 1607 |
+
0.4999,
|
| 1608 |
+
0.6866,
|
| 1609 |
+
0.4093,
|
| 1610 |
+
0.5709,
|
| 1611 |
+
0.6065,
|
| 1612 |
+
0.6415,
|
| 1613 |
+
0.4944,
|
| 1614 |
+
0.5726,
|
| 1615 |
+
1.2042,
|
| 1616 |
+
0.5458,
|
| 1617 |
+
1.6887,
|
| 1618 |
+
0.3971,
|
| 1619 |
+
1.0600,
|
| 1620 |
+
0.3943,
|
| 1621 |
+
0.5537,
|
| 1622 |
+
0.5444,
|
| 1623 |
+
0.4089,
|
| 1624 |
+
0.7468,
|
| 1625 |
+
0.7744,
|
| 1626 |
+
]
|
| 1627 |
+
self.mean = torch.tensor(mean)
|
| 1628 |
+
self.std = torch.tensor(std)
|
| 1629 |
+
self.scale = [self.mean, 1.0 / self.std]
|
| 1630 |
+
|
| 1631 |
+
# init model
|
| 1632 |
+
self.model = VideoVAE38_(z_dim=z_dim, dim=dim).eval().requires_grad_(False)
|
| 1633 |
+
self.upsampling_factor = 16
|
| 1634 |
+
self.z_dim = z_dim
|
pipelines/base.py
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
from PIL import Image
|
| 4 |
+
from torchvision.transforms import GaussianBlur
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class BasePipeline(torch.nn.Module):
|
| 8 |
+
def __init__(
|
| 9 |
+
self,
|
| 10 |
+
device="cuda",
|
| 11 |
+
torch_dtype=torch.float16,
|
| 12 |
+
height_division_factor=64,
|
| 13 |
+
width_division_factor=64,
|
| 14 |
+
):
|
| 15 |
+
super().__init__()
|
| 16 |
+
self.device = device
|
| 17 |
+
self.torch_dtype = torch_dtype
|
| 18 |
+
self.height_division_factor = height_division_factor
|
| 19 |
+
self.width_division_factor = width_division_factor
|
| 20 |
+
self.cpu_offload = False
|
| 21 |
+
self.model_names = []
|
| 22 |
+
|
| 23 |
+
def check_resize_height_width(self, height, width):
|
| 24 |
+
if height % self.height_division_factor != 0:
|
| 25 |
+
height = (
|
| 26 |
+
(height + self.height_division_factor - 1)
|
| 27 |
+
// self.height_division_factor
|
| 28 |
+
* self.height_division_factor
|
| 29 |
+
)
|
| 30 |
+
print(
|
| 31 |
+
f"The height cannot be evenly divided by {self.height_division_factor}. We round it up to {height}."
|
| 32 |
+
)
|
| 33 |
+
if width % self.width_division_factor != 0:
|
| 34 |
+
width = (
|
| 35 |
+
(width + self.width_division_factor - 1)
|
| 36 |
+
// self.width_division_factor
|
| 37 |
+
* self.width_division_factor
|
| 38 |
+
)
|
| 39 |
+
print(
|
| 40 |
+
f"The width cannot be evenly divided by {self.width_division_factor}. We round it up to {width}."
|
| 41 |
+
)
|
| 42 |
+
return height, width
|
| 43 |
+
|
| 44 |
+
def preprocess_image(self, image):
|
| 45 |
+
image = (
|
| 46 |
+
torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1)
|
| 47 |
+
.permute(2, 0, 1)
|
| 48 |
+
.unsqueeze(0)
|
| 49 |
+
)
|
| 50 |
+
return image
|
| 51 |
+
|
| 52 |
+
def preprocess_images(self, images):
|
| 53 |
+
return [self.preprocess_image(image) for image in images]
|
| 54 |
+
|
| 55 |
+
def vae_output_to_image(self, vae_output):
|
| 56 |
+
image = vae_output[0].cpu().float().permute(1, 2, 0).numpy()
|
| 57 |
+
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
|
| 58 |
+
return image
|
| 59 |
+
|
| 60 |
+
def vae_output_to_video(self, vae_output):
|
| 61 |
+
video = vae_output.cpu().permute(1, 2, 0).numpy()
|
| 62 |
+
video = [
|
| 63 |
+
Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
|
| 64 |
+
for image in video
|
| 65 |
+
]
|
| 66 |
+
return video
|
| 67 |
+
|
| 68 |
+
def merge_latents(
|
| 69 |
+
self, value, latents, masks, scales, blur_kernel_size=33, blur_sigma=10.0
|
| 70 |
+
):
|
| 71 |
+
if len(latents) > 0:
|
| 72 |
+
blur = GaussianBlur(kernel_size=blur_kernel_size, sigma=blur_sigma)
|
| 73 |
+
height, width = value.shape[-2:]
|
| 74 |
+
weight = torch.ones_like(value)
|
| 75 |
+
for latent, mask, scale in zip(latents, masks, scales):
|
| 76 |
+
mask = (
|
| 77 |
+
self.preprocess_image(mask.resize((width, height))).mean(
|
| 78 |
+
dim=1, keepdim=True
|
| 79 |
+
)
|
| 80 |
+
> 0
|
| 81 |
+
)
|
| 82 |
+
mask = mask.repeat(1, latent.shape[1], 1, 1).to(
|
| 83 |
+
dtype=latent.dtype, device=latent.device
|
| 84 |
+
)
|
| 85 |
+
mask = blur(mask)
|
| 86 |
+
value += latent * mask * scale
|
| 87 |
+
weight += mask * scale
|
| 88 |
+
value /= weight
|
| 89 |
+
return value
|
| 90 |
+
|
| 91 |
+
def control_noise_via_local_prompts(
|
| 92 |
+
self,
|
| 93 |
+
prompt_emb_global,
|
| 94 |
+
prompt_emb_locals,
|
| 95 |
+
masks,
|
| 96 |
+
mask_scales,
|
| 97 |
+
inference_callback,
|
| 98 |
+
special_kwargs=None,
|
| 99 |
+
special_local_kwargs_list=None,
|
| 100 |
+
):
|
| 101 |
+
if special_kwargs is None:
|
| 102 |
+
noise_pred_global = inference_callback(prompt_emb_global)
|
| 103 |
+
else:
|
| 104 |
+
noise_pred_global = inference_callback(prompt_emb_global, special_kwargs)
|
| 105 |
+
if special_local_kwargs_list is None:
|
| 106 |
+
noise_pred_locals = [
|
| 107 |
+
inference_callback(prompt_emb_local)
|
| 108 |
+
for prompt_emb_local in prompt_emb_locals
|
| 109 |
+
]
|
| 110 |
+
else:
|
| 111 |
+
noise_pred_locals = [
|
| 112 |
+
inference_callback(prompt_emb_local, special_kwargs)
|
| 113 |
+
for prompt_emb_local, special_kwargs in zip(
|
| 114 |
+
prompt_emb_locals, special_local_kwargs_list
|
| 115 |
+
)
|
| 116 |
+
]
|
| 117 |
+
noise_pred = self.merge_latents(
|
| 118 |
+
noise_pred_global, noise_pred_locals, masks, mask_scales
|
| 119 |
+
)
|
| 120 |
+
return noise_pred
|
| 121 |
+
|
| 122 |
+
def extend_prompt(self, prompt, local_prompts, masks, mask_scales):
|
| 123 |
+
local_prompts = local_prompts or []
|
| 124 |
+
masks = masks or []
|
| 125 |
+
mask_scales = mask_scales or []
|
| 126 |
+
extended_prompt_dict = self.prompter.extend_prompt(prompt)
|
| 127 |
+
prompt = extended_prompt_dict.get("prompt", prompt)
|
| 128 |
+
local_prompts += extended_prompt_dict.get("prompts", [])
|
| 129 |
+
masks += extended_prompt_dict.get("masks", [])
|
| 130 |
+
mask_scales += [100.0] * len(extended_prompt_dict.get("masks", []))
|
| 131 |
+
return prompt, local_prompts, masks, mask_scales
|
| 132 |
+
|
| 133 |
+
def enable_cpu_offload(self):
|
| 134 |
+
self.cpu_offload = True
|
| 135 |
+
|
| 136 |
+
def load_models_to_device(self, loadmodel_names=[]):
|
| 137 |
+
# only load models to device if cpu_offload is enabled
|
| 138 |
+
if not self.cpu_offload:
|
| 139 |
+
return
|
| 140 |
+
# offload the unneeded models to cpu
|
| 141 |
+
for model_name in self.model_names:
|
| 142 |
+
if model_name not in loadmodel_names:
|
| 143 |
+
model = getattr(self, model_name)
|
| 144 |
+
if model is not None:
|
| 145 |
+
if (
|
| 146 |
+
hasattr(model, "vram_management_enabled")
|
| 147 |
+
and model.vram_management_enabled
|
| 148 |
+
):
|
| 149 |
+
for module in model.modules():
|
| 150 |
+
if hasattr(module, "offload"):
|
| 151 |
+
module.offload()
|
| 152 |
+
else:
|
| 153 |
+
model.cpu()
|
| 154 |
+
# load the needed models to device
|
| 155 |
+
for model_name in loadmodel_names:
|
| 156 |
+
model = getattr(self, model_name)
|
| 157 |
+
if model is not None:
|
| 158 |
+
if (
|
| 159 |
+
hasattr(model, "vram_management_enabled")
|
| 160 |
+
and model.vram_management_enabled
|
| 161 |
+
):
|
| 162 |
+
for module in model.modules():
|
| 163 |
+
if hasattr(module, "onload"):
|
| 164 |
+
module.onload()
|
| 165 |
+
else:
|
| 166 |
+
model.to(self.device)
|
| 167 |
+
# fresh the cuda cache
|
| 168 |
+
torch.cuda.empty_cache()
|
| 169 |
+
|
| 170 |
+
def generate_noise(self, shape, seed=None, device="cpu", dtype=torch.float16):
|
| 171 |
+
generator = None if seed is None else torch.Generator(device).manual_seed(seed)
|
| 172 |
+
noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
|
| 173 |
+
return noise
|
pipelines/wan_video.py
ADDED
|
@@ -0,0 +1,1793 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch, types
|
| 2 |
+
import numpy as np
|
| 3 |
+
from PIL import Image
|
| 4 |
+
from einops import repeat
|
| 5 |
+
from typing import Optional, Union
|
| 6 |
+
from einops import rearrange
|
| 7 |
+
import numpy as np
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
from typing import Optional
|
| 10 |
+
from typing_extensions import Literal
|
| 11 |
+
import imageio
|
| 12 |
+
import os
|
| 13 |
+
from typing import List, Tuple
|
| 14 |
+
import PIL
|
| 15 |
+
from utils import BasePipeline, ModelConfig, PipelineUnit, PipelineUnitRunner
|
| 16 |
+
from models import ModelManager, load_state_dict
|
| 17 |
+
from models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d
|
| 18 |
+
from models.wan_video_text_encoder import (
|
| 19 |
+
WanTextEncoder,
|
| 20 |
+
T5RelativeEmbedding,
|
| 21 |
+
T5LayerNorm,
|
| 22 |
+
)
|
| 23 |
+
from models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample
|
| 24 |
+
from models.wan_video_image_encoder import WanImageEncoder
|
| 25 |
+
from models.wan_video_vace import VaceWanModel
|
| 26 |
+
from models.wan_video_motion_controller import WanMotionControllerModel
|
| 27 |
+
from schedulers.flow_match import FlowMatchScheduler
|
| 28 |
+
from prompters import WanPrompter
|
| 29 |
+
from vram_management import (
|
| 30 |
+
enable_vram_management,
|
| 31 |
+
AutoWrappedModule,
|
| 32 |
+
AutoWrappedLinear,
|
| 33 |
+
WanAutoCastLayerNorm,
|
| 34 |
+
)
|
| 35 |
+
from lora import GeneralLoRALoader
|
| 36 |
+
|
| 37 |
+
def load_video_as_list(video_path: str) -> Tuple[List[Image.Image], int, int, int]:
|
| 38 |
+
if not os.path.isfile(video_path):
|
| 39 |
+
raise FileNotFoundError(f"Video file not found: {video_path}")
|
| 40 |
+
|
| 41 |
+
reader = imageio.get_reader(video_path)
|
| 42 |
+
|
| 43 |
+
meta_data = reader.get_meta_data()
|
| 44 |
+
original_width = meta_data['size'][0]
|
| 45 |
+
original_height = meta_data['size'][1]
|
| 46 |
+
|
| 47 |
+
new_width = (original_width // 16) * 16
|
| 48 |
+
new_height = (original_height // 16) * 16
|
| 49 |
+
|
| 50 |
+
left = (original_width - new_width) // 2
|
| 51 |
+
top = (original_height - new_height) // 2
|
| 52 |
+
right = left + new_width
|
| 53 |
+
bottom = top + new_height
|
| 54 |
+
crop_box = (left, top, right, bottom)
|
| 55 |
+
|
| 56 |
+
original_frame_count = reader.count_frames()
|
| 57 |
+
new_frame_count = original_frame_count - ((original_frame_count - 1) % 4)
|
| 58 |
+
|
| 59 |
+
frames = []
|
| 60 |
+
for i in range(new_frame_count):
|
| 61 |
+
try:
|
| 62 |
+
frame_data = reader.get_data(i)
|
| 63 |
+
pil_image = Image.fromarray(frame_data)
|
| 64 |
+
cropped_image = pil_image.crop(crop_box)
|
| 65 |
+
frames.append(cropped_image)
|
| 66 |
+
except IndexError:
|
| 67 |
+
print(f"Warning: Actual number of frames is less than expected. Stopping at frame {i}.")
|
| 68 |
+
new_frame_count = len(frames)
|
| 69 |
+
break
|
| 70 |
+
|
| 71 |
+
reader.close()
|
| 72 |
+
|
| 73 |
+
return frames, new_width, new_height, new_frame_count
|
| 74 |
+
|
| 75 |
+
class WanVideoPipeline(BasePipeline):
|
| 76 |
+
def __init__(self, device="cuda", torch_dtype=torch.bfloat16, tokenizer_path=None):
|
| 77 |
+
super().__init__(
|
| 78 |
+
device=device,
|
| 79 |
+
torch_dtype=torch_dtype,
|
| 80 |
+
height_division_factor=16,
|
| 81 |
+
width_division_factor=16,
|
| 82 |
+
time_division_factor=4,
|
| 83 |
+
time_division_remainder=1,
|
| 84 |
+
)
|
| 85 |
+
self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
|
| 86 |
+
self.prompter = WanPrompter(tokenizer_path=tokenizer_path)
|
| 87 |
+
self.text_encoder: WanTextEncoder = None
|
| 88 |
+
self.image_encoder: WanImageEncoder = None
|
| 89 |
+
self.dit: WanModel = None
|
| 90 |
+
self.dit2: WanModel = None
|
| 91 |
+
self.vae: WanVideoVAE = None
|
| 92 |
+
self.motion_controller: WanMotionControllerModel = None
|
| 93 |
+
self.vace: VaceWanModel = None
|
| 94 |
+
self.in_iteration_models = ("dit", "motion_controller", "vace")
|
| 95 |
+
self.in_iteration_models_2 = ("dit2", "motion_controller", "vace")
|
| 96 |
+
self.unit_runner = PipelineUnitRunner()
|
| 97 |
+
self.units = [
|
| 98 |
+
WanVideoUnit_ShapeChecker(),
|
| 99 |
+
WanVideoUnit_NoiseInitializer(),
|
| 100 |
+
WanVideoUnit_InputVideoEmbedder(),
|
| 101 |
+
WanVideoUnit_PromptEmbedder(),
|
| 102 |
+
WanVideoUnit_ImageEmbedderVAE(),
|
| 103 |
+
WanVideoUnit_ImageEmbedderCLIP(),
|
| 104 |
+
WanVideoUnit_ImageEmbedderFused(),
|
| 105 |
+
WanVideoUnit_FunControl(),
|
| 106 |
+
WanVideoUnit_FunReference(),
|
| 107 |
+
WanVideoUnit_FunCameraControl(),
|
| 108 |
+
WanVideoUnit_SpeedControl(),
|
| 109 |
+
WanVideoUnit_VACE(),
|
| 110 |
+
WanVideoUnit_UnifiedSequenceParallel(),
|
| 111 |
+
WanVideoUnit_TeaCache(),
|
| 112 |
+
WanVideoUnit_CfgMerger(),
|
| 113 |
+
]
|
| 114 |
+
self.model_fn = model_fn_wan_video
|
| 115 |
+
|
| 116 |
+
def encode_ip_image(self, ip_image):
|
| 117 |
+
self.load_models_to_device(["vae"])
|
| 118 |
+
ip_image = (
|
| 119 |
+
torch.tensor(np.array(ip_image)).permute(2, 0, 1).float() / 255.0
|
| 120 |
+
) # [3, H, W]
|
| 121 |
+
ip_image = (
|
| 122 |
+
ip_image.unsqueeze(1).unsqueeze(0).to(dtype=self.torch_dtype)
|
| 123 |
+
) # [B, 3, 1, H, W]
|
| 124 |
+
ip_image = ip_image * 2 - 1
|
| 125 |
+
ip_image_latent = self.vae.encode(ip_image, device=self.device, tiled=False)
|
| 126 |
+
return ip_image_latent
|
| 127 |
+
|
| 128 |
+
def load_lora(self, module, path, alpha=1):
|
| 129 |
+
loader = GeneralLoRALoader(torch_dtype=self.torch_dtype, device=self.device)
|
| 130 |
+
lora = load_state_dict(path, torch_dtype=self.torch_dtype, device=self.device)
|
| 131 |
+
loader.load(module, lora, alpha=alpha)
|
| 132 |
+
|
| 133 |
+
def training_loss(self, **inputs):
|
| 134 |
+
max_timestep_boundary = int(
|
| 135 |
+
inputs.get("max_timestep_boundary", 1) * self.scheduler.num_train_timesteps
|
| 136 |
+
)
|
| 137 |
+
min_timestep_boundary = int(
|
| 138 |
+
inputs.get("min_timestep_boundary", 0) * self.scheduler.num_train_timesteps
|
| 139 |
+
)
|
| 140 |
+
timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
|
| 141 |
+
timestep = self.scheduler.timesteps[timestep_id].to(
|
| 142 |
+
dtype=self.torch_dtype, device=self.device
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
inputs["latents"] = self.scheduler.add_noise(
|
| 146 |
+
inputs["input_latents"], inputs["noise"], timestep
|
| 147 |
+
)
|
| 148 |
+
training_target = self.scheduler.training_target(
|
| 149 |
+
inputs["input_latents"], inputs["noise"], timestep
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
noise_pred = self.model_fn(**inputs, timestep=timestep)
|
| 153 |
+
|
| 154 |
+
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
|
| 155 |
+
loss = loss * self.scheduler.training_weight(timestep)
|
| 156 |
+
return loss
|
| 157 |
+
|
| 158 |
+
def enable_vram_management(
|
| 159 |
+
self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5
|
| 160 |
+
):
|
| 161 |
+
self.vram_management_enabled = True
|
| 162 |
+
if num_persistent_param_in_dit is not None:
|
| 163 |
+
vram_limit = None
|
| 164 |
+
else:
|
| 165 |
+
if vram_limit is None:
|
| 166 |
+
vram_limit = self.get_vram()
|
| 167 |
+
vram_limit = vram_limit - vram_buffer
|
| 168 |
+
if self.text_encoder is not None:
|
| 169 |
+
dtype = next(iter(self.text_encoder.parameters())).dtype
|
| 170 |
+
enable_vram_management(
|
| 171 |
+
self.text_encoder,
|
| 172 |
+
module_map={
|
| 173 |
+
torch.nn.Linear: AutoWrappedLinear,
|
| 174 |
+
torch.nn.Embedding: AutoWrappedModule,
|
| 175 |
+
T5RelativeEmbedding: AutoWrappedModule,
|
| 176 |
+
T5LayerNorm: AutoWrappedModule,
|
| 177 |
+
},
|
| 178 |
+
module_config=dict(
|
| 179 |
+
offload_dtype=dtype,
|
| 180 |
+
offload_device="cpu",
|
| 181 |
+
onload_dtype=dtype,
|
| 182 |
+
onload_device="cpu",
|
| 183 |
+
computation_dtype=self.torch_dtype,
|
| 184 |
+
computation_device=self.device,
|
| 185 |
+
),
|
| 186 |
+
vram_limit=vram_limit,
|
| 187 |
+
)
|
| 188 |
+
if self.dit is not None:
|
| 189 |
+
dtype = next(iter(self.dit.parameters())).dtype
|
| 190 |
+
device = "cpu" if vram_limit is not None else self.device
|
| 191 |
+
enable_vram_management(
|
| 192 |
+
self.dit,
|
| 193 |
+
module_map={
|
| 194 |
+
torch.nn.Linear: AutoWrappedLinear,
|
| 195 |
+
torch.nn.Conv3d: AutoWrappedModule,
|
| 196 |
+
torch.nn.LayerNorm: WanAutoCastLayerNorm,
|
| 197 |
+
RMSNorm: AutoWrappedModule,
|
| 198 |
+
torch.nn.Conv2d: AutoWrappedModule,
|
| 199 |
+
},
|
| 200 |
+
module_config=dict(
|
| 201 |
+
offload_dtype=dtype,
|
| 202 |
+
offload_device="cpu",
|
| 203 |
+
onload_dtype=dtype,
|
| 204 |
+
onload_device=device,
|
| 205 |
+
computation_dtype=self.torch_dtype,
|
| 206 |
+
computation_device=self.device,
|
| 207 |
+
),
|
| 208 |
+
max_num_param=num_persistent_param_in_dit,
|
| 209 |
+
overflow_module_config=dict(
|
| 210 |
+
offload_dtype=dtype,
|
| 211 |
+
offload_device="cpu",
|
| 212 |
+
onload_dtype=dtype,
|
| 213 |
+
onload_device="cpu",
|
| 214 |
+
computation_dtype=self.torch_dtype,
|
| 215 |
+
computation_device=self.device,
|
| 216 |
+
),
|
| 217 |
+
vram_limit=vram_limit,
|
| 218 |
+
)
|
| 219 |
+
if self.dit2 is not None:
|
| 220 |
+
dtype = next(iter(self.dit2.parameters())).dtype
|
| 221 |
+
device = "cpu" if vram_limit is not None else self.device
|
| 222 |
+
enable_vram_management(
|
| 223 |
+
self.dit2,
|
| 224 |
+
module_map={
|
| 225 |
+
torch.nn.Linear: AutoWrappedLinear,
|
| 226 |
+
torch.nn.Conv3d: AutoWrappedModule,
|
| 227 |
+
torch.nn.LayerNorm: WanAutoCastLayerNorm,
|
| 228 |
+
RMSNorm: AutoWrappedModule,
|
| 229 |
+
torch.nn.Conv2d: AutoWrappedModule,
|
| 230 |
+
},
|
| 231 |
+
module_config=dict(
|
| 232 |
+
offload_dtype=dtype,
|
| 233 |
+
offload_device="cpu",
|
| 234 |
+
onload_dtype=dtype,
|
| 235 |
+
onload_device=device,
|
| 236 |
+
computation_dtype=self.torch_dtype,
|
| 237 |
+
computation_device=self.device,
|
| 238 |
+
),
|
| 239 |
+
max_num_param=num_persistent_param_in_dit,
|
| 240 |
+
overflow_module_config=dict(
|
| 241 |
+
offload_dtype=dtype,
|
| 242 |
+
offload_device="cpu",
|
| 243 |
+
onload_dtype=dtype,
|
| 244 |
+
onload_device="cpu",
|
| 245 |
+
computation_dtype=self.torch_dtype,
|
| 246 |
+
computation_device=self.device,
|
| 247 |
+
),
|
| 248 |
+
vram_limit=vram_limit,
|
| 249 |
+
)
|
| 250 |
+
if self.vae is not None:
|
| 251 |
+
dtype = next(iter(self.vae.parameters())).dtype
|
| 252 |
+
enable_vram_management(
|
| 253 |
+
self.vae,
|
| 254 |
+
module_map={
|
| 255 |
+
torch.nn.Linear: AutoWrappedLinear,
|
| 256 |
+
torch.nn.Conv2d: AutoWrappedModule,
|
| 257 |
+
RMS_norm: AutoWrappedModule,
|
| 258 |
+
CausalConv3d: AutoWrappedModule,
|
| 259 |
+
Upsample: AutoWrappedModule,
|
| 260 |
+
torch.nn.SiLU: AutoWrappedModule,
|
| 261 |
+
torch.nn.Dropout: AutoWrappedModule,
|
| 262 |
+
},
|
| 263 |
+
module_config=dict(
|
| 264 |
+
offload_dtype=dtype,
|
| 265 |
+
offload_device="cpu",
|
| 266 |
+
onload_dtype=dtype,
|
| 267 |
+
onload_device=self.device,
|
| 268 |
+
computation_dtype=self.torch_dtype,
|
| 269 |
+
computation_device=self.device,
|
| 270 |
+
),
|
| 271 |
+
)
|
| 272 |
+
if self.image_encoder is not None:
|
| 273 |
+
dtype = next(iter(self.image_encoder.parameters())).dtype
|
| 274 |
+
enable_vram_management(
|
| 275 |
+
self.image_encoder,
|
| 276 |
+
module_map={
|
| 277 |
+
torch.nn.Linear: AutoWrappedLinear,
|
| 278 |
+
torch.nn.Conv2d: AutoWrappedModule,
|
| 279 |
+
torch.nn.LayerNorm: AutoWrappedModule,
|
| 280 |
+
},
|
| 281 |
+
module_config=dict(
|
| 282 |
+
offload_dtype=dtype,
|
| 283 |
+
offload_device="cpu",
|
| 284 |
+
onload_dtype=dtype,
|
| 285 |
+
onload_device="cpu",
|
| 286 |
+
computation_dtype=dtype,
|
| 287 |
+
computation_device=self.device,
|
| 288 |
+
),
|
| 289 |
+
)
|
| 290 |
+
if self.motion_controller is not None:
|
| 291 |
+
dtype = next(iter(self.motion_controller.parameters())).dtype
|
| 292 |
+
enable_vram_management(
|
| 293 |
+
self.motion_controller,
|
| 294 |
+
module_map={
|
| 295 |
+
torch.nn.Linear: AutoWrappedLinear,
|
| 296 |
+
},
|
| 297 |
+
module_config=dict(
|
| 298 |
+
offload_dtype=dtype,
|
| 299 |
+
offload_device="cpu",
|
| 300 |
+
onload_dtype=dtype,
|
| 301 |
+
onload_device="cpu",
|
| 302 |
+
computation_dtype=dtype,
|
| 303 |
+
computation_device=self.device,
|
| 304 |
+
),
|
| 305 |
+
)
|
| 306 |
+
if self.vace is not None:
|
| 307 |
+
device = "cpu" if vram_limit is not None else self.device
|
| 308 |
+
enable_vram_management(
|
| 309 |
+
self.vace,
|
| 310 |
+
module_map={
|
| 311 |
+
torch.nn.Linear: AutoWrappedLinear,
|
| 312 |
+
torch.nn.Conv3d: AutoWrappedModule,
|
| 313 |
+
torch.nn.LayerNorm: AutoWrappedModule,
|
| 314 |
+
RMSNorm: AutoWrappedModule,
|
| 315 |
+
},
|
| 316 |
+
module_config=dict(
|
| 317 |
+
offload_dtype=dtype,
|
| 318 |
+
offload_device="cpu",
|
| 319 |
+
onload_dtype=dtype,
|
| 320 |
+
onload_device=device,
|
| 321 |
+
computation_dtype=self.torch_dtype,
|
| 322 |
+
computation_device=self.device,
|
| 323 |
+
),
|
| 324 |
+
vram_limit=vram_limit,
|
| 325 |
+
)
|
| 326 |
+
|
| 327 |
+
def initialize_usp(self):
|
| 328 |
+
import torch.distributed as dist
|
| 329 |
+
from xfuser.core.distributed import (
|
| 330 |
+
initialize_model_parallel,
|
| 331 |
+
init_distributed_environment,
|
| 332 |
+
)
|
| 333 |
+
|
| 334 |
+
dist.init_process_group(backend="nccl", init_method="env://")
|
| 335 |
+
init_distributed_environment(
|
| 336 |
+
rank=dist.get_rank(), world_size=dist.get_world_size()
|
| 337 |
+
)
|
| 338 |
+
initialize_model_parallel(
|
| 339 |
+
sequence_parallel_degree=dist.get_world_size(),
|
| 340 |
+
ring_degree=1,
|
| 341 |
+
ulysses_degree=dist.get_world_size(),
|
| 342 |
+
)
|
| 343 |
+
torch.cuda.set_device(dist.get_rank())
|
| 344 |
+
|
| 345 |
+
def enable_usp(self):
|
| 346 |
+
from xfuser.core.distributed import get_sequence_parallel_world_size
|
| 347 |
+
from distributed.xdit_context_parallel import (
|
| 348 |
+
usp_attn_forward,
|
| 349 |
+
usp_dit_forward,
|
| 350 |
+
)
|
| 351 |
+
|
| 352 |
+
for block in self.dit.blocks:
|
| 353 |
+
block.self_attn.forward = types.MethodType(
|
| 354 |
+
usp_attn_forward, block.self_attn
|
| 355 |
+
)
|
| 356 |
+
self.dit.forward = types.MethodType(usp_dit_forward, self.dit)
|
| 357 |
+
if self.dit2 is not None:
|
| 358 |
+
for block in self.dit2.blocks:
|
| 359 |
+
block.self_attn.forward = types.MethodType(
|
| 360 |
+
usp_attn_forward, block.self_attn
|
| 361 |
+
)
|
| 362 |
+
self.dit2.forward = types.MethodType(usp_dit_forward, self.dit2)
|
| 363 |
+
self.sp_size = get_sequence_parallel_world_size()
|
| 364 |
+
self.use_unified_sequence_parallel = True
|
| 365 |
+
|
| 366 |
+
@staticmethod
|
| 367 |
+
def from_pretrained(
|
| 368 |
+
torch_dtype: torch.dtype = torch.bfloat16,
|
| 369 |
+
device: Union[str, torch.device] = "cuda",
|
| 370 |
+
model_configs: list[ModelConfig] = [],
|
| 371 |
+
tokenizer_config: ModelConfig = ModelConfig(
|
| 372 |
+
model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/*"
|
| 373 |
+
),
|
| 374 |
+
redirect_common_files: bool = True,
|
| 375 |
+
use_usp=False,
|
| 376 |
+
):
|
| 377 |
+
# Redirect model path
|
| 378 |
+
if redirect_common_files:
|
| 379 |
+
redirect_dict = {
|
| 380 |
+
"models_t5_umt5-xxl-enc-bf16.pth": "Wan-AI/Wan2.1-T2V-1.3B",
|
| 381 |
+
"Wan2.1_VAE.pth": "Wan-AI/Wan2.1-T2V-1.3B",
|
| 382 |
+
"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth": "Wan-AI/Wan2.1-I2V-14B-480P",
|
| 383 |
+
}
|
| 384 |
+
for model_config in model_configs:
|
| 385 |
+
if (
|
| 386 |
+
model_config.origin_file_pattern is None
|
| 387 |
+
or model_config.model_id is None
|
| 388 |
+
):
|
| 389 |
+
continue
|
| 390 |
+
if (
|
| 391 |
+
model_config.origin_file_pattern in redirect_dict
|
| 392 |
+
and model_config.model_id
|
| 393 |
+
!= redirect_dict[model_config.origin_file_pattern]
|
| 394 |
+
):
|
| 395 |
+
print(
|
| 396 |
+
f"To avoid repeatedly downloading model files, ({model_config.model_id}, {model_config.origin_file_pattern}) is redirected to ({redirect_dict[model_config.origin_file_pattern]}, {model_config.origin_file_pattern}). You can use `redirect_common_files=False` to disable file redirection."
|
| 397 |
+
)
|
| 398 |
+
model_config.model_id = redirect_dict[
|
| 399 |
+
model_config.origin_file_pattern
|
| 400 |
+
]
|
| 401 |
+
|
| 402 |
+
# Initialize pipeline
|
| 403 |
+
pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
|
| 404 |
+
if use_usp:
|
| 405 |
+
pipe.initialize_usp()
|
| 406 |
+
|
| 407 |
+
# Download and load models
|
| 408 |
+
model_manager = ModelManager()
|
| 409 |
+
for model_config in model_configs:
|
| 410 |
+
model_config.download_if_necessary(use_usp=use_usp)
|
| 411 |
+
model_manager.load_model(
|
| 412 |
+
model_config.path,
|
| 413 |
+
device=model_config.offload_device or device,
|
| 414 |
+
torch_dtype=model_config.offload_dtype or torch_dtype,
|
| 415 |
+
)
|
| 416 |
+
|
| 417 |
+
# Load models
|
| 418 |
+
pipe.text_encoder = model_manager.fetch_model("wan_video_text_encoder")
|
| 419 |
+
dit = model_manager.fetch_model("wan_video_dit", index=2)
|
| 420 |
+
if isinstance(dit, list):
|
| 421 |
+
pipe.dit, pipe.dit2 = dit
|
| 422 |
+
else:
|
| 423 |
+
pipe.dit = dit
|
| 424 |
+
pipe.vae = model_manager.fetch_model("wan_video_vae")
|
| 425 |
+
pipe.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
|
| 426 |
+
pipe.motion_controller = model_manager.fetch_model(
|
| 427 |
+
"wan_video_motion_controller"
|
| 428 |
+
)
|
| 429 |
+
pipe.vace = model_manager.fetch_model("wan_video_vace")
|
| 430 |
+
|
| 431 |
+
# Size division factor
|
| 432 |
+
if pipe.vae is not None:
|
| 433 |
+
pipe.height_division_factor = pipe.vae.upsampling_factor * 2
|
| 434 |
+
pipe.width_division_factor = pipe.vae.upsampling_factor * 2
|
| 435 |
+
|
| 436 |
+
# Initialize tokenizer
|
| 437 |
+
tokenizer_config.download_if_necessary(use_usp=use_usp)
|
| 438 |
+
pipe.prompter.fetch_models(pipe.text_encoder)
|
| 439 |
+
pipe.prompter.fetch_tokenizer(tokenizer_config.path)
|
| 440 |
+
|
| 441 |
+
# Unified Sequence Parallel
|
| 442 |
+
if use_usp:
|
| 443 |
+
pipe.enable_usp()
|
| 444 |
+
return pipe
|
| 445 |
+
|
| 446 |
+
@torch.no_grad()
|
| 447 |
+
def __call__(
|
| 448 |
+
self,
|
| 449 |
+
# Prompt
|
| 450 |
+
prompt: str,
|
| 451 |
+
negative_prompt: Optional[str] = "",
|
| 452 |
+
# Image-to-video
|
| 453 |
+
input_image: Optional[Image.Image] = None,
|
| 454 |
+
# First-last-frame-to-video
|
| 455 |
+
end_image: Optional[Image.Image] = None,
|
| 456 |
+
# Video-to-video
|
| 457 |
+
input_video: Optional[list[Image.Image]] = None,
|
| 458 |
+
denoising_strength: Optional[float] = 1.0,
|
| 459 |
+
# ControlNet
|
| 460 |
+
control_video: Optional[list[Image.Image]] = None,
|
| 461 |
+
reference_image: Optional[Image.Image] = None,
|
| 462 |
+
# Camera control
|
| 463 |
+
camera_control_direction: Optional[
|
| 464 |
+
Literal[
|
| 465 |
+
"Left",
|
| 466 |
+
"Right",
|
| 467 |
+
"Up",
|
| 468 |
+
"Down",
|
| 469 |
+
"LeftUp",
|
| 470 |
+
"LeftDown",
|
| 471 |
+
"RightUp",
|
| 472 |
+
"RightDown",
|
| 473 |
+
]
|
| 474 |
+
] = None,
|
| 475 |
+
camera_control_speed: Optional[float] = 1 / 54,
|
| 476 |
+
camera_control_origin: Optional[tuple] = (
|
| 477 |
+
0,
|
| 478 |
+
0.532139961,
|
| 479 |
+
0.946026558,
|
| 480 |
+
0.5,
|
| 481 |
+
0.5,
|
| 482 |
+
0,
|
| 483 |
+
0,
|
| 484 |
+
1,
|
| 485 |
+
0,
|
| 486 |
+
0,
|
| 487 |
+
0,
|
| 488 |
+
0,
|
| 489 |
+
1,
|
| 490 |
+
0,
|
| 491 |
+
0,
|
| 492 |
+
0,
|
| 493 |
+
0,
|
| 494 |
+
1,
|
| 495 |
+
0,
|
| 496 |
+
),
|
| 497 |
+
# VACE
|
| 498 |
+
vace_video: Optional[list[Image.Image]] = None,
|
| 499 |
+
vace_video_mask: Optional[Image.Image] = None,
|
| 500 |
+
vace_reference_image: Optional[Image.Image] = None,
|
| 501 |
+
vace_scale: Optional[float] = 1.0,
|
| 502 |
+
# Randomness
|
| 503 |
+
seed: Optional[int] = None,
|
| 504 |
+
rand_device: Optional[str] = "cpu",
|
| 505 |
+
# Shape
|
| 506 |
+
height: Optional[int] = 480,
|
| 507 |
+
width: Optional[int] = 832,
|
| 508 |
+
num_frames=81,
|
| 509 |
+
# Classifier-free guidance
|
| 510 |
+
cfg_scale: Optional[float] = 5.0,
|
| 511 |
+
cfg_merge: Optional[bool] = False,
|
| 512 |
+
# Boundary
|
| 513 |
+
switch_DiT_boundary: Optional[float] = 0.875,
|
| 514 |
+
# Scheduler
|
| 515 |
+
num_inference_steps: Optional[int] = 50,
|
| 516 |
+
sigma_shift: Optional[float] = 5.0,
|
| 517 |
+
# Speed control
|
| 518 |
+
motion_bucket_id: Optional[int] = None,
|
| 519 |
+
# VAE tiling
|
| 520 |
+
tiled: Optional[bool] = True,
|
| 521 |
+
tile_size: Optional[tuple[int, int]] = (30, 52),
|
| 522 |
+
tile_stride: Optional[tuple[int, int]] = (15, 26),
|
| 523 |
+
# Sliding window
|
| 524 |
+
sliding_window_size: Optional[int] = None,
|
| 525 |
+
sliding_window_stride: Optional[int] = None,
|
| 526 |
+
# Teacache
|
| 527 |
+
tea_cache_l1_thresh: Optional[float] = None,
|
| 528 |
+
tea_cache_model_id: Optional[str] = "",
|
| 529 |
+
# progress_bar
|
| 530 |
+
progress_bar_cmd=tqdm,
|
| 531 |
+
# Stand-In
|
| 532 |
+
ip_image=None,
|
| 533 |
+
):
|
| 534 |
+
if ip_image is not None:
|
| 535 |
+
ip_image = self.encode_ip_image(ip_image)
|
| 536 |
+
if vace_video is not None:
|
| 537 |
+
vace_video, width, height, num_frames = load_video_as_list(vace_video)
|
| 538 |
+
if vace_reference_image is not None:
|
| 539 |
+
vace_reference_image = Image.open(vace_reference_image).convert('RGB')
|
| 540 |
+
ref_width, ref_height = vace_reference_image.size
|
| 541 |
+
if ref_width != width or ref_height != height:
|
| 542 |
+
scale_ratio = min(width / ref_width, height / ref_height)
|
| 543 |
+
|
| 544 |
+
new_ref_width = int(ref_width * scale_ratio)
|
| 545 |
+
new_ref_height = int(ref_height * scale_ratio)
|
| 546 |
+
|
| 547 |
+
resized_image = vace_reference_image.resize((new_ref_width, new_ref_height), Image.LANCZOS)
|
| 548 |
+
|
| 549 |
+
background = Image.new('RGB', (width, height), (255, 255, 255))
|
| 550 |
+
|
| 551 |
+
paste_x = (width - new_ref_width) // 2
|
| 552 |
+
paste_y = (height - new_ref_height) // 2
|
| 553 |
+
background.paste(resized_image, (paste_x, paste_y))
|
| 554 |
+
|
| 555 |
+
vace_reference_image = background
|
| 556 |
+
# Scheduler
|
| 557 |
+
self.scheduler.set_timesteps(
|
| 558 |
+
num_inference_steps,
|
| 559 |
+
denoising_strength=denoising_strength,
|
| 560 |
+
shift=sigma_shift,
|
| 561 |
+
)
|
| 562 |
+
|
| 563 |
+
# Inputs
|
| 564 |
+
inputs_posi = {
|
| 565 |
+
"prompt": prompt,
|
| 566 |
+
"tea_cache_l1_thresh": tea_cache_l1_thresh,
|
| 567 |
+
"tea_cache_model_id": tea_cache_model_id,
|
| 568 |
+
"num_inference_steps": num_inference_steps,
|
| 569 |
+
}
|
| 570 |
+
inputs_nega = {
|
| 571 |
+
"negative_prompt": negative_prompt,
|
| 572 |
+
"tea_cache_l1_thresh": tea_cache_l1_thresh,
|
| 573 |
+
"tea_cache_model_id": tea_cache_model_id,
|
| 574 |
+
"num_inference_steps": num_inference_steps,
|
| 575 |
+
}
|
| 576 |
+
inputs_shared = {
|
| 577 |
+
"input_image": input_image,
|
| 578 |
+
"end_image": end_image,
|
| 579 |
+
"input_video": input_video,
|
| 580 |
+
"denoising_strength": denoising_strength,
|
| 581 |
+
"control_video": control_video,
|
| 582 |
+
"reference_image": reference_image,
|
| 583 |
+
"camera_control_direction": camera_control_direction,
|
| 584 |
+
"camera_control_speed": camera_control_speed,
|
| 585 |
+
"camera_control_origin": camera_control_origin,
|
| 586 |
+
"vace_video": vace_video,
|
| 587 |
+
"vace_video_mask": vace_video_mask,
|
| 588 |
+
"vace_reference_image": vace_reference_image,
|
| 589 |
+
"vace_scale": vace_scale,
|
| 590 |
+
"seed": seed,
|
| 591 |
+
"rand_device": rand_device,
|
| 592 |
+
"height": height,
|
| 593 |
+
"width": width,
|
| 594 |
+
"num_frames": num_frames,
|
| 595 |
+
"cfg_scale": cfg_scale,
|
| 596 |
+
"cfg_merge": cfg_merge,
|
| 597 |
+
"sigma_shift": sigma_shift,
|
| 598 |
+
"motion_bucket_id": motion_bucket_id,
|
| 599 |
+
"tiled": tiled,
|
| 600 |
+
"tile_size": tile_size,
|
| 601 |
+
"tile_stride": tile_stride,
|
| 602 |
+
"sliding_window_size": sliding_window_size,
|
| 603 |
+
"sliding_window_stride": sliding_window_stride,
|
| 604 |
+
"ip_image": ip_image,
|
| 605 |
+
}
|
| 606 |
+
for unit in self.units:
|
| 607 |
+
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(
|
| 608 |
+
unit, self, inputs_shared, inputs_posi, inputs_nega
|
| 609 |
+
)
|
| 610 |
+
# Denoise
|
| 611 |
+
self.load_models_to_device(self.in_iteration_models)
|
| 612 |
+
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
| 613 |
+
for progress_id, timestep in enumerate(
|
| 614 |
+
progress_bar_cmd(self.scheduler.timesteps)
|
| 615 |
+
):
|
| 616 |
+
# Switch DiT if necessary
|
| 617 |
+
if (
|
| 618 |
+
timestep.item()
|
| 619 |
+
< switch_DiT_boundary * self.scheduler.num_train_timesteps
|
| 620 |
+
and self.dit2 is not None
|
| 621 |
+
and not models["dit"] is self.dit2
|
| 622 |
+
):
|
| 623 |
+
self.load_models_to_device(self.in_iteration_models_2)
|
| 624 |
+
models["dit"] = self.dit2
|
| 625 |
+
|
| 626 |
+
# Timestep
|
| 627 |
+
timestep = timestep.unsqueeze(0).to(
|
| 628 |
+
dtype=self.torch_dtype, device=self.device
|
| 629 |
+
)
|
| 630 |
+
|
| 631 |
+
# Inference
|
| 632 |
+
noise_pred_posi = self.model_fn(
|
| 633 |
+
**models, **inputs_shared, **inputs_posi, timestep=timestep
|
| 634 |
+
)
|
| 635 |
+
inputs_shared["ip_image"] = None
|
| 636 |
+
if cfg_scale != 1.0:
|
| 637 |
+
if cfg_merge:
|
| 638 |
+
noise_pred_posi, noise_pred_nega = noise_pred_posi.chunk(2, dim=0)
|
| 639 |
+
else:
|
| 640 |
+
noise_pred_nega = self.model_fn(
|
| 641 |
+
**models, **inputs_shared, **inputs_nega, timestep=timestep
|
| 642 |
+
)
|
| 643 |
+
noise_pred = noise_pred_nega + cfg_scale * (
|
| 644 |
+
noise_pred_posi - noise_pred_nega
|
| 645 |
+
)
|
| 646 |
+
else:
|
| 647 |
+
noise_pred = noise_pred_posi
|
| 648 |
+
|
| 649 |
+
# Scheduler
|
| 650 |
+
inputs_shared["latents"] = self.scheduler.step(
|
| 651 |
+
noise_pred,
|
| 652 |
+
self.scheduler.timesteps[progress_id],
|
| 653 |
+
inputs_shared["latents"],
|
| 654 |
+
)
|
| 655 |
+
if "first_frame_latents" in inputs_shared:
|
| 656 |
+
inputs_shared["latents"][:, :, 0:1] = inputs_shared[
|
| 657 |
+
"first_frame_latents"
|
| 658 |
+
]
|
| 659 |
+
|
| 660 |
+
if vace_reference_image is not None:
|
| 661 |
+
inputs_shared["latents"] = inputs_shared["latents"][:, :, 1:]
|
| 662 |
+
|
| 663 |
+
# Decode
|
| 664 |
+
self.load_models_to_device(["vae"])
|
| 665 |
+
video = self.vae.decode(
|
| 666 |
+
inputs_shared["latents"],
|
| 667 |
+
device=self.device,
|
| 668 |
+
tiled=tiled,
|
| 669 |
+
tile_size=tile_size,
|
| 670 |
+
tile_stride=tile_stride,
|
| 671 |
+
)
|
| 672 |
+
video = self.vae_output_to_video(video)
|
| 673 |
+
self.load_models_to_device([])
|
| 674 |
+
|
| 675 |
+
return video
|
| 676 |
+
|
| 677 |
+
|
| 678 |
+
class WanVideoUnit_ShapeChecker(PipelineUnit):
|
| 679 |
+
def __init__(self):
|
| 680 |
+
super().__init__(input_params=("height", "width", "num_frames"))
|
| 681 |
+
|
| 682 |
+
def process(self, pipe: WanVideoPipeline, height, width, num_frames):
|
| 683 |
+
height, width, num_frames = pipe.check_resize_height_width(
|
| 684 |
+
height, width, num_frames
|
| 685 |
+
)
|
| 686 |
+
return {"height": height, "width": width, "num_frames": num_frames}
|
| 687 |
+
|
| 688 |
+
|
| 689 |
+
class WanVideoUnit_NoiseInitializer(PipelineUnit):
|
| 690 |
+
def __init__(self):
|
| 691 |
+
super().__init__(
|
| 692 |
+
input_params=(
|
| 693 |
+
"height",
|
| 694 |
+
"width",
|
| 695 |
+
"num_frames",
|
| 696 |
+
"seed",
|
| 697 |
+
"rand_device",
|
| 698 |
+
"vace_reference_image",
|
| 699 |
+
)
|
| 700 |
+
)
|
| 701 |
+
|
| 702 |
+
def process(
|
| 703 |
+
self,
|
| 704 |
+
pipe: WanVideoPipeline,
|
| 705 |
+
height,
|
| 706 |
+
width,
|
| 707 |
+
num_frames,
|
| 708 |
+
seed,
|
| 709 |
+
rand_device,
|
| 710 |
+
vace_reference_image,
|
| 711 |
+
):
|
| 712 |
+
length = (num_frames - 1) // 4 + 1
|
| 713 |
+
if vace_reference_image is not None:
|
| 714 |
+
length += 1
|
| 715 |
+
shape = (
|
| 716 |
+
1,
|
| 717 |
+
pipe.vae.model.z_dim,
|
| 718 |
+
length,
|
| 719 |
+
height // pipe.vae.upsampling_factor,
|
| 720 |
+
width // pipe.vae.upsampling_factor,
|
| 721 |
+
)
|
| 722 |
+
noise = pipe.generate_noise(shape, seed=seed, rand_device=rand_device)
|
| 723 |
+
if vace_reference_image is not None:
|
| 724 |
+
noise = torch.concat((noise[:, :, -1:], noise[:, :, :-1]), dim=2)
|
| 725 |
+
return {"noise": noise}
|
| 726 |
+
|
| 727 |
+
|
| 728 |
+
class WanVideoUnit_InputVideoEmbedder(PipelineUnit):
|
| 729 |
+
def __init__(self):
|
| 730 |
+
super().__init__(
|
| 731 |
+
input_params=(
|
| 732 |
+
"input_video",
|
| 733 |
+
"noise",
|
| 734 |
+
"tiled",
|
| 735 |
+
"tile_size",
|
| 736 |
+
"tile_stride",
|
| 737 |
+
"vace_reference_image",
|
| 738 |
+
),
|
| 739 |
+
onload_model_names=("vae",),
|
| 740 |
+
)
|
| 741 |
+
|
| 742 |
+
def process(
|
| 743 |
+
self,
|
| 744 |
+
pipe: WanVideoPipeline,
|
| 745 |
+
input_video,
|
| 746 |
+
noise,
|
| 747 |
+
tiled,
|
| 748 |
+
tile_size,
|
| 749 |
+
tile_stride,
|
| 750 |
+
vace_reference_image,
|
| 751 |
+
):
|
| 752 |
+
if input_video is None:
|
| 753 |
+
return {"latents": noise}
|
| 754 |
+
pipe.load_models_to_device(["vae"])
|
| 755 |
+
input_video = pipe.preprocess_video(input_video)
|
| 756 |
+
input_latents = pipe.vae.encode(
|
| 757 |
+
input_video,
|
| 758 |
+
device=pipe.device,
|
| 759 |
+
tiled=tiled,
|
| 760 |
+
tile_size=tile_size,
|
| 761 |
+
tile_stride=tile_stride,
|
| 762 |
+
).to(dtype=pipe.torch_dtype, device=pipe.device)
|
| 763 |
+
if vace_reference_image is not None:
|
| 764 |
+
vace_reference_image = pipe.preprocess_video([vace_reference_image])
|
| 765 |
+
vace_reference_latents = pipe.vae.encode(
|
| 766 |
+
vace_reference_image, device=pipe.device
|
| 767 |
+
).to(dtype=pipe.torch_dtype, device=pipe.device)
|
| 768 |
+
input_latents = torch.concat([vace_reference_latents, input_latents], dim=2)
|
| 769 |
+
if pipe.scheduler.training:
|
| 770 |
+
return {"latents": noise, "input_latents": input_latents}
|
| 771 |
+
else:
|
| 772 |
+
latents = pipe.scheduler.add_noise(
|
| 773 |
+
input_latents, noise, timestep=pipe.scheduler.timesteps[0]
|
| 774 |
+
)
|
| 775 |
+
return {"latents": latents}
|
| 776 |
+
|
| 777 |
+
|
| 778 |
+
class WanVideoUnit_PromptEmbedder(PipelineUnit):
|
| 779 |
+
def __init__(self):
|
| 780 |
+
super().__init__(
|
| 781 |
+
seperate_cfg=True,
|
| 782 |
+
input_params_posi={"prompt": "prompt", "positive": "positive"},
|
| 783 |
+
input_params_nega={"prompt": "negative_prompt", "positive": "positive"},
|
| 784 |
+
onload_model_names=("text_encoder",),
|
| 785 |
+
)
|
| 786 |
+
|
| 787 |
+
def process(self, pipe: WanVideoPipeline, prompt, positive) -> dict:
|
| 788 |
+
pipe.load_models_to_device(self.onload_model_names)
|
| 789 |
+
prompt_emb = pipe.prompter.encode_prompt(
|
| 790 |
+
prompt, positive=positive, device=pipe.device
|
| 791 |
+
)
|
| 792 |
+
return {"context": prompt_emb}
|
| 793 |
+
|
| 794 |
+
|
| 795 |
+
class WanVideoUnit_ImageEmbedder(PipelineUnit):
|
| 796 |
+
"""
|
| 797 |
+
Deprecated
|
| 798 |
+
"""
|
| 799 |
+
|
| 800 |
+
def __init__(self):
|
| 801 |
+
super().__init__(
|
| 802 |
+
input_params=(
|
| 803 |
+
"input_image",
|
| 804 |
+
"end_image",
|
| 805 |
+
"num_frames",
|
| 806 |
+
"height",
|
| 807 |
+
"width",
|
| 808 |
+
"tiled",
|
| 809 |
+
"tile_size",
|
| 810 |
+
"tile_stride",
|
| 811 |
+
),
|
| 812 |
+
onload_model_names=("image_encoder", "vae"),
|
| 813 |
+
)
|
| 814 |
+
|
| 815 |
+
def process(
|
| 816 |
+
self,
|
| 817 |
+
pipe: WanVideoPipeline,
|
| 818 |
+
input_image,
|
| 819 |
+
end_image,
|
| 820 |
+
num_frames,
|
| 821 |
+
height,
|
| 822 |
+
width,
|
| 823 |
+
tiled,
|
| 824 |
+
tile_size,
|
| 825 |
+
tile_stride,
|
| 826 |
+
):
|
| 827 |
+
if input_image is None or pipe.image_encoder is None:
|
| 828 |
+
return {}
|
| 829 |
+
pipe.load_models_to_device(self.onload_model_names)
|
| 830 |
+
image = pipe.preprocess_image(input_image.resize((width, height))).to(
|
| 831 |
+
pipe.device
|
| 832 |
+
)
|
| 833 |
+
clip_context = pipe.image_encoder.encode_image([image])
|
| 834 |
+
msk = torch.ones(1, num_frames, height // 8, width // 8, device=pipe.device)
|
| 835 |
+
msk[:, 1:] = 0
|
| 836 |
+
if end_image is not None:
|
| 837 |
+
end_image = pipe.preprocess_image(end_image.resize((width, height))).to(
|
| 838 |
+
pipe.device
|
| 839 |
+
)
|
| 840 |
+
vae_input = torch.concat(
|
| 841 |
+
[
|
| 842 |
+
image.transpose(0, 1),
|
| 843 |
+
torch.zeros(3, num_frames - 2, height, width).to(image.device),
|
| 844 |
+
end_image.transpose(0, 1),
|
| 845 |
+
],
|
| 846 |
+
dim=1,
|
| 847 |
+
)
|
| 848 |
+
if pipe.dit.has_image_pos_emb:
|
| 849 |
+
clip_context = torch.concat(
|
| 850 |
+
[clip_context, pipe.image_encoder.encode_image([end_image])], dim=1
|
| 851 |
+
)
|
| 852 |
+
msk[:, -1:] = 1
|
| 853 |
+
else:
|
| 854 |
+
vae_input = torch.concat(
|
| 855 |
+
[
|
| 856 |
+
image.transpose(0, 1),
|
| 857 |
+
torch.zeros(3, num_frames - 1, height, width).to(image.device),
|
| 858 |
+
],
|
| 859 |
+
dim=1,
|
| 860 |
+
)
|
| 861 |
+
|
| 862 |
+
msk = torch.concat(
|
| 863 |
+
[torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1
|
| 864 |
+
)
|
| 865 |
+
msk = msk.view(1, msk.shape[1] // 4, 4, height // 8, width // 8)
|
| 866 |
+
msk = msk.transpose(1, 2)[0]
|
| 867 |
+
|
| 868 |
+
y = pipe.vae.encode(
|
| 869 |
+
[vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)],
|
| 870 |
+
device=pipe.device,
|
| 871 |
+
tiled=tiled,
|
| 872 |
+
tile_size=tile_size,
|
| 873 |
+
tile_stride=tile_stride,
|
| 874 |
+
)[0]
|
| 875 |
+
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
|
| 876 |
+
y = torch.concat([msk, y])
|
| 877 |
+
y = y.unsqueeze(0)
|
| 878 |
+
clip_context = clip_context.to(dtype=pipe.torch_dtype, device=pipe.device)
|
| 879 |
+
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
|
| 880 |
+
return {"clip_feature": clip_context, "y": y}
|
| 881 |
+
|
| 882 |
+
|
| 883 |
+
class WanVideoUnit_ImageEmbedderCLIP(PipelineUnit):
|
| 884 |
+
def __init__(self):
|
| 885 |
+
super().__init__(
|
| 886 |
+
input_params=("input_image", "end_image", "height", "width"),
|
| 887 |
+
onload_model_names=("image_encoder",),
|
| 888 |
+
)
|
| 889 |
+
|
| 890 |
+
def process(self, pipe: WanVideoPipeline, input_image, end_image, height, width):
|
| 891 |
+
if (
|
| 892 |
+
input_image is None
|
| 893 |
+
or pipe.image_encoder is None
|
| 894 |
+
or not pipe.dit.require_clip_embedding
|
| 895 |
+
):
|
| 896 |
+
return {}
|
| 897 |
+
pipe.load_models_to_device(self.onload_model_names)
|
| 898 |
+
image = pipe.preprocess_image(input_image.resize((width, height))).to(
|
| 899 |
+
pipe.device
|
| 900 |
+
)
|
| 901 |
+
clip_context = pipe.image_encoder.encode_image([image])
|
| 902 |
+
if end_image is not None:
|
| 903 |
+
end_image = pipe.preprocess_image(end_image.resize((width, height))).to(
|
| 904 |
+
pipe.device
|
| 905 |
+
)
|
| 906 |
+
if pipe.dit.has_image_pos_emb:
|
| 907 |
+
clip_context = torch.concat(
|
| 908 |
+
[clip_context, pipe.image_encoder.encode_image([end_image])], dim=1
|
| 909 |
+
)
|
| 910 |
+
clip_context = clip_context.to(dtype=pipe.torch_dtype, device=pipe.device)
|
| 911 |
+
return {"clip_feature": clip_context}
|
| 912 |
+
|
| 913 |
+
|
| 914 |
+
class WanVideoUnit_ImageEmbedderVAE(PipelineUnit):
|
| 915 |
+
def __init__(self):
|
| 916 |
+
super().__init__(
|
| 917 |
+
input_params=(
|
| 918 |
+
"input_image",
|
| 919 |
+
"end_image",
|
| 920 |
+
"num_frames",
|
| 921 |
+
"height",
|
| 922 |
+
"width",
|
| 923 |
+
"tiled",
|
| 924 |
+
"tile_size",
|
| 925 |
+
"tile_stride",
|
| 926 |
+
),
|
| 927 |
+
onload_model_names=("vae",),
|
| 928 |
+
)
|
| 929 |
+
|
| 930 |
+
def process(
|
| 931 |
+
self,
|
| 932 |
+
pipe: WanVideoPipeline,
|
| 933 |
+
input_image,
|
| 934 |
+
end_image,
|
| 935 |
+
num_frames,
|
| 936 |
+
height,
|
| 937 |
+
width,
|
| 938 |
+
tiled,
|
| 939 |
+
tile_size,
|
| 940 |
+
tile_stride,
|
| 941 |
+
):
|
| 942 |
+
if input_image is None or not pipe.dit.require_vae_embedding:
|
| 943 |
+
return {}
|
| 944 |
+
pipe.load_models_to_device(self.onload_model_names)
|
| 945 |
+
image = pipe.preprocess_image(input_image.resize((width, height))).to(
|
| 946 |
+
pipe.device
|
| 947 |
+
)
|
| 948 |
+
msk = torch.ones(1, num_frames, height // 8, width // 8, device=pipe.device)
|
| 949 |
+
msk[:, 1:] = 0
|
| 950 |
+
if end_image is not None:
|
| 951 |
+
end_image = pipe.preprocess_image(end_image.resize((width, height))).to(
|
| 952 |
+
pipe.device
|
| 953 |
+
)
|
| 954 |
+
vae_input = torch.concat(
|
| 955 |
+
[
|
| 956 |
+
image.transpose(0, 1),
|
| 957 |
+
torch.zeros(3, num_frames - 2, height, width).to(image.device),
|
| 958 |
+
end_image.transpose(0, 1),
|
| 959 |
+
],
|
| 960 |
+
dim=1,
|
| 961 |
+
)
|
| 962 |
+
msk[:, -1:] = 1
|
| 963 |
+
else:
|
| 964 |
+
vae_input = torch.concat(
|
| 965 |
+
[
|
| 966 |
+
image.transpose(0, 1),
|
| 967 |
+
torch.zeros(3, num_frames - 1, height, width).to(image.device),
|
| 968 |
+
],
|
| 969 |
+
dim=1,
|
| 970 |
+
)
|
| 971 |
+
|
| 972 |
+
msk = torch.concat(
|
| 973 |
+
[torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1
|
| 974 |
+
)
|
| 975 |
+
msk = msk.view(1, msk.shape[1] // 4, 4, height // 8, width // 8)
|
| 976 |
+
msk = msk.transpose(1, 2)[0]
|
| 977 |
+
|
| 978 |
+
y = pipe.vae.encode(
|
| 979 |
+
[vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)],
|
| 980 |
+
device=pipe.device,
|
| 981 |
+
tiled=tiled,
|
| 982 |
+
tile_size=tile_size,
|
| 983 |
+
tile_stride=tile_stride,
|
| 984 |
+
)[0]
|
| 985 |
+
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
|
| 986 |
+
y = torch.concat([msk, y])
|
| 987 |
+
y = y.unsqueeze(0)
|
| 988 |
+
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
|
| 989 |
+
return {"y": y}
|
| 990 |
+
|
| 991 |
+
|
| 992 |
+
class WanVideoUnit_ImageEmbedderFused(PipelineUnit):
|
| 993 |
+
"""
|
| 994 |
+
Encode input image to latents using VAE. This unit is for Wan-AI/Wan2.2-TI2V-5B.
|
| 995 |
+
"""
|
| 996 |
+
|
| 997 |
+
def __init__(self):
|
| 998 |
+
super().__init__(
|
| 999 |
+
input_params=(
|
| 1000 |
+
"input_image",
|
| 1001 |
+
"latents",
|
| 1002 |
+
"height",
|
| 1003 |
+
"width",
|
| 1004 |
+
"tiled",
|
| 1005 |
+
"tile_size",
|
| 1006 |
+
"tile_stride",
|
| 1007 |
+
),
|
| 1008 |
+
onload_model_names=("vae",),
|
| 1009 |
+
)
|
| 1010 |
+
|
| 1011 |
+
def process(
|
| 1012 |
+
self,
|
| 1013 |
+
pipe: WanVideoPipeline,
|
| 1014 |
+
input_image,
|
| 1015 |
+
latents,
|
| 1016 |
+
height,
|
| 1017 |
+
width,
|
| 1018 |
+
tiled,
|
| 1019 |
+
tile_size,
|
| 1020 |
+
tile_stride,
|
| 1021 |
+
):
|
| 1022 |
+
if input_image is None or not pipe.dit.fuse_vae_embedding_in_latents:
|
| 1023 |
+
return {}
|
| 1024 |
+
pipe.load_models_to_device(self.onload_model_names)
|
| 1025 |
+
image = pipe.preprocess_image(input_image.resize((width, height))).transpose(
|
| 1026 |
+
0, 1
|
| 1027 |
+
)
|
| 1028 |
+
z = pipe.vae.encode(
|
| 1029 |
+
[image],
|
| 1030 |
+
device=pipe.device,
|
| 1031 |
+
tiled=tiled,
|
| 1032 |
+
tile_size=tile_size,
|
| 1033 |
+
tile_stride=tile_stride,
|
| 1034 |
+
)
|
| 1035 |
+
latents[:, :, 0:1] = z
|
| 1036 |
+
return {
|
| 1037 |
+
"latents": latents,
|
| 1038 |
+
"fuse_vae_embedding_in_latents": True,
|
| 1039 |
+
"first_frame_latents": z,
|
| 1040 |
+
}
|
| 1041 |
+
|
| 1042 |
+
|
| 1043 |
+
class WanVideoUnit_FunControl(PipelineUnit):
|
| 1044 |
+
def __init__(self):
|
| 1045 |
+
super().__init__(
|
| 1046 |
+
input_params=(
|
| 1047 |
+
"control_video",
|
| 1048 |
+
"num_frames",
|
| 1049 |
+
"height",
|
| 1050 |
+
"width",
|
| 1051 |
+
"tiled",
|
| 1052 |
+
"tile_size",
|
| 1053 |
+
"tile_stride",
|
| 1054 |
+
"clip_feature",
|
| 1055 |
+
"y",
|
| 1056 |
+
),
|
| 1057 |
+
onload_model_names=("vae",),
|
| 1058 |
+
)
|
| 1059 |
+
|
| 1060 |
+
def process(
|
| 1061 |
+
self,
|
| 1062 |
+
pipe: WanVideoPipeline,
|
| 1063 |
+
control_video,
|
| 1064 |
+
num_frames,
|
| 1065 |
+
height,
|
| 1066 |
+
width,
|
| 1067 |
+
tiled,
|
| 1068 |
+
tile_size,
|
| 1069 |
+
tile_stride,
|
| 1070 |
+
clip_feature,
|
| 1071 |
+
y,
|
| 1072 |
+
):
|
| 1073 |
+
if control_video is None:
|
| 1074 |
+
return {}
|
| 1075 |
+
pipe.load_models_to_device(self.onload_model_names)
|
| 1076 |
+
control_video = pipe.preprocess_video(control_video)
|
| 1077 |
+
control_latents = pipe.vae.encode(
|
| 1078 |
+
control_video,
|
| 1079 |
+
device=pipe.device,
|
| 1080 |
+
tiled=tiled,
|
| 1081 |
+
tile_size=tile_size,
|
| 1082 |
+
tile_stride=tile_stride,
|
| 1083 |
+
).to(dtype=pipe.torch_dtype, device=pipe.device)
|
| 1084 |
+
control_latents = control_latents.to(dtype=pipe.torch_dtype, device=pipe.device)
|
| 1085 |
+
if clip_feature is None or y is None:
|
| 1086 |
+
clip_feature = torch.zeros(
|
| 1087 |
+
(1, 257, 1280), dtype=pipe.torch_dtype, device=pipe.device
|
| 1088 |
+
)
|
| 1089 |
+
y = torch.zeros(
|
| 1090 |
+
(1, 16, (num_frames - 1) // 4 + 1, height // 8, width // 8),
|
| 1091 |
+
dtype=pipe.torch_dtype,
|
| 1092 |
+
device=pipe.device,
|
| 1093 |
+
)
|
| 1094 |
+
else:
|
| 1095 |
+
y = y[:, -16:]
|
| 1096 |
+
y = torch.concat([control_latents, y], dim=1)
|
| 1097 |
+
return {"clip_feature": clip_feature, "y": y}
|
| 1098 |
+
|
| 1099 |
+
|
| 1100 |
+
class WanVideoUnit_FunReference(PipelineUnit):
|
| 1101 |
+
def __init__(self):
|
| 1102 |
+
super().__init__(
|
| 1103 |
+
input_params=("reference_image", "height", "width", "reference_image"),
|
| 1104 |
+
onload_model_names=("vae",),
|
| 1105 |
+
)
|
| 1106 |
+
|
| 1107 |
+
def process(self, pipe: WanVideoPipeline, reference_image, height, width):
|
| 1108 |
+
if reference_image is None:
|
| 1109 |
+
return {}
|
| 1110 |
+
pipe.load_models_to_device(["vae"])
|
| 1111 |
+
reference_image = reference_image.resize((width, height))
|
| 1112 |
+
reference_latents = pipe.preprocess_video([reference_image])
|
| 1113 |
+
reference_latents = pipe.vae.encode(reference_latents, device=pipe.device)
|
| 1114 |
+
clip_feature = pipe.preprocess_image(reference_image)
|
| 1115 |
+
clip_feature = pipe.image_encoder.encode_image([clip_feature])
|
| 1116 |
+
return {"reference_latents": reference_latents, "clip_feature": clip_feature}
|
| 1117 |
+
|
| 1118 |
+
|
| 1119 |
+
class WanVideoUnit_FunCameraControl(PipelineUnit):
|
| 1120 |
+
def __init__(self):
|
| 1121 |
+
super().__init__(
|
| 1122 |
+
input_params=(
|
| 1123 |
+
"height",
|
| 1124 |
+
"width",
|
| 1125 |
+
"num_frames",
|
| 1126 |
+
"camera_control_direction",
|
| 1127 |
+
"camera_control_speed",
|
| 1128 |
+
"camera_control_origin",
|
| 1129 |
+
"latents",
|
| 1130 |
+
"input_image",
|
| 1131 |
+
),
|
| 1132 |
+
onload_model_names=("vae",),
|
| 1133 |
+
)
|
| 1134 |
+
|
| 1135 |
+
def process(
|
| 1136 |
+
self,
|
| 1137 |
+
pipe: WanVideoPipeline,
|
| 1138 |
+
height,
|
| 1139 |
+
width,
|
| 1140 |
+
num_frames,
|
| 1141 |
+
camera_control_direction,
|
| 1142 |
+
camera_control_speed,
|
| 1143 |
+
camera_control_origin,
|
| 1144 |
+
latents,
|
| 1145 |
+
input_image,
|
| 1146 |
+
):
|
| 1147 |
+
if camera_control_direction is None:
|
| 1148 |
+
return {}
|
| 1149 |
+
camera_control_plucker_embedding = (
|
| 1150 |
+
pipe.dit.control_adapter.process_camera_coordinates(
|
| 1151 |
+
camera_control_direction,
|
| 1152 |
+
num_frames,
|
| 1153 |
+
height,
|
| 1154 |
+
width,
|
| 1155 |
+
camera_control_speed,
|
| 1156 |
+
camera_control_origin,
|
| 1157 |
+
)
|
| 1158 |
+
)
|
| 1159 |
+
|
| 1160 |
+
control_camera_video = (
|
| 1161 |
+
camera_control_plucker_embedding[:num_frames]
|
| 1162 |
+
.permute([3, 0, 1, 2])
|
| 1163 |
+
.unsqueeze(0)
|
| 1164 |
+
)
|
| 1165 |
+
control_camera_latents = torch.concat(
|
| 1166 |
+
[
|
| 1167 |
+
torch.repeat_interleave(
|
| 1168 |
+
control_camera_video[:, :, 0:1], repeats=4, dim=2
|
| 1169 |
+
),
|
| 1170 |
+
control_camera_video[:, :, 1:],
|
| 1171 |
+
],
|
| 1172 |
+
dim=2,
|
| 1173 |
+
).transpose(1, 2)
|
| 1174 |
+
b, f, c, h, w = control_camera_latents.shape
|
| 1175 |
+
control_camera_latents = (
|
| 1176 |
+
control_camera_latents.contiguous()
|
| 1177 |
+
.view(b, f // 4, 4, c, h, w)
|
| 1178 |
+
.transpose(2, 3)
|
| 1179 |
+
)
|
| 1180 |
+
control_camera_latents = (
|
| 1181 |
+
control_camera_latents.contiguous()
|
| 1182 |
+
.view(b, f // 4, c * 4, h, w)
|
| 1183 |
+
.transpose(1, 2)
|
| 1184 |
+
)
|
| 1185 |
+
control_camera_latents_input = control_camera_latents.to(
|
| 1186 |
+
device=pipe.device, dtype=pipe.torch_dtype
|
| 1187 |
+
)
|
| 1188 |
+
|
| 1189 |
+
input_image = input_image.resize((width, height))
|
| 1190 |
+
input_latents = pipe.preprocess_video([input_image])
|
| 1191 |
+
pipe.load_models_to_device(self.onload_model_names)
|
| 1192 |
+
input_latents = pipe.vae.encode(input_latents, device=pipe.device)
|
| 1193 |
+
y = torch.zeros_like(latents).to(pipe.device)
|
| 1194 |
+
y[:, :, :1] = input_latents
|
| 1195 |
+
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
|
| 1196 |
+
return {"control_camera_latents_input": control_camera_latents_input, "y": y}
|
| 1197 |
+
|
| 1198 |
+
|
| 1199 |
+
class WanVideoUnit_SpeedControl(PipelineUnit):
|
| 1200 |
+
def __init__(self):
|
| 1201 |
+
super().__init__(input_params=("motion_bucket_id",))
|
| 1202 |
+
|
| 1203 |
+
def process(self, pipe: WanVideoPipeline, motion_bucket_id):
|
| 1204 |
+
if motion_bucket_id is None:
|
| 1205 |
+
return {}
|
| 1206 |
+
motion_bucket_id = torch.Tensor((motion_bucket_id,)).to(
|
| 1207 |
+
dtype=pipe.torch_dtype, device=pipe.device
|
| 1208 |
+
)
|
| 1209 |
+
return {"motion_bucket_id": motion_bucket_id}
|
| 1210 |
+
|
| 1211 |
+
|
| 1212 |
+
class WanVideoUnit_VACE(PipelineUnit):
|
| 1213 |
+
def __init__(self):
|
| 1214 |
+
super().__init__(
|
| 1215 |
+
input_params=(
|
| 1216 |
+
"vace_video",
|
| 1217 |
+
"vace_video_mask",
|
| 1218 |
+
"vace_reference_image",
|
| 1219 |
+
"vace_scale",
|
| 1220 |
+
"height",
|
| 1221 |
+
"width",
|
| 1222 |
+
"num_frames",
|
| 1223 |
+
"tiled",
|
| 1224 |
+
"tile_size",
|
| 1225 |
+
"tile_stride",
|
| 1226 |
+
),
|
| 1227 |
+
onload_model_names=("vae",),
|
| 1228 |
+
)
|
| 1229 |
+
|
| 1230 |
+
def process(
|
| 1231 |
+
self,
|
| 1232 |
+
pipe: WanVideoPipeline,
|
| 1233 |
+
vace_video,
|
| 1234 |
+
vace_video_mask,
|
| 1235 |
+
vace_reference_image,
|
| 1236 |
+
vace_scale,
|
| 1237 |
+
height,
|
| 1238 |
+
width,
|
| 1239 |
+
num_frames,
|
| 1240 |
+
tiled,
|
| 1241 |
+
tile_size,
|
| 1242 |
+
tile_stride,
|
| 1243 |
+
):
|
| 1244 |
+
if (
|
| 1245 |
+
vace_video is not None
|
| 1246 |
+
or vace_video_mask is not None
|
| 1247 |
+
or vace_reference_image is not None
|
| 1248 |
+
):
|
| 1249 |
+
pipe.load_models_to_device(["vae"])
|
| 1250 |
+
if vace_video is None:
|
| 1251 |
+
vace_video = torch.zeros(
|
| 1252 |
+
(1, 3, num_frames, height, width),
|
| 1253 |
+
dtype=pipe.torch_dtype,
|
| 1254 |
+
device=pipe.device,
|
| 1255 |
+
)
|
| 1256 |
+
else:
|
| 1257 |
+
vace_video = pipe.preprocess_video(vace_video)
|
| 1258 |
+
|
| 1259 |
+
if vace_video_mask is None:
|
| 1260 |
+
vace_video_mask = torch.ones_like(vace_video)
|
| 1261 |
+
else:
|
| 1262 |
+
vace_video_mask = pipe.preprocess_video(
|
| 1263 |
+
vace_video_mask, min_value=0, max_value=1
|
| 1264 |
+
)
|
| 1265 |
+
|
| 1266 |
+
inactive = vace_video * (1 - vace_video_mask) + 0 * vace_video_mask
|
| 1267 |
+
reactive = vace_video * vace_video_mask + 0 * (1 - vace_video_mask)
|
| 1268 |
+
inactive = pipe.vae.encode(
|
| 1269 |
+
inactive,
|
| 1270 |
+
device=pipe.device,
|
| 1271 |
+
tiled=tiled,
|
| 1272 |
+
tile_size=tile_size,
|
| 1273 |
+
tile_stride=tile_stride,
|
| 1274 |
+
).to(dtype=pipe.torch_dtype, device=pipe.device)
|
| 1275 |
+
reactive = pipe.vae.encode(
|
| 1276 |
+
reactive,
|
| 1277 |
+
device=pipe.device,
|
| 1278 |
+
tiled=tiled,
|
| 1279 |
+
tile_size=tile_size,
|
| 1280 |
+
tile_stride=tile_stride,
|
| 1281 |
+
).to(dtype=pipe.torch_dtype, device=pipe.device)
|
| 1282 |
+
vace_video_latents = torch.concat((inactive, reactive), dim=1)
|
| 1283 |
+
|
| 1284 |
+
vace_mask_latents = rearrange(
|
| 1285 |
+
vace_video_mask[0, 0], "T (H P) (W Q) -> 1 (P Q) T H W", P=8, Q=8
|
| 1286 |
+
)
|
| 1287 |
+
vace_mask_latents = torch.nn.functional.interpolate(
|
| 1288 |
+
vace_mask_latents,
|
| 1289 |
+
size=(
|
| 1290 |
+
(vace_mask_latents.shape[2] + 3) // 4,
|
| 1291 |
+
vace_mask_latents.shape[3],
|
| 1292 |
+
vace_mask_latents.shape[4],
|
| 1293 |
+
),
|
| 1294 |
+
mode="nearest-exact",
|
| 1295 |
+
)
|
| 1296 |
+
|
| 1297 |
+
if vace_reference_image is None:
|
| 1298 |
+
pass
|
| 1299 |
+
else:
|
| 1300 |
+
vace_reference_image = pipe.preprocess_video([vace_reference_image])
|
| 1301 |
+
vace_reference_latents = pipe.vae.encode(
|
| 1302 |
+
vace_reference_image,
|
| 1303 |
+
device=pipe.device,
|
| 1304 |
+
tiled=tiled,
|
| 1305 |
+
tile_size=tile_size,
|
| 1306 |
+
tile_stride=tile_stride,
|
| 1307 |
+
).to(dtype=pipe.torch_dtype, device=pipe.device)
|
| 1308 |
+
vace_reference_latents = torch.concat(
|
| 1309 |
+
(vace_reference_latents, torch.zeros_like(vace_reference_latents)),
|
| 1310 |
+
dim=1,
|
| 1311 |
+
)
|
| 1312 |
+
vace_video_latents = torch.concat(
|
| 1313 |
+
(vace_reference_latents, vace_video_latents), dim=2
|
| 1314 |
+
)
|
| 1315 |
+
vace_mask_latents = torch.concat(
|
| 1316 |
+
(torch.zeros_like(vace_mask_latents[:, :, :1]), vace_mask_latents),
|
| 1317 |
+
dim=2,
|
| 1318 |
+
)
|
| 1319 |
+
|
| 1320 |
+
vace_context = torch.concat((vace_video_latents, vace_mask_latents), dim=1)
|
| 1321 |
+
return {"vace_context": vace_context, "vace_scale": vace_scale}
|
| 1322 |
+
else:
|
| 1323 |
+
return {"vace_context": None, "vace_scale": vace_scale}
|
| 1324 |
+
|
| 1325 |
+
|
| 1326 |
+
class WanVideoUnit_UnifiedSequenceParallel(PipelineUnit):
|
| 1327 |
+
def __init__(self):
|
| 1328 |
+
super().__init__(input_params=())
|
| 1329 |
+
|
| 1330 |
+
def process(self, pipe: WanVideoPipeline):
|
| 1331 |
+
if hasattr(pipe, "use_unified_sequence_parallel"):
|
| 1332 |
+
if pipe.use_unified_sequence_parallel:
|
| 1333 |
+
return {"use_unified_sequence_parallel": True}
|
| 1334 |
+
return {}
|
| 1335 |
+
|
| 1336 |
+
|
| 1337 |
+
class WanVideoUnit_TeaCache(PipelineUnit):
|
| 1338 |
+
def __init__(self):
|
| 1339 |
+
super().__init__(
|
| 1340 |
+
seperate_cfg=True,
|
| 1341 |
+
input_params_posi={
|
| 1342 |
+
"num_inference_steps": "num_inference_steps",
|
| 1343 |
+
"tea_cache_l1_thresh": "tea_cache_l1_thresh",
|
| 1344 |
+
"tea_cache_model_id": "tea_cache_model_id",
|
| 1345 |
+
},
|
| 1346 |
+
input_params_nega={
|
| 1347 |
+
"num_inference_steps": "num_inference_steps",
|
| 1348 |
+
"tea_cache_l1_thresh": "tea_cache_l1_thresh",
|
| 1349 |
+
"tea_cache_model_id": "tea_cache_model_id",
|
| 1350 |
+
},
|
| 1351 |
+
)
|
| 1352 |
+
|
| 1353 |
+
def process(
|
| 1354 |
+
self,
|
| 1355 |
+
pipe: WanVideoPipeline,
|
| 1356 |
+
num_inference_steps,
|
| 1357 |
+
tea_cache_l1_thresh,
|
| 1358 |
+
tea_cache_model_id,
|
| 1359 |
+
):
|
| 1360 |
+
if tea_cache_l1_thresh is None:
|
| 1361 |
+
return {}
|
| 1362 |
+
return {
|
| 1363 |
+
"tea_cache": TeaCache(
|
| 1364 |
+
num_inference_steps,
|
| 1365 |
+
rel_l1_thresh=tea_cache_l1_thresh,
|
| 1366 |
+
model_id=tea_cache_model_id,
|
| 1367 |
+
)
|
| 1368 |
+
}
|
| 1369 |
+
|
| 1370 |
+
|
| 1371 |
+
class WanVideoUnit_CfgMerger(PipelineUnit):
|
| 1372 |
+
def __init__(self):
|
| 1373 |
+
super().__init__(take_over=True)
|
| 1374 |
+
self.concat_tensor_names = ["context", "clip_feature", "y", "reference_latents"]
|
| 1375 |
+
|
| 1376 |
+
def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega):
|
| 1377 |
+
if not inputs_shared["cfg_merge"]:
|
| 1378 |
+
return inputs_shared, inputs_posi, inputs_nega
|
| 1379 |
+
for name in self.concat_tensor_names:
|
| 1380 |
+
tensor_posi = inputs_posi.get(name)
|
| 1381 |
+
tensor_nega = inputs_nega.get(name)
|
| 1382 |
+
tensor_shared = inputs_shared.get(name)
|
| 1383 |
+
if tensor_posi is not None and tensor_nega is not None:
|
| 1384 |
+
inputs_shared[name] = torch.concat((tensor_posi, tensor_nega), dim=0)
|
| 1385 |
+
elif tensor_shared is not None:
|
| 1386 |
+
inputs_shared[name] = torch.concat(
|
| 1387 |
+
(tensor_shared, tensor_shared), dim=0
|
| 1388 |
+
)
|
| 1389 |
+
inputs_posi.clear()
|
| 1390 |
+
inputs_nega.clear()
|
| 1391 |
+
return inputs_shared, inputs_posi, inputs_nega
|
| 1392 |
+
|
| 1393 |
+
|
| 1394 |
+
class TeaCache:
|
| 1395 |
+
def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
|
| 1396 |
+
self.num_inference_steps = num_inference_steps
|
| 1397 |
+
self.step = 0
|
| 1398 |
+
self.accumulated_rel_l1_distance = 0
|
| 1399 |
+
self.previous_modulated_input = None
|
| 1400 |
+
self.rel_l1_thresh = rel_l1_thresh
|
| 1401 |
+
self.previous_residual = None
|
| 1402 |
+
self.previous_hidden_states = None
|
| 1403 |
+
|
| 1404 |
+
self.coefficients_dict = {
|
| 1405 |
+
"Wan2.1-T2V-1.3B": [
|
| 1406 |
+
-5.21862437e04,
|
| 1407 |
+
9.23041404e03,
|
| 1408 |
+
-5.28275948e02,
|
| 1409 |
+
1.36987616e01,
|
| 1410 |
+
-4.99875664e-02,
|
| 1411 |
+
],
|
| 1412 |
+
"Wan2.1-T2V-14B": [
|
| 1413 |
+
-3.03318725e05,
|
| 1414 |
+
4.90537029e04,
|
| 1415 |
+
-2.65530556e03,
|
| 1416 |
+
5.87365115e01,
|
| 1417 |
+
-3.15583525e-01,
|
| 1418 |
+
],
|
| 1419 |
+
"Wan2.1-I2V-14B-480P": [
|
| 1420 |
+
2.57151496e05,
|
| 1421 |
+
-3.54229917e04,
|
| 1422 |
+
1.40286849e03,
|
| 1423 |
+
-1.35890334e01,
|
| 1424 |
+
1.32517977e-01,
|
| 1425 |
+
],
|
| 1426 |
+
"Wan2.1-I2V-14B-720P": [
|
| 1427 |
+
8.10705460e03,
|
| 1428 |
+
2.13393892e03,
|
| 1429 |
+
-3.72934672e02,
|
| 1430 |
+
1.66203073e01,
|
| 1431 |
+
-4.17769401e-02,
|
| 1432 |
+
],
|
| 1433 |
+
}
|
| 1434 |
+
if model_id not in self.coefficients_dict:
|
| 1435 |
+
supported_model_ids = ", ".join([i for i in self.coefficients_dict])
|
| 1436 |
+
raise ValueError(
|
| 1437 |
+
f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids})."
|
| 1438 |
+
)
|
| 1439 |
+
self.coefficients = self.coefficients_dict[model_id]
|
| 1440 |
+
|
| 1441 |
+
def check(self, dit: WanModel, x, t_mod):
|
| 1442 |
+
modulated_inp = t_mod.clone()
|
| 1443 |
+
if self.step == 0 or self.step == self.num_inference_steps - 1:
|
| 1444 |
+
should_calc = True
|
| 1445 |
+
self.accumulated_rel_l1_distance = 0
|
| 1446 |
+
else:
|
| 1447 |
+
coefficients = self.coefficients
|
| 1448 |
+
rescale_func = np.poly1d(coefficients)
|
| 1449 |
+
self.accumulated_rel_l1_distance += rescale_func(
|
| 1450 |
+
(
|
| 1451 |
+
(modulated_inp - self.previous_modulated_input).abs().mean()
|
| 1452 |
+
/ self.previous_modulated_input.abs().mean()
|
| 1453 |
+
)
|
| 1454 |
+
.cpu()
|
| 1455 |
+
.item()
|
| 1456 |
+
)
|
| 1457 |
+
if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
|
| 1458 |
+
should_calc = False
|
| 1459 |
+
else:
|
| 1460 |
+
should_calc = True
|
| 1461 |
+
self.accumulated_rel_l1_distance = 0
|
| 1462 |
+
self.previous_modulated_input = modulated_inp
|
| 1463 |
+
self.step += 1
|
| 1464 |
+
if self.step == self.num_inference_steps:
|
| 1465 |
+
self.step = 0
|
| 1466 |
+
if should_calc:
|
| 1467 |
+
self.previous_hidden_states = x.clone()
|
| 1468 |
+
return not should_calc
|
| 1469 |
+
|
| 1470 |
+
def store(self, hidden_states):
|
| 1471 |
+
self.previous_residual = hidden_states - self.previous_hidden_states
|
| 1472 |
+
self.previous_hidden_states = None
|
| 1473 |
+
|
| 1474 |
+
def update(self, hidden_states):
|
| 1475 |
+
hidden_states = hidden_states + self.previous_residual
|
| 1476 |
+
return hidden_states
|
| 1477 |
+
|
| 1478 |
+
|
| 1479 |
+
class TemporalTiler_BCTHW:
|
| 1480 |
+
def __init__(self):
|
| 1481 |
+
pass
|
| 1482 |
+
|
| 1483 |
+
def build_1d_mask(self, length, left_bound, right_bound, border_width):
|
| 1484 |
+
x = torch.ones((length,))
|
| 1485 |
+
if not left_bound:
|
| 1486 |
+
x[:border_width] = (torch.arange(border_width) + 1) / border_width
|
| 1487 |
+
if not right_bound:
|
| 1488 |
+
x[-border_width:] = torch.flip(
|
| 1489 |
+
(torch.arange(border_width) + 1) / border_width, dims=(0,)
|
| 1490 |
+
)
|
| 1491 |
+
return x
|
| 1492 |
+
|
| 1493 |
+
def build_mask(self, data, is_bound, border_width):
|
| 1494 |
+
_, _, T, _, _ = data.shape
|
| 1495 |
+
t = self.build_1d_mask(T, is_bound[0], is_bound[1], border_width[0])
|
| 1496 |
+
mask = repeat(t, "T -> 1 1 T 1 1")
|
| 1497 |
+
return mask
|
| 1498 |
+
|
| 1499 |
+
def run(
|
| 1500 |
+
self,
|
| 1501 |
+
model_fn,
|
| 1502 |
+
sliding_window_size,
|
| 1503 |
+
sliding_window_stride,
|
| 1504 |
+
computation_device,
|
| 1505 |
+
computation_dtype,
|
| 1506 |
+
model_kwargs,
|
| 1507 |
+
tensor_names,
|
| 1508 |
+
batch_size=None,
|
| 1509 |
+
):
|
| 1510 |
+
tensor_names = [
|
| 1511 |
+
tensor_name
|
| 1512 |
+
for tensor_name in tensor_names
|
| 1513 |
+
if model_kwargs.get(tensor_name) is not None
|
| 1514 |
+
]
|
| 1515 |
+
tensor_dict = {
|
| 1516 |
+
tensor_name: model_kwargs[tensor_name] for tensor_name in tensor_names
|
| 1517 |
+
}
|
| 1518 |
+
B, C, T, H, W = tensor_dict[tensor_names[0]].shape
|
| 1519 |
+
if batch_size is not None:
|
| 1520 |
+
B *= batch_size
|
| 1521 |
+
data_device, data_dtype = (
|
| 1522 |
+
tensor_dict[tensor_names[0]].device,
|
| 1523 |
+
tensor_dict[tensor_names[0]].dtype,
|
| 1524 |
+
)
|
| 1525 |
+
value = torch.zeros((B, C, T, H, W), device=data_device, dtype=data_dtype)
|
| 1526 |
+
weight = torch.zeros((1, 1, T, 1, 1), device=data_device, dtype=data_dtype)
|
| 1527 |
+
for t in range(0, T, sliding_window_stride):
|
| 1528 |
+
if (
|
| 1529 |
+
t - sliding_window_stride >= 0
|
| 1530 |
+
and t - sliding_window_stride + sliding_window_size >= T
|
| 1531 |
+
):
|
| 1532 |
+
continue
|
| 1533 |
+
t_ = min(t + sliding_window_size, T)
|
| 1534 |
+
model_kwargs.update(
|
| 1535 |
+
{
|
| 1536 |
+
tensor_name: tensor_dict[tensor_name][:, :, t:t_:, :].to(
|
| 1537 |
+
device=computation_device, dtype=computation_dtype
|
| 1538 |
+
)
|
| 1539 |
+
for tensor_name in tensor_names
|
| 1540 |
+
}
|
| 1541 |
+
)
|
| 1542 |
+
model_output = model_fn(**model_kwargs).to(
|
| 1543 |
+
device=data_device, dtype=data_dtype
|
| 1544 |
+
)
|
| 1545 |
+
mask = self.build_mask(
|
| 1546 |
+
model_output,
|
| 1547 |
+
is_bound=(t == 0, t_ == T),
|
| 1548 |
+
border_width=(sliding_window_size - sliding_window_stride,),
|
| 1549 |
+
).to(device=data_device, dtype=data_dtype)
|
| 1550 |
+
value[:, :, t:t_, :, :] += model_output * mask
|
| 1551 |
+
weight[:, :, t:t_, :, :] += mask
|
| 1552 |
+
value /= weight
|
| 1553 |
+
model_kwargs.update(tensor_dict)
|
| 1554 |
+
return value
|
| 1555 |
+
|
| 1556 |
+
|
| 1557 |
+
def model_fn_wan_video(
|
| 1558 |
+
dit: WanModel,
|
| 1559 |
+
motion_controller: WanMotionControllerModel = None,
|
| 1560 |
+
vace: VaceWanModel = None,
|
| 1561 |
+
latents: torch.Tensor = None,
|
| 1562 |
+
timestep: torch.Tensor = None,
|
| 1563 |
+
context: torch.Tensor = None,
|
| 1564 |
+
clip_feature: Optional[torch.Tensor] = None,
|
| 1565 |
+
y: Optional[torch.Tensor] = None,
|
| 1566 |
+
reference_latents=None,
|
| 1567 |
+
vace_context=None,
|
| 1568 |
+
vace_scale=1.0,
|
| 1569 |
+
tea_cache: TeaCache = None,
|
| 1570 |
+
use_unified_sequence_parallel: bool = False,
|
| 1571 |
+
motion_bucket_id: Optional[torch.Tensor] = None,
|
| 1572 |
+
sliding_window_size: Optional[int] = None,
|
| 1573 |
+
sliding_window_stride: Optional[int] = None,
|
| 1574 |
+
cfg_merge: bool = False,
|
| 1575 |
+
use_gradient_checkpointing: bool = False,
|
| 1576 |
+
use_gradient_checkpointing_offload: bool = False,
|
| 1577 |
+
control_camera_latents_input=None,
|
| 1578 |
+
fuse_vae_embedding_in_latents: bool = False,
|
| 1579 |
+
ip_image=None,
|
| 1580 |
+
**kwargs,
|
| 1581 |
+
):
|
| 1582 |
+
if sliding_window_size is not None and sliding_window_stride is not None:
|
| 1583 |
+
model_kwargs = dict(
|
| 1584 |
+
dit=dit,
|
| 1585 |
+
motion_controller=motion_controller,
|
| 1586 |
+
vace=vace,
|
| 1587 |
+
latents=latents,
|
| 1588 |
+
timestep=timestep,
|
| 1589 |
+
context=context,
|
| 1590 |
+
clip_feature=clip_feature,
|
| 1591 |
+
y=y,
|
| 1592 |
+
reference_latents=reference_latents,
|
| 1593 |
+
vace_context=vace_context,
|
| 1594 |
+
vace_scale=vace_scale,
|
| 1595 |
+
tea_cache=tea_cache,
|
| 1596 |
+
use_unified_sequence_parallel=use_unified_sequence_parallel,
|
| 1597 |
+
motion_bucket_id=motion_bucket_id,
|
| 1598 |
+
)
|
| 1599 |
+
return TemporalTiler_BCTHW().run(
|
| 1600 |
+
model_fn_wan_video,
|
| 1601 |
+
sliding_window_size,
|
| 1602 |
+
sliding_window_stride,
|
| 1603 |
+
latents.device,
|
| 1604 |
+
latents.dtype,
|
| 1605 |
+
model_kwargs=model_kwargs,
|
| 1606 |
+
tensor_names=["latents", "y"],
|
| 1607 |
+
batch_size=2 if cfg_merge else 1,
|
| 1608 |
+
)
|
| 1609 |
+
|
| 1610 |
+
if use_unified_sequence_parallel:
|
| 1611 |
+
import torch.distributed as dist
|
| 1612 |
+
from xfuser.core.distributed import (
|
| 1613 |
+
get_sequence_parallel_rank,
|
| 1614 |
+
get_sequence_parallel_world_size,
|
| 1615 |
+
get_sp_group,
|
| 1616 |
+
)
|
| 1617 |
+
x_ip = None
|
| 1618 |
+
t_mod_ip = None
|
| 1619 |
+
# Timestep
|
| 1620 |
+
if dit.seperated_timestep and fuse_vae_embedding_in_latents:
|
| 1621 |
+
timestep = torch.concat(
|
| 1622 |
+
[
|
| 1623 |
+
torch.zeros(
|
| 1624 |
+
(1, latents.shape[3] * latents.shape[4] // 4),
|
| 1625 |
+
dtype=latents.dtype,
|
| 1626 |
+
device=latents.device,
|
| 1627 |
+
),
|
| 1628 |
+
torch.ones(
|
| 1629 |
+
(latents.shape[2] - 1, latents.shape[3] * latents.shape[4] // 4),
|
| 1630 |
+
dtype=latents.dtype,
|
| 1631 |
+
device=latents.device,
|
| 1632 |
+
)
|
| 1633 |
+
* timestep,
|
| 1634 |
+
]
|
| 1635 |
+
).flatten()
|
| 1636 |
+
t = dit.time_embedding(
|
| 1637 |
+
sinusoidal_embedding_1d(dit.freq_dim, timestep).unsqueeze(0)
|
| 1638 |
+
)
|
| 1639 |
+
t_mod = dit.time_projection(t).unflatten(2, (6, dit.dim))
|
| 1640 |
+
else:
|
| 1641 |
+
t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
|
| 1642 |
+
t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim))
|
| 1643 |
+
|
| 1644 |
+
if ip_image is not None:
|
| 1645 |
+
timestep_ip = torch.zeros_like(timestep) # [B] with 0s
|
| 1646 |
+
t_ip = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep_ip))
|
| 1647 |
+
t_mod_ip = dit.time_projection(t_ip).unflatten(1, (6, dit.dim))
|
| 1648 |
+
|
| 1649 |
+
# Motion Controller
|
| 1650 |
+
if motion_bucket_id is not None and motion_controller is not None:
|
| 1651 |
+
t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim))
|
| 1652 |
+
context = dit.text_embedding(context)
|
| 1653 |
+
|
| 1654 |
+
x = latents
|
| 1655 |
+
# Merged cfg
|
| 1656 |
+
if x.shape[0] != context.shape[0]:
|
| 1657 |
+
x = torch.concat([x] * context.shape[0], dim=0)
|
| 1658 |
+
if timestep.shape[0] != context.shape[0]:
|
| 1659 |
+
timestep = torch.concat([timestep] * context.shape[0], dim=0)
|
| 1660 |
+
|
| 1661 |
+
# Image Embedding
|
| 1662 |
+
if y is not None and dit.require_vae_embedding:
|
| 1663 |
+
x = torch.cat([x, y], dim=1)
|
| 1664 |
+
if clip_feature is not None and dit.require_clip_embedding:
|
| 1665 |
+
clip_embdding = dit.img_emb(clip_feature)
|
| 1666 |
+
context = torch.cat([clip_embdding, context], dim=1)
|
| 1667 |
+
|
| 1668 |
+
# Add camera control
|
| 1669 |
+
x, (f, h, w) = dit.patchify(x, control_camera_latents_input)
|
| 1670 |
+
|
| 1671 |
+
# Reference image
|
| 1672 |
+
if reference_latents is not None:
|
| 1673 |
+
if len(reference_latents.shape) == 5:
|
| 1674 |
+
reference_latents = reference_latents[:, :, 0]
|
| 1675 |
+
reference_latents = dit.ref_conv(reference_latents).flatten(2).transpose(1, 2)
|
| 1676 |
+
x = torch.concat([reference_latents, x], dim=1)
|
| 1677 |
+
f += 1
|
| 1678 |
+
|
| 1679 |
+
offset = 1
|
| 1680 |
+
freqs = (
|
| 1681 |
+
torch.cat(
|
| 1682 |
+
[
|
| 1683 |
+
dit.freqs[0][offset : f + offset].view(f, 1, 1, -1).expand(f, h, w, -1),
|
| 1684 |
+
dit.freqs[1][offset : h + offset].view(1, h, 1, -1).expand(f, h, w, -1),
|
| 1685 |
+
dit.freqs[2][offset : w + offset].view(1, 1, w, -1).expand(f, h, w, -1),
|
| 1686 |
+
],
|
| 1687 |
+
dim=-1,
|
| 1688 |
+
)
|
| 1689 |
+
.reshape(f * h * w, 1, -1)
|
| 1690 |
+
.to(x.device)
|
| 1691 |
+
)
|
| 1692 |
+
|
| 1693 |
+
############################################################################################
|
| 1694 |
+
if ip_image is not None:
|
| 1695 |
+
x_ip, (f_ip, h_ip, w_ip) = dit.patchify(
|
| 1696 |
+
ip_image
|
| 1697 |
+
) # x_ip [1, 1024, 5120] [B, N, D] f_ip = 1 h_ip = 32 w_ip = 32
|
| 1698 |
+
freqs_ip = (
|
| 1699 |
+
torch.cat(
|
| 1700 |
+
[
|
| 1701 |
+
dit.freqs[0][0].view(f_ip, 1, 1, -1).expand(f_ip, h_ip, w_ip, -1),
|
| 1702 |
+
dit.freqs[1][h + offset : h + offset + h_ip]
|
| 1703 |
+
.view(1, h_ip, 1, -1)
|
| 1704 |
+
.expand(f_ip, h_ip, w_ip, -1),
|
| 1705 |
+
dit.freqs[2][w + offset : w + offset + w_ip]
|
| 1706 |
+
.view(1, 1, w_ip, -1)
|
| 1707 |
+
.expand(f_ip, h_ip, w_ip, -1),
|
| 1708 |
+
],
|
| 1709 |
+
dim=-1,
|
| 1710 |
+
)
|
| 1711 |
+
.reshape(f_ip * h_ip * w_ip, 1, -1)
|
| 1712 |
+
.to(x_ip.device)
|
| 1713 |
+
)
|
| 1714 |
+
freqs_original = freqs
|
| 1715 |
+
freqs = torch.cat([freqs, freqs_ip], dim=0)
|
| 1716 |
+
############################################################################################
|
| 1717 |
+
else:
|
| 1718 |
+
freqs_original = freqs
|
| 1719 |
+
# TeaCache
|
| 1720 |
+
if tea_cache is not None:
|
| 1721 |
+
tea_cache_update = tea_cache.check(dit, x, t_mod)
|
| 1722 |
+
else:
|
| 1723 |
+
tea_cache_update = False
|
| 1724 |
+
|
| 1725 |
+
if vace_context is not None:
|
| 1726 |
+
vace_hints = vace(x, vace_context, context, t_mod, freqs_original)
|
| 1727 |
+
|
| 1728 |
+
# blocks
|
| 1729 |
+
if use_unified_sequence_parallel:
|
| 1730 |
+
if dist.is_initialized() and dist.get_world_size() > 1:
|
| 1731 |
+
x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[
|
| 1732 |
+
get_sequence_parallel_rank()
|
| 1733 |
+
]
|
| 1734 |
+
if tea_cache_update:
|
| 1735 |
+
x = tea_cache.update(x)
|
| 1736 |
+
else:
|
| 1737 |
+
|
| 1738 |
+
def create_custom_forward(module):
|
| 1739 |
+
def custom_forward(*inputs):
|
| 1740 |
+
return module(*inputs)
|
| 1741 |
+
|
| 1742 |
+
return custom_forward
|
| 1743 |
+
|
| 1744 |
+
for block_id, block in enumerate(dit.blocks):
|
| 1745 |
+
if use_gradient_checkpointing_offload:
|
| 1746 |
+
with torch.autograd.graph.save_on_cpu():
|
| 1747 |
+
x, x_ip = torch.utils.checkpoint.checkpoint(
|
| 1748 |
+
create_custom_forward(block),
|
| 1749 |
+
x,
|
| 1750 |
+
context,
|
| 1751 |
+
t_mod,
|
| 1752 |
+
freqs,
|
| 1753 |
+
x_ip=x_ip,
|
| 1754 |
+
t_mod_ip=t_mod_ip,
|
| 1755 |
+
use_reentrant=False,
|
| 1756 |
+
)
|
| 1757 |
+
elif use_gradient_checkpointing:
|
| 1758 |
+
x, x_ip = torch.utils.checkpoint.checkpoint(
|
| 1759 |
+
create_custom_forward(block),
|
| 1760 |
+
x,
|
| 1761 |
+
context,
|
| 1762 |
+
t_mod,
|
| 1763 |
+
freqs,
|
| 1764 |
+
x_ip=x_ip,
|
| 1765 |
+
t_mod_ip=t_mod_ip,
|
| 1766 |
+
use_reentrant=False,
|
| 1767 |
+
)
|
| 1768 |
+
else:
|
| 1769 |
+
x, x_ip = block(x, context, t_mod, freqs, x_ip=x_ip, t_mod_ip=t_mod_ip)
|
| 1770 |
+
if vace_context is not None and block_id in vace.vace_layers_mapping:
|
| 1771 |
+
current_vace_hint = vace_hints[vace.vace_layers_mapping[block_id]]
|
| 1772 |
+
if (
|
| 1773 |
+
use_unified_sequence_parallel
|
| 1774 |
+
and dist.is_initialized()
|
| 1775 |
+
and dist.get_world_size() > 1
|
| 1776 |
+
):
|
| 1777 |
+
current_vace_hint = torch.chunk(
|
| 1778 |
+
current_vace_hint, get_sequence_parallel_world_size(), dim=1
|
| 1779 |
+
)[get_sequence_parallel_rank()]
|
| 1780 |
+
x = x + current_vace_hint * vace_scale
|
| 1781 |
+
if tea_cache is not None:
|
| 1782 |
+
tea_cache.store(x)
|
| 1783 |
+
|
| 1784 |
+
x = dit.head(x, t)
|
| 1785 |
+
if use_unified_sequence_parallel:
|
| 1786 |
+
if dist.is_initialized() and dist.get_world_size() > 1:
|
| 1787 |
+
x = get_sp_group().all_gather(x, dim=1)
|
| 1788 |
+
# Remove reference latents
|
| 1789 |
+
if reference_latents is not None:
|
| 1790 |
+
x = x[:, reference_latents.shape[1] :]
|
| 1791 |
+
f -= 1
|
| 1792 |
+
x = dit.unpatchify(x, (f, h, w))
|
| 1793 |
+
return x
|
pipelines/wan_video_face_swap.py
ADDED
|
@@ -0,0 +1,1786 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch, types
|
| 2 |
+
import numpy as np
|
| 3 |
+
from PIL import Image
|
| 4 |
+
from einops import repeat
|
| 5 |
+
from typing import Optional, Union
|
| 6 |
+
from einops import rearrange
|
| 7 |
+
import numpy as np
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
from typing import Optional
|
| 10 |
+
from typing_extensions import Literal
|
| 11 |
+
import imageio
|
| 12 |
+
import os
|
| 13 |
+
from typing import List
|
| 14 |
+
import cv2
|
| 15 |
+
|
| 16 |
+
from utils import BasePipeline, ModelConfig, PipelineUnit, PipelineUnitRunner
|
| 17 |
+
from models import ModelManager, load_state_dict
|
| 18 |
+
from models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d
|
| 19 |
+
from models.wan_video_text_encoder import (
|
| 20 |
+
WanTextEncoder,
|
| 21 |
+
T5RelativeEmbedding,
|
| 22 |
+
T5LayerNorm,
|
| 23 |
+
)
|
| 24 |
+
from models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample
|
| 25 |
+
from models.wan_video_image_encoder import WanImageEncoder
|
| 26 |
+
from models.wan_video_vace import VaceWanModel
|
| 27 |
+
from models.wan_video_motion_controller import WanMotionControllerModel
|
| 28 |
+
from schedulers.flow_match import FlowMatchScheduler
|
| 29 |
+
from prompters import WanPrompter
|
| 30 |
+
from vram_management import (
|
| 31 |
+
enable_vram_management,
|
| 32 |
+
AutoWrappedModule,
|
| 33 |
+
AutoWrappedLinear,
|
| 34 |
+
WanAutoCastLayerNorm,
|
| 35 |
+
)
|
| 36 |
+
from lora import GeneralLoRALoader
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def load_video_as_list(video_path: str) -> List[Image.Image]:
|
| 40 |
+
if not os.path.isfile(video_path):
|
| 41 |
+
raise FileNotFoundError(video_path)
|
| 42 |
+
reader = imageio.get_reader(video_path)
|
| 43 |
+
frames = []
|
| 44 |
+
for i, frame_data in enumerate(reader):
|
| 45 |
+
pil_image = Image.fromarray(frame_data)
|
| 46 |
+
frames.append(pil_image)
|
| 47 |
+
reader.close()
|
| 48 |
+
return frames
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class WanVideoPipeline_FaceSwap(BasePipeline):
|
| 52 |
+
def __init__(self, device="cuda", torch_dtype=torch.bfloat16, tokenizer_path=None):
|
| 53 |
+
super().__init__(
|
| 54 |
+
device=device,
|
| 55 |
+
torch_dtype=torch_dtype,
|
| 56 |
+
height_division_factor=16,
|
| 57 |
+
width_division_factor=16,
|
| 58 |
+
time_division_factor=4,
|
| 59 |
+
time_division_remainder=1,
|
| 60 |
+
)
|
| 61 |
+
self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
|
| 62 |
+
self.prompter = WanPrompter(tokenizer_path=tokenizer_path)
|
| 63 |
+
self.text_encoder: WanTextEncoder = None
|
| 64 |
+
self.image_encoder: WanImageEncoder = None
|
| 65 |
+
self.dit: WanModel = None
|
| 66 |
+
self.dit2: WanModel = None
|
| 67 |
+
self.vae: WanVideoVAE = None
|
| 68 |
+
self.motion_controller: WanMotionControllerModel = None
|
| 69 |
+
self.vace: VaceWanModel = None
|
| 70 |
+
self.in_iteration_models = ("dit", "motion_controller", "vace")
|
| 71 |
+
self.in_iteration_models_2 = ("dit2", "motion_controller", "vace")
|
| 72 |
+
self.unit_runner = PipelineUnitRunner()
|
| 73 |
+
self.units = [
|
| 74 |
+
WanVideoUnit_ShapeChecker(),
|
| 75 |
+
WanVideoUnit_NoiseInitializer(),
|
| 76 |
+
WanVideoUnit_InputVideoEmbedder(),
|
| 77 |
+
WanVideoUnit_PromptEmbedder(),
|
| 78 |
+
WanVideoUnit_ImageEmbedderVAE(),
|
| 79 |
+
WanVideoUnit_ImageEmbedderCLIP(),
|
| 80 |
+
WanVideoUnit_ImageEmbedderFused(),
|
| 81 |
+
WanVideoUnit_FunControl(),
|
| 82 |
+
WanVideoUnit_FunReference(),
|
| 83 |
+
WanVideoUnit_FunCameraControl(),
|
| 84 |
+
WanVideoUnit_SpeedControl(),
|
| 85 |
+
WanVideoUnit_VACE(),
|
| 86 |
+
WanVideoUnit_UnifiedSequenceParallel(),
|
| 87 |
+
WanVideoUnit_TeaCache(),
|
| 88 |
+
WanVideoUnit_CfgMerger(),
|
| 89 |
+
]
|
| 90 |
+
self.model_fn = model_fn_wan_video
|
| 91 |
+
|
| 92 |
+
def encode_ip_image(self, ip_image):
|
| 93 |
+
self.load_models_to_device(["vae"])
|
| 94 |
+
ip_image = (
|
| 95 |
+
torch.tensor(np.array(ip_image)).permute(2, 0, 1).float() / 255.0
|
| 96 |
+
) # [3, H, W]
|
| 97 |
+
ip_image = (
|
| 98 |
+
ip_image.unsqueeze(1).unsqueeze(0).to(dtype=self.torch_dtype)
|
| 99 |
+
) # [B, 3, 1, H, W]
|
| 100 |
+
ip_image = ip_image * 2 - 1
|
| 101 |
+
ip_image_latent = self.vae.encode(ip_image, device=self.device, tiled=False)
|
| 102 |
+
return ip_image_latent
|
| 103 |
+
|
| 104 |
+
def load_lora(self, module, path, alpha=1):
|
| 105 |
+
loader = GeneralLoRALoader(torch_dtype=self.torch_dtype, device=self.device)
|
| 106 |
+
lora = load_state_dict(path, torch_dtype=self.torch_dtype, device=self.device)
|
| 107 |
+
loader.load(module, lora, alpha=alpha)
|
| 108 |
+
|
| 109 |
+
def training_loss(self, **inputs):
|
| 110 |
+
max_timestep_boundary = int(
|
| 111 |
+
inputs.get("max_timestep_boundary", 1) * self.scheduler.num_train_timesteps
|
| 112 |
+
)
|
| 113 |
+
min_timestep_boundary = int(
|
| 114 |
+
inputs.get("min_timestep_boundary", 0) * self.scheduler.num_train_timesteps
|
| 115 |
+
)
|
| 116 |
+
timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
|
| 117 |
+
timestep = self.scheduler.timesteps[timestep_id].to(
|
| 118 |
+
dtype=self.torch_dtype, device=self.device
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
inputs["latents"] = self.scheduler.add_noise(
|
| 122 |
+
inputs["input_latents"], inputs["noise"], timestep
|
| 123 |
+
)
|
| 124 |
+
training_target = self.scheduler.training_target(
|
| 125 |
+
inputs["input_latents"], inputs["noise"], timestep
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
noise_pred = self.model_fn(**inputs, timestep=timestep)
|
| 129 |
+
|
| 130 |
+
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
|
| 131 |
+
loss = loss * self.scheduler.training_weight(timestep)
|
| 132 |
+
return loss
|
| 133 |
+
|
| 134 |
+
def enable_vram_management(
|
| 135 |
+
self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5
|
| 136 |
+
):
|
| 137 |
+
self.vram_management_enabled = True
|
| 138 |
+
if num_persistent_param_in_dit is not None:
|
| 139 |
+
vram_limit = None
|
| 140 |
+
else:
|
| 141 |
+
if vram_limit is None:
|
| 142 |
+
vram_limit = self.get_vram()
|
| 143 |
+
vram_limit = vram_limit - vram_buffer
|
| 144 |
+
if self.text_encoder is not None:
|
| 145 |
+
dtype = next(iter(self.text_encoder.parameters())).dtype
|
| 146 |
+
enable_vram_management(
|
| 147 |
+
self.text_encoder,
|
| 148 |
+
module_map={
|
| 149 |
+
torch.nn.Linear: AutoWrappedLinear,
|
| 150 |
+
torch.nn.Embedding: AutoWrappedModule,
|
| 151 |
+
T5RelativeEmbedding: AutoWrappedModule,
|
| 152 |
+
T5LayerNorm: AutoWrappedModule,
|
| 153 |
+
},
|
| 154 |
+
module_config=dict(
|
| 155 |
+
offload_dtype=dtype,
|
| 156 |
+
offload_device="cpu",
|
| 157 |
+
onload_dtype=dtype,
|
| 158 |
+
onload_device="cpu",
|
| 159 |
+
computation_dtype=self.torch_dtype,
|
| 160 |
+
computation_device=self.device,
|
| 161 |
+
),
|
| 162 |
+
vram_limit=vram_limit,
|
| 163 |
+
)
|
| 164 |
+
if self.dit is not None:
|
| 165 |
+
dtype = next(iter(self.dit.parameters())).dtype
|
| 166 |
+
device = "cpu" if vram_limit is not None else self.device
|
| 167 |
+
enable_vram_management(
|
| 168 |
+
self.dit,
|
| 169 |
+
module_map={
|
| 170 |
+
torch.nn.Linear: AutoWrappedLinear,
|
| 171 |
+
torch.nn.Conv3d: AutoWrappedModule,
|
| 172 |
+
torch.nn.LayerNorm: WanAutoCastLayerNorm,
|
| 173 |
+
RMSNorm: AutoWrappedModule,
|
| 174 |
+
torch.nn.Conv2d: AutoWrappedModule,
|
| 175 |
+
},
|
| 176 |
+
module_config=dict(
|
| 177 |
+
offload_dtype=dtype,
|
| 178 |
+
offload_device="cpu",
|
| 179 |
+
onload_dtype=dtype,
|
| 180 |
+
onload_device=device,
|
| 181 |
+
computation_dtype=self.torch_dtype,
|
| 182 |
+
computation_device=self.device,
|
| 183 |
+
),
|
| 184 |
+
max_num_param=num_persistent_param_in_dit,
|
| 185 |
+
overflow_module_config=dict(
|
| 186 |
+
offload_dtype=dtype,
|
| 187 |
+
offload_device="cpu",
|
| 188 |
+
onload_dtype=dtype,
|
| 189 |
+
onload_device="cpu",
|
| 190 |
+
computation_dtype=self.torch_dtype,
|
| 191 |
+
computation_device=self.device,
|
| 192 |
+
),
|
| 193 |
+
vram_limit=vram_limit,
|
| 194 |
+
)
|
| 195 |
+
if self.dit2 is not None:
|
| 196 |
+
dtype = next(iter(self.dit2.parameters())).dtype
|
| 197 |
+
device = "cpu" if vram_limit is not None else self.device
|
| 198 |
+
enable_vram_management(
|
| 199 |
+
self.dit2,
|
| 200 |
+
module_map={
|
| 201 |
+
torch.nn.Linear: AutoWrappedLinear,
|
| 202 |
+
torch.nn.Conv3d: AutoWrappedModule,
|
| 203 |
+
torch.nn.LayerNorm: WanAutoCastLayerNorm,
|
| 204 |
+
RMSNorm: AutoWrappedModule,
|
| 205 |
+
torch.nn.Conv2d: AutoWrappedModule,
|
| 206 |
+
},
|
| 207 |
+
module_config=dict(
|
| 208 |
+
offload_dtype=dtype,
|
| 209 |
+
offload_device="cpu",
|
| 210 |
+
onload_dtype=dtype,
|
| 211 |
+
onload_device=device,
|
| 212 |
+
computation_dtype=self.torch_dtype,
|
| 213 |
+
computation_device=self.device,
|
| 214 |
+
),
|
| 215 |
+
max_num_param=num_persistent_param_in_dit,
|
| 216 |
+
overflow_module_config=dict(
|
| 217 |
+
offload_dtype=dtype,
|
| 218 |
+
offload_device="cpu",
|
| 219 |
+
onload_dtype=dtype,
|
| 220 |
+
onload_device="cpu",
|
| 221 |
+
computation_dtype=self.torch_dtype,
|
| 222 |
+
computation_device=self.device,
|
| 223 |
+
),
|
| 224 |
+
vram_limit=vram_limit,
|
| 225 |
+
)
|
| 226 |
+
if self.vae is not None:
|
| 227 |
+
dtype = next(iter(self.vae.parameters())).dtype
|
| 228 |
+
enable_vram_management(
|
| 229 |
+
self.vae,
|
| 230 |
+
module_map={
|
| 231 |
+
torch.nn.Linear: AutoWrappedLinear,
|
| 232 |
+
torch.nn.Conv2d: AutoWrappedModule,
|
| 233 |
+
RMS_norm: AutoWrappedModule,
|
| 234 |
+
CausalConv3d: AutoWrappedModule,
|
| 235 |
+
Upsample: AutoWrappedModule,
|
| 236 |
+
torch.nn.SiLU: AutoWrappedModule,
|
| 237 |
+
torch.nn.Dropout: AutoWrappedModule,
|
| 238 |
+
},
|
| 239 |
+
module_config=dict(
|
| 240 |
+
offload_dtype=dtype,
|
| 241 |
+
offload_device="cpu",
|
| 242 |
+
onload_dtype=dtype,
|
| 243 |
+
onload_device=self.device,
|
| 244 |
+
computation_dtype=self.torch_dtype,
|
| 245 |
+
computation_device=self.device,
|
| 246 |
+
),
|
| 247 |
+
)
|
| 248 |
+
if self.image_encoder is not None:
|
| 249 |
+
dtype = next(iter(self.image_encoder.parameters())).dtype
|
| 250 |
+
enable_vram_management(
|
| 251 |
+
self.image_encoder,
|
| 252 |
+
module_map={
|
| 253 |
+
torch.nn.Linear: AutoWrappedLinear,
|
| 254 |
+
torch.nn.Conv2d: AutoWrappedModule,
|
| 255 |
+
torch.nn.LayerNorm: AutoWrappedModule,
|
| 256 |
+
},
|
| 257 |
+
module_config=dict(
|
| 258 |
+
offload_dtype=dtype,
|
| 259 |
+
offload_device="cpu",
|
| 260 |
+
onload_dtype=dtype,
|
| 261 |
+
onload_device="cpu",
|
| 262 |
+
computation_dtype=dtype,
|
| 263 |
+
computation_device=self.device,
|
| 264 |
+
),
|
| 265 |
+
)
|
| 266 |
+
if self.motion_controller is not None:
|
| 267 |
+
dtype = next(iter(self.motion_controller.parameters())).dtype
|
| 268 |
+
enable_vram_management(
|
| 269 |
+
self.motion_controller,
|
| 270 |
+
module_map={
|
| 271 |
+
torch.nn.Linear: AutoWrappedLinear,
|
| 272 |
+
},
|
| 273 |
+
module_config=dict(
|
| 274 |
+
offload_dtype=dtype,
|
| 275 |
+
offload_device="cpu",
|
| 276 |
+
onload_dtype=dtype,
|
| 277 |
+
onload_device="cpu",
|
| 278 |
+
computation_dtype=dtype,
|
| 279 |
+
computation_device=self.device,
|
| 280 |
+
),
|
| 281 |
+
)
|
| 282 |
+
if self.vace is not None:
|
| 283 |
+
device = "cpu" if vram_limit is not None else self.device
|
| 284 |
+
enable_vram_management(
|
| 285 |
+
self.vace,
|
| 286 |
+
module_map={
|
| 287 |
+
torch.nn.Linear: AutoWrappedLinear,
|
| 288 |
+
torch.nn.Conv3d: AutoWrappedModule,
|
| 289 |
+
torch.nn.LayerNorm: AutoWrappedModule,
|
| 290 |
+
RMSNorm: AutoWrappedModule,
|
| 291 |
+
},
|
| 292 |
+
module_config=dict(
|
| 293 |
+
offload_dtype=dtype,
|
| 294 |
+
offload_device="cpu",
|
| 295 |
+
onload_dtype=dtype,
|
| 296 |
+
onload_device=device,
|
| 297 |
+
computation_dtype=self.torch_dtype,
|
| 298 |
+
computation_device=self.device,
|
| 299 |
+
),
|
| 300 |
+
vram_limit=vram_limit,
|
| 301 |
+
)
|
| 302 |
+
|
| 303 |
+
def initialize_usp(self):
|
| 304 |
+
import torch.distributed as dist
|
| 305 |
+
from xfuser.core.distributed import (
|
| 306 |
+
initialize_model_parallel,
|
| 307 |
+
init_distributed_environment,
|
| 308 |
+
)
|
| 309 |
+
|
| 310 |
+
dist.init_process_group(backend="nccl", init_method="env://")
|
| 311 |
+
init_distributed_environment(
|
| 312 |
+
rank=dist.get_rank(), world_size=dist.get_world_size()
|
| 313 |
+
)
|
| 314 |
+
initialize_model_parallel(
|
| 315 |
+
sequence_parallel_degree=dist.get_world_size(),
|
| 316 |
+
ring_degree=1,
|
| 317 |
+
ulysses_degree=dist.get_world_size(),
|
| 318 |
+
)
|
| 319 |
+
torch.cuda.set_device(dist.get_rank())
|
| 320 |
+
|
| 321 |
+
def enable_usp(self):
|
| 322 |
+
from xfuser.core.distributed import get_sequence_parallel_world_size
|
| 323 |
+
from distributed.xdit_context_parallel import (
|
| 324 |
+
usp_attn_forward,
|
| 325 |
+
usp_dit_forward,
|
| 326 |
+
)
|
| 327 |
+
|
| 328 |
+
for block in self.dit.blocks:
|
| 329 |
+
block.self_attn.forward = types.MethodType(
|
| 330 |
+
usp_attn_forward, block.self_attn
|
| 331 |
+
)
|
| 332 |
+
self.dit.forward = types.MethodType(usp_dit_forward, self.dit)
|
| 333 |
+
if self.dit2 is not None:
|
| 334 |
+
for block in self.dit2.blocks:
|
| 335 |
+
block.self_attn.forward = types.MethodType(
|
| 336 |
+
usp_attn_forward, block.self_attn
|
| 337 |
+
)
|
| 338 |
+
self.dit2.forward = types.MethodType(usp_dit_forward, self.dit2)
|
| 339 |
+
self.sp_size = get_sequence_parallel_world_size()
|
| 340 |
+
self.use_unified_sequence_parallel = True
|
| 341 |
+
|
| 342 |
+
@staticmethod
|
| 343 |
+
def from_pretrained(
|
| 344 |
+
torch_dtype: torch.dtype = torch.bfloat16,
|
| 345 |
+
device: Union[str, torch.device] = "cuda",
|
| 346 |
+
model_configs: list[ModelConfig] = [],
|
| 347 |
+
tokenizer_config: ModelConfig = ModelConfig(
|
| 348 |
+
model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/*"
|
| 349 |
+
),
|
| 350 |
+
redirect_common_files: bool = True,
|
| 351 |
+
use_usp=False,
|
| 352 |
+
):
|
| 353 |
+
# Redirect model path
|
| 354 |
+
if redirect_common_files:
|
| 355 |
+
redirect_dict = {
|
| 356 |
+
"models_t5_umt5-xxl-enc-bf16.pth": "Wan-AI/Wan2.1-T2V-1.3B",
|
| 357 |
+
"Wan2.1_VAE.pth": "Wan-AI/Wan2.1-T2V-1.3B",
|
| 358 |
+
"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth": "Wan-AI/Wan2.1-I2V-14B-480P",
|
| 359 |
+
}
|
| 360 |
+
for model_config in model_configs:
|
| 361 |
+
if (
|
| 362 |
+
model_config.origin_file_pattern is None
|
| 363 |
+
or model_config.model_id is None
|
| 364 |
+
):
|
| 365 |
+
continue
|
| 366 |
+
if (
|
| 367 |
+
model_config.origin_file_pattern in redirect_dict
|
| 368 |
+
and model_config.model_id
|
| 369 |
+
!= redirect_dict[model_config.origin_file_pattern]
|
| 370 |
+
):
|
| 371 |
+
print(
|
| 372 |
+
f"To avoid repeatedly downloading model files, ({model_config.model_id}, {model_config.origin_file_pattern}) is redirected to ({redirect_dict[model_config.origin_file_pattern]}, {model_config.origin_file_pattern}). You can use `redirect_common_files=False` to disable file redirection."
|
| 373 |
+
)
|
| 374 |
+
model_config.model_id = redirect_dict[
|
| 375 |
+
model_config.origin_file_pattern
|
| 376 |
+
]
|
| 377 |
+
|
| 378 |
+
# Initialize pipeline
|
| 379 |
+
pipe = WanVideoPipeline_FaceSwap(device=device, torch_dtype=torch_dtype)
|
| 380 |
+
if use_usp:
|
| 381 |
+
pipe.initialize_usp()
|
| 382 |
+
|
| 383 |
+
# Download and load models
|
| 384 |
+
model_manager = ModelManager()
|
| 385 |
+
for model_config in model_configs:
|
| 386 |
+
model_config.download_if_necessary(use_usp=use_usp)
|
| 387 |
+
model_manager.load_model(
|
| 388 |
+
model_config.path,
|
| 389 |
+
device=model_config.offload_device or device,
|
| 390 |
+
torch_dtype=model_config.offload_dtype or torch_dtype,
|
| 391 |
+
)
|
| 392 |
+
|
| 393 |
+
# Load models
|
| 394 |
+
pipe.text_encoder = model_manager.fetch_model("wan_video_text_encoder")
|
| 395 |
+
dit = model_manager.fetch_model("wan_video_dit", index=2)
|
| 396 |
+
if isinstance(dit, list):
|
| 397 |
+
pipe.dit, pipe.dit2 = dit
|
| 398 |
+
else:
|
| 399 |
+
pipe.dit = dit
|
| 400 |
+
pipe.vae = model_manager.fetch_model("wan_video_vae")
|
| 401 |
+
pipe.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
|
| 402 |
+
pipe.motion_controller = model_manager.fetch_model(
|
| 403 |
+
"wan_video_motion_controller"
|
| 404 |
+
)
|
| 405 |
+
pipe.vace = model_manager.fetch_model("wan_video_vace")
|
| 406 |
+
|
| 407 |
+
# Size division factor
|
| 408 |
+
if pipe.vae is not None:
|
| 409 |
+
pipe.height_division_factor = pipe.vae.upsampling_factor * 2
|
| 410 |
+
pipe.width_division_factor = pipe.vae.upsampling_factor * 2
|
| 411 |
+
|
| 412 |
+
# Initialize tokenizer
|
| 413 |
+
tokenizer_config.download_if_necessary(use_usp=use_usp)
|
| 414 |
+
pipe.prompter.fetch_models(pipe.text_encoder)
|
| 415 |
+
pipe.prompter.fetch_tokenizer(tokenizer_config.path)
|
| 416 |
+
|
| 417 |
+
# Unified Sequence Parallel
|
| 418 |
+
if use_usp:
|
| 419 |
+
pipe.enable_usp()
|
| 420 |
+
return pipe
|
| 421 |
+
|
| 422 |
+
@torch.no_grad()
|
| 423 |
+
def __call__(
|
| 424 |
+
self,
|
| 425 |
+
# Prompt
|
| 426 |
+
prompt: str,
|
| 427 |
+
negative_prompt: Optional[str] = "",
|
| 428 |
+
# Image-to-video
|
| 429 |
+
input_image: Optional[Image.Image] = None,
|
| 430 |
+
# First-last-frame-to-video
|
| 431 |
+
end_image: Optional[Image.Image] = None,
|
| 432 |
+
# Video-to-video
|
| 433 |
+
input_video: Optional[list[Image.Image]] = None,
|
| 434 |
+
denoising_strength: Optional[float] = 1,
|
| 435 |
+
# ControlNet
|
| 436 |
+
control_video: Optional[list[Image.Image]] = None,
|
| 437 |
+
reference_image: Optional[Image.Image] = None,
|
| 438 |
+
# Camera control
|
| 439 |
+
camera_control_direction: Optional[
|
| 440 |
+
Literal[
|
| 441 |
+
"Left",
|
| 442 |
+
"Right",
|
| 443 |
+
"Up",
|
| 444 |
+
"Down",
|
| 445 |
+
"LeftUp",
|
| 446 |
+
"LeftDown",
|
| 447 |
+
"RightUp",
|
| 448 |
+
"RightDown",
|
| 449 |
+
]
|
| 450 |
+
] = None,
|
| 451 |
+
camera_control_speed: Optional[float] = 1 / 54,
|
| 452 |
+
camera_control_origin: Optional[tuple] = (
|
| 453 |
+
0,
|
| 454 |
+
0.532139961,
|
| 455 |
+
0.946026558,
|
| 456 |
+
0.5,
|
| 457 |
+
0.5,
|
| 458 |
+
0,
|
| 459 |
+
0,
|
| 460 |
+
1,
|
| 461 |
+
0,
|
| 462 |
+
0,
|
| 463 |
+
0,
|
| 464 |
+
0,
|
| 465 |
+
1,
|
| 466 |
+
0,
|
| 467 |
+
0,
|
| 468 |
+
0,
|
| 469 |
+
0,
|
| 470 |
+
1,
|
| 471 |
+
0,
|
| 472 |
+
),
|
| 473 |
+
# VACE
|
| 474 |
+
vace_video: Optional[list[Image.Image]] = None,
|
| 475 |
+
vace_video_mask: Optional[Image.Image] = None,
|
| 476 |
+
vace_reference_image: Optional[Image.Image] = None,
|
| 477 |
+
vace_scale: Optional[float] = 1.0,
|
| 478 |
+
# Randomness
|
| 479 |
+
seed: Optional[int] = None,
|
| 480 |
+
rand_device: Optional[str] = "cpu",
|
| 481 |
+
# Shape
|
| 482 |
+
height: Optional[int] = 480,
|
| 483 |
+
width: Optional[int] = 832,
|
| 484 |
+
num_frames=81,
|
| 485 |
+
# Classifier-free guidance
|
| 486 |
+
cfg_scale: Optional[float] = 5.0,
|
| 487 |
+
cfg_merge: Optional[bool] = False,
|
| 488 |
+
# Boundary
|
| 489 |
+
switch_DiT_boundary: Optional[float] = 0.875,
|
| 490 |
+
# Scheduler
|
| 491 |
+
num_inference_steps: Optional[int] = 50,
|
| 492 |
+
sigma_shift: Optional[float] = 5.0,
|
| 493 |
+
# Speed control
|
| 494 |
+
motion_bucket_id: Optional[int] = None,
|
| 495 |
+
# VAE tiling
|
| 496 |
+
tiled: Optional[bool] = True,
|
| 497 |
+
tile_size: Optional[tuple[int, int]] = (30, 52),
|
| 498 |
+
tile_stride: Optional[tuple[int, int]] = (15, 26),
|
| 499 |
+
# Sliding window
|
| 500 |
+
sliding_window_size: Optional[int] = None,
|
| 501 |
+
sliding_window_stride: Optional[int] = None,
|
| 502 |
+
# Teacache
|
| 503 |
+
tea_cache_l1_thresh: Optional[float] = None,
|
| 504 |
+
tea_cache_model_id: Optional[str] = "",
|
| 505 |
+
# progress_bar
|
| 506 |
+
progress_bar_cmd=tqdm,
|
| 507 |
+
# Stand-In
|
| 508 |
+
face_mask=None,
|
| 509 |
+
ip_image=None,
|
| 510 |
+
force_background_consistency=False
|
| 511 |
+
):
|
| 512 |
+
if ip_image is not None:
|
| 513 |
+
ip_image = self.encode_ip_image(ip_image)
|
| 514 |
+
# Scheduler
|
| 515 |
+
self.scheduler.set_timesteps(
|
| 516 |
+
num_inference_steps,
|
| 517 |
+
denoising_strength=denoising_strength,
|
| 518 |
+
shift=sigma_shift,
|
| 519 |
+
)
|
| 520 |
+
|
| 521 |
+
# Inputs
|
| 522 |
+
inputs_posi = {
|
| 523 |
+
"prompt": prompt,
|
| 524 |
+
"tea_cache_l1_thresh": tea_cache_l1_thresh,
|
| 525 |
+
"tea_cache_model_id": tea_cache_model_id,
|
| 526 |
+
"num_inference_steps": num_inference_steps,
|
| 527 |
+
}
|
| 528 |
+
inputs_nega = {
|
| 529 |
+
"negative_prompt": negative_prompt,
|
| 530 |
+
"tea_cache_l1_thresh": tea_cache_l1_thresh,
|
| 531 |
+
"tea_cache_model_id": tea_cache_model_id,
|
| 532 |
+
"num_inference_steps": num_inference_steps,
|
| 533 |
+
}
|
| 534 |
+
inputs_shared = {
|
| 535 |
+
"input_image": input_image,
|
| 536 |
+
"end_image": end_image,
|
| 537 |
+
"input_video": input_video,
|
| 538 |
+
"denoising_strength": denoising_strength,
|
| 539 |
+
"control_video": control_video,
|
| 540 |
+
"reference_image": reference_image,
|
| 541 |
+
"camera_control_direction": camera_control_direction,
|
| 542 |
+
"camera_control_speed": camera_control_speed,
|
| 543 |
+
"camera_control_origin": camera_control_origin,
|
| 544 |
+
"vace_video": vace_video,
|
| 545 |
+
"vace_video_mask": vace_video_mask,
|
| 546 |
+
"vace_reference_image": vace_reference_image,
|
| 547 |
+
"vace_scale": vace_scale,
|
| 548 |
+
"seed": seed,
|
| 549 |
+
"rand_device": rand_device,
|
| 550 |
+
"height": height,
|
| 551 |
+
"width": width,
|
| 552 |
+
"num_frames": num_frames,
|
| 553 |
+
"cfg_scale": cfg_scale,
|
| 554 |
+
"cfg_merge": cfg_merge,
|
| 555 |
+
"sigma_shift": sigma_shift,
|
| 556 |
+
"motion_bucket_id": motion_bucket_id,
|
| 557 |
+
"tiled": tiled,
|
| 558 |
+
"tile_size": tile_size,
|
| 559 |
+
"tile_stride": tile_stride,
|
| 560 |
+
"sliding_window_size": sliding_window_size,
|
| 561 |
+
"sliding_window_stride": sliding_window_stride,
|
| 562 |
+
"ip_image": ip_image,
|
| 563 |
+
}
|
| 564 |
+
for unit in self.units:
|
| 565 |
+
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(
|
| 566 |
+
unit, self, inputs_shared, inputs_posi, inputs_nega
|
| 567 |
+
)
|
| 568 |
+
if face_mask is not None:
|
| 569 |
+
mask_processed = self.preprocess_video(face_mask)
|
| 570 |
+
mask_processed = mask_processed[:, 0:1, ...]
|
| 571 |
+
latent_mask = torch.nn.functional.interpolate(
|
| 572 |
+
mask_processed,
|
| 573 |
+
size=inputs_shared["latents"].shape[2:],
|
| 574 |
+
mode="nearest-exact",
|
| 575 |
+
)
|
| 576 |
+
# Denoise
|
| 577 |
+
self.load_models_to_device(self.in_iteration_models)
|
| 578 |
+
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
| 579 |
+
for progress_id, timestep in enumerate(
|
| 580 |
+
progress_bar_cmd(self.scheduler.timesteps)
|
| 581 |
+
):
|
| 582 |
+
# Switch DiT if necessary
|
| 583 |
+
if (
|
| 584 |
+
timestep.item()
|
| 585 |
+
< switch_DiT_boundary * self.scheduler.num_train_timesteps
|
| 586 |
+
and self.dit2 is not None
|
| 587 |
+
and not models["dit"] is self.dit2
|
| 588 |
+
):
|
| 589 |
+
self.load_models_to_device(self.in_iteration_models_2)
|
| 590 |
+
models["dit"] = self.dit2
|
| 591 |
+
|
| 592 |
+
# Timestep
|
| 593 |
+
timestep = timestep.unsqueeze(0).to(
|
| 594 |
+
dtype=self.torch_dtype, device=self.device
|
| 595 |
+
)
|
| 596 |
+
|
| 597 |
+
# Inference
|
| 598 |
+
noise_pred_posi = self.model_fn(
|
| 599 |
+
**models, **inputs_shared, **inputs_posi, timestep=timestep
|
| 600 |
+
)
|
| 601 |
+
inputs_shared["ip_image"] = None
|
| 602 |
+
if cfg_scale != 1.0:
|
| 603 |
+
if cfg_merge:
|
| 604 |
+
noise_pred_posi, noise_pred_nega = noise_pred_posi.chunk(2, dim=0)
|
| 605 |
+
else:
|
| 606 |
+
noise_pred_nega = self.model_fn(
|
| 607 |
+
**models, **inputs_shared, **inputs_nega, timestep=timestep
|
| 608 |
+
)
|
| 609 |
+
noise_pred = noise_pred_nega + cfg_scale * (
|
| 610 |
+
noise_pred_posi - noise_pred_nega
|
| 611 |
+
)
|
| 612 |
+
else:
|
| 613 |
+
noise_pred = noise_pred_posi
|
| 614 |
+
|
| 615 |
+
# Scheduler
|
| 616 |
+
inputs_shared["latents"] = self.scheduler.step(
|
| 617 |
+
noise_pred,
|
| 618 |
+
self.scheduler.timesteps[progress_id],
|
| 619 |
+
inputs_shared["latents"],
|
| 620 |
+
)
|
| 621 |
+
if force_background_consistency:
|
| 622 |
+
if (
|
| 623 |
+
inputs_shared["input_latents"] is not None
|
| 624 |
+
and latent_mask is not None
|
| 625 |
+
):
|
| 626 |
+
if progress_id == len(self.scheduler.timesteps) - 1:
|
| 627 |
+
noised_original_latents = inputs_shared["input_latents"]
|
| 628 |
+
else:
|
| 629 |
+
next_timestep = self.scheduler.timesteps[progress_id + 1]
|
| 630 |
+
noised_original_latents = self.scheduler.add_noise(
|
| 631 |
+
inputs_shared["input_latents"],
|
| 632 |
+
inputs_shared["noise"],
|
| 633 |
+
timestep=next_timestep,
|
| 634 |
+
)
|
| 635 |
+
|
| 636 |
+
hard_mask = (latent_mask > 0.5).to(
|
| 637 |
+
dtype=inputs_shared["latents"].dtype
|
| 638 |
+
)
|
| 639 |
+
|
| 640 |
+
inputs_shared["latents"] = (
|
| 641 |
+
1 - hard_mask
|
| 642 |
+
) * noised_original_latents + hard_mask * inputs_shared["latents"]
|
| 643 |
+
|
| 644 |
+
if "first_frame_latents" in inputs_shared:
|
| 645 |
+
inputs_shared["latents"][:, :, 0:1] = inputs_shared[
|
| 646 |
+
"first_frame_latents"
|
| 647 |
+
]
|
| 648 |
+
|
| 649 |
+
if vace_reference_image is not None:
|
| 650 |
+
inputs_shared["latents"] = inputs_shared["latents"][:, :, 1:]
|
| 651 |
+
|
| 652 |
+
# Decode
|
| 653 |
+
self.load_models_to_device(["vae"])
|
| 654 |
+
video = self.vae.decode(
|
| 655 |
+
inputs_shared["latents"],
|
| 656 |
+
device=self.device,
|
| 657 |
+
tiled=tiled,
|
| 658 |
+
tile_size=tile_size,
|
| 659 |
+
tile_stride=tile_stride,
|
| 660 |
+
)
|
| 661 |
+
video = self.vae_output_to_video(video)
|
| 662 |
+
self.load_models_to_device([])
|
| 663 |
+
|
| 664 |
+
return video
|
| 665 |
+
|
| 666 |
+
|
| 667 |
+
class WanVideoUnit_ShapeChecker(PipelineUnit):
|
| 668 |
+
def __init__(self):
|
| 669 |
+
super().__init__(input_params=("height", "width", "num_frames"))
|
| 670 |
+
|
| 671 |
+
def process(self, pipe: WanVideoPipeline_FaceSwap, height, width, num_frames):
|
| 672 |
+
height, width, num_frames = pipe.check_resize_height_width(
|
| 673 |
+
height, width, num_frames
|
| 674 |
+
)
|
| 675 |
+
return {"height": height, "width": width, "num_frames": num_frames}
|
| 676 |
+
|
| 677 |
+
|
| 678 |
+
class WanVideoUnit_NoiseInitializer(PipelineUnit):
|
| 679 |
+
def __init__(self):
|
| 680 |
+
super().__init__(
|
| 681 |
+
input_params=(
|
| 682 |
+
"height",
|
| 683 |
+
"width",
|
| 684 |
+
"num_frames",
|
| 685 |
+
"seed",
|
| 686 |
+
"rand_device",
|
| 687 |
+
"vace_reference_image",
|
| 688 |
+
)
|
| 689 |
+
)
|
| 690 |
+
|
| 691 |
+
def process(
|
| 692 |
+
self,
|
| 693 |
+
pipe: WanVideoPipeline_FaceSwap,
|
| 694 |
+
height,
|
| 695 |
+
width,
|
| 696 |
+
num_frames,
|
| 697 |
+
seed,
|
| 698 |
+
rand_device,
|
| 699 |
+
vace_reference_image,
|
| 700 |
+
):
|
| 701 |
+
length = (num_frames - 1) // 4 + 1
|
| 702 |
+
if vace_reference_image is not None:
|
| 703 |
+
length += 1
|
| 704 |
+
shape = (
|
| 705 |
+
1,
|
| 706 |
+
pipe.vae.model.z_dim,
|
| 707 |
+
length,
|
| 708 |
+
height // pipe.vae.upsampling_factor,
|
| 709 |
+
width // pipe.vae.upsampling_factor,
|
| 710 |
+
)
|
| 711 |
+
noise = pipe.generate_noise(shape, seed=seed, rand_device=rand_device)
|
| 712 |
+
if vace_reference_image is not None:
|
| 713 |
+
noise = torch.concat((noise[:, :, -1:], noise[:, :, :-1]), dim=2)
|
| 714 |
+
return {"noise": noise}
|
| 715 |
+
|
| 716 |
+
|
| 717 |
+
class WanVideoUnit_InputVideoEmbedder(PipelineUnit):
|
| 718 |
+
def __init__(self):
|
| 719 |
+
super().__init__(
|
| 720 |
+
input_params=(
|
| 721 |
+
"input_video",
|
| 722 |
+
"noise",
|
| 723 |
+
"tiled",
|
| 724 |
+
"tile_size",
|
| 725 |
+
"tile_stride",
|
| 726 |
+
"vace_reference_image",
|
| 727 |
+
),
|
| 728 |
+
onload_model_names=("vae",),
|
| 729 |
+
)
|
| 730 |
+
|
| 731 |
+
def process(
|
| 732 |
+
self,
|
| 733 |
+
pipe: WanVideoPipeline_FaceSwap,
|
| 734 |
+
input_video,
|
| 735 |
+
noise,
|
| 736 |
+
tiled,
|
| 737 |
+
tile_size,
|
| 738 |
+
tile_stride,
|
| 739 |
+
vace_reference_image,
|
| 740 |
+
):
|
| 741 |
+
if input_video is None:
|
| 742 |
+
return {"latents": noise}
|
| 743 |
+
pipe.load_models_to_device(["vae"])
|
| 744 |
+
input_video = pipe.preprocess_video(input_video)
|
| 745 |
+
input_latents = pipe.vae.encode(
|
| 746 |
+
input_video,
|
| 747 |
+
device=pipe.device,
|
| 748 |
+
tiled=tiled,
|
| 749 |
+
tile_size=tile_size,
|
| 750 |
+
tile_stride=tile_stride,
|
| 751 |
+
).to(dtype=pipe.torch_dtype, device=pipe.device)
|
| 752 |
+
if vace_reference_image is not None:
|
| 753 |
+
vace_reference_image = pipe.preprocess_video([vace_reference_image])
|
| 754 |
+
vace_reference_latents = pipe.vae.encode(
|
| 755 |
+
vace_reference_image, device=pipe.device
|
| 756 |
+
).to(dtype=pipe.torch_dtype, device=pipe.device)
|
| 757 |
+
input_latents = torch.concat([vace_reference_latents, input_latents], dim=2)
|
| 758 |
+
if pipe.scheduler.training:
|
| 759 |
+
return {"latents": noise, "input_latents": input_latents}
|
| 760 |
+
else:
|
| 761 |
+
latents = pipe.scheduler.add_noise(
|
| 762 |
+
input_latents, noise, timestep=pipe.scheduler.timesteps[0]
|
| 763 |
+
)
|
| 764 |
+
return {"latents": latents, "input_latents": input_latents}
|
| 765 |
+
|
| 766 |
+
|
| 767 |
+
class WanVideoUnit_PromptEmbedder(PipelineUnit):
|
| 768 |
+
def __init__(self):
|
| 769 |
+
super().__init__(
|
| 770 |
+
seperate_cfg=True,
|
| 771 |
+
input_params_posi={"prompt": "prompt", "positive": "positive"},
|
| 772 |
+
input_params_nega={"prompt": "negative_prompt", "positive": "positive"},
|
| 773 |
+
onload_model_names=("text_encoder",),
|
| 774 |
+
)
|
| 775 |
+
|
| 776 |
+
def process(self, pipe: WanVideoPipeline_FaceSwap, prompt, positive) -> dict:
|
| 777 |
+
pipe.load_models_to_device(self.onload_model_names)
|
| 778 |
+
prompt_emb = pipe.prompter.encode_prompt(
|
| 779 |
+
prompt, positive=positive, device=pipe.device
|
| 780 |
+
)
|
| 781 |
+
return {"context": prompt_emb}
|
| 782 |
+
|
| 783 |
+
|
| 784 |
+
class WanVideoUnit_ImageEmbedder(PipelineUnit):
|
| 785 |
+
"""
|
| 786 |
+
Deprecated
|
| 787 |
+
"""
|
| 788 |
+
|
| 789 |
+
def __init__(self):
|
| 790 |
+
super().__init__(
|
| 791 |
+
input_params=(
|
| 792 |
+
"input_image",
|
| 793 |
+
"end_image",
|
| 794 |
+
"num_frames",
|
| 795 |
+
"height",
|
| 796 |
+
"width",
|
| 797 |
+
"tiled",
|
| 798 |
+
"tile_size",
|
| 799 |
+
"tile_stride",
|
| 800 |
+
),
|
| 801 |
+
onload_model_names=("image_encoder", "vae"),
|
| 802 |
+
)
|
| 803 |
+
|
| 804 |
+
def process(
|
| 805 |
+
self,
|
| 806 |
+
pipe: WanVideoPipeline_FaceSwap,
|
| 807 |
+
input_image,
|
| 808 |
+
end_image,
|
| 809 |
+
num_frames,
|
| 810 |
+
height,
|
| 811 |
+
width,
|
| 812 |
+
tiled,
|
| 813 |
+
tile_size,
|
| 814 |
+
tile_stride,
|
| 815 |
+
):
|
| 816 |
+
if input_image is None or pipe.image_encoder is None:
|
| 817 |
+
return {}
|
| 818 |
+
pipe.load_models_to_device(self.onload_model_names)
|
| 819 |
+
image = pipe.preprocess_image(input_image.resize((width, height))).to(
|
| 820 |
+
pipe.device
|
| 821 |
+
)
|
| 822 |
+
clip_context = pipe.image_encoder.encode_image([image])
|
| 823 |
+
msk = torch.ones(1, num_frames, height // 8, width // 8, device=pipe.device)
|
| 824 |
+
msk[:, 1:] = 0
|
| 825 |
+
if end_image is not None:
|
| 826 |
+
end_image = pipe.preprocess_image(end_image.resize((width, height))).to(
|
| 827 |
+
pipe.device
|
| 828 |
+
)
|
| 829 |
+
vae_input = torch.concat(
|
| 830 |
+
[
|
| 831 |
+
image.transpose(0, 1),
|
| 832 |
+
torch.zeros(3, num_frames - 2, height, width).to(image.device),
|
| 833 |
+
end_image.transpose(0, 1),
|
| 834 |
+
],
|
| 835 |
+
dim=1,
|
| 836 |
+
)
|
| 837 |
+
if pipe.dit.has_image_pos_emb:
|
| 838 |
+
clip_context = torch.concat(
|
| 839 |
+
[clip_context, pipe.image_encoder.encode_image([end_image])], dim=1
|
| 840 |
+
)
|
| 841 |
+
msk[:, -1:] = 1
|
| 842 |
+
else:
|
| 843 |
+
vae_input = torch.concat(
|
| 844 |
+
[
|
| 845 |
+
image.transpose(0, 1),
|
| 846 |
+
torch.zeros(3, num_frames - 1, height, width).to(image.device),
|
| 847 |
+
],
|
| 848 |
+
dim=1,
|
| 849 |
+
)
|
| 850 |
+
|
| 851 |
+
msk = torch.concat(
|
| 852 |
+
[torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1
|
| 853 |
+
)
|
| 854 |
+
msk = msk.view(1, msk.shape[1] // 4, 4, height // 8, width // 8)
|
| 855 |
+
msk = msk.transpose(1, 2)[0]
|
| 856 |
+
|
| 857 |
+
y = pipe.vae.encode(
|
| 858 |
+
[vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)],
|
| 859 |
+
device=pipe.device,
|
| 860 |
+
tiled=tiled,
|
| 861 |
+
tile_size=tile_size,
|
| 862 |
+
tile_stride=tile_stride,
|
| 863 |
+
)[0]
|
| 864 |
+
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
|
| 865 |
+
y = torch.concat([msk, y])
|
| 866 |
+
y = y.unsqueeze(0)
|
| 867 |
+
clip_context = clip_context.to(dtype=pipe.torch_dtype, device=pipe.device)
|
| 868 |
+
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
|
| 869 |
+
return {"clip_feature": clip_context, "y": y}
|
| 870 |
+
|
| 871 |
+
|
| 872 |
+
class WanVideoUnit_ImageEmbedderCLIP(PipelineUnit):
|
| 873 |
+
def __init__(self):
|
| 874 |
+
super().__init__(
|
| 875 |
+
input_params=("input_image", "end_image", "height", "width"),
|
| 876 |
+
onload_model_names=("image_encoder",),
|
| 877 |
+
)
|
| 878 |
+
|
| 879 |
+
def process(
|
| 880 |
+
self, pipe: WanVideoPipeline_FaceSwap, input_image, end_image, height, width
|
| 881 |
+
):
|
| 882 |
+
if (
|
| 883 |
+
input_image is None
|
| 884 |
+
or pipe.image_encoder is None
|
| 885 |
+
or not pipe.dit.require_clip_embedding
|
| 886 |
+
):
|
| 887 |
+
return {}
|
| 888 |
+
pipe.load_models_to_device(self.onload_model_names)
|
| 889 |
+
image = pipe.preprocess_image(input_image.resize((width, height))).to(
|
| 890 |
+
pipe.device
|
| 891 |
+
)
|
| 892 |
+
clip_context = pipe.image_encoder.encode_image([image])
|
| 893 |
+
if end_image is not None:
|
| 894 |
+
end_image = pipe.preprocess_image(end_image.resize((width, height))).to(
|
| 895 |
+
pipe.device
|
| 896 |
+
)
|
| 897 |
+
if pipe.dit.has_image_pos_emb:
|
| 898 |
+
clip_context = torch.concat(
|
| 899 |
+
[clip_context, pipe.image_encoder.encode_image([end_image])], dim=1
|
| 900 |
+
)
|
| 901 |
+
clip_context = clip_context.to(dtype=pipe.torch_dtype, device=pipe.device)
|
| 902 |
+
return {"clip_feature": clip_context}
|
| 903 |
+
|
| 904 |
+
|
| 905 |
+
class WanVideoUnit_ImageEmbedderVAE(PipelineUnit):
|
| 906 |
+
def __init__(self):
|
| 907 |
+
super().__init__(
|
| 908 |
+
input_params=(
|
| 909 |
+
"input_image",
|
| 910 |
+
"end_image",
|
| 911 |
+
"num_frames",
|
| 912 |
+
"height",
|
| 913 |
+
"width",
|
| 914 |
+
"tiled",
|
| 915 |
+
"tile_size",
|
| 916 |
+
"tile_stride",
|
| 917 |
+
),
|
| 918 |
+
onload_model_names=("vae",),
|
| 919 |
+
)
|
| 920 |
+
|
| 921 |
+
def process(
|
| 922 |
+
self,
|
| 923 |
+
pipe: WanVideoPipeline_FaceSwap,
|
| 924 |
+
input_image,
|
| 925 |
+
end_image,
|
| 926 |
+
num_frames,
|
| 927 |
+
height,
|
| 928 |
+
width,
|
| 929 |
+
tiled,
|
| 930 |
+
tile_size,
|
| 931 |
+
tile_stride,
|
| 932 |
+
):
|
| 933 |
+
if input_image is None or not pipe.dit.require_vae_embedding:
|
| 934 |
+
return {}
|
| 935 |
+
pipe.load_models_to_device(self.onload_model_names)
|
| 936 |
+
image = pipe.preprocess_image(input_image.resize((width, height))).to(
|
| 937 |
+
pipe.device
|
| 938 |
+
)
|
| 939 |
+
msk = torch.ones(1, num_frames, height // 8, width // 8, device=pipe.device)
|
| 940 |
+
msk[:, 1:] = 0
|
| 941 |
+
if end_image is not None:
|
| 942 |
+
end_image = pipe.preprocess_image(end_image.resize((width, height))).to(
|
| 943 |
+
pipe.device
|
| 944 |
+
)
|
| 945 |
+
vae_input = torch.concat(
|
| 946 |
+
[
|
| 947 |
+
image.transpose(0, 1),
|
| 948 |
+
torch.zeros(3, num_frames - 2, height, width).to(image.device),
|
| 949 |
+
end_image.transpose(0, 1),
|
| 950 |
+
],
|
| 951 |
+
dim=1,
|
| 952 |
+
)
|
| 953 |
+
msk[:, -1:] = 1
|
| 954 |
+
else:
|
| 955 |
+
vae_input = torch.concat(
|
| 956 |
+
[
|
| 957 |
+
image.transpose(0, 1),
|
| 958 |
+
torch.zeros(3, num_frames - 1, height, width).to(image.device),
|
| 959 |
+
],
|
| 960 |
+
dim=1,
|
| 961 |
+
)
|
| 962 |
+
|
| 963 |
+
msk = torch.concat(
|
| 964 |
+
[torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1
|
| 965 |
+
)
|
| 966 |
+
msk = msk.view(1, msk.shape[1] // 4, 4, height // 8, width // 8)
|
| 967 |
+
msk = msk.transpose(1, 2)[0]
|
| 968 |
+
|
| 969 |
+
y = pipe.vae.encode(
|
| 970 |
+
[vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)],
|
| 971 |
+
device=pipe.device,
|
| 972 |
+
tiled=tiled,
|
| 973 |
+
tile_size=tile_size,
|
| 974 |
+
tile_stride=tile_stride,
|
| 975 |
+
)[0]
|
| 976 |
+
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
|
| 977 |
+
y = torch.concat([msk, y])
|
| 978 |
+
y = y.unsqueeze(0)
|
| 979 |
+
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
|
| 980 |
+
return {"y": y}
|
| 981 |
+
|
| 982 |
+
|
| 983 |
+
class WanVideoUnit_ImageEmbedderFused(PipelineUnit):
|
| 984 |
+
"""
|
| 985 |
+
Encode input image to latents using VAE. This unit is for Wan-AI/Wan2.2-TI2V-5B.
|
| 986 |
+
"""
|
| 987 |
+
|
| 988 |
+
def __init__(self):
|
| 989 |
+
super().__init__(
|
| 990 |
+
input_params=(
|
| 991 |
+
"input_image",
|
| 992 |
+
"latents",
|
| 993 |
+
"height",
|
| 994 |
+
"width",
|
| 995 |
+
"tiled",
|
| 996 |
+
"tile_size",
|
| 997 |
+
"tile_stride",
|
| 998 |
+
),
|
| 999 |
+
onload_model_names=("vae",),
|
| 1000 |
+
)
|
| 1001 |
+
|
| 1002 |
+
def process(
|
| 1003 |
+
self,
|
| 1004 |
+
pipe: WanVideoPipeline_FaceSwap,
|
| 1005 |
+
input_image,
|
| 1006 |
+
latents,
|
| 1007 |
+
height,
|
| 1008 |
+
width,
|
| 1009 |
+
tiled,
|
| 1010 |
+
tile_size,
|
| 1011 |
+
tile_stride,
|
| 1012 |
+
):
|
| 1013 |
+
if input_image is None or not pipe.dit.fuse_vae_embedding_in_latents:
|
| 1014 |
+
return {}
|
| 1015 |
+
pipe.load_models_to_device(self.onload_model_names)
|
| 1016 |
+
image = pipe.preprocess_image(input_image.resize((width, height))).transpose(
|
| 1017 |
+
0, 1
|
| 1018 |
+
)
|
| 1019 |
+
z = pipe.vae.encode(
|
| 1020 |
+
[image],
|
| 1021 |
+
device=pipe.device,
|
| 1022 |
+
tiled=tiled,
|
| 1023 |
+
tile_size=tile_size,
|
| 1024 |
+
tile_stride=tile_stride,
|
| 1025 |
+
)
|
| 1026 |
+
latents[:, :, 0:1] = z
|
| 1027 |
+
return {
|
| 1028 |
+
"latents": latents,
|
| 1029 |
+
"fuse_vae_embedding_in_latents": True,
|
| 1030 |
+
"first_frame_latents": z,
|
| 1031 |
+
}
|
| 1032 |
+
|
| 1033 |
+
|
| 1034 |
+
class WanVideoUnit_FunControl(PipelineUnit):
|
| 1035 |
+
def __init__(self):
|
| 1036 |
+
super().__init__(
|
| 1037 |
+
input_params=(
|
| 1038 |
+
"control_video",
|
| 1039 |
+
"num_frames",
|
| 1040 |
+
"height",
|
| 1041 |
+
"width",
|
| 1042 |
+
"tiled",
|
| 1043 |
+
"tile_size",
|
| 1044 |
+
"tile_stride",
|
| 1045 |
+
"clip_feature",
|
| 1046 |
+
"y",
|
| 1047 |
+
),
|
| 1048 |
+
onload_model_names=("vae",),
|
| 1049 |
+
)
|
| 1050 |
+
|
| 1051 |
+
def process(
|
| 1052 |
+
self,
|
| 1053 |
+
pipe: WanVideoPipeline_FaceSwap,
|
| 1054 |
+
control_video,
|
| 1055 |
+
num_frames,
|
| 1056 |
+
height,
|
| 1057 |
+
width,
|
| 1058 |
+
tiled,
|
| 1059 |
+
tile_size,
|
| 1060 |
+
tile_stride,
|
| 1061 |
+
clip_feature,
|
| 1062 |
+
y,
|
| 1063 |
+
):
|
| 1064 |
+
if control_video is None:
|
| 1065 |
+
return {}
|
| 1066 |
+
pipe.load_models_to_device(self.onload_model_names)
|
| 1067 |
+
control_video = pipe.preprocess_video(control_video)
|
| 1068 |
+
control_latents = pipe.vae.encode(
|
| 1069 |
+
control_video,
|
| 1070 |
+
device=pipe.device,
|
| 1071 |
+
tiled=tiled,
|
| 1072 |
+
tile_size=tile_size,
|
| 1073 |
+
tile_stride=tile_stride,
|
| 1074 |
+
).to(dtype=pipe.torch_dtype, device=pipe.device)
|
| 1075 |
+
control_latents = control_latents.to(dtype=pipe.torch_dtype, device=pipe.device)
|
| 1076 |
+
if clip_feature is None or y is None:
|
| 1077 |
+
clip_feature = torch.zeros(
|
| 1078 |
+
(1, 257, 1280), dtype=pipe.torch_dtype, device=pipe.device
|
| 1079 |
+
)
|
| 1080 |
+
y = torch.zeros(
|
| 1081 |
+
(1, 16, (num_frames - 1) // 4 + 1, height // 8, width // 8),
|
| 1082 |
+
dtype=pipe.torch_dtype,
|
| 1083 |
+
device=pipe.device,
|
| 1084 |
+
)
|
| 1085 |
+
else:
|
| 1086 |
+
y = y[:, -16:]
|
| 1087 |
+
y = torch.concat([control_latents, y], dim=1)
|
| 1088 |
+
return {"clip_feature": clip_feature, "y": y}
|
| 1089 |
+
|
| 1090 |
+
|
| 1091 |
+
class WanVideoUnit_FunReference(PipelineUnit):
|
| 1092 |
+
def __init__(self):
|
| 1093 |
+
super().__init__(
|
| 1094 |
+
input_params=("reference_image", "height", "width", "reference_image"),
|
| 1095 |
+
onload_model_names=("vae",),
|
| 1096 |
+
)
|
| 1097 |
+
|
| 1098 |
+
def process(self, pipe: WanVideoPipeline_FaceSwap, reference_image, height, width):
|
| 1099 |
+
if reference_image is None:
|
| 1100 |
+
return {}
|
| 1101 |
+
pipe.load_models_to_device(["vae"])
|
| 1102 |
+
reference_image = reference_image.resize((width, height))
|
| 1103 |
+
reference_latents = pipe.preprocess_video([reference_image])
|
| 1104 |
+
reference_latents = pipe.vae.encode(reference_latents, device=pipe.device)
|
| 1105 |
+
clip_feature = pipe.preprocess_image(reference_image)
|
| 1106 |
+
clip_feature = pipe.image_encoder.encode_image([clip_feature])
|
| 1107 |
+
return {"reference_latents": reference_latents, "clip_feature": clip_feature}
|
| 1108 |
+
|
| 1109 |
+
|
| 1110 |
+
class WanVideoUnit_FunCameraControl(PipelineUnit):
|
| 1111 |
+
def __init__(self):
|
| 1112 |
+
super().__init__(
|
| 1113 |
+
input_params=(
|
| 1114 |
+
"height",
|
| 1115 |
+
"width",
|
| 1116 |
+
"num_frames",
|
| 1117 |
+
"camera_control_direction",
|
| 1118 |
+
"camera_control_speed",
|
| 1119 |
+
"camera_control_origin",
|
| 1120 |
+
"latents",
|
| 1121 |
+
"input_image",
|
| 1122 |
+
),
|
| 1123 |
+
onload_model_names=("vae",),
|
| 1124 |
+
)
|
| 1125 |
+
|
| 1126 |
+
def process(
|
| 1127 |
+
self,
|
| 1128 |
+
pipe: WanVideoPipeline_FaceSwap,
|
| 1129 |
+
height,
|
| 1130 |
+
width,
|
| 1131 |
+
num_frames,
|
| 1132 |
+
camera_control_direction,
|
| 1133 |
+
camera_control_speed,
|
| 1134 |
+
camera_control_origin,
|
| 1135 |
+
latents,
|
| 1136 |
+
input_image,
|
| 1137 |
+
):
|
| 1138 |
+
if camera_control_direction is None:
|
| 1139 |
+
return {}
|
| 1140 |
+
camera_control_plucker_embedding = (
|
| 1141 |
+
pipe.dit.control_adapter.process_camera_coordinates(
|
| 1142 |
+
camera_control_direction,
|
| 1143 |
+
num_frames,
|
| 1144 |
+
height,
|
| 1145 |
+
width,
|
| 1146 |
+
camera_control_speed,
|
| 1147 |
+
camera_control_origin,
|
| 1148 |
+
)
|
| 1149 |
+
)
|
| 1150 |
+
|
| 1151 |
+
control_camera_video = (
|
| 1152 |
+
camera_control_plucker_embedding[:num_frames]
|
| 1153 |
+
.permute([3, 0, 1, 2])
|
| 1154 |
+
.unsqueeze(0)
|
| 1155 |
+
)
|
| 1156 |
+
control_camera_latents = torch.concat(
|
| 1157 |
+
[
|
| 1158 |
+
torch.repeat_interleave(
|
| 1159 |
+
control_camera_video[:, :, 0:1], repeats=4, dim=2
|
| 1160 |
+
),
|
| 1161 |
+
control_camera_video[:, :, 1:],
|
| 1162 |
+
],
|
| 1163 |
+
dim=2,
|
| 1164 |
+
).transpose(1, 2)
|
| 1165 |
+
b, f, c, h, w = control_camera_latents.shape
|
| 1166 |
+
control_camera_latents = (
|
| 1167 |
+
control_camera_latents.contiguous()
|
| 1168 |
+
.view(b, f // 4, 4, c, h, w)
|
| 1169 |
+
.transpose(2, 3)
|
| 1170 |
+
)
|
| 1171 |
+
control_camera_latents = (
|
| 1172 |
+
control_camera_latents.contiguous()
|
| 1173 |
+
.view(b, f // 4, c * 4, h, w)
|
| 1174 |
+
.transpose(1, 2)
|
| 1175 |
+
)
|
| 1176 |
+
control_camera_latents_input = control_camera_latents.to(
|
| 1177 |
+
device=pipe.device, dtype=pipe.torch_dtype
|
| 1178 |
+
)
|
| 1179 |
+
|
| 1180 |
+
input_image = input_image.resize((width, height))
|
| 1181 |
+
input_latents = pipe.preprocess_video([input_image])
|
| 1182 |
+
pipe.load_models_to_device(self.onload_model_names)
|
| 1183 |
+
input_latents = pipe.vae.encode(input_latents, device=pipe.device)
|
| 1184 |
+
y = torch.zeros_like(latents).to(pipe.device)
|
| 1185 |
+
y[:, :, :1] = input_latents
|
| 1186 |
+
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
|
| 1187 |
+
return {"control_camera_latents_input": control_camera_latents_input, "y": y}
|
| 1188 |
+
|
| 1189 |
+
|
| 1190 |
+
class WanVideoUnit_SpeedControl(PipelineUnit):
|
| 1191 |
+
def __init__(self):
|
| 1192 |
+
super().__init__(input_params=("motion_bucket_id",))
|
| 1193 |
+
|
| 1194 |
+
def process(self, pipe: WanVideoPipeline_FaceSwap, motion_bucket_id):
|
| 1195 |
+
if motion_bucket_id is None:
|
| 1196 |
+
return {}
|
| 1197 |
+
motion_bucket_id = torch.Tensor((motion_bucket_id,)).to(
|
| 1198 |
+
dtype=pipe.torch_dtype, device=pipe.device
|
| 1199 |
+
)
|
| 1200 |
+
return {"motion_bucket_id": motion_bucket_id}
|
| 1201 |
+
|
| 1202 |
+
|
| 1203 |
+
class WanVideoUnit_VACE(PipelineUnit):
|
| 1204 |
+
def __init__(self):
|
| 1205 |
+
super().__init__(
|
| 1206 |
+
input_params=(
|
| 1207 |
+
"vace_video",
|
| 1208 |
+
"vace_video_mask",
|
| 1209 |
+
"vace_reference_image",
|
| 1210 |
+
"vace_scale",
|
| 1211 |
+
"height",
|
| 1212 |
+
"width",
|
| 1213 |
+
"num_frames",
|
| 1214 |
+
"tiled",
|
| 1215 |
+
"tile_size",
|
| 1216 |
+
"tile_stride",
|
| 1217 |
+
),
|
| 1218 |
+
onload_model_names=("vae",),
|
| 1219 |
+
)
|
| 1220 |
+
|
| 1221 |
+
def process(
|
| 1222 |
+
self,
|
| 1223 |
+
pipe: WanVideoPipeline_FaceSwap,
|
| 1224 |
+
vace_video,
|
| 1225 |
+
vace_video_mask,
|
| 1226 |
+
vace_reference_image,
|
| 1227 |
+
vace_scale,
|
| 1228 |
+
height,
|
| 1229 |
+
width,
|
| 1230 |
+
num_frames,
|
| 1231 |
+
tiled,
|
| 1232 |
+
tile_size,
|
| 1233 |
+
tile_stride,
|
| 1234 |
+
):
|
| 1235 |
+
if (
|
| 1236 |
+
vace_video is not None
|
| 1237 |
+
or vace_video_mask is not None
|
| 1238 |
+
or vace_reference_image is not None
|
| 1239 |
+
):
|
| 1240 |
+
pipe.load_models_to_device(["vae"])
|
| 1241 |
+
if vace_video is None:
|
| 1242 |
+
vace_video = torch.zeros(
|
| 1243 |
+
(1, 3, num_frames, height, width),
|
| 1244 |
+
dtype=pipe.torch_dtype,
|
| 1245 |
+
device=pipe.device,
|
| 1246 |
+
)
|
| 1247 |
+
else:
|
| 1248 |
+
vace_video = pipe.preprocess_video(vace_video)
|
| 1249 |
+
|
| 1250 |
+
if vace_video_mask is None:
|
| 1251 |
+
vace_video_mask = torch.ones_like(vace_video)
|
| 1252 |
+
else:
|
| 1253 |
+
vace_video_mask = pipe.preprocess_video(
|
| 1254 |
+
vace_video_mask, min_value=0, max_value=1
|
| 1255 |
+
)
|
| 1256 |
+
|
| 1257 |
+
inactive = vace_video * (1 - vace_video_mask) + 0 * vace_video_mask
|
| 1258 |
+
reactive = vace_video * vace_video_mask + 0 * (1 - vace_video_mask)
|
| 1259 |
+
inactive = pipe.vae.encode(
|
| 1260 |
+
inactive,
|
| 1261 |
+
device=pipe.device,
|
| 1262 |
+
tiled=tiled,
|
| 1263 |
+
tile_size=tile_size,
|
| 1264 |
+
tile_stride=tile_stride,
|
| 1265 |
+
).to(dtype=pipe.torch_dtype, device=pipe.device)
|
| 1266 |
+
reactive = pipe.vae.encode(
|
| 1267 |
+
reactive,
|
| 1268 |
+
device=pipe.device,
|
| 1269 |
+
tiled=tiled,
|
| 1270 |
+
tile_size=tile_size,
|
| 1271 |
+
tile_stride=tile_stride,
|
| 1272 |
+
).to(dtype=pipe.torch_dtype, device=pipe.device)
|
| 1273 |
+
vace_video_latents = torch.concat((inactive, reactive), dim=1)
|
| 1274 |
+
|
| 1275 |
+
vace_mask_latents = rearrange(
|
| 1276 |
+
vace_video_mask[0, 0], "T (H P) (W Q) -> 1 (P Q) T H W", P=8, Q=8
|
| 1277 |
+
)
|
| 1278 |
+
vace_mask_latents = torch.nn.functional.interpolate(
|
| 1279 |
+
vace_mask_latents,
|
| 1280 |
+
size=(
|
| 1281 |
+
(vace_mask_latents.shape[2] + 3) // 4,
|
| 1282 |
+
vace_mask_latents.shape[3],
|
| 1283 |
+
vace_mask_latents.shape[4],
|
| 1284 |
+
),
|
| 1285 |
+
mode="nearest-exact",
|
| 1286 |
+
)
|
| 1287 |
+
|
| 1288 |
+
if vace_reference_image is None:
|
| 1289 |
+
pass
|
| 1290 |
+
else:
|
| 1291 |
+
vace_reference_image = pipe.preprocess_video([vace_reference_image])
|
| 1292 |
+
vace_reference_latents = pipe.vae.encode(
|
| 1293 |
+
vace_reference_image,
|
| 1294 |
+
device=pipe.device,
|
| 1295 |
+
tiled=tiled,
|
| 1296 |
+
tile_size=tile_size,
|
| 1297 |
+
tile_stride=tile_stride,
|
| 1298 |
+
).to(dtype=pipe.torch_dtype, device=pipe.device)
|
| 1299 |
+
vace_reference_latents = torch.concat(
|
| 1300 |
+
(vace_reference_latents, torch.zeros_like(vace_reference_latents)),
|
| 1301 |
+
dim=1,
|
| 1302 |
+
)
|
| 1303 |
+
vace_video_latents = torch.concat(
|
| 1304 |
+
(vace_reference_latents, vace_video_latents), dim=2
|
| 1305 |
+
)
|
| 1306 |
+
vace_mask_latents = torch.concat(
|
| 1307 |
+
(torch.zeros_like(vace_mask_latents[:, :, :1]), vace_mask_latents),
|
| 1308 |
+
dim=2,
|
| 1309 |
+
)
|
| 1310 |
+
|
| 1311 |
+
vace_context = torch.concat((vace_video_latents, vace_mask_latents), dim=1)
|
| 1312 |
+
return {"vace_context": vace_context, "vace_scale": vace_scale}
|
| 1313 |
+
else:
|
| 1314 |
+
return {"vace_context": None, "vace_scale": vace_scale}
|
| 1315 |
+
|
| 1316 |
+
|
| 1317 |
+
class WanVideoUnit_UnifiedSequenceParallel(PipelineUnit):
|
| 1318 |
+
def __init__(self):
|
| 1319 |
+
super().__init__(input_params=())
|
| 1320 |
+
|
| 1321 |
+
def process(self, pipe: WanVideoPipeline_FaceSwap):
|
| 1322 |
+
if hasattr(pipe, "use_unified_sequence_parallel"):
|
| 1323 |
+
if pipe.use_unified_sequence_parallel:
|
| 1324 |
+
return {"use_unified_sequence_parallel": True}
|
| 1325 |
+
return {}
|
| 1326 |
+
|
| 1327 |
+
|
| 1328 |
+
class WanVideoUnit_TeaCache(PipelineUnit):
|
| 1329 |
+
def __init__(self):
|
| 1330 |
+
super().__init__(
|
| 1331 |
+
seperate_cfg=True,
|
| 1332 |
+
input_params_posi={
|
| 1333 |
+
"num_inference_steps": "num_inference_steps",
|
| 1334 |
+
"tea_cache_l1_thresh": "tea_cache_l1_thresh",
|
| 1335 |
+
"tea_cache_model_id": "tea_cache_model_id",
|
| 1336 |
+
},
|
| 1337 |
+
input_params_nega={
|
| 1338 |
+
"num_inference_steps": "num_inference_steps",
|
| 1339 |
+
"tea_cache_l1_thresh": "tea_cache_l1_thresh",
|
| 1340 |
+
"tea_cache_model_id": "tea_cache_model_id",
|
| 1341 |
+
},
|
| 1342 |
+
)
|
| 1343 |
+
|
| 1344 |
+
def process(
|
| 1345 |
+
self,
|
| 1346 |
+
pipe: WanVideoPipeline_FaceSwap,
|
| 1347 |
+
num_inference_steps,
|
| 1348 |
+
tea_cache_l1_thresh,
|
| 1349 |
+
tea_cache_model_id,
|
| 1350 |
+
):
|
| 1351 |
+
if tea_cache_l1_thresh is None:
|
| 1352 |
+
return {}
|
| 1353 |
+
return {
|
| 1354 |
+
"tea_cache": TeaCache(
|
| 1355 |
+
num_inference_steps,
|
| 1356 |
+
rel_l1_thresh=tea_cache_l1_thresh,
|
| 1357 |
+
model_id=tea_cache_model_id,
|
| 1358 |
+
)
|
| 1359 |
+
}
|
| 1360 |
+
|
| 1361 |
+
|
| 1362 |
+
class WanVideoUnit_CfgMerger(PipelineUnit):
|
| 1363 |
+
def __init__(self):
|
| 1364 |
+
super().__init__(take_over=True)
|
| 1365 |
+
self.concat_tensor_names = ["context", "clip_feature", "y", "reference_latents"]
|
| 1366 |
+
|
| 1367 |
+
def process(
|
| 1368 |
+
self, pipe: WanVideoPipeline_FaceSwap, inputs_shared, inputs_posi, inputs_nega
|
| 1369 |
+
):
|
| 1370 |
+
if not inputs_shared["cfg_merge"]:
|
| 1371 |
+
return inputs_shared, inputs_posi, inputs_nega
|
| 1372 |
+
for name in self.concat_tensor_names:
|
| 1373 |
+
tensor_posi = inputs_posi.get(name)
|
| 1374 |
+
tensor_nega = inputs_nega.get(name)
|
| 1375 |
+
tensor_shared = inputs_shared.get(name)
|
| 1376 |
+
if tensor_posi is not None and tensor_nega is not None:
|
| 1377 |
+
inputs_shared[name] = torch.concat((tensor_posi, tensor_nega), dim=0)
|
| 1378 |
+
elif tensor_shared is not None:
|
| 1379 |
+
inputs_shared[name] = torch.concat(
|
| 1380 |
+
(tensor_shared, tensor_shared), dim=0
|
| 1381 |
+
)
|
| 1382 |
+
inputs_posi.clear()
|
| 1383 |
+
inputs_nega.clear()
|
| 1384 |
+
return inputs_shared, inputs_posi, inputs_nega
|
| 1385 |
+
|
| 1386 |
+
|
| 1387 |
+
class TeaCache:
|
| 1388 |
+
def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
|
| 1389 |
+
self.num_inference_steps = num_inference_steps
|
| 1390 |
+
self.step = 0
|
| 1391 |
+
self.accumulated_rel_l1_distance = 0
|
| 1392 |
+
self.previous_modulated_input = None
|
| 1393 |
+
self.rel_l1_thresh = rel_l1_thresh
|
| 1394 |
+
self.previous_residual = None
|
| 1395 |
+
self.previous_hidden_states = None
|
| 1396 |
+
|
| 1397 |
+
self.coefficients_dict = {
|
| 1398 |
+
"Wan2.1-T2V-1.3B": [
|
| 1399 |
+
-5.21862437e04,
|
| 1400 |
+
9.23041404e03,
|
| 1401 |
+
-5.28275948e02,
|
| 1402 |
+
1.36987616e01,
|
| 1403 |
+
-4.99875664e-02,
|
| 1404 |
+
],
|
| 1405 |
+
"Wan2.1-T2V-14B": [
|
| 1406 |
+
-3.03318725e05,
|
| 1407 |
+
4.90537029e04,
|
| 1408 |
+
-2.65530556e03,
|
| 1409 |
+
5.87365115e01,
|
| 1410 |
+
-3.15583525e-01,
|
| 1411 |
+
],
|
| 1412 |
+
"Wan2.1-I2V-14B-480P": [
|
| 1413 |
+
2.57151496e05,
|
| 1414 |
+
-3.54229917e04,
|
| 1415 |
+
1.40286849e03,
|
| 1416 |
+
-1.35890334e01,
|
| 1417 |
+
1.32517977e-01,
|
| 1418 |
+
],
|
| 1419 |
+
"Wan2.1-I2V-14B-720P": [
|
| 1420 |
+
8.10705460e03,
|
| 1421 |
+
2.13393892e03,
|
| 1422 |
+
-3.72934672e02,
|
| 1423 |
+
1.66203073e01,
|
| 1424 |
+
-4.17769401e-02,
|
| 1425 |
+
],
|
| 1426 |
+
}
|
| 1427 |
+
if model_id not in self.coefficients_dict:
|
| 1428 |
+
supported_model_ids = ", ".join([i for i in self.coefficients_dict])
|
| 1429 |
+
raise ValueError(
|
| 1430 |
+
f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids})."
|
| 1431 |
+
)
|
| 1432 |
+
self.coefficients = self.coefficients_dict[model_id]
|
| 1433 |
+
|
| 1434 |
+
def check(self, dit: WanModel, x, t_mod):
|
| 1435 |
+
modulated_inp = t_mod.clone()
|
| 1436 |
+
if self.step == 0 or self.step == self.num_inference_steps - 1:
|
| 1437 |
+
should_calc = True
|
| 1438 |
+
self.accumulated_rel_l1_distance = 0
|
| 1439 |
+
else:
|
| 1440 |
+
coefficients = self.coefficients
|
| 1441 |
+
rescale_func = np.poly1d(coefficients)
|
| 1442 |
+
self.accumulated_rel_l1_distance += rescale_func(
|
| 1443 |
+
(
|
| 1444 |
+
(modulated_inp - self.previous_modulated_input).abs().mean()
|
| 1445 |
+
/ self.previous_modulated_input.abs().mean()
|
| 1446 |
+
)
|
| 1447 |
+
.cpu()
|
| 1448 |
+
.item()
|
| 1449 |
+
)
|
| 1450 |
+
if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
|
| 1451 |
+
should_calc = False
|
| 1452 |
+
else:
|
| 1453 |
+
should_calc = True
|
| 1454 |
+
self.accumulated_rel_l1_distance = 0
|
| 1455 |
+
self.previous_modulated_input = modulated_inp
|
| 1456 |
+
self.step += 1
|
| 1457 |
+
if self.step == self.num_inference_steps:
|
| 1458 |
+
self.step = 0
|
| 1459 |
+
if should_calc:
|
| 1460 |
+
self.previous_hidden_states = x.clone()
|
| 1461 |
+
return not should_calc
|
| 1462 |
+
|
| 1463 |
+
def store(self, hidden_states):
|
| 1464 |
+
self.previous_residual = hidden_states - self.previous_hidden_states
|
| 1465 |
+
self.previous_hidden_states = None
|
| 1466 |
+
|
| 1467 |
+
def update(self, hidden_states):
|
| 1468 |
+
hidden_states = hidden_states + self.previous_residual
|
| 1469 |
+
return hidden_states
|
| 1470 |
+
|
| 1471 |
+
|
| 1472 |
+
class TemporalTiler_BCTHW:
|
| 1473 |
+
def __init__(self):
|
| 1474 |
+
pass
|
| 1475 |
+
|
| 1476 |
+
def build_1d_mask(self, length, left_bound, right_bound, border_width):
|
| 1477 |
+
x = torch.ones((length,))
|
| 1478 |
+
if not left_bound:
|
| 1479 |
+
x[:border_width] = (torch.arange(border_width) + 1) / border_width
|
| 1480 |
+
if not right_bound:
|
| 1481 |
+
x[-border_width:] = torch.flip(
|
| 1482 |
+
(torch.arange(border_width) + 1) / border_width, dims=(0,)
|
| 1483 |
+
)
|
| 1484 |
+
return x
|
| 1485 |
+
|
| 1486 |
+
def build_mask(self, data, is_bound, border_width):
|
| 1487 |
+
_, _, T, _, _ = data.shape
|
| 1488 |
+
t = self.build_1d_mask(T, is_bound[0], is_bound[1], border_width[0])
|
| 1489 |
+
mask = repeat(t, "T -> 1 1 T 1 1")
|
| 1490 |
+
return mask
|
| 1491 |
+
|
| 1492 |
+
def run(
|
| 1493 |
+
self,
|
| 1494 |
+
model_fn,
|
| 1495 |
+
sliding_window_size,
|
| 1496 |
+
sliding_window_stride,
|
| 1497 |
+
computation_device,
|
| 1498 |
+
computation_dtype,
|
| 1499 |
+
model_kwargs,
|
| 1500 |
+
tensor_names,
|
| 1501 |
+
batch_size=None,
|
| 1502 |
+
):
|
| 1503 |
+
tensor_names = [
|
| 1504 |
+
tensor_name
|
| 1505 |
+
for tensor_name in tensor_names
|
| 1506 |
+
if model_kwargs.get(tensor_name) is not None
|
| 1507 |
+
]
|
| 1508 |
+
tensor_dict = {
|
| 1509 |
+
tensor_name: model_kwargs[tensor_name] for tensor_name in tensor_names
|
| 1510 |
+
}
|
| 1511 |
+
B, C, T, H, W = tensor_dict[tensor_names[0]].shape
|
| 1512 |
+
if batch_size is not None:
|
| 1513 |
+
B *= batch_size
|
| 1514 |
+
data_device, data_dtype = (
|
| 1515 |
+
tensor_dict[tensor_names[0]].device,
|
| 1516 |
+
tensor_dict[tensor_names[0]].dtype,
|
| 1517 |
+
)
|
| 1518 |
+
value = torch.zeros((B, C, T, H, W), device=data_device, dtype=data_dtype)
|
| 1519 |
+
weight = torch.zeros((1, 1, T, 1, 1), device=data_device, dtype=data_dtype)
|
| 1520 |
+
for t in range(0, T, sliding_window_stride):
|
| 1521 |
+
if (
|
| 1522 |
+
t - sliding_window_stride >= 0
|
| 1523 |
+
and t - sliding_window_stride + sliding_window_size >= T
|
| 1524 |
+
):
|
| 1525 |
+
continue
|
| 1526 |
+
t_ = min(t + sliding_window_size, T)
|
| 1527 |
+
model_kwargs.update(
|
| 1528 |
+
{
|
| 1529 |
+
tensor_name: tensor_dict[tensor_name][:, :, t:t_:, :].to(
|
| 1530 |
+
device=computation_device, dtype=computation_dtype
|
| 1531 |
+
)
|
| 1532 |
+
for tensor_name in tensor_names
|
| 1533 |
+
}
|
| 1534 |
+
)
|
| 1535 |
+
model_output = model_fn(**model_kwargs).to(
|
| 1536 |
+
device=data_device, dtype=data_dtype
|
| 1537 |
+
)
|
| 1538 |
+
mask = self.build_mask(
|
| 1539 |
+
model_output,
|
| 1540 |
+
is_bound=(t == 0, t_ == T),
|
| 1541 |
+
border_width=(sliding_window_size - sliding_window_stride,),
|
| 1542 |
+
).to(device=data_device, dtype=data_dtype)
|
| 1543 |
+
value[:, :, t:t_, :, :] += model_output * mask
|
| 1544 |
+
weight[:, :, t:t_, :, :] += mask
|
| 1545 |
+
value /= weight
|
| 1546 |
+
model_kwargs.update(tensor_dict)
|
| 1547 |
+
return value
|
| 1548 |
+
|
| 1549 |
+
|
| 1550 |
+
def model_fn_wan_video(
|
| 1551 |
+
dit: WanModel,
|
| 1552 |
+
motion_controller: WanMotionControllerModel = None,
|
| 1553 |
+
vace: VaceWanModel = None,
|
| 1554 |
+
latents: torch.Tensor = None,
|
| 1555 |
+
timestep: torch.Tensor = None,
|
| 1556 |
+
context: torch.Tensor = None,
|
| 1557 |
+
clip_feature: Optional[torch.Tensor] = None,
|
| 1558 |
+
y: Optional[torch.Tensor] = None,
|
| 1559 |
+
reference_latents=None,
|
| 1560 |
+
vace_context=None,
|
| 1561 |
+
vace_scale=1.0,
|
| 1562 |
+
tea_cache: TeaCache = None,
|
| 1563 |
+
use_unified_sequence_parallel: bool = False,
|
| 1564 |
+
motion_bucket_id: Optional[torch.Tensor] = None,
|
| 1565 |
+
sliding_window_size: Optional[int] = None,
|
| 1566 |
+
sliding_window_stride: Optional[int] = None,
|
| 1567 |
+
cfg_merge: bool = False,
|
| 1568 |
+
use_gradient_checkpointing: bool = False,
|
| 1569 |
+
use_gradient_checkpointing_offload: bool = False,
|
| 1570 |
+
control_camera_latents_input=None,
|
| 1571 |
+
fuse_vae_embedding_in_latents: bool = False,
|
| 1572 |
+
ip_image=None,
|
| 1573 |
+
**kwargs,
|
| 1574 |
+
):
|
| 1575 |
+
if sliding_window_size is not None and sliding_window_stride is not None:
|
| 1576 |
+
model_kwargs = dict(
|
| 1577 |
+
dit=dit,
|
| 1578 |
+
motion_controller=motion_controller,
|
| 1579 |
+
vace=vace,
|
| 1580 |
+
latents=latents,
|
| 1581 |
+
timestep=timestep,
|
| 1582 |
+
context=context,
|
| 1583 |
+
clip_feature=clip_feature,
|
| 1584 |
+
y=y,
|
| 1585 |
+
reference_latents=reference_latents,
|
| 1586 |
+
vace_context=vace_context,
|
| 1587 |
+
vace_scale=vace_scale,
|
| 1588 |
+
tea_cache=tea_cache,
|
| 1589 |
+
use_unified_sequence_parallel=use_unified_sequence_parallel,
|
| 1590 |
+
motion_bucket_id=motion_bucket_id,
|
| 1591 |
+
)
|
| 1592 |
+
return TemporalTiler_BCTHW().run(
|
| 1593 |
+
model_fn_wan_video,
|
| 1594 |
+
sliding_window_size,
|
| 1595 |
+
sliding_window_stride,
|
| 1596 |
+
latents.device,
|
| 1597 |
+
latents.dtype,
|
| 1598 |
+
model_kwargs=model_kwargs,
|
| 1599 |
+
tensor_names=["latents", "y"],
|
| 1600 |
+
batch_size=2 if cfg_merge else 1,
|
| 1601 |
+
)
|
| 1602 |
+
|
| 1603 |
+
if use_unified_sequence_parallel:
|
| 1604 |
+
import torch.distributed as dist
|
| 1605 |
+
from xfuser.core.distributed import (
|
| 1606 |
+
get_sequence_parallel_rank,
|
| 1607 |
+
get_sequence_parallel_world_size,
|
| 1608 |
+
get_sp_group,
|
| 1609 |
+
)
|
| 1610 |
+
x_ip = None
|
| 1611 |
+
t_mod_ip = None
|
| 1612 |
+
# Timestep
|
| 1613 |
+
if dit.seperated_timestep and fuse_vae_embedding_in_latents:
|
| 1614 |
+
timestep = torch.concat(
|
| 1615 |
+
[
|
| 1616 |
+
torch.zeros(
|
| 1617 |
+
(1, latents.shape[3] * latents.shape[4] // 4),
|
| 1618 |
+
dtype=latents.dtype,
|
| 1619 |
+
device=latents.device,
|
| 1620 |
+
),
|
| 1621 |
+
torch.ones(
|
| 1622 |
+
(latents.shape[2] - 1, latents.shape[3] * latents.shape[4] // 4),
|
| 1623 |
+
dtype=latents.dtype,
|
| 1624 |
+
device=latents.device,
|
| 1625 |
+
)
|
| 1626 |
+
* timestep,
|
| 1627 |
+
]
|
| 1628 |
+
).flatten()
|
| 1629 |
+
t = dit.time_embedding(
|
| 1630 |
+
sinusoidal_embedding_1d(dit.freq_dim, timestep).unsqueeze(0)
|
| 1631 |
+
)
|
| 1632 |
+
t_mod = dit.time_projection(t).unflatten(2, (6, dit.dim))
|
| 1633 |
+
else:
|
| 1634 |
+
t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
|
| 1635 |
+
t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim))
|
| 1636 |
+
|
| 1637 |
+
if ip_image is not None:
|
| 1638 |
+
timestep_ip = torch.zeros_like(timestep) # [B] with 0s
|
| 1639 |
+
t_ip = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep_ip))
|
| 1640 |
+
t_mod_ip = dit.time_projection(t_ip).unflatten(1, (6, dit.dim))
|
| 1641 |
+
|
| 1642 |
+
# Motion Controller
|
| 1643 |
+
if motion_bucket_id is not None and motion_controller is not None:
|
| 1644 |
+
t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim))
|
| 1645 |
+
context = dit.text_embedding(context)
|
| 1646 |
+
|
| 1647 |
+
x = latents
|
| 1648 |
+
# Merged cfg
|
| 1649 |
+
if x.shape[0] != context.shape[0]:
|
| 1650 |
+
x = torch.concat([x] * context.shape[0], dim=0)
|
| 1651 |
+
if timestep.shape[0] != context.shape[0]:
|
| 1652 |
+
timestep = torch.concat([timestep] * context.shape[0], dim=0)
|
| 1653 |
+
|
| 1654 |
+
# Image Embedding
|
| 1655 |
+
if y is not None and dit.require_vae_embedding:
|
| 1656 |
+
x = torch.cat([x, y], dim=1)
|
| 1657 |
+
if clip_feature is not None and dit.require_clip_embedding:
|
| 1658 |
+
clip_embdding = dit.img_emb(clip_feature)
|
| 1659 |
+
context = torch.cat([clip_embdding, context], dim=1)
|
| 1660 |
+
|
| 1661 |
+
# Add camera control
|
| 1662 |
+
x, (f, h, w) = dit.patchify(x, control_camera_latents_input)
|
| 1663 |
+
|
| 1664 |
+
# Reference image
|
| 1665 |
+
if reference_latents is not None:
|
| 1666 |
+
if len(reference_latents.shape) == 5:
|
| 1667 |
+
reference_latents = reference_latents[:, :, 0]
|
| 1668 |
+
reference_latents = dit.ref_conv(reference_latents).flatten(2).transpose(1, 2)
|
| 1669 |
+
x = torch.concat([reference_latents, x], dim=1)
|
| 1670 |
+
f += 1
|
| 1671 |
+
|
| 1672 |
+
offset = 1
|
| 1673 |
+
freqs = (
|
| 1674 |
+
torch.cat(
|
| 1675 |
+
[
|
| 1676 |
+
dit.freqs[0][offset : f + offset].view(f, 1, 1, -1).expand(f, h, w, -1),
|
| 1677 |
+
dit.freqs[1][offset : h + offset].view(1, h, 1, -1).expand(f, h, w, -1),
|
| 1678 |
+
dit.freqs[2][offset : w + offset].view(1, 1, w, -1).expand(f, h, w, -1),
|
| 1679 |
+
],
|
| 1680 |
+
dim=-1,
|
| 1681 |
+
)
|
| 1682 |
+
.reshape(f * h * w, 1, -1)
|
| 1683 |
+
.to(x.device)
|
| 1684 |
+
)
|
| 1685 |
+
|
| 1686 |
+
############################################################################################
|
| 1687 |
+
if ip_image is not None:
|
| 1688 |
+
x_ip, (f_ip, h_ip, w_ip) = dit.patchify(
|
| 1689 |
+
ip_image
|
| 1690 |
+
) # x_ip [1, 1024, 5120] [B, N, D] f_ip = 1 h_ip = 32 w_ip = 32
|
| 1691 |
+
freqs_ip = (
|
| 1692 |
+
torch.cat(
|
| 1693 |
+
[
|
| 1694 |
+
dit.freqs[0][0].view(f_ip, 1, 1, -1).expand(f_ip, h_ip, w_ip, -1),
|
| 1695 |
+
dit.freqs[1][h + offset : h + offset + h_ip]
|
| 1696 |
+
.view(1, h_ip, 1, -1)
|
| 1697 |
+
.expand(f_ip, h_ip, w_ip, -1),
|
| 1698 |
+
dit.freqs[2][w + offset : w + offset + w_ip]
|
| 1699 |
+
.view(1, 1, w_ip, -1)
|
| 1700 |
+
.expand(f_ip, h_ip, w_ip, -1),
|
| 1701 |
+
],
|
| 1702 |
+
dim=-1,
|
| 1703 |
+
)
|
| 1704 |
+
.reshape(f_ip * h_ip * w_ip, 1, -1)
|
| 1705 |
+
.to(x_ip.device)
|
| 1706 |
+
)
|
| 1707 |
+
freqs_original = freqs
|
| 1708 |
+
freqs = torch.cat([freqs, freqs_ip], dim=0)
|
| 1709 |
+
############################################################################################
|
| 1710 |
+
else:
|
| 1711 |
+
freqs_original = freqs
|
| 1712 |
+
# TeaCache
|
| 1713 |
+
if tea_cache is not None:
|
| 1714 |
+
tea_cache_update = tea_cache.check(dit, x, t_mod)
|
| 1715 |
+
else:
|
| 1716 |
+
tea_cache_update = False
|
| 1717 |
+
|
| 1718 |
+
if vace_context is not None:
|
| 1719 |
+
vace_hints = vace(x, vace_context, context, t_mod, freqs)
|
| 1720 |
+
|
| 1721 |
+
# blocks
|
| 1722 |
+
if use_unified_sequence_parallel:
|
| 1723 |
+
if dist.is_initialized() and dist.get_world_size() > 1:
|
| 1724 |
+
x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[
|
| 1725 |
+
get_sequence_parallel_rank()
|
| 1726 |
+
]
|
| 1727 |
+
if tea_cache_update:
|
| 1728 |
+
x = tea_cache.update(x)
|
| 1729 |
+
else:
|
| 1730 |
+
|
| 1731 |
+
def create_custom_forward(module):
|
| 1732 |
+
def custom_forward(*inputs):
|
| 1733 |
+
return module(*inputs)
|
| 1734 |
+
|
| 1735 |
+
return custom_forward
|
| 1736 |
+
|
| 1737 |
+
for block_id, block in enumerate(dit.blocks):
|
| 1738 |
+
if use_gradient_checkpointing_offload:
|
| 1739 |
+
with torch.autograd.graph.save_on_cpu():
|
| 1740 |
+
x, x_ip = torch.utils.checkpoint.checkpoint(
|
| 1741 |
+
create_custom_forward(block),
|
| 1742 |
+
x,
|
| 1743 |
+
context,
|
| 1744 |
+
t_mod,
|
| 1745 |
+
freqs,
|
| 1746 |
+
x_ip=x_ip,
|
| 1747 |
+
t_mod_ip=t_mod_ip,
|
| 1748 |
+
use_reentrant=False,
|
| 1749 |
+
)
|
| 1750 |
+
elif use_gradient_checkpointing:
|
| 1751 |
+
x, x_ip = torch.utils.checkpoint.checkpoint(
|
| 1752 |
+
create_custom_forward(block),
|
| 1753 |
+
x,
|
| 1754 |
+
context,
|
| 1755 |
+
t_mod,
|
| 1756 |
+
freqs,
|
| 1757 |
+
x_ip=x_ip,
|
| 1758 |
+
t_mod_ip=t_mod_ip,
|
| 1759 |
+
use_reentrant=False,
|
| 1760 |
+
)
|
| 1761 |
+
else:
|
| 1762 |
+
x, x_ip = block(x, context, t_mod, freqs, x_ip=x_ip, t_mod_ip=t_mod_ip)
|
| 1763 |
+
if vace_context is not None and block_id in vace.vace_layers_mapping:
|
| 1764 |
+
current_vace_hint = vace_hints[vace.vace_layers_mapping[block_id]]
|
| 1765 |
+
if (
|
| 1766 |
+
use_unified_sequence_parallel
|
| 1767 |
+
and dist.is_initialized()
|
| 1768 |
+
and dist.get_world_size() > 1
|
| 1769 |
+
):
|
| 1770 |
+
current_vace_hint = torch.chunk(
|
| 1771 |
+
current_vace_hint, get_sequence_parallel_world_size(), dim=1
|
| 1772 |
+
)[get_sequence_parallel_rank()]
|
| 1773 |
+
x = x + current_vace_hint * vace_scale
|
| 1774 |
+
if tea_cache is not None:
|
| 1775 |
+
tea_cache.store(x)
|
| 1776 |
+
|
| 1777 |
+
x = dit.head(x, t)
|
| 1778 |
+
if use_unified_sequence_parallel:
|
| 1779 |
+
if dist.is_initialized() and dist.get_world_size() > 1:
|
| 1780 |
+
x = get_sp_group().all_gather(x, dim=1)
|
| 1781 |
+
# Remove reference latents
|
| 1782 |
+
if reference_latents is not None:
|
| 1783 |
+
x = x[:, reference_latents.shape[1] :]
|
| 1784 |
+
f -= 1
|
| 1785 |
+
x = dit.unpatchify(x, (f, h, w))
|
| 1786 |
+
return x
|
preprocessor/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .image_input_preprocessor import FaceProcessor
|
| 2 |
+
from .videomask_generator import VideoMaskGenerator
|
preprocessor/image_input_preprocessor.py
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import cv2
|
| 3 |
+
import requests
|
| 4 |
+
import torch
|
| 5 |
+
import numpy as np
|
| 6 |
+
import PIL.Image
|
| 7 |
+
import PIL.ImageOps
|
| 8 |
+
from insightface.app import FaceAnalysis
|
| 9 |
+
from facexlib.parsing import init_parsing_model
|
| 10 |
+
from torchvision.transforms.functional import normalize
|
| 11 |
+
from typing import Union, Optional
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def _img2tensor(img: np.ndarray, bgr2rgb: bool = True) -> torch.Tensor:
|
| 15 |
+
if bgr2rgb:
|
| 16 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
| 17 |
+
img = img.astype(np.float32) / 255.0
|
| 18 |
+
img = np.transpose(img, (2, 0, 1))
|
| 19 |
+
return torch.from_numpy(img)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def _pad_to_square(img: np.ndarray, pad_color: int = 255) -> np.ndarray:
|
| 23 |
+
h, w, _ = img.shape
|
| 24 |
+
if h == w:
|
| 25 |
+
return img
|
| 26 |
+
|
| 27 |
+
if h > w:
|
| 28 |
+
pad_size = (h - w) // 2
|
| 29 |
+
padded_img = cv2.copyMakeBorder(
|
| 30 |
+
img,
|
| 31 |
+
0,
|
| 32 |
+
0,
|
| 33 |
+
pad_size,
|
| 34 |
+
h - w - pad_size,
|
| 35 |
+
cv2.BORDER_CONSTANT,
|
| 36 |
+
value=[pad_color] * 3,
|
| 37 |
+
)
|
| 38 |
+
else:
|
| 39 |
+
pad_size = (w - h) // 2
|
| 40 |
+
padded_img = cv2.copyMakeBorder(
|
| 41 |
+
img,
|
| 42 |
+
pad_size,
|
| 43 |
+
w - h - pad_size,
|
| 44 |
+
0,
|
| 45 |
+
0,
|
| 46 |
+
cv2.BORDER_CONSTANT,
|
| 47 |
+
value=[pad_color] * 3,
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
return padded_img
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class FaceProcessor:
|
| 54 |
+
def __init__(self, antelopv2_path=".", device: Optional[torch.device] = None):
|
| 55 |
+
if device is None:
|
| 56 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 57 |
+
else:
|
| 58 |
+
self.device = device
|
| 59 |
+
|
| 60 |
+
providers = (
|
| 61 |
+
["CUDAExecutionProvider"]
|
| 62 |
+
if self.device.type == "cuda"
|
| 63 |
+
else ["CPUExecutionProvider"]
|
| 64 |
+
)
|
| 65 |
+
self.app = FaceAnalysis(
|
| 66 |
+
name="antelopev2", root=antelopv2_path, providers=providers
|
| 67 |
+
)
|
| 68 |
+
self.app.prepare(ctx_id=0, det_size=(640, 640))
|
| 69 |
+
|
| 70 |
+
self.parsing_model = init_parsing_model(
|
| 71 |
+
model_name="bisenet", device=self.device
|
| 72 |
+
)
|
| 73 |
+
self.parsing_model.eval()
|
| 74 |
+
|
| 75 |
+
print("FaceProcessor initialized successfully.")
|
| 76 |
+
|
| 77 |
+
def process(
|
| 78 |
+
self,
|
| 79 |
+
image: Union[str, PIL.Image.Image],
|
| 80 |
+
resize_to: int = 512,
|
| 81 |
+
border_thresh: int = 10,
|
| 82 |
+
face_crop_scale: float = 1.5,
|
| 83 |
+
extra_input: bool = False,
|
| 84 |
+
) -> PIL.Image.Image:
|
| 85 |
+
if isinstance(image, str):
|
| 86 |
+
if image.startswith("http://") or image.startswith("https://"):
|
| 87 |
+
image = PIL.Image.open(requests.get(image, stream=True, timeout=10).raw)
|
| 88 |
+
elif os.path.isfile(image):
|
| 89 |
+
image = PIL.Image.open(image)
|
| 90 |
+
else:
|
| 91 |
+
raise ValueError(
|
| 92 |
+
f"Input string is not a valid URL or file path: {image}"
|
| 93 |
+
)
|
| 94 |
+
elif not isinstance(image, PIL.Image.Image):
|
| 95 |
+
raise TypeError(
|
| 96 |
+
"Input must be a file path, a URL, or a PIL.Image.Image object."
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
image = PIL.ImageOps.exif_transpose(image).convert("RGB")
|
| 100 |
+
|
| 101 |
+
frame = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
|
| 102 |
+
|
| 103 |
+
faces = self.app.get(frame)
|
| 104 |
+
h, w, _ = frame.shape
|
| 105 |
+
image_to_process = None
|
| 106 |
+
|
| 107 |
+
if not faces:
|
| 108 |
+
print(
|
| 109 |
+
"[Warning] No face detected. Using the whole image, padded to square."
|
| 110 |
+
)
|
| 111 |
+
image_to_process = _pad_to_square(frame, pad_color=255)
|
| 112 |
+
else:
|
| 113 |
+
largest_face = max(
|
| 114 |
+
faces, key=lambda f: (f.bbox[2] - f.bbox[0]) * (f.bbox[3] - f.bbox[1])
|
| 115 |
+
)
|
| 116 |
+
x1, y1, x2, y2 = map(int, largest_face.bbox)
|
| 117 |
+
|
| 118 |
+
is_close_to_border = (
|
| 119 |
+
x1 <= border_thresh
|
| 120 |
+
and y1 <= border_thresh
|
| 121 |
+
and x2 >= w - border_thresh
|
| 122 |
+
and y2 >= h - border_thresh
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
if is_close_to_border:
|
| 126 |
+
print(
|
| 127 |
+
"[Info] Face is close to border, padding original image to square."
|
| 128 |
+
)
|
| 129 |
+
image_to_process = _pad_to_square(frame, pad_color=255)
|
| 130 |
+
else:
|
| 131 |
+
cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
|
| 132 |
+
side = int(max(x2 - x1, y2 - y1) * face_crop_scale)
|
| 133 |
+
half = side // 2
|
| 134 |
+
|
| 135 |
+
left = max(cx - half, 0)
|
| 136 |
+
top = max(cy - half, 0)
|
| 137 |
+
right = min(cx + half, w)
|
| 138 |
+
bottom = min(cy + half, h)
|
| 139 |
+
|
| 140 |
+
cropped_face = frame[top:bottom, left:right]
|
| 141 |
+
image_to_process = _pad_to_square(cropped_face, pad_color=255)
|
| 142 |
+
|
| 143 |
+
image_resized = cv2.resize(
|
| 144 |
+
image_to_process, (resize_to, resize_to), interpolation=cv2.INTER_AREA
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
face_tensor = (
|
| 148 |
+
_img2tensor(image_resized, bgr2rgb=True).unsqueeze(0).to(self.device)
|
| 149 |
+
)
|
| 150 |
+
with torch.no_grad():
|
| 151 |
+
normalized_face = normalize(face_tensor, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
|
| 152 |
+
parsing_out = self.parsing_model(normalized_face)[0]
|
| 153 |
+
parsing_mask = parsing_out.argmax(dim=1, keepdim=True)
|
| 154 |
+
|
| 155 |
+
background_mask_np = (parsing_mask.squeeze().cpu().numpy() == 0).astype(
|
| 156 |
+
np.uint8
|
| 157 |
+
)
|
| 158 |
+
white_background = np.ones_like(image_resized, dtype=np.uint8) * 255
|
| 159 |
+
mask_3channel = cv2.cvtColor(background_mask_np * 255, cv2.COLOR_GRAY2BGR)
|
| 160 |
+
result_img_bgr = np.where(mask_3channel == 255, white_background, image_resized)
|
| 161 |
+
result_img_rgb = cv2.cvtColor(result_img_bgr, cv2.COLOR_BGR2RGB)
|
| 162 |
+
img_white_bg = PIL.Image.fromarray(result_img_rgb)
|
| 163 |
+
if extra_input:
|
| 164 |
+
# 2. Create image with transparent background (new logic)
|
| 165 |
+
# Create an alpha channel: 255 for foreground (not background), 0 for background
|
| 166 |
+
alpha_channel = (parsing_mask.squeeze().cpu().numpy() != 0).astype(
|
| 167 |
+
np.uint8
|
| 168 |
+
) * 255
|
| 169 |
+
|
| 170 |
+
# Convert the resized BGR image to RGB
|
| 171 |
+
image_resized_rgb = cv2.cvtColor(image_resized, cv2.COLOR_BGR2RGB)
|
| 172 |
+
|
| 173 |
+
# Stack RGB channels with the new alpha channel
|
| 174 |
+
rgba_image = np.dstack((image_resized_rgb, alpha_channel))
|
| 175 |
+
|
| 176 |
+
# Create PIL image from the RGBA numpy array
|
| 177 |
+
img_transparent_bg = PIL.Image.fromarray(rgba_image, "RGBA")
|
| 178 |
+
|
| 179 |
+
return img_white_bg, img_transparent_bg
|
| 180 |
+
else:
|
| 181 |
+
return img_white_bg
|
preprocessor/videomask_generator.py
ADDED
|
@@ -0,0 +1,242 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import cv2
|
| 3 |
+
import numpy as np
|
| 4 |
+
from torchvision.transforms.functional import normalize
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
from PIL import Image, ImageOps
|
| 7 |
+
import random
|
| 8 |
+
import os
|
| 9 |
+
import requests
|
| 10 |
+
from insightface.app import FaceAnalysis
|
| 11 |
+
from facexlib.parsing import init_parsing_model
|
| 12 |
+
from typing import Union, Optional, Tuple, List
|
| 13 |
+
|
| 14 |
+
# --- Helper Functions (Unchanged) ---
|
| 15 |
+
def tensor_to_cv2_img(tensor_frame: torch.Tensor) -> np.ndarray:
|
| 16 |
+
"""Converts a single RGB torch tensor to a BGR OpenCV image."""
|
| 17 |
+
img_np = (tensor_frame.cpu().numpy() * 255).astype(np.uint8)
|
| 18 |
+
return cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
|
| 19 |
+
|
| 20 |
+
def tensor_to_cv2_bgra_img(tensor_frame: torch.Tensor) -> np.ndarray:
|
| 21 |
+
"""Converts a single RGBA torch tensor to a BGRA OpenCV image."""
|
| 22 |
+
if tensor_frame.shape[2] != 4:
|
| 23 |
+
raise ValueError("Input tensor must be an RGBA image with 4 channels.")
|
| 24 |
+
img_np = (tensor_frame.cpu().numpy() * 255).astype(np.uint8)
|
| 25 |
+
return cv2.cvtColor(img_np, cv2.COLOR_RGBA2BGRA)
|
| 26 |
+
|
| 27 |
+
def pil_to_tensor(image: Image.Image) -> torch.Tensor:
|
| 28 |
+
"""Converts a PIL image to a torch tensor."""
|
| 29 |
+
return torch.from_numpy(np.array(image).astype(np.float32) / 255.0)
|
| 30 |
+
|
| 31 |
+
class VideoMaskGenerator:
|
| 32 |
+
def __init__(self, antelopv2_path=".", device: Optional[torch.device] = None):
|
| 33 |
+
if device is None:
|
| 34 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 35 |
+
else:
|
| 36 |
+
self.device = device
|
| 37 |
+
|
| 38 |
+
print(f"Using device: {self.device}")
|
| 39 |
+
|
| 40 |
+
providers = ["CUDAExecutionProvider"] if self.device.type == "cuda" else ["CPUExecutionProvider"]
|
| 41 |
+
|
| 42 |
+
# Initialize face detection and landmark model (antelopev2 provides both)
|
| 43 |
+
self.detection_model = FaceAnalysis(name="antelopev2", root=antelopv2_path, providers=providers)
|
| 44 |
+
self.detection_model.prepare(ctx_id=0, det_size=(640, 640))
|
| 45 |
+
|
| 46 |
+
# Initialize face parsing model
|
| 47 |
+
self.parsing_model = init_parsing_model(model_name="bisenet", device=self.device)
|
| 48 |
+
self.parsing_model.eval()
|
| 49 |
+
|
| 50 |
+
print("FaceProcessor initialized successfully.")
|
| 51 |
+
|
| 52 |
+
def process(
|
| 53 |
+
self,
|
| 54 |
+
video_path: str,
|
| 55 |
+
face_image: Union[str, Image.Image],
|
| 56 |
+
confidence_threshold: float = 0.5,
|
| 57 |
+
face_crop_scale: float = 1.5,
|
| 58 |
+
dilation_kernel_size: int = 10,
|
| 59 |
+
feather_amount: int = 21,
|
| 60 |
+
random_horizontal_flip_chance: float = 0.0,
|
| 61 |
+
match_angle_and_size: bool = True
|
| 62 |
+
) -> Tuple[np.ndarray, np.ndarray, int, int, int]:
|
| 63 |
+
"""
|
| 64 |
+
Processes a video to replace a face with a provided face image.
|
| 65 |
+
|
| 66 |
+
Args:
|
| 67 |
+
video_path (str): Path to the input video file.
|
| 68 |
+
face_image (Union[str, Image.Image]): Path or PIL image of the face to paste.
|
| 69 |
+
confidence_threshold (float): Confidence threshold for face detection.
|
| 70 |
+
face_crop_scale (float): Scale factor for cropping the detected face box.
|
| 71 |
+
dilation_kernel_size (int): Kernel size for mask dilation.
|
| 72 |
+
feather_amount (int): Amount of feathering for the mask edges.
|
| 73 |
+
random_horizontal_flip_chance (float): Chance to flip the source face horizontally.
|
| 74 |
+
match_angle_and_size (bool): Whether to use landmark matching for rotation and scale.
|
| 75 |
+
|
| 76 |
+
Returns:
|
| 77 |
+
Tuple[np.ndarray, np.ndarray, int, int, int]:
|
| 78 |
+
- Processed video as a numpy array (F, H, W, C).
|
| 79 |
+
- Generated masks as a numpy array (F, H, W).
|
| 80 |
+
- Width of the processed video.
|
| 81 |
+
- Height of the processed video.
|
| 82 |
+
- Number of frames in the processed video.
|
| 83 |
+
"""
|
| 84 |
+
# --- Video Pre-processing ---
|
| 85 |
+
if not os.path.exists(video_path):
|
| 86 |
+
raise FileNotFoundError(f"Video file not found at: {video_path}")
|
| 87 |
+
|
| 88 |
+
cap = cv2.VideoCapture(video_path)
|
| 89 |
+
frames = []
|
| 90 |
+
while cap.isOpened():
|
| 91 |
+
ret, frame = cap.read()
|
| 92 |
+
if not ret:
|
| 93 |
+
break
|
| 94 |
+
frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
|
| 95 |
+
cap.release()
|
| 96 |
+
|
| 97 |
+
if not frames:
|
| 98 |
+
raise ValueError("Could not read any frames from the video.")
|
| 99 |
+
|
| 100 |
+
video_np = np.array(frames)
|
| 101 |
+
|
| 102 |
+
h, w = video_np.shape[1], video_np.shape[2]
|
| 103 |
+
new_h, new_w = (h // 16) * 16, (w // 16) * 16
|
| 104 |
+
|
| 105 |
+
y_start = (h - new_h) // 2
|
| 106 |
+
x_start = (w - new_w) // 2
|
| 107 |
+
video_cropped = video_np[:, y_start:y_start+new_h, x_start:x_start+new_w, :]
|
| 108 |
+
|
| 109 |
+
num_frames = video_cropped.shape[0]
|
| 110 |
+
target_frames = (num_frames // 4) * 4 + 1
|
| 111 |
+
video_trimmed = video_cropped[:target_frames]
|
| 112 |
+
|
| 113 |
+
final_h, final_w, final_frames = video_trimmed.shape[1], video_trimmed.shape[2], video_trimmed.shape[0]
|
| 114 |
+
print(f"Video pre-processed: {final_w}x{final_h}, {final_frames} frames.")
|
| 115 |
+
|
| 116 |
+
# --- Face Image Pre-processing & Source Landmark Extraction ---
|
| 117 |
+
if isinstance(face_image, str):
|
| 118 |
+
if face_image.startswith("http"):
|
| 119 |
+
face_image = Image.open(requests.get(face_image, stream=True, timeout=10).raw)
|
| 120 |
+
else:
|
| 121 |
+
face_image = Image.open(face_image)
|
| 122 |
+
|
| 123 |
+
face_image = ImageOps.exif_transpose(face_image).convert("RGBA")
|
| 124 |
+
face_rgba_tensor = pil_to_tensor(face_image)
|
| 125 |
+
face_to_paste_cv2 = tensor_to_cv2_bgra_img(face_rgba_tensor)
|
| 126 |
+
|
| 127 |
+
source_kpts = None
|
| 128 |
+
if match_angle_and_size:
|
| 129 |
+
# Use insightface (antelopev2) to get landmarks from the source face image
|
| 130 |
+
source_face_bgr = cv2.cvtColor(face_to_paste_cv2, cv2.COLOR_BGRA2BGR)
|
| 131 |
+
source_faces = self.detection_model.get(source_face_bgr)
|
| 132 |
+
if source_faces:
|
| 133 |
+
# Use the landmarks from the first (and likely only) detected face
|
| 134 |
+
source_kpts = source_faces[0].kps
|
| 135 |
+
else:
|
| 136 |
+
print("[Warning] No face or landmarks found in source image. Disabling angle matching.")
|
| 137 |
+
match_angle_and_size = False
|
| 138 |
+
|
| 139 |
+
face_to_paste_pil = Image.fromarray((face_rgba_tensor.cpu().numpy() * 255).astype(np.uint8), 'RGBA')
|
| 140 |
+
|
| 141 |
+
# --- Main Processing Loop ---
|
| 142 |
+
processed_frames_list = []
|
| 143 |
+
mask_list = []
|
| 144 |
+
|
| 145 |
+
for i in tqdm(range(final_frames), desc="Pasting face onto frames"):
|
| 146 |
+
frame_rgb = video_trimmed[i]
|
| 147 |
+
frame_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR)
|
| 148 |
+
|
| 149 |
+
# Use insightface for detection and landmarks
|
| 150 |
+
faces = self.detection_model.get(frame_bgr)
|
| 151 |
+
|
| 152 |
+
pasted = False
|
| 153 |
+
final_mask = np.zeros((final_h, final_w), dtype=np.uint8)
|
| 154 |
+
|
| 155 |
+
if faces:
|
| 156 |
+
largest_face = max(faces, key=lambda f: (f.bbox[2] - f.bbox[0]) * (f.bbox[3] - f.bbox[1]))
|
| 157 |
+
|
| 158 |
+
if largest_face.det_score > confidence_threshold:
|
| 159 |
+
# **MODIFIED BLOCK**: Use insightface landmarks for affine transform
|
| 160 |
+
if match_angle_and_size and source_kpts is not None:
|
| 161 |
+
target_kpts = largest_face.kps # Get landmarks directly from the detected face
|
| 162 |
+
|
| 163 |
+
# Estimate the transformation matrix
|
| 164 |
+
M, _ = cv2.estimateAffinePartial2D(source_kpts, target_kpts, method=cv2.LMEDS)
|
| 165 |
+
|
| 166 |
+
if M is not None:
|
| 167 |
+
# Split the RGBA source face for separate warping
|
| 168 |
+
b, g, r, a = cv2.split(face_to_paste_cv2)
|
| 169 |
+
source_rgb_cv2 = cv2.merge([r, g, b])
|
| 170 |
+
|
| 171 |
+
# Warp the face and its alpha channel
|
| 172 |
+
warped_face = cv2.warpAffine(source_rgb_cv2, M, (final_w, final_h))
|
| 173 |
+
warped_alpha = cv2.warpAffine(a, M, (final_w, final_h))
|
| 174 |
+
|
| 175 |
+
# Blend the warped face onto the frame using the warped alpha channel
|
| 176 |
+
alpha_float = warped_alpha.astype(np.float32) / 255.0
|
| 177 |
+
alpha_expanded = np.expand_dims(alpha_float, axis=2)
|
| 178 |
+
|
| 179 |
+
frame_rgb = (1.0 - alpha_expanded) * frame_rgb + alpha_expanded * warped_face
|
| 180 |
+
frame_rgb = frame_rgb.astype(np.uint8)
|
| 181 |
+
final_mask = warped_alpha
|
| 182 |
+
pasted = True
|
| 183 |
+
|
| 184 |
+
# Fallback to simple box-pasting if angle matching is off or fails
|
| 185 |
+
if not pasted:
|
| 186 |
+
x1, y1, x2, y2 = map(int, largest_face.bbox)
|
| 187 |
+
center_x, center_y = (x1 + x2) // 2, (y1 + y2) // 2
|
| 188 |
+
side_len = int(max(x2 - x1, y2 - y1) * face_crop_scale)
|
| 189 |
+
half_side = side_len // 2
|
| 190 |
+
|
| 191 |
+
crop_y1, crop_x1 = max(center_y - half_side, 0), max(center_x - half_side, 0)
|
| 192 |
+
crop_y2, crop_x2 = min(center_y + half_side, final_h), min(center_x + half_side, final_w)
|
| 193 |
+
|
| 194 |
+
box_w, box_h = crop_x2 - crop_x1, crop_y2 - crop_y1
|
| 195 |
+
|
| 196 |
+
if box_w > 0 and box_h > 0:
|
| 197 |
+
source_img = face_to_paste_pil.copy()
|
| 198 |
+
if random.random() < random_horizontal_flip_chance:
|
| 199 |
+
source_img = source_img.transpose(Image.FLIP_LEFT_RIGHT)
|
| 200 |
+
|
| 201 |
+
face_resized = source_img.resize((box_w, box_h), Image.Resampling.LANCZOS)
|
| 202 |
+
|
| 203 |
+
target_frame_pil = Image.fromarray(frame_rgb)
|
| 204 |
+
|
| 205 |
+
# --- Mask Generation using BiSeNet ---
|
| 206 |
+
face_crop_bgr = cv2.cvtColor(frame_rgb[crop_y1:crop_y2, crop_x1:crop_x2], cv2.COLOR_RGB2BGR)
|
| 207 |
+
if face_crop_bgr.size > 0:
|
| 208 |
+
face_resized_512 = cv2.resize(face_crop_bgr, (512, 512), interpolation=cv2.INTER_AREA)
|
| 209 |
+
face_rgb_512 = cv2.cvtColor(face_resized_512, cv2.COLOR_BGR2RGB)
|
| 210 |
+
face_tensor_in = torch.from_numpy(face_rgb_512.astype(np.float32) / 255.0).permute(2, 0, 1).unsqueeze(0).to(self.device)
|
| 211 |
+
|
| 212 |
+
with torch.no_grad():
|
| 213 |
+
normalized_face = normalize(face_tensor_in, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
| 214 |
+
parsing_map = self.parsing_model(normalized_face)[0].argmax(dim=1, keepdim=True)
|
| 215 |
+
|
| 216 |
+
parsing_map_np = parsing_map.squeeze().cpu().numpy().astype(np.uint8)
|
| 217 |
+
parts_to_include = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13] # All face parts
|
| 218 |
+
final_mask_512 = np.isin(parsing_map_np, parts_to_include).astype(np.uint8) * 255
|
| 219 |
+
|
| 220 |
+
if dilation_kernel_size > 0:
|
| 221 |
+
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (dilation_kernel_size, dilation_kernel_size))
|
| 222 |
+
final_mask_512 = cv2.dilate(final_mask_512, kernel, iterations=1)
|
| 223 |
+
|
| 224 |
+
if feather_amount > 0:
|
| 225 |
+
if feather_amount % 2 == 0: feather_amount += 1
|
| 226 |
+
final_mask_512 = cv2.GaussianBlur(final_mask_512, (feather_amount, feather_amount), 0)
|
| 227 |
+
|
| 228 |
+
mask_resized_to_crop = cv2.resize(final_mask_512, (box_w, box_h), interpolation=cv2.INTER_LINEAR)
|
| 229 |
+
generated_mask_pil = Image.fromarray(mask_resized_to_crop, mode='L')
|
| 230 |
+
|
| 231 |
+
target_frame_pil.paste(face_resized, (crop_x1, crop_y1), mask=generated_mask_pil)
|
| 232 |
+
frame_rgb = np.array(target_frame_pil)
|
| 233 |
+
final_mask[crop_y1:crop_y2, crop_x1:crop_x2] = mask_resized_to_crop
|
| 234 |
+
|
| 235 |
+
processed_frames_list.append(frame_rgb)
|
| 236 |
+
mask_list.append(final_mask)
|
| 237 |
+
|
| 238 |
+
output_video = np.stack(processed_frames_list)
|
| 239 |
+
# Ensure mask has a channel dimension for consistency
|
| 240 |
+
output_masks = np.stack(mask_list)[..., np.newaxis]
|
| 241 |
+
|
| 242 |
+
return (output_video, output_masks, final_w, final_h, final_frames)
|
prompters/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .prompt_refiners import Translator, BeautifulPrompt, QwenPrompt
|
| 2 |
+
from .omost import OmostPromter
|
| 3 |
+
from .wan_prompter import WanPrompter
|
prompters/base_prompter.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from models.model_manager import ModelManager
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def tokenize_long_prompt(tokenizer, prompt, max_length=None):
|
| 6 |
+
# Get model_max_length from self.tokenizer
|
| 7 |
+
length = tokenizer.model_max_length if max_length is None else max_length
|
| 8 |
+
|
| 9 |
+
# To avoid the warning. set self.tokenizer.model_max_length to +oo.
|
| 10 |
+
tokenizer.model_max_length = 99999999
|
| 11 |
+
|
| 12 |
+
# Tokenize it!
|
| 13 |
+
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
|
| 14 |
+
|
| 15 |
+
# Determine the real length.
|
| 16 |
+
max_length = (input_ids.shape[1] + length - 1) // length * length
|
| 17 |
+
|
| 18 |
+
# Restore tokenizer.model_max_length
|
| 19 |
+
tokenizer.model_max_length = length
|
| 20 |
+
|
| 21 |
+
# Tokenize it again with fixed length.
|
| 22 |
+
input_ids = tokenizer(
|
| 23 |
+
prompt,
|
| 24 |
+
return_tensors="pt",
|
| 25 |
+
padding="max_length",
|
| 26 |
+
max_length=max_length,
|
| 27 |
+
truncation=True,
|
| 28 |
+
).input_ids
|
| 29 |
+
|
| 30 |
+
# Reshape input_ids to fit the text encoder.
|
| 31 |
+
num_sentence = input_ids.shape[1] // length
|
| 32 |
+
input_ids = input_ids.reshape((num_sentence, length))
|
| 33 |
+
|
| 34 |
+
return input_ids
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class BasePrompter:
|
| 38 |
+
def __init__(self):
|
| 39 |
+
self.refiners = []
|
| 40 |
+
self.extenders = []
|
| 41 |
+
|
| 42 |
+
def load_prompt_refiners(self, model_manager: ModelManager, refiner_classes=[]):
|
| 43 |
+
for refiner_class in refiner_classes:
|
| 44 |
+
refiner = refiner_class.from_model_manager(model_manager)
|
| 45 |
+
self.refiners.append(refiner)
|
| 46 |
+
|
| 47 |
+
def load_prompt_extenders(self, model_manager: ModelManager, extender_classes=[]):
|
| 48 |
+
for extender_class in extender_classes:
|
| 49 |
+
extender = extender_class.from_model_manager(model_manager)
|
| 50 |
+
self.extenders.append(extender)
|
| 51 |
+
|
| 52 |
+
@torch.no_grad()
|
| 53 |
+
def process_prompt(self, prompt, positive=True):
|
| 54 |
+
if isinstance(prompt, list):
|
| 55 |
+
prompt = [
|
| 56 |
+
self.process_prompt(prompt_, positive=positive) for prompt_ in prompt
|
| 57 |
+
]
|
| 58 |
+
else:
|
| 59 |
+
for refiner in self.refiners:
|
| 60 |
+
prompt = refiner(prompt, positive=positive)
|
| 61 |
+
return prompt
|
| 62 |
+
|
| 63 |
+
@torch.no_grad()
|
| 64 |
+
def extend_prompt(self, prompt: str, positive=True):
|
| 65 |
+
extended_prompt = dict(prompt=prompt)
|
| 66 |
+
for extender in self.extenders:
|
| 67 |
+
extended_prompt = extender(extended_prompt)
|
| 68 |
+
return extended_prompt
|
prompters/omost.py
ADDED
|
@@ -0,0 +1,472 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import AutoTokenizer, TextIteratorStreamer
|
| 2 |
+
import difflib
|
| 3 |
+
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
+
import re
|
| 6 |
+
from models.model_manager import ModelManager
|
| 7 |
+
from PIL import Image
|
| 8 |
+
|
| 9 |
+
valid_colors = { # r, g, b
|
| 10 |
+
"aliceblue": (240, 248, 255),
|
| 11 |
+
"antiquewhite": (250, 235, 215),
|
| 12 |
+
"aqua": (0, 255, 255),
|
| 13 |
+
"aquamarine": (127, 255, 212),
|
| 14 |
+
"azure": (240, 255, 255),
|
| 15 |
+
"beige": (245, 245, 220),
|
| 16 |
+
"bisque": (255, 228, 196),
|
| 17 |
+
"black": (0, 0, 0),
|
| 18 |
+
"blanchedalmond": (255, 235, 205),
|
| 19 |
+
"blue": (0, 0, 255),
|
| 20 |
+
"blueviolet": (138, 43, 226),
|
| 21 |
+
"brown": (165, 42, 42),
|
| 22 |
+
"burlywood": (222, 184, 135),
|
| 23 |
+
"cadetblue": (95, 158, 160),
|
| 24 |
+
"chartreuse": (127, 255, 0),
|
| 25 |
+
"chocolate": (210, 105, 30),
|
| 26 |
+
"coral": (255, 127, 80),
|
| 27 |
+
"cornflowerblue": (100, 149, 237),
|
| 28 |
+
"cornsilk": (255, 248, 220),
|
| 29 |
+
"crimson": (220, 20, 60),
|
| 30 |
+
"cyan": (0, 255, 255),
|
| 31 |
+
"darkblue": (0, 0, 139),
|
| 32 |
+
"darkcyan": (0, 139, 139),
|
| 33 |
+
"darkgoldenrod": (184, 134, 11),
|
| 34 |
+
"darkgray": (169, 169, 169),
|
| 35 |
+
"darkgrey": (169, 169, 169),
|
| 36 |
+
"darkgreen": (0, 100, 0),
|
| 37 |
+
"darkkhaki": (189, 183, 107),
|
| 38 |
+
"darkmagenta": (139, 0, 139),
|
| 39 |
+
"darkolivegreen": (85, 107, 47),
|
| 40 |
+
"darkorange": (255, 140, 0),
|
| 41 |
+
"darkorchid": (153, 50, 204),
|
| 42 |
+
"darkred": (139, 0, 0),
|
| 43 |
+
"darksalmon": (233, 150, 122),
|
| 44 |
+
"darkseagreen": (143, 188, 143),
|
| 45 |
+
"darkslateblue": (72, 61, 139),
|
| 46 |
+
"darkslategray": (47, 79, 79),
|
| 47 |
+
"darkslategrey": (47, 79, 79),
|
| 48 |
+
"darkturquoise": (0, 206, 209),
|
| 49 |
+
"darkviolet": (148, 0, 211),
|
| 50 |
+
"deeppink": (255, 20, 147),
|
| 51 |
+
"deepskyblue": (0, 191, 255),
|
| 52 |
+
"dimgray": (105, 105, 105),
|
| 53 |
+
"dimgrey": (105, 105, 105),
|
| 54 |
+
"dodgerblue": (30, 144, 255),
|
| 55 |
+
"firebrick": (178, 34, 34),
|
| 56 |
+
"floralwhite": (255, 250, 240),
|
| 57 |
+
"forestgreen": (34, 139, 34),
|
| 58 |
+
"fuchsia": (255, 0, 255),
|
| 59 |
+
"gainsboro": (220, 220, 220),
|
| 60 |
+
"ghostwhite": (248, 248, 255),
|
| 61 |
+
"gold": (255, 215, 0),
|
| 62 |
+
"goldenrod": (218, 165, 32),
|
| 63 |
+
"gray": (128, 128, 128),
|
| 64 |
+
"grey": (128, 128, 128),
|
| 65 |
+
"green": (0, 128, 0),
|
| 66 |
+
"greenyellow": (173, 255, 47),
|
| 67 |
+
"honeydew": (240, 255, 240),
|
| 68 |
+
"hotpink": (255, 105, 180),
|
| 69 |
+
"indianred": (205, 92, 92),
|
| 70 |
+
"indigo": (75, 0, 130),
|
| 71 |
+
"ivory": (255, 255, 240),
|
| 72 |
+
"khaki": (240, 230, 140),
|
| 73 |
+
"lavender": (230, 230, 250),
|
| 74 |
+
"lavenderblush": (255, 240, 245),
|
| 75 |
+
"lawngreen": (124, 252, 0),
|
| 76 |
+
"lemonchiffon": (255, 250, 205),
|
| 77 |
+
"lightblue": (173, 216, 230),
|
| 78 |
+
"lightcoral": (240, 128, 128),
|
| 79 |
+
"lightcyan": (224, 255, 255),
|
| 80 |
+
"lightgoldenrodyellow": (250, 250, 210),
|
| 81 |
+
"lightgray": (211, 211, 211),
|
| 82 |
+
"lightgrey": (211, 211, 211),
|
| 83 |
+
"lightgreen": (144, 238, 144),
|
| 84 |
+
"lightpink": (255, 182, 193),
|
| 85 |
+
"lightsalmon": (255, 160, 122),
|
| 86 |
+
"lightseagreen": (32, 178, 170),
|
| 87 |
+
"lightskyblue": (135, 206, 250),
|
| 88 |
+
"lightslategray": (119, 136, 153),
|
| 89 |
+
"lightslategrey": (119, 136, 153),
|
| 90 |
+
"lightsteelblue": (176, 196, 222),
|
| 91 |
+
"lightyellow": (255, 255, 224),
|
| 92 |
+
"lime": (0, 255, 0),
|
| 93 |
+
"limegreen": (50, 205, 50),
|
| 94 |
+
"linen": (250, 240, 230),
|
| 95 |
+
"magenta": (255, 0, 255),
|
| 96 |
+
"maroon": (128, 0, 0),
|
| 97 |
+
"mediumaquamarine": (102, 205, 170),
|
| 98 |
+
"mediumblue": (0, 0, 205),
|
| 99 |
+
"mediumorchid": (186, 85, 211),
|
| 100 |
+
"mediumpurple": (147, 112, 219),
|
| 101 |
+
"mediumseagreen": (60, 179, 113),
|
| 102 |
+
"mediumslateblue": (123, 104, 238),
|
| 103 |
+
"mediumspringgreen": (0, 250, 154),
|
| 104 |
+
"mediumturquoise": (72, 209, 204),
|
| 105 |
+
"mediumvioletred": (199, 21, 133),
|
| 106 |
+
"midnightblue": (25, 25, 112),
|
| 107 |
+
"mintcream": (245, 255, 250),
|
| 108 |
+
"mistyrose": (255, 228, 225),
|
| 109 |
+
"moccasin": (255, 228, 181),
|
| 110 |
+
"navajowhite": (255, 222, 173),
|
| 111 |
+
"navy": (0, 0, 128),
|
| 112 |
+
"navyblue": (0, 0, 128),
|
| 113 |
+
"oldlace": (253, 245, 230),
|
| 114 |
+
"olive": (128, 128, 0),
|
| 115 |
+
"olivedrab": (107, 142, 35),
|
| 116 |
+
"orange": (255, 165, 0),
|
| 117 |
+
"orangered": (255, 69, 0),
|
| 118 |
+
"orchid": (218, 112, 214),
|
| 119 |
+
"palegoldenrod": (238, 232, 170),
|
| 120 |
+
"palegreen": (152, 251, 152),
|
| 121 |
+
"paleturquoise": (175, 238, 238),
|
| 122 |
+
"palevioletred": (219, 112, 147),
|
| 123 |
+
"papayawhip": (255, 239, 213),
|
| 124 |
+
"peachpuff": (255, 218, 185),
|
| 125 |
+
"peru": (205, 133, 63),
|
| 126 |
+
"pink": (255, 192, 203),
|
| 127 |
+
"plum": (221, 160, 221),
|
| 128 |
+
"powderblue": (176, 224, 230),
|
| 129 |
+
"purple": (128, 0, 128),
|
| 130 |
+
"rebeccapurple": (102, 51, 153),
|
| 131 |
+
"red": (255, 0, 0),
|
| 132 |
+
"rosybrown": (188, 143, 143),
|
| 133 |
+
"royalblue": (65, 105, 225),
|
| 134 |
+
"saddlebrown": (139, 69, 19),
|
| 135 |
+
"salmon": (250, 128, 114),
|
| 136 |
+
"sandybrown": (244, 164, 96),
|
| 137 |
+
"seagreen": (46, 139, 87),
|
| 138 |
+
"seashell": (255, 245, 238),
|
| 139 |
+
"sienna": (160, 82, 45),
|
| 140 |
+
"silver": (192, 192, 192),
|
| 141 |
+
"skyblue": (135, 206, 235),
|
| 142 |
+
"slateblue": (106, 90, 205),
|
| 143 |
+
"slategray": (112, 128, 144),
|
| 144 |
+
"slategrey": (112, 128, 144),
|
| 145 |
+
"snow": (255, 250, 250),
|
| 146 |
+
"springgreen": (0, 255, 127),
|
| 147 |
+
"steelblue": (70, 130, 180),
|
| 148 |
+
"tan": (210, 180, 140),
|
| 149 |
+
"teal": (0, 128, 128),
|
| 150 |
+
"thistle": (216, 191, 216),
|
| 151 |
+
"tomato": (255, 99, 71),
|
| 152 |
+
"turquoise": (64, 224, 208),
|
| 153 |
+
"violet": (238, 130, 238),
|
| 154 |
+
"wheat": (245, 222, 179),
|
| 155 |
+
"white": (255, 255, 255),
|
| 156 |
+
"whitesmoke": (245, 245, 245),
|
| 157 |
+
"yellow": (255, 255, 0),
|
| 158 |
+
"yellowgreen": (154, 205, 50),
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
+
valid_locations = { # x, y in 90*90
|
| 162 |
+
"in the center": (45, 45),
|
| 163 |
+
"on the left": (15, 45),
|
| 164 |
+
"on the right": (75, 45),
|
| 165 |
+
"on the top": (45, 15),
|
| 166 |
+
"on the bottom": (45, 75),
|
| 167 |
+
"on the top-left": (15, 15),
|
| 168 |
+
"on the top-right": (75, 15),
|
| 169 |
+
"on the bottom-left": (15, 75),
|
| 170 |
+
"on the bottom-right": (75, 75),
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
valid_offsets = { # x, y in 90*90
|
| 174 |
+
"no offset": (0, 0),
|
| 175 |
+
"slightly to the left": (-10, 0),
|
| 176 |
+
"slightly to the right": (10, 0),
|
| 177 |
+
"slightly to the upper": (0, -10),
|
| 178 |
+
"slightly to the lower": (0, 10),
|
| 179 |
+
"slightly to the upper-left": (-10, -10),
|
| 180 |
+
"slightly to the upper-right": (10, -10),
|
| 181 |
+
"slightly to the lower-left": (-10, 10),
|
| 182 |
+
"slightly to the lower-right": (10, 10),
|
| 183 |
+
}
|
| 184 |
+
|
| 185 |
+
valid_areas = { # w, h in 90*90
|
| 186 |
+
"a small square area": (50, 50),
|
| 187 |
+
"a small vertical area": (40, 60),
|
| 188 |
+
"a small horizontal area": (60, 40),
|
| 189 |
+
"a medium-sized square area": (60, 60),
|
| 190 |
+
"a medium-sized vertical area": (50, 80),
|
| 191 |
+
"a medium-sized horizontal area": (80, 50),
|
| 192 |
+
"a large square area": (70, 70),
|
| 193 |
+
"a large vertical area": (60, 90),
|
| 194 |
+
"a large horizontal area": (90, 60),
|
| 195 |
+
}
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def safe_str(x):
|
| 199 |
+
return x.strip(",. ") + "."
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def closest_name(input_str, options):
|
| 203 |
+
input_str = input_str.lower()
|
| 204 |
+
|
| 205 |
+
closest_match = difflib.get_close_matches(
|
| 206 |
+
input_str, list(options.keys()), n=1, cutoff=0.5
|
| 207 |
+
)
|
| 208 |
+
assert isinstance(closest_match, list) and len(closest_match) > 0, (
|
| 209 |
+
f"The value [{input_str}] is not valid!"
|
| 210 |
+
)
|
| 211 |
+
result = closest_match[0]
|
| 212 |
+
|
| 213 |
+
if result != input_str:
|
| 214 |
+
print(f"Automatically corrected [{input_str}] -> [{result}].")
|
| 215 |
+
|
| 216 |
+
return result
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
class Canvas:
|
| 220 |
+
@staticmethod
|
| 221 |
+
def from_bot_response(response: str):
|
| 222 |
+
matched = re.search(r"```python\n(.*?)\n```", response, re.DOTALL)
|
| 223 |
+
assert matched, "Response does not contain codes!"
|
| 224 |
+
code_content = matched.group(1)
|
| 225 |
+
assert "canvas = Canvas()" in code_content, (
|
| 226 |
+
"Code block must include valid canvas var!"
|
| 227 |
+
)
|
| 228 |
+
local_vars = {"Canvas": Canvas}
|
| 229 |
+
exec(code_content, {}, local_vars)
|
| 230 |
+
canvas = local_vars.get("canvas", None)
|
| 231 |
+
assert isinstance(canvas, Canvas), "Code block must produce valid canvas var!"
|
| 232 |
+
return canvas
|
| 233 |
+
|
| 234 |
+
def __init__(self):
|
| 235 |
+
self.components = []
|
| 236 |
+
self.color = None
|
| 237 |
+
self.record_tags = True
|
| 238 |
+
self.prefixes = []
|
| 239 |
+
self.suffixes = []
|
| 240 |
+
return
|
| 241 |
+
|
| 242 |
+
def set_global_description(
|
| 243 |
+
self,
|
| 244 |
+
description: str,
|
| 245 |
+
detailed_descriptions: list,
|
| 246 |
+
tags: str,
|
| 247 |
+
HTML_web_color_name: str,
|
| 248 |
+
):
|
| 249 |
+
assert isinstance(description, str), "Global description is not valid!"
|
| 250 |
+
assert isinstance(detailed_descriptions, list) and all(
|
| 251 |
+
isinstance(item, str) for item in detailed_descriptions
|
| 252 |
+
), "Global detailed_descriptions is not valid!"
|
| 253 |
+
assert isinstance(tags, str), "Global tags is not valid!"
|
| 254 |
+
|
| 255 |
+
HTML_web_color_name = closest_name(HTML_web_color_name, valid_colors)
|
| 256 |
+
self.color = np.array([[valid_colors[HTML_web_color_name]]], dtype=np.uint8)
|
| 257 |
+
|
| 258 |
+
self.prefixes = [description]
|
| 259 |
+
self.suffixes = detailed_descriptions
|
| 260 |
+
|
| 261 |
+
if self.record_tags:
|
| 262 |
+
self.suffixes = self.suffixes + [tags]
|
| 263 |
+
|
| 264 |
+
self.prefixes = [safe_str(x) for x in self.prefixes]
|
| 265 |
+
self.suffixes = [safe_str(x) for x in self.suffixes]
|
| 266 |
+
|
| 267 |
+
return
|
| 268 |
+
|
| 269 |
+
def add_local_description(
|
| 270 |
+
self,
|
| 271 |
+
location: str,
|
| 272 |
+
offset: str,
|
| 273 |
+
area: str,
|
| 274 |
+
distance_to_viewer: float,
|
| 275 |
+
description: str,
|
| 276 |
+
detailed_descriptions: list,
|
| 277 |
+
tags: str,
|
| 278 |
+
atmosphere: str,
|
| 279 |
+
style: str,
|
| 280 |
+
quality_meta: str,
|
| 281 |
+
HTML_web_color_name: str,
|
| 282 |
+
):
|
| 283 |
+
assert isinstance(description, str), "Local description is wrong!"
|
| 284 |
+
assert (
|
| 285 |
+
isinstance(distance_to_viewer, (int, float)) and distance_to_viewer > 0
|
| 286 |
+
), f"The distance_to_viewer for [{description}] is not positive float number!"
|
| 287 |
+
assert isinstance(detailed_descriptions, list) and all(
|
| 288 |
+
isinstance(item, str) for item in detailed_descriptions
|
| 289 |
+
), f"The detailed_descriptions for [{description}] is not valid!"
|
| 290 |
+
assert isinstance(tags, str), f"The tags for [{description}] is not valid!"
|
| 291 |
+
assert isinstance(atmosphere, str), (
|
| 292 |
+
f"The atmosphere for [{description}] is not valid!"
|
| 293 |
+
)
|
| 294 |
+
assert isinstance(style, str), f"The style for [{description}] is not valid!"
|
| 295 |
+
assert isinstance(quality_meta, str), (
|
| 296 |
+
f"The quality_meta for [{description}] is not valid!"
|
| 297 |
+
)
|
| 298 |
+
|
| 299 |
+
location = closest_name(location, valid_locations)
|
| 300 |
+
offset = closest_name(offset, valid_offsets)
|
| 301 |
+
area = closest_name(area, valid_areas)
|
| 302 |
+
HTML_web_color_name = closest_name(HTML_web_color_name, valid_colors)
|
| 303 |
+
|
| 304 |
+
xb, yb = valid_locations[location]
|
| 305 |
+
xo, yo = valid_offsets[offset]
|
| 306 |
+
w, h = valid_areas[area]
|
| 307 |
+
rect = (yb + yo - h // 2, yb + yo + h // 2, xb + xo - w // 2, xb + xo + w // 2)
|
| 308 |
+
rect = [max(0, min(90, i)) for i in rect]
|
| 309 |
+
color = np.array([[valid_colors[HTML_web_color_name]]], dtype=np.uint8)
|
| 310 |
+
|
| 311 |
+
prefixes = self.prefixes + [description]
|
| 312 |
+
suffixes = detailed_descriptions
|
| 313 |
+
|
| 314 |
+
if self.record_tags:
|
| 315 |
+
suffixes = suffixes + [tags, atmosphere, style, quality_meta]
|
| 316 |
+
|
| 317 |
+
prefixes = [safe_str(x) for x in prefixes]
|
| 318 |
+
suffixes = [safe_str(x) for x in suffixes]
|
| 319 |
+
|
| 320 |
+
self.components.append(
|
| 321 |
+
dict(
|
| 322 |
+
rect=rect,
|
| 323 |
+
distance_to_viewer=distance_to_viewer,
|
| 324 |
+
color=color,
|
| 325 |
+
prefixes=prefixes,
|
| 326 |
+
suffixes=suffixes,
|
| 327 |
+
location=location,
|
| 328 |
+
)
|
| 329 |
+
)
|
| 330 |
+
|
| 331 |
+
return
|
| 332 |
+
|
| 333 |
+
def process(self):
|
| 334 |
+
# sort components
|
| 335 |
+
self.components = sorted(
|
| 336 |
+
self.components, key=lambda x: x["distance_to_viewer"], reverse=True
|
| 337 |
+
)
|
| 338 |
+
|
| 339 |
+
# compute initial latent
|
| 340 |
+
# print(self.color)
|
| 341 |
+
initial_latent = np.zeros(shape=(90, 90, 3), dtype=np.float32) + self.color
|
| 342 |
+
|
| 343 |
+
for component in self.components:
|
| 344 |
+
a, b, c, d = component["rect"]
|
| 345 |
+
initial_latent[a:b, c:d] = (
|
| 346 |
+
0.7 * component["color"] + 0.3 * initial_latent[a:b, c:d]
|
| 347 |
+
)
|
| 348 |
+
|
| 349 |
+
initial_latent = initial_latent.clip(0, 255).astype(np.uint8)
|
| 350 |
+
|
| 351 |
+
# compute conditions
|
| 352 |
+
|
| 353 |
+
bag_of_conditions = [
|
| 354 |
+
dict(
|
| 355 |
+
mask=np.ones(shape=(90, 90), dtype=np.float32),
|
| 356 |
+
prefixes=self.prefixes,
|
| 357 |
+
suffixes=self.suffixes,
|
| 358 |
+
location="full",
|
| 359 |
+
)
|
| 360 |
+
]
|
| 361 |
+
|
| 362 |
+
for i, component in enumerate(self.components):
|
| 363 |
+
a, b, c, d = component["rect"]
|
| 364 |
+
m = np.zeros(shape=(90, 90), dtype=np.float32)
|
| 365 |
+
m[a:b, c:d] = 1.0
|
| 366 |
+
bag_of_conditions.append(
|
| 367 |
+
dict(
|
| 368 |
+
mask=m,
|
| 369 |
+
prefixes=component["prefixes"],
|
| 370 |
+
suffixes=component["suffixes"],
|
| 371 |
+
location=component["location"],
|
| 372 |
+
)
|
| 373 |
+
)
|
| 374 |
+
|
| 375 |
+
return dict(
|
| 376 |
+
initial_latent=initial_latent,
|
| 377 |
+
bag_of_conditions=bag_of_conditions,
|
| 378 |
+
)
|
| 379 |
+
|
| 380 |
+
|
| 381 |
+
class OmostPromter(torch.nn.Module):
|
| 382 |
+
def __init__(self, model=None, tokenizer=None, template="", device="cpu"):
|
| 383 |
+
super().__init__()
|
| 384 |
+
self.model = model
|
| 385 |
+
self.tokenizer = tokenizer
|
| 386 |
+
self.device = device
|
| 387 |
+
if template == "":
|
| 388 |
+
template = r"""You are a helpful AI assistant to compose images using the below python class `Canvas`:
|
| 389 |
+
```python
|
| 390 |
+
class Canvas:
|
| 391 |
+
def set_global_description(self, description: str, detailed_descriptions: list[str], tags: str, HTML_web_color_name: str):
|
| 392 |
+
pass
|
| 393 |
+
|
| 394 |
+
def add_local_description(self, location: str, offset: str, area: str, distance_to_viewer: float, description: str, detailed_descriptions: list[str], tags: str, atmosphere: str, style: str, quality_meta: str, HTML_web_color_name: str):
|
| 395 |
+
assert location in ["in the center", "on the left", "on the right", "on the top", "on the bottom", "on the top-left", "on the top-right", "on the bottom-left", "on the bottom-right"]
|
| 396 |
+
assert offset in ["no offset", "slightly to the left", "slightly to the right", "slightly to the upper", "slightly to the lower", "slightly to the upper-left", "slightly to the upper-right", "slightly to the lower-left", "slightly to the lower-right"]
|
| 397 |
+
assert area in ["a small square area", "a small vertical area", "a small horizontal area", "a medium-sized square area", "a medium-sized vertical area", "a medium-sized horizontal area", "a large square area", "a large vertical area", "a large horizontal area"]
|
| 398 |
+
assert distance_to_viewer > 0
|
| 399 |
+
pass
|
| 400 |
+
```"""
|
| 401 |
+
self.template = template
|
| 402 |
+
|
| 403 |
+
@staticmethod
|
| 404 |
+
def from_model_manager(model_manager: ModelManager):
|
| 405 |
+
model, model_path = model_manager.fetch_model(
|
| 406 |
+
"omost_prompt", require_model_path=True
|
| 407 |
+
)
|
| 408 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
| 409 |
+
omost = OmostPromter(
|
| 410 |
+
model=model, tokenizer=tokenizer, device=model_manager.device
|
| 411 |
+
)
|
| 412 |
+
return omost
|
| 413 |
+
|
| 414 |
+
def __call__(self, prompt_dict: dict):
|
| 415 |
+
raw_prompt = prompt_dict["prompt"]
|
| 416 |
+
conversation = [{"role": "system", "content": self.template}]
|
| 417 |
+
conversation.append({"role": "user", "content": raw_prompt})
|
| 418 |
+
|
| 419 |
+
input_ids = self.tokenizer.apply_chat_template(
|
| 420 |
+
conversation, return_tensors="pt", add_generation_prompt=True
|
| 421 |
+
).to(self.device)
|
| 422 |
+
streamer = TextIteratorStreamer(
|
| 423 |
+
self.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
|
| 424 |
+
)
|
| 425 |
+
attention_mask = torch.ones(
|
| 426 |
+
input_ids.shape, dtype=torch.bfloat16, device=self.device
|
| 427 |
+
)
|
| 428 |
+
|
| 429 |
+
generate_kwargs = dict(
|
| 430 |
+
input_ids=input_ids,
|
| 431 |
+
streamer=streamer,
|
| 432 |
+
# stopping_criteria=stopping_criteria,
|
| 433 |
+
# max_new_tokens=max_new_tokens,
|
| 434 |
+
do_sample=True,
|
| 435 |
+
attention_mask=attention_mask,
|
| 436 |
+
pad_token_id=self.tokenizer.eos_token_id,
|
| 437 |
+
# temperature=temperature,
|
| 438 |
+
# top_p=top_p,
|
| 439 |
+
)
|
| 440 |
+
self.model.generate(**generate_kwargs)
|
| 441 |
+
outputs = []
|
| 442 |
+
for text in streamer:
|
| 443 |
+
outputs.append(text)
|
| 444 |
+
llm_outputs = "".join(outputs)
|
| 445 |
+
|
| 446 |
+
canvas = Canvas.from_bot_response(llm_outputs)
|
| 447 |
+
canvas_output = canvas.process()
|
| 448 |
+
|
| 449 |
+
prompts = [
|
| 450 |
+
" ".join(_["prefixes"] + _["suffixes"][:2])
|
| 451 |
+
for _ in canvas_output["bag_of_conditions"]
|
| 452 |
+
]
|
| 453 |
+
canvas_output["prompt"] = prompts[0]
|
| 454 |
+
canvas_output["prompts"] = prompts[1:]
|
| 455 |
+
|
| 456 |
+
raw_masks = [_["mask"] for _ in canvas_output["bag_of_conditions"]]
|
| 457 |
+
masks = []
|
| 458 |
+
for mask in raw_masks:
|
| 459 |
+
mask[mask > 0.5] = 255
|
| 460 |
+
mask = np.stack([mask] * 3, axis=-1).astype("uint8")
|
| 461 |
+
masks.append(Image.fromarray(mask))
|
| 462 |
+
|
| 463 |
+
canvas_output["masks"] = masks
|
| 464 |
+
prompt_dict.update(canvas_output)
|
| 465 |
+
print(f"Your prompt is extended by Omost:\n")
|
| 466 |
+
cnt = 0
|
| 467 |
+
for component, pmt in zip(canvas_output["bag_of_conditions"], prompts):
|
| 468 |
+
loc = component["location"]
|
| 469 |
+
cnt += 1
|
| 470 |
+
print(f"Component {cnt} - Location : {loc}\nPrompt:{pmt}\n")
|
| 471 |
+
|
| 472 |
+
return prompt_dict
|
prompters/prompt_refiners.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import AutoTokenizer
|
| 2 |
+
from models.model_manager import ModelManager
|
| 3 |
+
import torch
|
| 4 |
+
from .omost import OmostPromter
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class BeautifulPrompt(torch.nn.Module):
|
| 8 |
+
def __init__(self, tokenizer_path=None, model=None, template=""):
|
| 9 |
+
super().__init__()
|
| 10 |
+
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
| 11 |
+
self.model = model
|
| 12 |
+
self.template = template
|
| 13 |
+
|
| 14 |
+
@staticmethod
|
| 15 |
+
def from_model_manager(model_manager: ModelManager):
|
| 16 |
+
model, model_path = model_manager.fetch_model(
|
| 17 |
+
"beautiful_prompt", require_model_path=True
|
| 18 |
+
)
|
| 19 |
+
template = "Instruction: Give a simple description of the image to generate a drawing prompt.\nInput: {raw_prompt}\nOutput:"
|
| 20 |
+
if model_path.endswith("v2"):
|
| 21 |
+
template = """Converts a simple image description into a prompt. \
|
| 22 |
+
Prompts are formatted as multiple related tags separated by commas, plus you can use () to increase the weight, [] to decrease the weight, \
|
| 23 |
+
or use a number to specify the weight. You should add appropriate words to make the images described in the prompt more aesthetically pleasing, \
|
| 24 |
+
but make sure there is a correlation between the input and output.\n\
|
| 25 |
+
### Input: {raw_prompt}\n### Output:"""
|
| 26 |
+
beautiful_prompt = BeautifulPrompt(
|
| 27 |
+
tokenizer_path=model_path, model=model, template=template
|
| 28 |
+
)
|
| 29 |
+
return beautiful_prompt
|
| 30 |
+
|
| 31 |
+
def __call__(self, raw_prompt, positive=True, **kwargs):
|
| 32 |
+
if positive:
|
| 33 |
+
model_input = self.template.format(raw_prompt=raw_prompt)
|
| 34 |
+
input_ids = self.tokenizer.encode(model_input, return_tensors="pt").to(
|
| 35 |
+
self.model.device
|
| 36 |
+
)
|
| 37 |
+
outputs = self.model.generate(
|
| 38 |
+
input_ids,
|
| 39 |
+
max_new_tokens=384,
|
| 40 |
+
do_sample=True,
|
| 41 |
+
temperature=0.9,
|
| 42 |
+
top_k=50,
|
| 43 |
+
top_p=0.95,
|
| 44 |
+
repetition_penalty=1.1,
|
| 45 |
+
num_return_sequences=1,
|
| 46 |
+
)
|
| 47 |
+
prompt = (
|
| 48 |
+
raw_prompt
|
| 49 |
+
+ ", "
|
| 50 |
+
+ self.tokenizer.batch_decode(
|
| 51 |
+
outputs[:, input_ids.size(1) :], skip_special_tokens=True
|
| 52 |
+
)[0].strip()
|
| 53 |
+
)
|
| 54 |
+
print(f"Your prompt is refined by BeautifulPrompt: {prompt}")
|
| 55 |
+
return prompt
|
| 56 |
+
else:
|
| 57 |
+
return raw_prompt
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class QwenPrompt(torch.nn.Module):
|
| 61 |
+
# This class leverages the open-source Qwen model to translate Chinese prompts into English,
|
| 62 |
+
# with an integrated optimization mechanism for enhanced translation quality.
|
| 63 |
+
def __init__(self, tokenizer_path=None, model=None, system_prompt=""):
|
| 64 |
+
super().__init__()
|
| 65 |
+
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
| 66 |
+
self.model = model
|
| 67 |
+
self.system_prompt = system_prompt
|
| 68 |
+
|
| 69 |
+
@staticmethod
|
| 70 |
+
def from_model_manager(model_nameger: ModelManager):
|
| 71 |
+
model, model_path = model_nameger.fetch_model(
|
| 72 |
+
"qwen_prompt", require_model_path=True
|
| 73 |
+
)
|
| 74 |
+
system_prompt = """You are an English image describer. Here are some example image styles:\n\n1. Extreme close-up: Clear focus on a single object with a blurred background, highlighted under natural sunlight.\n2. Vintage: A photograph of a historical scene, using techniques such as Daguerreotype or cyanotype.\n3. Anime: A stylized cartoon image, emphasizing hyper-realistic portraits and luminous brushwork.\n4. Candid: A natural, unposed shot capturing spontaneous moments, often with cinematic qualities.\n5. Landscape: A photorealistic image of natural scenery, such as a sunrise over the sea.\n6. Design: Colorful and detailed illustrations, often in the style of 2D game art or botanical illustrations.\n7. Urban: An ultrarealistic scene in a modern setting, possibly a cityscape viewed from indoors.\n\nYour task is to translate a given Chinese image description into a concise and precise English description. Ensure that the imagery is vivid and descriptive, and include stylistic elements to enrich the description.\nPlease note the following points:\n\n1. Capture the essence and mood of the Chinese description without including direct phrases or words from the examples provided.\n2. You should add appropriate words to make the images described in the prompt more aesthetically pleasing. If the Chinese description does not specify a style, you need to add some stylistic descriptions based on the essence of the Chinese text.\n3. The generated English description should not exceed 200 words.\n\n"""
|
| 75 |
+
qwen_prompt = QwenPrompt(
|
| 76 |
+
tokenizer_path=model_path, model=model, system_prompt=system_prompt
|
| 77 |
+
)
|
| 78 |
+
return qwen_prompt
|
| 79 |
+
|
| 80 |
+
def __call__(self, raw_prompt, positive=True, **kwargs):
|
| 81 |
+
if positive:
|
| 82 |
+
messages = [
|
| 83 |
+
{"role": "system", "content": self.system_prompt},
|
| 84 |
+
{"role": "user", "content": raw_prompt},
|
| 85 |
+
]
|
| 86 |
+
text = self.tokenizer.apply_chat_template(
|
| 87 |
+
messages, tokenize=False, add_generation_prompt=True
|
| 88 |
+
)
|
| 89 |
+
model_inputs = self.tokenizer([text], return_tensors="pt").to(
|
| 90 |
+
self.model.device
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
generated_ids = self.model.generate(
|
| 94 |
+
model_inputs.input_ids, max_new_tokens=512
|
| 95 |
+
)
|
| 96 |
+
generated_ids = [
|
| 97 |
+
output_ids[len(input_ids) :]
|
| 98 |
+
for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
|
| 99 |
+
]
|
| 100 |
+
|
| 101 |
+
prompt = self.tokenizer.batch_decode(
|
| 102 |
+
generated_ids, skip_special_tokens=True
|
| 103 |
+
)[0]
|
| 104 |
+
print(f"Your prompt is refined by Qwen: {prompt}")
|
| 105 |
+
return prompt
|
| 106 |
+
else:
|
| 107 |
+
return raw_prompt
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class Translator(torch.nn.Module):
|
| 111 |
+
def __init__(self, tokenizer_path=None, model=None):
|
| 112 |
+
super().__init__()
|
| 113 |
+
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
| 114 |
+
self.model = model
|
| 115 |
+
|
| 116 |
+
@staticmethod
|
| 117 |
+
def from_model_manager(model_manager: ModelManager):
|
| 118 |
+
model, model_path = model_manager.fetch_model(
|
| 119 |
+
"translator", require_model_path=True
|
| 120 |
+
)
|
| 121 |
+
translator = Translator(tokenizer_path=model_path, model=model)
|
| 122 |
+
return translator
|
| 123 |
+
|
| 124 |
+
def __call__(self, prompt, **kwargs):
|
| 125 |
+
input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(
|
| 126 |
+
self.model.device
|
| 127 |
+
)
|
| 128 |
+
output_ids = self.model.generate(input_ids)
|
| 129 |
+
prompt = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
|
| 130 |
+
print(f"Your prompt is translated: {prompt}")
|
| 131 |
+
return prompt
|
prompters/wan_prompter.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .base_prompter import BasePrompter
|
| 2 |
+
from models.wan_video_text_encoder import WanTextEncoder
|
| 3 |
+
from transformers import AutoTokenizer
|
| 4 |
+
import os, torch
|
| 5 |
+
import ftfy
|
| 6 |
+
import html
|
| 7 |
+
import string
|
| 8 |
+
import regex as re
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def basic_clean(text):
|
| 12 |
+
text = ftfy.fix_text(text)
|
| 13 |
+
text = html.unescape(html.unescape(text))
|
| 14 |
+
return text.strip()
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def whitespace_clean(text):
|
| 18 |
+
text = re.sub(r"\s+", " ", text)
|
| 19 |
+
text = text.strip()
|
| 20 |
+
return text
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def canonicalize(text, keep_punctuation_exact_string=None):
|
| 24 |
+
text = text.replace("_", " ")
|
| 25 |
+
if keep_punctuation_exact_string:
|
| 26 |
+
text = keep_punctuation_exact_string.join(
|
| 27 |
+
part.translate(str.maketrans("", "", string.punctuation))
|
| 28 |
+
for part in text.split(keep_punctuation_exact_string)
|
| 29 |
+
)
|
| 30 |
+
else:
|
| 31 |
+
text = text.translate(str.maketrans("", "", string.punctuation))
|
| 32 |
+
text = text.lower()
|
| 33 |
+
text = re.sub(r"\s+", " ", text)
|
| 34 |
+
return text.strip()
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class HuggingfaceTokenizer:
|
| 38 |
+
def __init__(self, name, seq_len=None, clean=None, **kwargs):
|
| 39 |
+
assert clean in (None, "whitespace", "lower", "canonicalize")
|
| 40 |
+
self.name = name
|
| 41 |
+
self.seq_len = seq_len
|
| 42 |
+
self.clean = clean
|
| 43 |
+
|
| 44 |
+
# init tokenizer
|
| 45 |
+
self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
|
| 46 |
+
self.vocab_size = self.tokenizer.vocab_size
|
| 47 |
+
|
| 48 |
+
def __call__(self, sequence, **kwargs):
|
| 49 |
+
return_mask = kwargs.pop("return_mask", False)
|
| 50 |
+
|
| 51 |
+
# arguments
|
| 52 |
+
_kwargs = {"return_tensors": "pt"}
|
| 53 |
+
if self.seq_len is not None:
|
| 54 |
+
_kwargs.update(
|
| 55 |
+
{
|
| 56 |
+
"padding": "max_length",
|
| 57 |
+
"truncation": True,
|
| 58 |
+
"max_length": self.seq_len,
|
| 59 |
+
}
|
| 60 |
+
)
|
| 61 |
+
_kwargs.update(**kwargs)
|
| 62 |
+
|
| 63 |
+
# tokenization
|
| 64 |
+
if isinstance(sequence, str):
|
| 65 |
+
sequence = [sequence]
|
| 66 |
+
if self.clean:
|
| 67 |
+
sequence = [self._clean(u) for u in sequence]
|
| 68 |
+
ids = self.tokenizer(sequence, **_kwargs)
|
| 69 |
+
|
| 70 |
+
# output
|
| 71 |
+
if return_mask:
|
| 72 |
+
return ids.input_ids, ids.attention_mask
|
| 73 |
+
else:
|
| 74 |
+
return ids.input_ids
|
| 75 |
+
|
| 76 |
+
def _clean(self, text):
|
| 77 |
+
if self.clean == "whitespace":
|
| 78 |
+
text = whitespace_clean(basic_clean(text))
|
| 79 |
+
elif self.clean == "lower":
|
| 80 |
+
text = whitespace_clean(basic_clean(text)).lower()
|
| 81 |
+
elif self.clean == "canonicalize":
|
| 82 |
+
text = canonicalize(basic_clean(text))
|
| 83 |
+
return text
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class WanPrompter(BasePrompter):
|
| 87 |
+
def __init__(self, tokenizer_path=None, text_len=512):
|
| 88 |
+
super().__init__()
|
| 89 |
+
self.text_len = text_len
|
| 90 |
+
self.text_encoder = None
|
| 91 |
+
self.fetch_tokenizer(tokenizer_path)
|
| 92 |
+
|
| 93 |
+
def fetch_tokenizer(self, tokenizer_path=None):
|
| 94 |
+
if tokenizer_path is not None:
|
| 95 |
+
self.tokenizer = HuggingfaceTokenizer(
|
| 96 |
+
name=tokenizer_path, seq_len=self.text_len, clean="whitespace"
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
def fetch_models(self, text_encoder: WanTextEncoder = None):
|
| 100 |
+
self.text_encoder = text_encoder
|
| 101 |
+
|
| 102 |
+
def encode_prompt(self, prompt, positive=True, device="cuda"):
|
| 103 |
+
prompt = self.process_prompt(prompt, positive=positive)
|
| 104 |
+
|
| 105 |
+
ids, mask = self.tokenizer(prompt, return_mask=True, add_special_tokens=True)
|
| 106 |
+
ids = ids.to(device)
|
| 107 |
+
mask = mask.to(device)
|
| 108 |
+
seq_lens = mask.gt(0).sum(dim=1).long()
|
| 109 |
+
prompt_emb = self.text_encoder(ids, mask)
|
| 110 |
+
for i, v in enumerate(seq_lens):
|
| 111 |
+
prompt_emb[:, v:] = 0
|
| 112 |
+
return prompt_emb
|
requirements.txt
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch==2.7.0
|
| 2 |
+
torchvision==0.22.0
|
| 3 |
+
ftfy==6.3.1
|
| 4 |
+
huggingface_hub==0.31.1
|
| 5 |
+
imageio==2.37.0
|
| 6 |
+
insightface==0.7.3
|
| 7 |
+
numpy==2.2.6
|
| 8 |
+
opencv_python==4.11.0.86
|
| 9 |
+
Pillow==11.3.0
|
| 10 |
+
safetensors==0.5.3
|
| 11 |
+
tqdm==4.67.1
|
| 12 |
+
transformers==4.46.2
|
| 13 |
+
facexlib==0.3.0
|
| 14 |
+
einops==0.8.1
|
| 15 |
+
onnxruntime-gpu==1.22.0
|
| 16 |
+
imageio-ffmpeg==0.6.0
|
| 17 |
+
scikit-image==0.25.2
|
schedulers/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .ddim import EnhancedDDIMScheduler
|
| 2 |
+
from .continuous_ode import ContinuousODEScheduler
|
| 3 |
+
from .flow_match import FlowMatchScheduler
|
schedulers/continuous_ode.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class ContinuousODEScheduler:
|
| 5 |
+
def __init__(
|
| 6 |
+
self, num_inference_steps=100, sigma_max=700.0, sigma_min=0.002, rho=7.0
|
| 7 |
+
):
|
| 8 |
+
self.sigma_max = sigma_max
|
| 9 |
+
self.sigma_min = sigma_min
|
| 10 |
+
self.rho = rho
|
| 11 |
+
self.set_timesteps(num_inference_steps)
|
| 12 |
+
|
| 13 |
+
def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, **kwargs):
|
| 14 |
+
ramp = torch.linspace(1 - denoising_strength, 1, num_inference_steps)
|
| 15 |
+
min_inv_rho = torch.pow(torch.tensor((self.sigma_min,)), (1 / self.rho))
|
| 16 |
+
max_inv_rho = torch.pow(torch.tensor((self.sigma_max,)), (1 / self.rho))
|
| 17 |
+
self.sigmas = torch.pow(
|
| 18 |
+
max_inv_rho + ramp * (min_inv_rho - max_inv_rho), self.rho
|
| 19 |
+
)
|
| 20 |
+
self.timesteps = torch.log(self.sigmas) * 0.25
|
| 21 |
+
|
| 22 |
+
def step(self, model_output, timestep, sample, to_final=False):
|
| 23 |
+
timestep_id = torch.argmin((self.timesteps - timestep).abs())
|
| 24 |
+
sigma = self.sigmas[timestep_id]
|
| 25 |
+
sample *= (sigma * sigma + 1).sqrt()
|
| 26 |
+
estimated_sample = (
|
| 27 |
+
-sigma / (sigma * sigma + 1).sqrt() * model_output
|
| 28 |
+
+ 1 / (sigma * sigma + 1) * sample
|
| 29 |
+
)
|
| 30 |
+
if to_final or timestep_id + 1 >= len(self.timesteps):
|
| 31 |
+
prev_sample = estimated_sample
|
| 32 |
+
else:
|
| 33 |
+
sigma_ = self.sigmas[timestep_id + 1]
|
| 34 |
+
derivative = 1 / sigma * (sample - estimated_sample)
|
| 35 |
+
prev_sample = sample + derivative * (sigma_ - sigma)
|
| 36 |
+
prev_sample /= (sigma_ * sigma_ + 1).sqrt()
|
| 37 |
+
return prev_sample
|
| 38 |
+
|
| 39 |
+
def return_to_timestep(self, timestep, sample, sample_stablized):
|
| 40 |
+
# This scheduler doesn't support this function.
|
| 41 |
+
pass
|
| 42 |
+
|
| 43 |
+
def add_noise(self, original_samples, noise, timestep):
|
| 44 |
+
timestep_id = torch.argmin((self.timesteps - timestep).abs())
|
| 45 |
+
sigma = self.sigmas[timestep_id]
|
| 46 |
+
sample = (original_samples + noise * sigma) / (sigma * sigma + 1).sqrt()
|
| 47 |
+
return sample
|
| 48 |
+
|
| 49 |
+
def training_target(self, sample, noise, timestep):
|
| 50 |
+
timestep_id = torch.argmin((self.timesteps - timestep).abs())
|
| 51 |
+
sigma = self.sigmas[timestep_id]
|
| 52 |
+
target = (
|
| 53 |
+
-(sigma * sigma + 1).sqrt() / sigma + 1 / (sigma * sigma + 1).sqrt() / sigma
|
| 54 |
+
) * sample + 1 / (sigma * sigma + 1).sqrt() * noise
|
| 55 |
+
return target
|
| 56 |
+
|
| 57 |
+
def training_weight(self, timestep):
|
| 58 |
+
timestep_id = torch.argmin((self.timesteps - timestep).abs())
|
| 59 |
+
sigma = self.sigmas[timestep_id]
|
| 60 |
+
weight = (1 + sigma * sigma).sqrt() / sigma
|
| 61 |
+
return weight
|
schedulers/ddim.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch, math
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class EnhancedDDIMScheduler:
|
| 5 |
+
def __init__(
|
| 6 |
+
self,
|
| 7 |
+
num_train_timesteps=1000,
|
| 8 |
+
beta_start=0.00085,
|
| 9 |
+
beta_end=0.012,
|
| 10 |
+
beta_schedule="scaled_linear",
|
| 11 |
+
prediction_type="epsilon",
|
| 12 |
+
rescale_zero_terminal_snr=False,
|
| 13 |
+
):
|
| 14 |
+
self.num_train_timesteps = num_train_timesteps
|
| 15 |
+
if beta_schedule == "scaled_linear":
|
| 16 |
+
betas = torch.square(
|
| 17 |
+
torch.linspace(
|
| 18 |
+
math.sqrt(beta_start),
|
| 19 |
+
math.sqrt(beta_end),
|
| 20 |
+
num_train_timesteps,
|
| 21 |
+
dtype=torch.float32,
|
| 22 |
+
)
|
| 23 |
+
)
|
| 24 |
+
elif beta_schedule == "linear":
|
| 25 |
+
betas = torch.linspace(
|
| 26 |
+
beta_start, beta_end, num_train_timesteps, dtype=torch.float32
|
| 27 |
+
)
|
| 28 |
+
else:
|
| 29 |
+
raise NotImplementedError(f"{beta_schedule} is not implemented")
|
| 30 |
+
self.alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
|
| 31 |
+
if rescale_zero_terminal_snr:
|
| 32 |
+
self.alphas_cumprod = self.rescale_zero_terminal_snr(self.alphas_cumprod)
|
| 33 |
+
self.alphas_cumprod = self.alphas_cumprod.tolist()
|
| 34 |
+
self.set_timesteps(10)
|
| 35 |
+
self.prediction_type = prediction_type
|
| 36 |
+
|
| 37 |
+
def rescale_zero_terminal_snr(self, alphas_cumprod):
|
| 38 |
+
alphas_bar_sqrt = alphas_cumprod.sqrt()
|
| 39 |
+
|
| 40 |
+
# Store old values.
|
| 41 |
+
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
|
| 42 |
+
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
|
| 43 |
+
|
| 44 |
+
# Shift so the last timestep is zero.
|
| 45 |
+
alphas_bar_sqrt -= alphas_bar_sqrt_T
|
| 46 |
+
|
| 47 |
+
# Scale so the first timestep is back to the old value.
|
| 48 |
+
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
|
| 49 |
+
|
| 50 |
+
# Convert alphas_bar_sqrt to betas
|
| 51 |
+
alphas_bar = alphas_bar_sqrt.square() # Revert sqrt
|
| 52 |
+
|
| 53 |
+
return alphas_bar
|
| 54 |
+
|
| 55 |
+
def set_timesteps(self, num_inference_steps, denoising_strength=1.0, **kwargs):
|
| 56 |
+
# The timesteps are aligned to 999...0, which is different from other implementations,
|
| 57 |
+
# but I think this implementation is more reasonable in theory.
|
| 58 |
+
max_timestep = max(round(self.num_train_timesteps * denoising_strength) - 1, 0)
|
| 59 |
+
num_inference_steps = min(num_inference_steps, max_timestep + 1)
|
| 60 |
+
if num_inference_steps == 1:
|
| 61 |
+
self.timesteps = torch.Tensor([max_timestep])
|
| 62 |
+
else:
|
| 63 |
+
step_length = max_timestep / (num_inference_steps - 1)
|
| 64 |
+
self.timesteps = torch.Tensor(
|
| 65 |
+
[
|
| 66 |
+
round(max_timestep - i * step_length)
|
| 67 |
+
for i in range(num_inference_steps)
|
| 68 |
+
]
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
def denoise(self, model_output, sample, alpha_prod_t, alpha_prod_t_prev):
|
| 72 |
+
if self.prediction_type == "epsilon":
|
| 73 |
+
weight_e = math.sqrt(1 - alpha_prod_t_prev) - math.sqrt(
|
| 74 |
+
alpha_prod_t_prev * (1 - alpha_prod_t) / alpha_prod_t
|
| 75 |
+
)
|
| 76 |
+
weight_x = math.sqrt(alpha_prod_t_prev / alpha_prod_t)
|
| 77 |
+
prev_sample = sample * weight_x + model_output * weight_e
|
| 78 |
+
elif self.prediction_type == "v_prediction":
|
| 79 |
+
weight_e = -math.sqrt(alpha_prod_t_prev * (1 - alpha_prod_t)) + math.sqrt(
|
| 80 |
+
alpha_prod_t * (1 - alpha_prod_t_prev)
|
| 81 |
+
)
|
| 82 |
+
weight_x = math.sqrt(alpha_prod_t * alpha_prod_t_prev) + math.sqrt(
|
| 83 |
+
(1 - alpha_prod_t) * (1 - alpha_prod_t_prev)
|
| 84 |
+
)
|
| 85 |
+
prev_sample = sample * weight_x + model_output * weight_e
|
| 86 |
+
else:
|
| 87 |
+
raise NotImplementedError(f"{self.prediction_type} is not implemented")
|
| 88 |
+
return prev_sample
|
| 89 |
+
|
| 90 |
+
def step(self, model_output, timestep, sample, to_final=False):
|
| 91 |
+
alpha_prod_t = self.alphas_cumprod[int(timestep.flatten().tolist()[0])]
|
| 92 |
+
if isinstance(timestep, torch.Tensor):
|
| 93 |
+
timestep = timestep.cpu()
|
| 94 |
+
timestep_id = torch.argmin((self.timesteps - timestep).abs())
|
| 95 |
+
if to_final or timestep_id + 1 >= len(self.timesteps):
|
| 96 |
+
alpha_prod_t_prev = 1.0
|
| 97 |
+
else:
|
| 98 |
+
timestep_prev = int(self.timesteps[timestep_id + 1])
|
| 99 |
+
alpha_prod_t_prev = self.alphas_cumprod[timestep_prev]
|
| 100 |
+
|
| 101 |
+
return self.denoise(model_output, sample, alpha_prod_t, alpha_prod_t_prev)
|
| 102 |
+
|
| 103 |
+
def return_to_timestep(self, timestep, sample, sample_stablized):
|
| 104 |
+
alpha_prod_t = self.alphas_cumprod[int(timestep.flatten().tolist()[0])]
|
| 105 |
+
noise_pred = (sample - math.sqrt(alpha_prod_t) * sample_stablized) / math.sqrt(
|
| 106 |
+
1 - alpha_prod_t
|
| 107 |
+
)
|
| 108 |
+
return noise_pred
|
| 109 |
+
|
| 110 |
+
def add_noise(self, original_samples, noise, timestep):
|
| 111 |
+
sqrt_alpha_prod = math.sqrt(
|
| 112 |
+
self.alphas_cumprod[int(timestep.flatten().tolist()[0])]
|
| 113 |
+
)
|
| 114 |
+
sqrt_one_minus_alpha_prod = math.sqrt(
|
| 115 |
+
1 - self.alphas_cumprod[int(timestep.flatten().tolist()[0])]
|
| 116 |
+
)
|
| 117 |
+
noisy_samples = (
|
| 118 |
+
sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
| 119 |
+
)
|
| 120 |
+
return noisy_samples
|
| 121 |
+
|
| 122 |
+
def training_target(self, sample, noise, timestep):
|
| 123 |
+
if self.prediction_type == "epsilon":
|
| 124 |
+
return noise
|
| 125 |
+
else:
|
| 126 |
+
sqrt_alpha_prod = math.sqrt(
|
| 127 |
+
self.alphas_cumprod[int(timestep.flatten().tolist()[0])]
|
| 128 |
+
)
|
| 129 |
+
sqrt_one_minus_alpha_prod = math.sqrt(
|
| 130 |
+
1 - self.alphas_cumprod[int(timestep.flatten().tolist()[0])]
|
| 131 |
+
)
|
| 132 |
+
target = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
|
| 133 |
+
return target
|
| 134 |
+
|
| 135 |
+
def training_weight(self, timestep):
|
| 136 |
+
return 1.0
|
schedulers/flow_match.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class FlowMatchScheduler:
|
| 5 |
+
def __init__(
|
| 6 |
+
self,
|
| 7 |
+
num_inference_steps=100,
|
| 8 |
+
num_train_timesteps=1000,
|
| 9 |
+
shift=3.0,
|
| 10 |
+
sigma_max=1.0,
|
| 11 |
+
sigma_min=0.003 / 1.002,
|
| 12 |
+
inverse_timesteps=False,
|
| 13 |
+
extra_one_step=False,
|
| 14 |
+
reverse_sigmas=False,
|
| 15 |
+
):
|
| 16 |
+
self.num_train_timesteps = num_train_timesteps
|
| 17 |
+
self.shift = shift
|
| 18 |
+
self.sigma_max = sigma_max
|
| 19 |
+
self.sigma_min = sigma_min
|
| 20 |
+
self.inverse_timesteps = inverse_timesteps
|
| 21 |
+
self.extra_one_step = extra_one_step
|
| 22 |
+
self.reverse_sigmas = reverse_sigmas
|
| 23 |
+
self.set_timesteps(num_inference_steps)
|
| 24 |
+
|
| 25 |
+
def set_timesteps(
|
| 26 |
+
self,
|
| 27 |
+
num_inference_steps=100,
|
| 28 |
+
denoising_strength=1.0,
|
| 29 |
+
training=False,
|
| 30 |
+
shift=None,
|
| 31 |
+
):
|
| 32 |
+
if shift is not None:
|
| 33 |
+
self.shift = shift
|
| 34 |
+
sigma_start = (
|
| 35 |
+
self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
|
| 36 |
+
)
|
| 37 |
+
if self.extra_one_step:
|
| 38 |
+
self.sigmas = torch.linspace(
|
| 39 |
+
sigma_start, self.sigma_min, num_inference_steps + 1
|
| 40 |
+
)[:-1]
|
| 41 |
+
else:
|
| 42 |
+
self.sigmas = torch.linspace(
|
| 43 |
+
sigma_start, self.sigma_min, num_inference_steps
|
| 44 |
+
)
|
| 45 |
+
if self.inverse_timesteps:
|
| 46 |
+
self.sigmas = torch.flip(self.sigmas, dims=[0])
|
| 47 |
+
self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas)
|
| 48 |
+
if self.reverse_sigmas:
|
| 49 |
+
self.sigmas = 1 - self.sigmas
|
| 50 |
+
self.timesteps = self.sigmas * self.num_train_timesteps
|
| 51 |
+
if training:
|
| 52 |
+
x = self.timesteps
|
| 53 |
+
y = torch.exp(
|
| 54 |
+
-2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2
|
| 55 |
+
)
|
| 56 |
+
y_shifted = y - y.min()
|
| 57 |
+
bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum())
|
| 58 |
+
self.linear_timesteps_weights = bsmntw_weighing
|
| 59 |
+
self.training = True
|
| 60 |
+
else:
|
| 61 |
+
self.training = False
|
| 62 |
+
|
| 63 |
+
def step(self, model_output, timestep, sample, to_final=False, **kwargs):
|
| 64 |
+
if isinstance(timestep, torch.Tensor):
|
| 65 |
+
timestep = timestep.cpu()
|
| 66 |
+
timestep_id = torch.argmin((self.timesteps - timestep).abs())
|
| 67 |
+
sigma = self.sigmas[timestep_id]
|
| 68 |
+
if to_final or timestep_id + 1 >= len(self.timesteps):
|
| 69 |
+
sigma_ = 1 if (self.inverse_timesteps or self.reverse_sigmas) else 0
|
| 70 |
+
else:
|
| 71 |
+
sigma_ = self.sigmas[timestep_id + 1]
|
| 72 |
+
prev_sample = sample + model_output * (sigma_ - sigma)
|
| 73 |
+
return prev_sample
|
| 74 |
+
|
| 75 |
+
def return_to_timestep(self, timestep, sample, sample_stablized):
|
| 76 |
+
if isinstance(timestep, torch.Tensor):
|
| 77 |
+
timestep = timestep.cpu()
|
| 78 |
+
timestep_id = torch.argmin((self.timesteps - timestep).abs())
|
| 79 |
+
sigma = self.sigmas[timestep_id]
|
| 80 |
+
model_output = (sample - sample_stablized) / sigma
|
| 81 |
+
return model_output
|
| 82 |
+
|
| 83 |
+
def add_noise(self, original_samples, noise, timestep):
|
| 84 |
+
if isinstance(timestep, torch.Tensor):
|
| 85 |
+
timestep = timestep.cpu()
|
| 86 |
+
timestep_id = torch.argmin((self.timesteps - timestep).abs())
|
| 87 |
+
sigma = self.sigmas[timestep_id]
|
| 88 |
+
sample = (1 - sigma) * original_samples + sigma * noise
|
| 89 |
+
return sample
|
| 90 |
+
|
| 91 |
+
def training_target(self, sample, noise, timestep):
|
| 92 |
+
target = noise - sample
|
| 93 |
+
return target
|
| 94 |
+
|
| 95 |
+
def training_weight(self, timestep):
|
| 96 |
+
timestep_id = torch.argmin(
|
| 97 |
+
(self.timesteps - timestep.to(self.timesteps.device)).abs()
|
| 98 |
+
)
|
| 99 |
+
weights = self.linear_timesteps_weights[timestep_id]
|
| 100 |
+
return weights
|
test/input/first_frame.png
ADDED
|
Git LFS Details
|
test/input/lecun.jpg
ADDED
|
test/input/pose.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f2038896160fb990162832cbff6eaebcf05b25e1a3b8c201e5b147a4ce3ce01d
|
| 3 |
+
size 173260
|
test/input/ruonan.jpg
ADDED
|
Git LFS Details
|
test/input/woman.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ee8fb303f53a89c0ab36c0457c9452149c58a881055becb4da8abc41766bc6db
|
| 3 |
+
size 8399484
|