Upload folder using huggingface_hub
Browse files- LICENSE +202 -0
- README.md +220 -19
- added_tokens.json +1 -0
- config.json +53 -0
- configuration_ernie4_5_moe.py +192 -0
- generation_config.json +13 -0
- model-00001-of-00014.safetensors +3 -0
- model-00002-of-00014.safetensors +3 -0
- model-00003-of-00014.safetensors +3 -0
- model-00004-of-00014.safetensors +3 -0
- model-00005-of-00014.safetensors +3 -0
- model-00006-of-00014.safetensors +3 -0
- model-00007-of-00014.safetensors +3 -0
- model-00008-of-00014.safetensors +3 -0
- model-00009-of-00014.safetensors +3 -0
- model-00010-of-00014.safetensors +3 -0
- model-00011-of-00014.safetensors +3 -0
- model-00012-of-00014.safetensors +3 -0
- model-00013-of-00014.safetensors +3 -0
- model-00014-of-00014.safetensors +3 -0
- model.safetensors.index.json +0 -0
- modeling_ernie4_5_moe.py +1590 -0
- quantization_config.json +0 -0
- special_tokens_map.json +1 -0
- tokenization_ernie4_5.py +352 -0
- tokenizer.json +0 -0
- tokenizer.model +3 -0
- tokenizer_config.json +22 -0
LICENSE
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
| 202 |
+
|
README.md
CHANGED
|
@@ -1,22 +1,223 @@
|
|
| 1 |
---
|
| 2 |
license: apache-2.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
---
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
license: apache-2.0
|
| 3 |
+
language:
|
| 4 |
+
- en
|
| 5 |
+
- zh
|
| 6 |
+
pipeline_tag: text-generation
|
| 7 |
+
tags:
|
| 8 |
+
- ERNIE4.5
|
| 9 |
+
library_name: transformers
|
| 10 |
---
|
| 11 |
+
|
| 12 |
+
<div align="center" style="line-height: 1;">
|
| 13 |
+
<a href="https://ernie.baidu.com/" target="_blank" style="margin: 2px;">
|
| 14 |
+
<img alt="Chat" src="https://img.shields.io/badge/🤖_Chat-ERNIE_Bot-blue" style="display: inline-block; vertical-align: middle;"/>
|
| 15 |
+
</a>
|
| 16 |
+
<a href="https://huggingface.co/baidu" target="_blank" style="margin: 2px;">
|
| 17 |
+
<img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Baidu-ffc107?color=ffc107&logoColor=white" style="display: inline-block; vertical-align: middle;"/>
|
| 18 |
+
</a>
|
| 19 |
+
<a href="https://github.com/PaddlePaddle/ERNIE" target="_blank" style="margin: 2px;">
|
| 20 |
+
<img alt="Github" src="https://img.shields.io/badge/GitHub-ERNIE-000?logo=github&color=0000FF" style="display: inline-block; vertical-align: middle;"/>
|
| 21 |
+
</a>
|
| 22 |
+
<a href="https://ernie.baidu.com/blog/ernie4.5" target="_blank" style="margin: 2px;">
|
| 23 |
+
<img alt="Blog" src="https://img.shields.io/badge/🖖_Blog-ERNIE4.5-A020A0" style="display: inline-block; vertical-align: middle;"/>
|
| 24 |
+
</a>
|
| 25 |
+
</div>
|
| 26 |
+
|
| 27 |
+
<div align="center" style="line-height: 1;">
|
| 28 |
+
<a href="#license" style="margin: 2px;">
|
| 29 |
+
<img alt="License" src="https://img.shields.io/badge/License-Apache2.0-A5de54" style="display: inline-block; vertical-align: middle;"/>
|
| 30 |
+
</a>
|
| 31 |
+
</div>
|
| 32 |
+
|
| 33 |
+
# ERNIE-4.5-300B-A47B
|
| 34 |
+
|
| 35 |
+
> [!NOTE]
|
| 36 |
+
> Note: "**-Paddle**" models use [PaddlePaddle](https://github.com/PaddlePaddle/Paddle) weights, while "**-PT**" models use Transformer-style PyTorch weights.
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
## ERNIE 4.5 Highlights
|
| 40 |
+
|
| 41 |
+
The advanced capabilities of the ERNIE 4.5 models, particularly the MoE-based A47B and A3B series, are underpinned by several key technical innovations:
|
| 42 |
+
|
| 43 |
+
1. **Multimodal Heterogeneous MoE Pre-Training:** Our models are jointly trained on both textual and visual modalities to better capture the nuances of multimodal information and improve performance on tasks involving text understanding and generation, image understanding, and cross-modal reasoning. To achieve this without one modality hindering the learning of another, we designed a *heterogeneous MoE structure*, incorporated *modality-isolated routing*, and employed *router orthogonal loss* and *multimodal token-balanced loss*. These architectural choices ensure that both modalities are effectively represented, allowing for mutual reinforcement during training.
|
| 44 |
+
|
| 45 |
+
2. **Scaling-Efficient Infrastructure:** We propose a novel heterogeneous hybrid parallelism and hierarchical load balancing strategy for efficient training of ERNIE 4.5 models. By using intra-node expert parallelism, memory-efficient pipeline scheduling, FP8 mixed-precision training and finegrained recomputation methods, we achieve remarkable pre-training throughput. For inference, we propose *multi-expert parallel collaboration* method and *convolutional code quantization* algorithm to achieve 4-bit/2-bit lossless quantization. Furthermore, we introduce PD disaggregation with dynamic role switching for effective resource utilization to enhance inference performance for ERNIE 4.5 MoE models. Built on [PaddlePaddle](https://github.com/PaddlePaddle/Paddle), ERNIE 4.5 delivers high-performance inference across a wide range of hardware platforms.
|
| 46 |
+
|
| 47 |
+
3. **Modality-Specific Post-Training:** To meet the diverse requirements of real-world applications, we fine-tuned variants of the pre-trained model for specific modalities. Our LLMs are optimized for general-purpose language understanding and generation. The VLMs focuses on visuallanguage understanding and supports both thinking and non-thinking modes. Each model employed a combination of *Supervised Fine-tuning (SFT)*, *Direct Preference Optimization (DPO)* or a modified reinforcement learning method named *Unified Preference Optimization (UPO)* for post-training.
|
| 48 |
+
|
| 49 |
+
## Model Overview
|
| 50 |
+
|
| 51 |
+
ERNIE-4.5-300B-A47B is a text MoE Post-trained model, with 300B total parameters and 47B activated parameters for each token. The following are the model configuration details:
|
| 52 |
+
|
| 53 |
+
|Key|Value|
|
| 54 |
+
|-|-|
|
| 55 |
+
|Modality|Text|
|
| 56 |
+
|Training Stage|Pretraining|
|
| 57 |
+
|Params(Total / Activated)|300B / 47B|
|
| 58 |
+
|Layers|54|
|
| 59 |
+
|Heads(Q/KV)|64 / 8|
|
| 60 |
+
|Text Experts(Total / Activated)|64 / 8|
|
| 61 |
+
|Vision Experts(Total / Activated)|64 / 8|
|
| 62 |
+
|Context Length|131072|
|
| 63 |
+
|
| 64 |
+
## Quickstart
|
| 65 |
+
|
| 66 |
+
### Using `transformers` library
|
| 67 |
+
|
| 68 |
+
**Note**: Before using the model, please ensure you have the `transformers` library installed (version 4.50.0 or higher)
|
| 69 |
+
|
| 70 |
+
The following contains a code snippet illustrating how to use the model generate content based on given inputs.
|
| 71 |
+
|
| 72 |
+
```python
|
| 73 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 74 |
+
|
| 75 |
+
model_name = "baidu/ERNIE-4.5-300B-A47B-PT"
|
| 76 |
+
|
| 77 |
+
# load the tokenizer and the model
|
| 78 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
| 79 |
+
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
|
| 80 |
+
|
| 81 |
+
# prepare the model input
|
| 82 |
+
prompt = "Give me a short introduction to large language model."
|
| 83 |
+
messages = [
|
| 84 |
+
{"role": "user", "content": prompt}
|
| 85 |
+
]
|
| 86 |
+
text = tokenizer.apply_chat_template(
|
| 87 |
+
messages,
|
| 88 |
+
tokenize=False,
|
| 89 |
+
add_generation_prompt=True
|
| 90 |
+
)
|
| 91 |
+
model_inputs = tokenizer([text], add_special_tokens=False, return_tensors="pt").to(model.device)
|
| 92 |
+
|
| 93 |
+
# conduct text completion
|
| 94 |
+
generated_ids = model.generate(
|
| 95 |
+
model_inputs.input_ids,
|
| 96 |
+
max_new_tokens=1024
|
| 97 |
+
)
|
| 98 |
+
output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
|
| 99 |
+
|
| 100 |
+
# decode the generated ids
|
| 101 |
+
generate_text = tokenizer.decode(output_ids, skip_special_tokens=True).strip("\n")
|
| 102 |
+
print("generate_text:", generate_text)
|
| 103 |
+
```
|
| 104 |
+
|
| 105 |
+
### Using vLLM
|
| 106 |
+
|
| 107 |
+
[vllm](https://github.com/vllm-project/vllm/tree/main) github library. Python-only [build](https://docs.vllm.ai/en/latest/getting_started/installation/gpu.html#set-up-using-python-only-build-without-compilation).
|
| 108 |
+
|
| 109 |
+
```bash
|
| 110 |
+
# 80G * 16 GPU
|
| 111 |
+
vllm serve baidu/ERNIE-4.5-300B-A47B-PT --trust-remote-code
|
| 112 |
+
```
|
| 113 |
+
|
| 114 |
+
```bash
|
| 115 |
+
# FP8 online quantification 80G * 16 GPU
|
| 116 |
+
vllm serve baidu/ERNIE-4.5-300B-A47B-PT --trust-remote-code --quantization fp8
|
| 117 |
+
```
|
| 118 |
+
|
| 119 |
+
## Best Practices
|
| 120 |
+
|
| 121 |
+
### **Sampling Parameters**
|
| 122 |
+
|
| 123 |
+
To achieve optimal performance, we suggest using `Temperature=0.8`, `TopP=0.8`.
|
| 124 |
+
|
| 125 |
+
### Prompts for Web Search
|
| 126 |
+
|
| 127 |
+
For Web Search, {references}, {date}, and {question} are arguments.
|
| 128 |
+
|
| 129 |
+
For Chinese question, we use the prompt:
|
| 130 |
+
|
| 131 |
+
```python
|
| 132 |
+
ernie_search_zh_prompt = \
|
| 133 |
+
'''下面你会收到当前时间、多个不同来源的参考文章和一段对话。你的任务是阅读多个参考文章,并根据参考文章中的信息回答对话中的问题。
|
| 134 |
+
以下是当前时间和参考文章:
|
| 135 |
+
---------
|
| 136 |
+
#当前时间
|
| 137 |
+
{date}
|
| 138 |
+
|
| 139 |
+
#参考文章
|
| 140 |
+
{references}
|
| 141 |
+
|
| 142 |
+
---------
|
| 143 |
+
请注意:
|
| 144 |
+
1. 回答必须结合问题需求和当前时间,对参考文章的可用性进行判断,避免在回答中使用错误或过时的信息。
|
| 145 |
+
2. 当参考文章中的信息无法准确地回答问题时,你需要在回答中提供获取相应信息的建议,或承认无法提供相应信息。
|
| 146 |
+
3. 你需要优先根据百科、官网、权威机构、专业网站等高权威性来源的信息来回答问题。
|
| 147 |
+
4. 回复需要综合参考文章中的相关数字、案例、法律条文、公式等信息,使你的答案更专业。
|
| 148 |
+
5. 当问题属于创作类任务时,需注意以下维度:
|
| 149 |
+
- 态度鲜明:观点、立场清晰明确,避免模棱两可,语言果断直接
|
| 150 |
+
- 文采飞扬:用词精准生动,善用修辞手法,增强感染力
|
| 151 |
+
- 有理有据:逻辑严密递进,结合权威数据/事实支撑论点
|
| 152 |
+
---------
|
| 153 |
+
下面请结合以上信息,回答问题,补全对话
|
| 154 |
+
{question}'''
|
| 155 |
+
```
|
| 156 |
+
|
| 157 |
+
For English question, we use the prompt:
|
| 158 |
+
|
| 159 |
+
```python
|
| 160 |
+
ernie_search_en_prompt = \
|
| 161 |
+
'''
|
| 162 |
+
Below you will be given the current time, multiple references from different sources, and a conversation. Your task is to read the references and use the information in them to answer the question in the conversation.
|
| 163 |
+
Here are the current time and the references:
|
| 164 |
+
---------
|
| 165 |
+
#Current Time
|
| 166 |
+
{date}
|
| 167 |
+
|
| 168 |
+
#References
|
| 169 |
+
{references}
|
| 170 |
+
|
| 171 |
+
---------
|
| 172 |
+
Please note:
|
| 173 |
+
1. Based on the question’s requirements and the current time, assess the usefulness of the references to avoid using inaccurate or outdated information in the answer.
|
| 174 |
+
2. If the references do not provide enough information to accurately answer the question, you should suggest how to obtain the relevant information or acknowledge that you are unable to provide it.
|
| 175 |
+
3. Prioritize using information from highly authoritative sources such as encyclopedias, official websites, authoritative institutions, and professional websites when answering questions.
|
| 176 |
+
4. Incorporate relevant numbers, cases, legal provisions, formulas, and other details from the references to make your answer more professional.
|
| 177 |
+
5. For creative tasks, keep these dimensions in mind:
|
| 178 |
+
- Clear attitude: Clear views and positions, avoid ambiguity, and use decisive and direct language
|
| 179 |
+
- Brilliant writing: Precise and vivid words, good use of rhetoric, and enhance the appeal
|
| 180 |
+
- Well-reasoned: Rigorous logic and progressive, combined with authoritative data/facts to support the argument
|
| 181 |
+
|
| 182 |
+
---------
|
| 183 |
+
Now, using the information above, answer the question and complete the conversation:
|
| 184 |
+
{question}'''
|
| 185 |
+
```
|
| 186 |
+
|
| 187 |
+
Parameter notes:
|
| 188 |
+
|
| 189 |
+
* {question} is the user’s question
|
| 190 |
+
* {date} is the current time, and the recommended format is “YYYY-MM-DD HH:MM:SS, Day of the Week, Beijing/China.”
|
| 191 |
+
* {references} is the references, and the recommended format is:
|
| 192 |
+
|
| 193 |
+
```text
|
| 194 |
+
##参考文章1
|
| 195 |
+
标题:周杰伦
|
| 196 |
+
文章发布时间:2025-04-20
|
| 197 |
+
内容:周杰伦(Jay Chou),1979年1月18日出生于台湾省新北市,祖籍福建省永春县,华语流行乐男歌手、音乐人、演员、导演、编剧,毕业于淡江中学。2000年,发行个人首张音乐专辑《Jay》。...
|
| 198 |
+
来源网站网址:baike.baidu.com
|
| 199 |
+
来源网站的网站名:百度百科
|
| 200 |
+
|
| 201 |
+
##参考文章2
|
| 202 |
+
...
|
| 203 |
+
```
|
| 204 |
+
|
| 205 |
+
## License
|
| 206 |
+
|
| 207 |
+
The ERNIE 4.5 models are provided under the Apache License 2.0. This license permits commercial use, subject to its terms and conditions. Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
|
| 208 |
+
|
| 209 |
+
## Citation
|
| 210 |
+
|
| 211 |
+
If you find ERNIE 4.5 useful or wish to use it in your projects, please kindly cite our technical report:
|
| 212 |
+
|
| 213 |
+
```bibtex
|
| 214 |
+
@misc{ernie2025technicalreport,
|
| 215 |
+
title={ERNIE 4.5 Technical Report},
|
| 216 |
+
author={Baidu ERNIE Team},
|
| 217 |
+
year={2025},
|
| 218 |
+
eprint={},
|
| 219 |
+
archivePrefix={arXiv},
|
| 220 |
+
primaryClass={cs.CL},
|
| 221 |
+
url={}
|
| 222 |
+
}
|
| 223 |
+
```
|
added_tokens.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"<|IMAGE_PLACEHOLDER|>": 100295, "<|AUDIO_PLACEHOLDER|>": 100296, "<|LOC_0|>": 100297, "<|LOC_1|>": 100298, "<|LOC_2|>": 100299, "<|LOC_3|>": 100300, "<|LOC_4|>": 100301, "<|LOC_5|>": 100302, "<|LOC_6|>": 100303, "<|LOC_7|>": 100304, "<|LOC_8|>": 100305, "<|LOC_9|>": 100306, "<|LOC_10|>": 100307, "<|LOC_11|>": 100308, "<|LOC_12|>": 100309, "<|LOC_13|>": 100310, "<|LOC_14|>": 100311, "<|LOC_15|>": 100312, "<|LOC_16|>": 100313, "<|LOC_17|>": 100314, "<|LOC_18|>": 100315, "<|LOC_19|>": 100316, "<|LOC_20|>": 100317, "<|LOC_21|>": 100318, "<|LOC_22|>": 100319, "<|LOC_23|>": 100320, "<|LOC_24|>": 100321, "<|LOC_25|>": 100322, "<|LOC_26|>": 100323, "<|LOC_27|>": 100324, "<|LOC_28|>": 100325, "<|LOC_29|>": 100326, "<|LOC_30|>": 100327, "<|LOC_31|>": 100328, "<|LOC_32|>": 100329, "<|LOC_33|>": 100330, "<|LOC_34|>": 100331, "<|LOC_35|>": 100332, "<|LOC_36|>": 100333, "<|LOC_37|>": 100334, "<|LOC_38|>": 100335, "<|LOC_39|>": 100336, "<|LOC_40|>": 100337, "<|LOC_41|>": 100338, "<|LOC_42|>": 100339, "<|LOC_43|>": 100340, "<|LOC_44|>": 100341, "<|LOC_45|>": 100342, "<|LOC_46|>": 100343, "<|LOC_47|>": 100344, "<|LOC_48|>": 100345, "<|LOC_49|>": 100346, "<|LOC_50|>": 100347, "<|LOC_51|>": 100348, "<|LOC_52|>": 100349, "<|LOC_53|>": 100350, "<|LOC_54|>": 100351, "<|LOC_55|>": 100352, "<|LOC_56|>": 100353, "<|LOC_57|>": 100354, "<|LOC_58|>": 100355, "<|LOC_59|>": 100356, "<|LOC_60|>": 100357, "<|LOC_61|>": 100358, "<|LOC_62|>": 100359, "<|LOC_63|>": 100360, "<|LOC_64|>": 100361, "<|LOC_65|>": 100362, "<|LOC_66|>": 100363, "<|LOC_67|>": 100364, "<|LOC_68|>": 100365, "<|LOC_69|>": 100366, "<|LOC_70|>": 100367, "<|LOC_71|>": 100368, "<|LOC_72|>": 100369, "<|LOC_73|>": 100370, "<|LOC_74|>": 100371, "<|LOC_75|>": 100372, "<|LOC_76|>": 100373, "<|LOC_77|>": 100374, "<|LOC_78|>": 100375, "<|LOC_79|>": 100376, "<|LOC_80|>": 100377, "<|LOC_81|>": 100378, "<|LOC_82|>": 100379, "<|LOC_83|>": 100380, "<|LOC_84|>": 100381, "<|LOC_85|>": 100382, "<|LOC_86|>": 100383, "<|LOC_87|>": 100384, "<|LOC_88|>": 100385, "<|LOC_89|>": 100386, "<|LOC_90|>": 100387, "<|LOC_91|>": 100388, "<|LOC_92|>": 100389, "<|LOC_93|>": 100390, "<|LOC_94|>": 100391, "<|LOC_95|>": 100392, "<|LOC_96|>": 100393, "<|LOC_97|>": 100394, "<|LOC_98|>": 100395, "<|LOC_99|>": 100396, "<|LOC_100|>": 100397, "<|LOC_101|>": 100398, "<|LOC_102|>": 100399, "<|LOC_103|>": 100400, "<|LOC_104|>": 100401, "<|LOC_105|>": 100402, "<|LOC_106|>": 100403, "<|LOC_107|>": 100404, "<|LOC_108|>": 100405, "<|LOC_109|>": 100406, "<|LOC_110|>": 100407, "<|LOC_111|>": 100408, "<|LOC_112|>": 100409, "<|LOC_113|>": 100410, "<|LOC_114|>": 100411, "<|LOC_115|>": 100412, "<|LOC_116|>": 100413, "<|LOC_117|>": 100414, "<|LOC_118|>": 100415, "<|LOC_119|>": 100416, "<|LOC_120|>": 100417, "<|LOC_121|>": 100418, "<|LOC_122|>": 100419, "<|LOC_123|>": 100420, "<|LOC_124|>": 100421, "<|LOC_125|>": 100422, "<|LOC_126|>": 100423, "<|LOC_127|>": 100424, "<|LOC_128|>": 100425, "<|LOC_129|>": 100426, "<|LOC_130|>": 100427, "<|LOC_131|>": 100428, "<|LOC_132|>": 100429, "<|LOC_133|>": 100430, "<|LOC_134|>": 100431, "<|LOC_135|>": 100432, "<|LOC_136|>": 100433, "<|LOC_137|>": 100434, "<|LOC_138|>": 100435, "<|LOC_139|>": 100436, "<|LOC_140|>": 100437, "<|LOC_141|>": 100438, "<|LOC_142|>": 100439, "<|LOC_143|>": 100440, "<|LOC_144|>": 100441, "<|LOC_145|>": 100442, "<|LOC_146|>": 100443, "<|LOC_147|>": 100444, "<|LOC_148|>": 100445, "<|LOC_149|>": 100446, "<|LOC_150|>": 100447, "<|LOC_151|>": 100448, "<|LOC_152|>": 100449, "<|LOC_153|>": 100450, "<|LOC_154|>": 100451, "<|LOC_155|>": 100452, "<|LOC_156|>": 100453, "<|LOC_157|>": 100454, "<|LOC_158|>": 100455, "<|LOC_159|>": 100456, "<|LOC_160|>": 100457, "<|LOC_161|>": 100458, "<|LOC_162|>": 100459, "<|LOC_163|>": 100460, "<|LOC_164|>": 100461, "<|LOC_165|>": 100462, "<|LOC_166|>": 100463, "<|LOC_167|>": 100464, "<|LOC_168|>": 100465, "<|LOC_169|>": 100466, "<|LOC_170|>": 100467, "<|LOC_171|>": 100468, "<|LOC_172|>": 100469, "<|LOC_173|>": 100470, "<|LOC_174|>": 100471, "<|LOC_175|>": 100472, "<|LOC_176|>": 100473, "<|LOC_177|>": 100474, "<|LOC_178|>": 100475, "<|LOC_179|>": 100476, "<|LOC_180|>": 100477, "<|LOC_181|>": 100478, "<|LOC_182|>": 100479, "<|LOC_183|>": 100480, "<|LOC_184|>": 100481, "<|LOC_185|>": 100482, "<|LOC_186|>": 100483, "<|LOC_187|>": 100484, "<|LOC_188|>": 100485, "<|LOC_189|>": 100486, "<|LOC_190|>": 100487, "<|LOC_191|>": 100488, "<|LOC_192|>": 100489, "<|LOC_193|>": 100490, "<|LOC_194|>": 100491, "<|LOC_195|>": 100492, "<|LOC_196|>": 100493, "<|LOC_197|>": 100494, "<|LOC_198|>": 100495, "<|LOC_199|>": 100496, "<|LOC_200|>": 100497, "<|LOC_201|>": 100498, "<|LOC_202|>": 100499, "<|LOC_203|>": 100500, "<|LOC_204|>": 100501, "<|LOC_205|>": 100502, "<|LOC_206|>": 100503, "<|LOC_207|>": 100504, "<|LOC_208|>": 100505, "<|LOC_209|>": 100506, "<|LOC_210|>": 100507, "<|LOC_211|>": 100508, "<|LOC_212|>": 100509, "<|LOC_213|>": 100510, "<|LOC_214|>": 100511, "<|LOC_215|>": 100512, "<|LOC_216|>": 100513, "<|LOC_217|>": 100514, "<|LOC_218|>": 100515, "<|LOC_219|>": 100516, "<|LOC_220|>": 100517, "<|LOC_221|>": 100518, "<|LOC_222|>": 100519, "<|LOC_223|>": 100520, "<|LOC_224|>": 100521, "<|LOC_225|>": 100522, "<|LOC_226|>": 100523, "<|LOC_227|>": 100524, "<|LOC_228|>": 100525, "<|LOC_229|>": 100526, "<|LOC_230|>": 100527, "<|LOC_231|>": 100528, "<|LOC_232|>": 100529, "<|LOC_233|>": 100530, "<|LOC_234|>": 100531, "<|LOC_235|>": 100532, "<|LOC_236|>": 100533, "<|LOC_237|>": 100534, "<|LOC_238|>": 100535, "<|LOC_239|>": 100536, "<|LOC_240|>": 100537, "<|LOC_241|>": 100538, "<|LOC_242|>": 100539, "<|LOC_243|>": 100540, "<|LOC_244|>": 100541, "<|LOC_245|>": 100542, "<|LOC_246|>": 100543, "<|LOC_247|>": 100544, "<|LOC_248|>": 100545, "<|LOC_249|>": 100546, "<|LOC_250|>": 100547, "<|LOC_251|>": 100548, "<|LOC_252|>": 100549, "<|LOC_253|>": 100550, "<|LOC_254|>": 100551, "<|LOC_255|>": 100552, "<|LOC_256|>": 100553, "<|LOC_257|>": 100554, "<|LOC_258|>": 100555, "<|LOC_259|>": 100556, "<|LOC_260|>": 100557, "<|LOC_261|>": 100558, "<|LOC_262|>": 100559, "<|LOC_263|>": 100560, "<|LOC_264|>": 100561, "<|LOC_265|>": 100562, "<|LOC_266|>": 100563, "<|LOC_267|>": 100564, "<|LOC_268|>": 100565, "<|LOC_269|>": 100566, "<|LOC_270|>": 100567, "<|LOC_271|>": 100568, "<|LOC_272|>": 100569, "<|LOC_273|>": 100570, "<|LOC_274|>": 100571, "<|LOC_275|>": 100572, "<|LOC_276|>": 100573, "<|LOC_277|>": 100574, "<|LOC_278|>": 100575, "<|LOC_279|>": 100576, "<|LOC_280|>": 100577, "<|LOC_281|>": 100578, "<|LOC_282|>": 100579, "<|LOC_283|>": 100580, "<|LOC_284|>": 100581, "<|LOC_285|>": 100582, "<|LOC_286|>": 100583, "<|LOC_287|>": 100584, "<|LOC_288|>": 100585, "<|LOC_289|>": 100586, "<|LOC_290|>": 100587, "<|LOC_291|>": 100588, "<|LOC_292|>": 100589, "<|LOC_293|>": 100590, "<|LOC_294|>": 100591, "<|LOC_295|>": 100592, "<|LOC_296|>": 100593, "<|LOC_297|>": 100594, "<|LOC_298|>": 100595, "<|LOC_299|>": 100596, "<|LOC_300|>": 100597, "<|LOC_301|>": 100598, "<|LOC_302|>": 100599, "<|LOC_303|>": 100600, "<|LOC_304|>": 100601, "<|LOC_305|>": 100602, "<|LOC_306|>": 100603, "<|LOC_307|>": 100604, "<|LOC_308|>": 100605, "<|LOC_309|>": 100606, "<|LOC_310|>": 100607, "<|LOC_311|>": 100608, "<|LOC_312|>": 100609, "<|LOC_313|>": 100610, "<|LOC_314|>": 100611, "<|LOC_315|>": 100612, "<|LOC_316|>": 100613, "<|LOC_317|>": 100614, "<|LOC_318|>": 100615, "<|LOC_319|>": 100616, "<|LOC_320|>": 100617, "<|LOC_321|>": 100618, "<|LOC_322|>": 100619, "<|LOC_323|>": 100620, "<|LOC_324|>": 100621, "<|LOC_325|>": 100622, "<|LOC_326|>": 100623, "<|LOC_327|>": 100624, "<|LOC_328|>": 100625, "<|LOC_329|>": 100626, "<|LOC_330|>": 100627, "<|LOC_331|>": 100628, "<|LOC_332|>": 100629, "<|LOC_333|>": 100630, "<|LOC_334|>": 100631, "<|LOC_335|>": 100632, "<|LOC_336|>": 100633, "<|LOC_337|>": 100634, "<|LOC_338|>": 100635, "<|LOC_339|>": 100636, "<|LOC_340|>": 100637, "<|LOC_341|>": 100638, "<|LOC_342|>": 100639, "<|LOC_343|>": 100640, "<|LOC_344|>": 100641, "<|LOC_345|>": 100642, "<|LOC_346|>": 100643, "<|LOC_347|>": 100644, "<|LOC_348|>": 100645, "<|LOC_349|>": 100646, "<|LOC_350|>": 100647, "<|LOC_351|>": 100648, "<|LOC_352|>": 100649, "<|LOC_353|>": 100650, "<|LOC_354|>": 100651, "<|LOC_355|>": 100652, "<|LOC_356|>": 100653, "<|LOC_357|>": 100654, "<|LOC_358|>": 100655, "<|LOC_359|>": 100656, "<|LOC_360|>": 100657, "<|LOC_361|>": 100658, "<|LOC_362|>": 100659, "<|LOC_363|>": 100660, "<|LOC_364|>": 100661, "<|LOC_365|>": 100662, "<|LOC_366|>": 100663, "<|LOC_367|>": 100664, "<|LOC_368|>": 100665, "<|LOC_369|>": 100666, "<|LOC_370|>": 100667, "<|LOC_371|>": 100668, "<|LOC_372|>": 100669, "<|LOC_373|>": 100670, "<|LOC_374|>": 100671, "<|LOC_375|>": 100672, "<|LOC_376|>": 100673, "<|LOC_377|>": 100674, "<|LOC_378|>": 100675, "<|LOC_379|>": 100676, "<|LOC_380|>": 100677, "<|LOC_381|>": 100678, "<|LOC_382|>": 100679, "<|LOC_383|>": 100680, "<|LOC_384|>": 100681, "<|LOC_385|>": 100682, "<|LOC_386|>": 100683, "<|LOC_387|>": 100684, "<|LOC_388|>": 100685, "<|LOC_389|>": 100686, "<|LOC_390|>": 100687, "<|LOC_391|>": 100688, "<|LOC_392|>": 100689, "<|LOC_393|>": 100690, "<|LOC_394|>": 100691, "<|LOC_395|>": 100692, "<|LOC_396|>": 100693, "<|LOC_397|>": 100694, "<|LOC_398|>": 100695, "<|LOC_399|>": 100696, "<|LOC_400|>": 100697, "<|LOC_401|>": 100698, "<|LOC_402|>": 100699, "<|LOC_403|>": 100700, "<|LOC_404|>": 100701, "<|LOC_405|>": 100702, "<|LOC_406|>": 100703, "<|LOC_407|>": 100704, "<|LOC_408|>": 100705, "<|LOC_409|>": 100706, "<|LOC_410|>": 100707, "<|LOC_411|>": 100708, "<|LOC_412|>": 100709, "<|LOC_413|>": 100710, "<|LOC_414|>": 100711, "<|LOC_415|>": 100712, "<|LOC_416|>": 100713, "<|LOC_417|>": 100714, "<|LOC_418|>": 100715, "<|LOC_419|>": 100716, "<|LOC_420|>": 100717, "<|LOC_421|>": 100718, "<|LOC_422|>": 100719, "<|LOC_423|>": 100720, "<|LOC_424|>": 100721, "<|LOC_425|>": 100722, "<|LOC_426|>": 100723, "<|LOC_427|>": 100724, "<|LOC_428|>": 100725, "<|LOC_429|>": 100726, "<|LOC_430|>": 100727, "<|LOC_431|>": 100728, "<|LOC_432|>": 100729, "<|LOC_433|>": 100730, "<|LOC_434|>": 100731, "<|LOC_435|>": 100732, "<|LOC_436|>": 100733, "<|LOC_437|>": 100734, "<|LOC_438|>": 100735, "<|LOC_439|>": 100736, "<|LOC_440|>": 100737, "<|LOC_441|>": 100738, "<|LOC_442|>": 100739, "<|LOC_443|>": 100740, "<|LOC_444|>": 100741, "<|LOC_445|>": 100742, "<|LOC_446|>": 100743, "<|LOC_447|>": 100744, "<|LOC_448|>": 100745, "<|LOC_449|>": 100746, "<|LOC_450|>": 100747, "<|LOC_451|>": 100748, "<|LOC_452|>": 100749, "<|LOC_453|>": 100750, "<|LOC_454|>": 100751, "<|LOC_455|>": 100752, "<|LOC_456|>": 100753, "<|LOC_457|>": 100754, "<|LOC_458|>": 100755, "<|LOC_459|>": 100756, "<|LOC_460|>": 100757, "<|LOC_461|>": 100758, "<|LOC_462|>": 100759, "<|LOC_463|>": 100760, "<|LOC_464|>": 100761, "<|LOC_465|>": 100762, "<|LOC_466|>": 100763, "<|LOC_467|>": 100764, "<|LOC_468|>": 100765, "<|LOC_469|>": 100766, "<|LOC_470|>": 100767, "<|LOC_471|>": 100768, "<|LOC_472|>": 100769, "<|LOC_473|>": 100770, "<|LOC_474|>": 100771, "<|LOC_475|>": 100772, "<|LOC_476|>": 100773, "<|LOC_477|>": 100774, "<|LOC_478|>": 100775, "<|LOC_479|>": 100776, "<|LOC_480|>": 100777, "<|LOC_481|>": 100778, "<|LOC_482|>": 100779, "<|LOC_483|>": 100780, "<|LOC_484|>": 100781, "<|LOC_485|>": 100782, "<|LOC_486|>": 100783, "<|LOC_487|>": 100784, "<|LOC_488|>": 100785, "<|LOC_489|>": 100786, "<|LOC_490|>": 100787, "<|LOC_491|>": 100788, "<|LOC_492|>": 100789, "<|LOC_493|>": 100790, "<|LOC_494|>": 100791, "<|LOC_495|>": 100792, "<|LOC_496|>": 100793, "<|LOC_497|>": 100794, "<|LOC_498|>": 100795, "<|LOC_499|>": 100796, "<|LOC_500|>": 100797, "<|LOC_501|>": 100798, "<|LOC_502|>": 100799, "<|LOC_503|>": 100800, "<|LOC_504|>": 100801, "<|LOC_505|>": 100802, "<|LOC_506|>": 100803, "<|LOC_507|>": 100804, "<|LOC_508|>": 100805, "<|LOC_509|>": 100806, "<|LOC_510|>": 100807, "<|LOC_511|>": 100808, "<|LOC_512|>": 100809, "<|LOC_513|>": 100810, "<|LOC_514|>": 100811, "<|LOC_515|>": 100812, "<|LOC_516|>": 100813, "<|LOC_517|>": 100814, "<|LOC_518|>": 100815, "<|LOC_519|>": 100816, "<|LOC_520|>": 100817, "<|LOC_521|>": 100818, "<|LOC_522|>": 100819, "<|LOC_523|>": 100820, "<|LOC_524|>": 100821, "<|LOC_525|>": 100822, "<|LOC_526|>": 100823, "<|LOC_527|>": 100824, "<|LOC_528|>": 100825, "<|LOC_529|>": 100826, "<|LOC_530|>": 100827, "<|LOC_531|>": 100828, "<|LOC_532|>": 100829, "<|LOC_533|>": 100830, "<|LOC_534|>": 100831, "<|LOC_535|>": 100832, "<|LOC_536|>": 100833, "<|LOC_537|>": 100834, "<|LOC_538|>": 100835, "<|LOC_539|>": 100836, "<|LOC_540|>": 100837, "<|LOC_541|>": 100838, "<|LOC_542|>": 100839, "<|LOC_543|>": 100840, "<|LOC_544|>": 100841, "<|LOC_545|>": 100842, "<|LOC_546|>": 100843, "<|LOC_547|>": 100844, "<|LOC_548|>": 100845, "<|LOC_549|>": 100846, "<|LOC_550|>": 100847, "<|LOC_551|>": 100848, "<|LOC_552|>": 100849, "<|LOC_553|>": 100850, "<|LOC_554|>": 100851, "<|LOC_555|>": 100852, "<|LOC_556|>": 100853, "<|LOC_557|>": 100854, "<|LOC_558|>": 100855, "<|LOC_559|>": 100856, "<|LOC_560|>": 100857, "<|LOC_561|>": 100858, "<|LOC_562|>": 100859, "<|LOC_563|>": 100860, "<|LOC_564|>": 100861, "<|LOC_565|>": 100862, "<|LOC_566|>": 100863, "<|LOC_567|>": 100864, "<|LOC_568|>": 100865, "<|LOC_569|>": 100866, "<|LOC_570|>": 100867, "<|LOC_571|>": 100868, "<|LOC_572|>": 100869, "<|LOC_573|>": 100870, "<|LOC_574|>": 100871, "<|LOC_575|>": 100872, "<|LOC_576|>": 100873, "<|LOC_577|>": 100874, "<|LOC_578|>": 100875, "<|LOC_579|>": 100876, "<|LOC_580|>": 100877, "<|LOC_581|>": 100878, "<|LOC_582|>": 100879, "<|LOC_583|>": 100880, "<|LOC_584|>": 100881, "<|LOC_585|>": 100882, "<|LOC_586|>": 100883, "<|LOC_587|>": 100884, "<|LOC_588|>": 100885, "<|LOC_589|>": 100886, "<|LOC_590|>": 100887, "<|LOC_591|>": 100888, "<|LOC_592|>": 100889, "<|LOC_593|>": 100890, "<|LOC_594|>": 100891, "<|LOC_595|>": 100892, "<|LOC_596|>": 100893, "<|LOC_597|>": 100894, "<|LOC_598|>": 100895, "<|LOC_599|>": 100896, "<|LOC_600|>": 100897, "<|LOC_601|>": 100898, "<|LOC_602|>": 100899, "<|LOC_603|>": 100900, "<|LOC_604|>": 100901, "<|LOC_605|>": 100902, "<|LOC_606|>": 100903, "<|LOC_607|>": 100904, "<|LOC_608|>": 100905, "<|LOC_609|>": 100906, "<|LOC_610|>": 100907, "<|LOC_611|>": 100908, "<|LOC_612|>": 100909, "<|LOC_613|>": 100910, "<|LOC_614|>": 100911, "<|LOC_615|>": 100912, "<|LOC_616|>": 100913, "<|LOC_617|>": 100914, "<|LOC_618|>": 100915, "<|LOC_619|>": 100916, "<|LOC_620|>": 100917, "<|LOC_621|>": 100918, "<|LOC_622|>": 100919, "<|LOC_623|>": 100920, "<|LOC_624|>": 100921, "<|LOC_625|>": 100922, "<|LOC_626|>": 100923, "<|LOC_627|>": 100924, "<|LOC_628|>": 100925, "<|LOC_629|>": 100926, "<|LOC_630|>": 100927, "<|LOC_631|>": 100928, "<|LOC_632|>": 100929, "<|LOC_633|>": 100930, "<|LOC_634|>": 100931, "<|LOC_635|>": 100932, "<|LOC_636|>": 100933, "<|LOC_637|>": 100934, "<|LOC_638|>": 100935, "<|LOC_639|>": 100936, "<|LOC_640|>": 100937, "<|LOC_641|>": 100938, "<|LOC_642|>": 100939, "<|LOC_643|>": 100940, "<|LOC_644|>": 100941, "<|LOC_645|>": 100942, "<|LOC_646|>": 100943, "<|LOC_647|>": 100944, "<|LOC_648|>": 100945, "<|LOC_649|>": 100946, "<|LOC_650|>": 100947, "<|LOC_651|>": 100948, "<|LOC_652|>": 100949, "<|LOC_653|>": 100950, "<|LOC_654|>": 100951, "<|LOC_655|>": 100952, "<|LOC_656|>": 100953, "<|LOC_657|>": 100954, "<|LOC_658|>": 100955, "<|LOC_659|>": 100956, "<|LOC_660|>": 100957, "<|LOC_661|>": 100958, "<|LOC_662|>": 100959, "<|LOC_663|>": 100960, "<|LOC_664|>": 100961, "<|LOC_665|>": 100962, "<|LOC_666|>": 100963, "<|LOC_667|>": 100964, "<|LOC_668|>": 100965, "<|LOC_669|>": 100966, "<|LOC_670|>": 100967, "<|LOC_671|>": 100968, "<|LOC_672|>": 100969, "<|LOC_673|>": 100970, "<|LOC_674|>": 100971, "<|LOC_675|>": 100972, "<|LOC_676|>": 100973, "<|LOC_677|>": 100974, "<|LOC_678|>": 100975, "<|LOC_679|>": 100976, "<|LOC_680|>": 100977, "<|LOC_681|>": 100978, "<|LOC_682|>": 100979, "<|LOC_683|>": 100980, "<|LOC_684|>": 100981, "<|LOC_685|>": 100982, "<|LOC_686|>": 100983, "<|LOC_687|>": 100984, "<|LOC_688|>": 100985, "<|LOC_689|>": 100986, "<|LOC_690|>": 100987, "<|LOC_691|>": 100988, "<|LOC_692|>": 100989, "<|LOC_693|>": 100990, "<|LOC_694|>": 100991, "<|LOC_695|>": 100992, "<|LOC_696|>": 100993, "<|LOC_697|>": 100994, "<|LOC_698|>": 100995, "<|LOC_699|>": 100996, "<|LOC_700|>": 100997, "<|LOC_701|>": 100998, "<|LOC_702|>": 100999, "<|LOC_703|>": 101000, "<|LOC_704|>": 101001, "<|LOC_705|>": 101002, "<|LOC_706|>": 101003, "<|LOC_707|>": 101004, "<|LOC_708|>": 101005, "<|LOC_709|>": 101006, "<|LOC_710|>": 101007, "<|LOC_711|>": 101008, "<|LOC_712|>": 101009, "<|LOC_713|>": 101010, "<|LOC_714|>": 101011, "<|LOC_715|>": 101012, "<|LOC_716|>": 101013, "<|LOC_717|>": 101014, "<|LOC_718|>": 101015, "<|LOC_719|>": 101016, "<|LOC_720|>": 101017, "<|LOC_721|>": 101018, "<|LOC_722|>": 101019, "<|LOC_723|>": 101020, "<|LOC_724|>": 101021, "<|LOC_725|>": 101022, "<|LOC_726|>": 101023, "<|LOC_727|>": 101024, "<|LOC_728|>": 101025, "<|LOC_729|>": 101026, "<|LOC_730|>": 101027, "<|LOC_731|>": 101028, "<|LOC_732|>": 101029, "<|LOC_733|>": 101030, "<|LOC_734|>": 101031, "<|LOC_735|>": 101032, "<|LOC_736|>": 101033, "<|LOC_737|>": 101034, "<|LOC_738|>": 101035, "<|LOC_739|>": 101036, "<|LOC_740|>": 101037, "<|LOC_741|>": 101038, "<|LOC_742|>": 101039, "<|LOC_743|>": 101040, "<|LOC_744|>": 101041, "<|LOC_745|>": 101042, "<|LOC_746|>": 101043, "<|LOC_747|>": 101044, "<|LOC_748|>": 101045, "<|LOC_749|>": 101046, "<|LOC_750|>": 101047, "<|LOC_751|>": 101048, "<|LOC_752|>": 101049, "<|LOC_753|>": 101050, "<|LOC_754|>": 101051, "<|LOC_755|>": 101052, "<|LOC_756|>": 101053, "<|LOC_757|>": 101054, "<|LOC_758|>": 101055, "<|LOC_759|>": 101056, "<|LOC_760|>": 101057, "<|LOC_761|>": 101058, "<|LOC_762|>": 101059, "<|LOC_763|>": 101060, "<|LOC_764|>": 101061, "<|LOC_765|>": 101062, "<|LOC_766|>": 101063, "<|LOC_767|>": 101064, "<|LOC_768|>": 101065, "<|LOC_769|>": 101066, "<|LOC_770|>": 101067, "<|LOC_771|>": 101068, "<|LOC_772|>": 101069, "<|LOC_773|>": 101070, "<|LOC_774|>": 101071, "<|LOC_775|>": 101072, "<|LOC_776|>": 101073, "<|LOC_777|>": 101074, "<|LOC_778|>": 101075, "<|LOC_779|>": 101076, "<|LOC_780|>": 101077, "<|LOC_781|>": 101078, "<|LOC_782|>": 101079, "<|LOC_783|>": 101080, "<|LOC_784|>": 101081, "<|LOC_785|>": 101082, "<|LOC_786|>": 101083, "<|LOC_787|>": 101084, "<|LOC_788|>": 101085, "<|LOC_789|>": 101086, "<|LOC_790|>": 101087, "<|LOC_791|>": 101088, "<|LOC_792|>": 101089, "<|LOC_793|>": 101090, "<|LOC_794|>": 101091, "<|LOC_795|>": 101092, "<|LOC_796|>": 101093, "<|LOC_797|>": 101094, "<|LOC_798|>": 101095, "<|LOC_799|>": 101096, "<|LOC_800|>": 101097, "<|LOC_801|>": 101098, "<|LOC_802|>": 101099, "<|LOC_803|>": 101100, "<|LOC_804|>": 101101, "<|LOC_805|>": 101102, "<|LOC_806|>": 101103, "<|LOC_807|>": 101104, "<|LOC_808|>": 101105, "<|LOC_809|>": 101106, "<|LOC_810|>": 101107, "<|LOC_811|>": 101108, "<|LOC_812|>": 101109, "<|LOC_813|>": 101110, "<|LOC_814|>": 101111, "<|LOC_815|>": 101112, "<|LOC_816|>": 101113, "<|LOC_817|>": 101114, "<|LOC_818|>": 101115, "<|LOC_819|>": 101116, "<|LOC_820|>": 101117, "<|LOC_821|>": 101118, "<|LOC_822|>": 101119, "<|LOC_823|>": 101120, "<|LOC_824|>": 101121, "<|LOC_825|>": 101122, "<|LOC_826|>": 101123, "<|LOC_827|>": 101124, "<|LOC_828|>": 101125, "<|LOC_829|>": 101126, "<|LOC_830|>": 101127, "<|LOC_831|>": 101128, "<|LOC_832|>": 101129, "<|LOC_833|>": 101130, "<|LOC_834|>": 101131, "<|LOC_835|>": 101132, "<|LOC_836|>": 101133, "<|LOC_837|>": 101134, "<|LOC_838|>": 101135, "<|LOC_839|>": 101136, "<|LOC_840|>": 101137, "<|LOC_841|>": 101138, "<|LOC_842|>": 101139, "<|LOC_843|>": 101140, "<|LOC_844|>": 101141, "<|LOC_845|>": 101142, "<|LOC_846|>": 101143, "<|LOC_847|>": 101144, "<|LOC_848|>": 101145, "<|LOC_849|>": 101146, "<|LOC_850|>": 101147, "<|LOC_851|>": 101148, "<|LOC_852|>": 101149, "<|LOC_853|>": 101150, "<|LOC_854|>": 101151, "<|LOC_855|>": 101152, "<|LOC_856|>": 101153, "<|LOC_857|>": 101154, "<|LOC_858|>": 101155, "<|LOC_859|>": 101156, "<|LOC_860|>": 101157, "<|LOC_861|>": 101158, "<|LOC_862|>": 101159, "<|LOC_863|>": 101160, "<|LOC_864|>": 101161, "<|LOC_865|>": 101162, "<|LOC_866|>": 101163, "<|LOC_867|>": 101164, "<|LOC_868|>": 101165, "<|LOC_869|>": 101166, "<|LOC_870|>": 101167, "<|LOC_871|>": 101168, "<|LOC_872|>": 101169, "<|LOC_873|>": 101170, "<|LOC_874|>": 101171, "<|LOC_875|>": 101172, "<|LOC_876|>": 101173, "<|LOC_877|>": 101174, "<|LOC_878|>": 101175, "<|LOC_879|>": 101176, "<|LOC_880|>": 101177, "<|LOC_881|>": 101178, "<|LOC_882|>": 101179, "<|LOC_883|>": 101180, "<|LOC_884|>": 101181, "<|LOC_885|>": 101182, "<|LOC_886|>": 101183, "<|LOC_887|>": 101184, "<|LOC_888|>": 101185, "<|LOC_889|>": 101186, "<|LOC_890|>": 101187, "<|LOC_891|>": 101188, "<|LOC_892|>": 101189, "<|LOC_893|>": 101190, "<|LOC_894|>": 101191, "<|LOC_895|>": 101192, "<|LOC_896|>": 101193, "<|LOC_897|>": 101194, "<|LOC_898|>": 101195, "<|LOC_899|>": 101196, "<|LOC_900|>": 101197, "<|LOC_901|>": 101198, "<|LOC_902|>": 101199, "<|LOC_903|>": 101200, "<|LOC_904|>": 101201, "<|LOC_905|>": 101202, "<|LOC_906|>": 101203, "<|LOC_907|>": 101204, "<|LOC_908|>": 101205, "<|LOC_909|>": 101206, "<|LOC_910|>": 101207, "<|LOC_911|>": 101208, "<|LOC_912|>": 101209, "<|LOC_913|>": 101210, "<|LOC_914|>": 101211, "<|LOC_915|>": 101212, "<|LOC_916|>": 101213, "<|LOC_917|>": 101214, "<|LOC_918|>": 101215, "<|LOC_919|>": 101216, "<|LOC_920|>": 101217, "<|LOC_921|>": 101218, "<|LOC_922|>": 101219, "<|LOC_923|>": 101220, "<|LOC_924|>": 101221, "<|LOC_925|>": 101222, "<|LOC_926|>": 101223, "<|LOC_927|>": 101224, "<|LOC_928|>": 101225, "<|LOC_929|>": 101226, "<|LOC_930|>": 101227, "<|LOC_931|>": 101228, "<|LOC_932|>": 101229, "<|LOC_933|>": 101230, "<|LOC_934|>": 101231, "<|LOC_935|>": 101232, "<|LOC_936|>": 101233, "<|LOC_937|>": 101234, "<|LOC_938|>": 101235, "<|LOC_939|>": 101236, "<|LOC_940|>": 101237, "<|LOC_941|>": 101238, "<|LOC_942|>": 101239, "<|LOC_943|>": 101240, "<|LOC_944|>": 101241, "<|LOC_945|>": 101242, "<|LOC_946|>": 101243, "<|LOC_947|>": 101244, "<|LOC_948|>": 101245, "<|LOC_949|>": 101246, "<|LOC_950|>": 101247, "<|LOC_951|>": 101248, "<|LOC_952|>": 101249, "<|LOC_953|>": 101250, "<|LOC_954|>": 101251, "<|LOC_955|>": 101252, "<|LOC_956|>": 101253, "<|LOC_957|>": 101254, "<|LOC_958|>": 101255, "<|LOC_959|>": 101256, "<|LOC_960|>": 101257, "<|LOC_961|>": 101258, "<|LOC_962|>": 101259, "<|LOC_963|>": 101260, "<|LOC_964|>": 101261, "<|LOC_965|>": 101262, "<|LOC_966|>": 101263, "<|LOC_967|>": 101264, "<|LOC_968|>": 101265, "<|LOC_969|>": 101266, "<|LOC_970|>": 101267, "<|LOC_971|>": 101268, "<|LOC_972|>": 101269, "<|LOC_973|>": 101270, "<|LOC_974|>": 101271, "<|LOC_975|>": 101272, "<|LOC_976|>": 101273, "<|LOC_977|>": 101274, "<|LOC_978|>": 101275, "<|LOC_979|>": 101276, "<|LOC_980|>": 101277, "<|LOC_981|>": 101278, "<|LOC_982|>": 101279, "<|LOC_983|>": 101280, "<|LOC_984|>": 101281, "<|LOC_985|>": 101282, "<|LOC_986|>": 101283, "<|LOC_987|>": 101284, "<|LOC_988|>": 101285, "<|LOC_989|>": 101286, "<|LOC_990|>": 101287, "<|LOC_991|>": 101288, "<|LOC_992|>": 101289, "<|LOC_993|>": 101290, "<|LOC_994|>": 101291, "<|LOC_995|>": 101292, "<|LOC_996|>": 101293, "<|LOC_997|>": 101294, "<|LOC_998|>": 101295, "<|LOC_999|>": 101296, "<|LOC_1000|>": 101297, "<|LOC_BEGIN|>": 101298, "<|LOC_END|>": 101299, "<|LOC_SEP|>": 101300, "<|CROP_COL_SEP|>": 101301, "<|CROP_ROW_SEP|>": 101302, "<|IMAGE_SEP|>": 101303}
|
config.json
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_attn_implementation": "sdpa",
|
| 3 |
+
"architectures": [
|
| 4 |
+
"Ernie4_5_MoeForCausalLM"
|
| 5 |
+
],
|
| 6 |
+
"auto_map": {
|
| 7 |
+
"AutoConfig": "configuration_ernie4_5_moe.Ernie4_5_MoeConfig",
|
| 8 |
+
"AutoModel": "modeling_ernie4_5_moe.Ernie4_5_Model",
|
| 9 |
+
"AutoModelForCausalLM": "modeling_ernie4_5_moe.Ernie4_5_MoeForCausalLM"
|
| 10 |
+
},
|
| 11 |
+
"bos_token_id": 1,
|
| 12 |
+
"eos_token_id": 2,
|
| 13 |
+
"hidden_act": "silu",
|
| 14 |
+
"hidden_size": 8192,
|
| 15 |
+
"intermediate_size": 28672,
|
| 16 |
+
"max_position_embeddings": 131072,
|
| 17 |
+
"model_type": "ernie4_5_moe",
|
| 18 |
+
"moe_capacity": [
|
| 19 |
+
64,
|
| 20 |
+
64,
|
| 21 |
+
64
|
| 22 |
+
],
|
| 23 |
+
"moe_gate": "topk",
|
| 24 |
+
"moe_intermediate_size": 3584,
|
| 25 |
+
"moe_k": 8,
|
| 26 |
+
"moe_layer_interval": 1,
|
| 27 |
+
"moe_layer_start_index": 3,
|
| 28 |
+
"moe_num_experts": 64,
|
| 29 |
+
"moe_use_aux_free": true,
|
| 30 |
+
"num_attention_heads": 64,
|
| 31 |
+
"num_hidden_layers": 54,
|
| 32 |
+
"num_key_value_heads": 8,
|
| 33 |
+
"num_nextn_predict_layers": 1,
|
| 34 |
+
"pad_token_id": 0,
|
| 35 |
+
"rms_norm_eps": 1e-05,
|
| 36 |
+
"rope_theta": 500000,
|
| 37 |
+
"tie_word_embeddings": false,
|
| 38 |
+
"torch_dtype": "bfloat16",
|
| 39 |
+
"use_bias": false,
|
| 40 |
+
"use_cache": true,
|
| 41 |
+
"vocab_size": 103424,
|
| 42 |
+
"quantization_config": {
|
| 43 |
+
"quant_method": "exl3",
|
| 44 |
+
"version": "0.0.4",
|
| 45 |
+
"bits": 2.5,
|
| 46 |
+
"head_bits": 6,
|
| 47 |
+
"calibration": {
|
| 48 |
+
"rows": 100,
|
| 49 |
+
"cols": 2048
|
| 50 |
+
},
|
| 51 |
+
"out_scales": "auto"
|
| 52 |
+
}
|
| 53 |
+
}
|
configuration_ernie4_5_moe.py
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""Ernie4_5_Moe model configuration"""
|
| 15 |
+
|
| 16 |
+
from transformers import PretrainedConfig
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class Ernie4_5_MoeConfig(PretrainedConfig):
|
| 21 |
+
r"""
|
| 22 |
+
This is the configuration class to store the configuration of a [`Ernie4_5_Model`].
|
| 23 |
+
It is used to instantiate an ERNIE-4.5 model according to the specified arguments,
|
| 24 |
+
defining the model architecture. Instantiating a configuration with the defaults
|
| 25 |
+
will yield a similar configuration to that of ERNIE-4.5-300B-A47B-PT [baidu/ERNIE-4.5-300B-A47B-PT].
|
| 26 |
+
|
| 27 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 28 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
Args:
|
| 32 |
+
vocab_size (int): Size of the vocabulary (number of unique tokens)
|
| 33 |
+
hidden_size (int): Dimensionality of the encoder layers and the pooler layer
|
| 34 |
+
intermediate_size (int): Dimensionality of the "intermediate" (feed-forward) layer
|
| 35 |
+
max_position_embeddings (int): Maximum sequence length the model can handle
|
| 36 |
+
num_hidden_layers (int): Number of hidden layers in the Transformer encoder
|
| 37 |
+
num_attention_heads (int): Number of attention heads for each attention layer
|
| 38 |
+
rms_norm_eps (float): The epsilon used by the RMS normalization layers
|
| 39 |
+
use_cache (bool): Whether to use caching for faster generation (decoding)
|
| 40 |
+
use_flash_attention (bool): Whether to use FlashAttention for optimized attention computation
|
| 41 |
+
pad_token_id (int): Token ID used for padding sequences
|
| 42 |
+
bos_token_id (int): Token ID used for beginning-of-sequence
|
| 43 |
+
eos_token_id (int): Token ID used for end-of-sequence
|
| 44 |
+
use_bias (bool): Whether to use bias terms in linear layers
|
| 45 |
+
rope_theta (float): The base period of the RoPE embeddings
|
| 46 |
+
weight_share_add_bias (bool): Whether to share bias weights in certain layers
|
| 47 |
+
ignored_index (int): Target value that is ignored during loss computation
|
| 48 |
+
attention_probs_dropout_prob (float): Dropout probability for attention weights
|
| 49 |
+
hidden_dropout_prob (float): Dropout probability for hidden layers
|
| 50 |
+
num_key_value_heads (int): Number of key/value heads (for Grouped Query Attention)
|
| 51 |
+
max_sequence_length (int): Maximum sequence length for positional embeddings
|
| 52 |
+
moe_num_experts: Number of experts in MoE layers
|
| 53 |
+
moe_capacity: Capacity configuration for MoE layers
|
| 54 |
+
moe_layer_interval: Interval between MoE layers
|
| 55 |
+
moe_layer_start_index: Starting layer index for MoE
|
| 56 |
+
moe_layer_end_index: Ending layer index for MoE (-1 means last layer)
|
| 57 |
+
sinkhorn_2gate: Whether to use sinkhorn 2-gate routing
|
| 58 |
+
sinkhorn_temp: Temperature for sinkhorn routing
|
| 59 |
+
moe_dropout_prob: Dropout probability for MoE layers
|
| 60 |
+
moe_gate: Type of gating mechanism ('top2', etc.)
|
| 61 |
+
moe_intermediate_size: Intermediate size for MoE layers
|
| 62 |
+
moe_gate_act: Activation function for gating
|
| 63 |
+
moe_k: Number of experts to route to
|
| 64 |
+
**kwargs: Additional base model configuration parameters
|
| 65 |
+
"""
|
| 66 |
+
|
| 67 |
+
model_type = "ernie4_5_moe"
|
| 68 |
+
use_keep_in_fp32_modules = True
|
| 69 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
| 70 |
+
|
| 71 |
+
attribute_map = {
|
| 72 |
+
"n_positions": "max_position_embeddings",
|
| 73 |
+
"n_embd": "hidden_size",
|
| 74 |
+
"n_layer": "num_hidden_layers",
|
| 75 |
+
"n_head": "num_attention_heads",
|
| 76 |
+
"n_inner": "intermediate_size",
|
| 77 |
+
"activation_function": "hidden_act",
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
# Default tensor parallel plan for base model `ernie_4_5_moe`
|
| 81 |
+
base_model_tp_plan = {
|
| 82 |
+
"model.layers.*.self_attn.q_proj": "colwise_rep",
|
| 83 |
+
"model.layers.*.self_attn.k_proj": "colwise_rep",
|
| 84 |
+
"model.layers.*.self_attn.v_proj": "colwise_rep",
|
| 85 |
+
"model.layers.*.self_attn.o_proj": "rowwise_rep",
|
| 86 |
+
"model.layers.*.mlp.experts.*.gate_proj": "colwise",
|
| 87 |
+
"model.layers.*.mlp.experts.*.up_proj": "colwise",
|
| 88 |
+
"model.layers.*.mlp.experts.*.down_proj": "rowwise",
|
| 89 |
+
"model.layers.*.mlp.gate_proj": "colwise",
|
| 90 |
+
"model.layers.*.mlp.up_proj": "colwise",
|
| 91 |
+
"model.layers.*.mlp.down_proj": "rowwise",
|
| 92 |
+
}
|
| 93 |
+
base_model_pp_plan = {
|
| 94 |
+
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
| 95 |
+
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
| 96 |
+
"norm": (["hidden_states"], ["hidden_states"]),
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
def __init__(
|
| 100 |
+
self,
|
| 101 |
+
vocab_size=32000,
|
| 102 |
+
hidden_size=768,
|
| 103 |
+
intermediate_size=11008,
|
| 104 |
+
num_hidden_layers=2,
|
| 105 |
+
num_attention_heads=2,
|
| 106 |
+
num_key_value_heads=None,
|
| 107 |
+
max_position_embeddings=32768,
|
| 108 |
+
use_sliding_window=None,
|
| 109 |
+
sliding_window=None,
|
| 110 |
+
rms_norm_eps=1e-6,
|
| 111 |
+
use_cache=False,
|
| 112 |
+
pad_token_id=0,
|
| 113 |
+
bos_token_id=1,
|
| 114 |
+
eos_token_id=2,
|
| 115 |
+
attention_probs_dropout_prob=0.0,
|
| 116 |
+
hidden_dropout_prob=0.0,
|
| 117 |
+
rope_theta=10000.0,
|
| 118 |
+
use_flash_attention=False,
|
| 119 |
+
use_rmsnorm=True,
|
| 120 |
+
use_bias=False,
|
| 121 |
+
weight_share_add_bias=True,
|
| 122 |
+
max_sequence_length=None,
|
| 123 |
+
ignored_index=-100,
|
| 124 |
+
use_moe=True,
|
| 125 |
+
moe_num_experts=64,
|
| 126 |
+
moe_capacity=(64, 64, 64),
|
| 127 |
+
moe_layer_interval=2,
|
| 128 |
+
moe_layer_start_index=0,
|
| 129 |
+
moe_layer_end_index=-1,
|
| 130 |
+
sinkhorn_2gate=True,
|
| 131 |
+
sinkhorn_temp=3e-2,
|
| 132 |
+
moe_dropout_prob=0.0,
|
| 133 |
+
moe_gate="top2",
|
| 134 |
+
moe_intermediate_size=3584,
|
| 135 |
+
moe_k=2,
|
| 136 |
+
moe_gate_act="softmax",
|
| 137 |
+
moe_use_aux_free=False,
|
| 138 |
+
**kwargs
|
| 139 |
+
):
|
| 140 |
+
self.vocab_size = vocab_size
|
| 141 |
+
self.max_position_embeddings = max_position_embeddings
|
| 142 |
+
self.use_sliding_window = use_sliding_window
|
| 143 |
+
self.sliding_window = sliding_window
|
| 144 |
+
self.hidden_size = hidden_size
|
| 145 |
+
self.intermediate_size = intermediate_size
|
| 146 |
+
self.num_hidden_layers = num_hidden_layers
|
| 147 |
+
self.num_attention_heads = num_attention_heads
|
| 148 |
+
|
| 149 |
+
if num_key_value_heads is None:
|
| 150 |
+
num_key_value_heads = num_attention_heads
|
| 151 |
+
|
| 152 |
+
self.num_key_value_heads = num_key_value_heads
|
| 153 |
+
self.use_rmsnorm = use_rmsnorm
|
| 154 |
+
self.rms_norm_eps = rms_norm_eps
|
| 155 |
+
self.rope_theta = rope_theta
|
| 156 |
+
self.max_sequence_length = max_sequence_length
|
| 157 |
+
self.pad_token_id = pad_token_id
|
| 158 |
+
self.bos_token_id = bos_token_id
|
| 159 |
+
self.eos_token_id = eos_token_id
|
| 160 |
+
self.ignored_index = ignored_index
|
| 161 |
+
self.use_cache = use_cache
|
| 162 |
+
self.use_bias = use_bias
|
| 163 |
+
self.weight_share_add_bias = weight_share_add_bias
|
| 164 |
+
self.use_flash_attention = use_flash_attention
|
| 165 |
+
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
| 166 |
+
self.hidden_dropout_prob = hidden_dropout_prob
|
| 167 |
+
|
| 168 |
+
self.use_moe = moe_num_experts > 0 and use_moe
|
| 169 |
+
self.moe_num_experts = moe_num_experts
|
| 170 |
+
self.moe_capacity = moe_capacity
|
| 171 |
+
self.sinkhorn_2gate = sinkhorn_2gate
|
| 172 |
+
self.sinkhorn_temp = sinkhorn_temp
|
| 173 |
+
self.moe_layer_interval = moe_layer_interval
|
| 174 |
+
self.moe_dropout_prob = moe_dropout_prob
|
| 175 |
+
self.moe_gate = moe_gate
|
| 176 |
+
self.moe_intermediate_size = moe_intermediate_size
|
| 177 |
+
self.moe_k = moe_k
|
| 178 |
+
self.moe_layer_start_index = moe_layer_start_index
|
| 179 |
+
self.moe_layer_end_index = self.num_hidden_layers - 1 if moe_layer_end_index == -1 else moe_layer_end_index
|
| 180 |
+
self.moe_gate_act = moe_gate_act
|
| 181 |
+
self.moe_use_aux_free = moe_use_aux_free
|
| 182 |
+
|
| 183 |
+
# Set default for tied embeddings if not specified.
|
| 184 |
+
if "tie_word_embeddings" not in kwargs:
|
| 185 |
+
kwargs["tie_word_embeddings"] = False
|
| 186 |
+
|
| 187 |
+
super().__init__(
|
| 188 |
+
pad_token_id=pad_token_id,
|
| 189 |
+
bos_token_id=bos_token_id,
|
| 190 |
+
eos_token_id=eos_token_id,
|
| 191 |
+
**kwargs,
|
| 192 |
+
)
|
generation_config.json
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"do_sample": true,
|
| 3 |
+
"top_p": 0.8,
|
| 4 |
+
"temperature": 0.8,
|
| 5 |
+
"repetition_penalty": 1.0,
|
| 6 |
+
"frequency_penalty": 0.0,
|
| 7 |
+
"presence_penalty": 0.0,
|
| 8 |
+
"bos_token_id": 1,
|
| 9 |
+
"eos_token_id": 2,
|
| 10 |
+
"pad_token_id": 0,
|
| 11 |
+
"transformers_version": "4.52.4",
|
| 12 |
+
"use_cache": true
|
| 13 |
+
}
|
model-00001-of-00014.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9d8b444b107080dc7ad65c07c08a87d9bfb6044595f8416429bfa1e2d0872e4a
|
| 3 |
+
size 8544909440
|
model-00002-of-00014.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:36d148846a5da84072a62c053f413be4f65d9f497d9b21956db370b3eed36578
|
| 3 |
+
size 8063060728
|
model-00003-of-00014.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3cddc25e0ad12fa7adf1577e8f56bb77445e1d67c26e6706770fed4553c8af29
|
| 3 |
+
size 7358419432
|
model-00004-of-00014.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7e709f9a63af17136dc3384a4e904de526ee5cb51c8e24236930796c316ae1a6
|
| 3 |
+
size 7358419432
|
model-00005-of-00014.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:83ac6b105d725522ee1526356b4df714605aa41b0da730631d3dcfd293b6cc06
|
| 3 |
+
size 7296021288
|
model-00006-of-00014.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:09de0993f9fc0125cbb7708df148e7f1ffe3fb68cd7675348abc48e62ae7f51b
|
| 3 |
+
size 7296021288
|
model-00007-of-00014.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:67fa28d6ec39d168fb6e860067d972bce3ff0513feae8942c64ef7c6c14f9521
|
| 3 |
+
size 6528980168
|
model-00008-of-00014.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cb1821e950c75b2e5c9456ddf8b4feaf5a0a79f2bf9dc0a9fb1dd22f61b84e00
|
| 3 |
+
size 6528980168
|
model-00009-of-00014.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f959159bca990432a6775c306d1f3157ff4e8235167f4a61169285568b485908
|
| 3 |
+
size 6528980168
|
model-00010-of-00014.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:336b118b795c0ea438eddb3dbe13b3f648927a8947dc45cde68a40570678239b
|
| 3 |
+
size 6528980168
|
model-00011-of-00014.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ee5c49b6bdc2e9c1e0eb301ec9ef5aafa0733e0c35060d85634925bcea6c0a66
|
| 3 |
+
size 6528980168
|
model-00012-of-00014.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ba44b18f0ff1d0b57eaa5a785719fe76217e8b1a973685bc3dd6ac5bf8fac2ad
|
| 3 |
+
size 6528980168
|
model-00013-of-00014.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d8bf31a1b58662788ef74bf7304e956db331d8e07933509c5c5b3793b7dbcf0c
|
| 3 |
+
size 8063062600
|
model-00014-of-00014.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ac60124df3af664b096b76ba5c65c83845c3ac1eb276fe0859a3e3a31baeb350
|
| 3 |
+
size 2107360272
|
model.safetensors.index.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
modeling_ernie4_5_moe.py
ADDED
|
@@ -0,0 +1,1590 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
""" Ernie4_5_Moe model """
|
| 15 |
+
|
| 16 |
+
from copy import deepcopy
|
| 17 |
+
from dataclasses import dataclass
|
| 18 |
+
from functools import partial
|
| 19 |
+
from typing import Callable, Optional, Tuple, Union
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import torch.nn.functional as F
|
| 23 |
+
import torch.nn as nn
|
| 24 |
+
|
| 25 |
+
from transformers.cache_utils import (
|
| 26 |
+
Cache,
|
| 27 |
+
DynamicCache,
|
| 28 |
+
SlidingWindowCache,
|
| 29 |
+
StaticCache,
|
| 30 |
+
)
|
| 31 |
+
from transformers.generation import GenerationMixin
|
| 32 |
+
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
| 33 |
+
from transformers.modeling_outputs import ModelOutput, MoeCausalLMOutputWithPast
|
| 34 |
+
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
|
| 35 |
+
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
| 36 |
+
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
| 37 |
+
from transformers.processing_utils import Unpack
|
| 38 |
+
from transformers.utils import (
|
| 39 |
+
LossKwargs,
|
| 40 |
+
auto_docstring,
|
| 41 |
+
can_return_tuple,
|
| 42 |
+
logging,
|
| 43 |
+
is_torch_flex_attn_available,
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
from .configuration_ernie4_5_moe import Ernie4_5_MoeConfig
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
if is_torch_flex_attn_available():
|
| 50 |
+
from torch.nn.attention.flex_attention import BlockMask
|
| 51 |
+
|
| 52 |
+
from transformers.integrations.flex_attention import make_flex_block_causal_mask
|
| 53 |
+
|
| 54 |
+
logger = logging.get_logger(__name__)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs):
|
| 58 |
+
"""Kwargs class used during autoregressive generation"""
|
| 59 |
+
|
| 60 |
+
...
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
@dataclass
|
| 64 |
+
class Erine4_5_MoeModelOutputWithPast(ModelOutput):
|
| 65 |
+
"""Class for Ernie4_5_Moe model outputs with past keys."""
|
| 66 |
+
|
| 67 |
+
last_hidden_state: Optional[torch.FloatTensor] = None
|
| 68 |
+
past_key_values: Optional[Cache] = None
|
| 69 |
+
hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
|
| 70 |
+
attentions: Optional[tuple[torch.FloatTensor, ...]] = None
|
| 71 |
+
router_loss: Optional[torch.FloatTensor] = None
|
| 72 |
+
gate_logits: Optional[tuple[torch.FloatTensor, ...]] = None
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
@dataclass
|
| 76 |
+
class Ernie4_5_MoeCausalLMOutputWithPast(MoeCausalLMOutputWithPast):
|
| 77 |
+
"""Class for Ernie4_5_Moe causal LM output with past keys"""
|
| 78 |
+
|
| 79 |
+
router_loss: Optional[torch.FloatTensor] = None
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def rotate_half(x):
|
| 83 |
+
"""Rotates half the hidden dims of the input."""
|
| 84 |
+
|
| 85 |
+
x1 = x[..., 0::2]
|
| 86 |
+
x2 = x[..., 1::2]
|
| 87 |
+
return torch.stack((-x2, x1), dim=-1).reshape(x.shape)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
| 91 |
+
"""
|
| 92 |
+
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
| 93 |
+
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
| 94 |
+
"""
|
| 95 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
| 96 |
+
if n_rep == 1:
|
| 97 |
+
return hidden_states
|
| 98 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(
|
| 99 |
+
batch, num_key_value_heads, n_rep, slen, head_dim
|
| 100 |
+
)
|
| 101 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
| 105 |
+
"""Applies Rotary Position Embedding to the query and key tensors.
|
| 106 |
+
|
| 107 |
+
Args:
|
| 108 |
+
q (`torch.Tensor`): The query tensor.
|
| 109 |
+
k (`torch.Tensor`): The key tensor.
|
| 110 |
+
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
| 111 |
+
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
| 112 |
+
position_ids (`torch.Tensor`, *optional*):
|
| 113 |
+
Deprecated and unused.
|
| 114 |
+
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
| 115 |
+
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
| 116 |
+
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
| 117 |
+
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
| 118 |
+
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
| 119 |
+
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
| 120 |
+
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
| 121 |
+
Returns:
|
| 122 |
+
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
| 123 |
+
"""
|
| 124 |
+
orig_dtype = q.dtype
|
| 125 |
+
sin_pos = torch.stack([sin, sin], dim=-1).reshape(*sin.shape[:-1], -1)
|
| 126 |
+
cos_pos = torch.stack([cos, cos], dim=-1).reshape(*sin.shape[:-1], -1)
|
| 127 |
+
q_embed = (q.float() * cos_pos) + (rotate_half(q).float() * sin_pos)
|
| 128 |
+
k_embed = (k.float() * cos_pos) + (rotate_half(k).float() * sin_pos)
|
| 129 |
+
return q_embed.to(orig_dtype), k_embed.to(orig_dtype)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def eager_attention_forward(
|
| 133 |
+
module: nn.Module,
|
| 134 |
+
query: torch.Tensor,
|
| 135 |
+
key: torch.Tensor,
|
| 136 |
+
value: torch.Tensor,
|
| 137 |
+
attention_mask: Optional[torch.Tensor],
|
| 138 |
+
scaling: float,
|
| 139 |
+
dropout: float = 0.0,
|
| 140 |
+
**kwargs,
|
| 141 |
+
):
|
| 142 |
+
"""
|
| 143 |
+
Eager attention for Ernie4_5_Attention forward function.
|
| 144 |
+
"""
|
| 145 |
+
key_states = repeat_kv(key, module.num_key_value_groups)
|
| 146 |
+
value_states = repeat_kv(value, module.num_key_value_groups)
|
| 147 |
+
|
| 148 |
+
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
| 149 |
+
if attention_mask is not None:
|
| 150 |
+
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
| 151 |
+
attn_weights = attn_weights + causal_mask.to(attn_weights.device)
|
| 152 |
+
|
| 153 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
|
| 154 |
+
query.dtype
|
| 155 |
+
)
|
| 156 |
+
attn_weights = nn.functional.dropout(
|
| 157 |
+
attn_weights, p=dropout, training=module.training
|
| 158 |
+
)
|
| 159 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
| 160 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 161 |
+
|
| 162 |
+
return attn_output, attn_weights
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def topk_gate_func(
|
| 166 |
+
module: nn.Module,
|
| 167 |
+
hidden_states: torch.Tensor,
|
| 168 |
+
):
|
| 169 |
+
"""
|
| 170 |
+
Topk gate function for Ernie4_5_MoEMlp
|
| 171 |
+
"""
|
| 172 |
+
capacity = module.get_capacity(hidden_states.shape[0])
|
| 173 |
+
with torch.autocast(device_type="cuda", dtype=torch.float32):
|
| 174 |
+
logits = module.gate(hidden_states.float())
|
| 175 |
+
router_loss = torch.zeros([1], dtype=torch.float32, device=hidden_states.device)
|
| 176 |
+
router_loss.detach()
|
| 177 |
+
return logits, capacity, router_loss
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
class Ernie4_5_ResidualWithDropout(nn.Module):
|
| 181 |
+
"""
|
| 182 |
+
Fused dropout implementation with residual connection support.
|
| 183 |
+
|
| 184 |
+
This layer combines dropout and residual addition in a single operation for better performance,
|
| 185 |
+
particularly on GPU devices. The dropout is conditionally applied based on the probability.
|
| 186 |
+
|
| 187 |
+
Args:
|
| 188 |
+
prob (float): Dropout probability (between 0 and 1)
|
| 189 |
+
|
| 190 |
+
Attributes:
|
| 191 |
+
prob (float): Stores the dropout probability
|
| 192 |
+
dropout (nn.Dropout): The actual dropout layer instance
|
| 193 |
+
"""
|
| 194 |
+
|
| 195 |
+
def __init__(self, prob):
|
| 196 |
+
"""
|
| 197 |
+
Initialize the fused dropout layer.
|
| 198 |
+
|
| 199 |
+
Args:
|
| 200 |
+
prob (float): Dropout probability (0 means no dropout)
|
| 201 |
+
"""
|
| 202 |
+
super().__init__()
|
| 203 |
+
self.prob = prob
|
| 204 |
+
self.dropout = nn.Dropout(p=prob)
|
| 205 |
+
|
| 206 |
+
def forward(self, x, y):
|
| 207 |
+
"""
|
| 208 |
+
Forward pass of the fused dropout layer.
|
| 209 |
+
|
| 210 |
+
Args:
|
| 211 |
+
x (torch.Tensor): Input tensor to potentially apply dropout on
|
| 212 |
+
y (torch.Tensor): Residual tensor to add to the (possibly dropped out) x
|
| 213 |
+
|
| 214 |
+
Returns:
|
| 215 |
+
torch.Tensor: Result of x (with optional dropout) + y
|
| 216 |
+
"""
|
| 217 |
+
if self.prob > 0:
|
| 218 |
+
x = self.dropout(x)
|
| 219 |
+
output = x + y
|
| 220 |
+
|
| 221 |
+
return output
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
class Ernie4_5_Attention(nn.Module):
|
| 225 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
| 226 |
+
|
| 227 |
+
def __init__(self, config, layer_idx=0):
|
| 228 |
+
"""
|
| 229 |
+
Args:
|
| 230 |
+
config (ErnieConfig): Model configuration.
|
| 231 |
+
layer_idx (int, optional): Index in transformer stack. Defaults to 0.
|
| 232 |
+
"""
|
| 233 |
+
super().__init__()
|
| 234 |
+
self.layer_idx = layer_idx
|
| 235 |
+
self.hidden_size = config.hidden_size
|
| 236 |
+
self.num_heads = config.num_attention_heads
|
| 237 |
+
self.num_key_value_heads = (
|
| 238 |
+
config.num_key_value_heads
|
| 239 |
+
if config.num_key_value_heads is not None
|
| 240 |
+
else self.nums_head
|
| 241 |
+
)
|
| 242 |
+
self.num_key_value_groups = (
|
| 243 |
+
config.num_attention_heads // config.num_key_value_heads
|
| 244 |
+
)
|
| 245 |
+
self.head_dim = self.hidden_size // self.num_heads
|
| 246 |
+
self.freq_allocation = (
|
| 247 |
+
config.freq_allocation if hasattr(config, "freq_allocation") else 0
|
| 248 |
+
)
|
| 249 |
+
self.scaling = self.head_dim**-0.5
|
| 250 |
+
self.attention_dropout = getattr(config, "attention_probs_dropout_prob", 0.0)
|
| 251 |
+
self.is_causal = True
|
| 252 |
+
|
| 253 |
+
self.q_proj = nn.Linear(
|
| 254 |
+
self.hidden_size,
|
| 255 |
+
self.num_heads * self.head_dim,
|
| 256 |
+
bias=config.use_bias,
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
self.k_proj = nn.Linear(
|
| 260 |
+
self.hidden_size,
|
| 261 |
+
self.num_key_value_heads * self.head_dim,
|
| 262 |
+
bias=config.use_bias,
|
| 263 |
+
)
|
| 264 |
+
|
| 265 |
+
self.v_proj = nn.Linear(
|
| 266 |
+
self.hidden_size,
|
| 267 |
+
self.num_key_value_heads * self.head_dim,
|
| 268 |
+
bias=config.use_bias,
|
| 269 |
+
)
|
| 270 |
+
|
| 271 |
+
self.o_proj = nn.Linear(
|
| 272 |
+
self.hidden_size,
|
| 273 |
+
self.hidden_size,
|
| 274 |
+
bias=config.use_bias,
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
self.config = config
|
| 278 |
+
|
| 279 |
+
def forward(
|
| 280 |
+
self,
|
| 281 |
+
hidden_states: torch.Tensor,
|
| 282 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 283 |
+
past_key_value: Optional[Cache] = None,
|
| 284 |
+
position_ids: Optional[torch.Tensor] = None,
|
| 285 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 286 |
+
position_embeddings: tuple[torch.Tensor, torch.Tensor] = None,
|
| 287 |
+
**kwargs: Unpack[FlashAttentionKwargs],
|
| 288 |
+
) -> Tuple[
|
| 289 |
+
torch.Tensor,
|
| 290 |
+
Optional[torch.Tensor],
|
| 291 |
+
Optional[Tuple[torch.Tensor, torch.Tensor]],
|
| 292 |
+
]:
|
| 293 |
+
"""
|
| 294 |
+
Ernie4_5_Attention forward function
|
| 295 |
+
"""
|
| 296 |
+
B, L = hidden_states.shape[:-1]
|
| 297 |
+
|
| 298 |
+
query_states = (
|
| 299 |
+
self.q_proj(hidden_states).view(B, L, self.num_heads, -1).transpose(1, 2)
|
| 300 |
+
)
|
| 301 |
+
key_states = (
|
| 302 |
+
self.k_proj(hidden_states)
|
| 303 |
+
.view(B, L, self.num_key_value_heads, -1)
|
| 304 |
+
.transpose(1, 2)
|
| 305 |
+
)
|
| 306 |
+
value_states = (
|
| 307 |
+
self.v_proj(hidden_states)
|
| 308 |
+
.view(B, L, self.num_key_value_heads, -1)
|
| 309 |
+
.transpose(1, 2)
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
cos, sin = position_embeddings
|
| 313 |
+
query_states, key_states = apply_rotary_pos_emb(
|
| 314 |
+
query_states, key_states, cos, sin
|
| 315 |
+
)
|
| 316 |
+
|
| 317 |
+
if past_key_value is not None:
|
| 318 |
+
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
| 319 |
+
cache_kwargs = {"cache_position": cache_position}
|
| 320 |
+
key_states, value_states = past_key_value.update(
|
| 321 |
+
key_states, value_states, self.layer_idx, cache_kwargs
|
| 322 |
+
)
|
| 323 |
+
|
| 324 |
+
attention_interface: Callable = eager_attention_forward
|
| 325 |
+
if self.config._attn_implementation != "eager":
|
| 326 |
+
attention_interface = ALL_ATTENTION_FUNCTIONS[
|
| 327 |
+
self.config._attn_implementation
|
| 328 |
+
]
|
| 329 |
+
|
| 330 |
+
attn_output, attn_weights = attention_interface(
|
| 331 |
+
self,
|
| 332 |
+
query_states,
|
| 333 |
+
key_states,
|
| 334 |
+
value_states,
|
| 335 |
+
attention_mask,
|
| 336 |
+
dropout=0.0 if not self.training else self.attention_dropout,
|
| 337 |
+
scaling=self.scaling,
|
| 338 |
+
**kwargs,
|
| 339 |
+
)
|
| 340 |
+
attn_output = attn_output.reshape(B, L, -1).contiguous()
|
| 341 |
+
attn_output = self.o_proj(attn_output)
|
| 342 |
+
|
| 343 |
+
return attn_output, attn_weights
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
class Ernie4_5_MLP(nn.Module):
|
| 347 |
+
"""
|
| 348 |
+
Ernie4_5_MLP - Gated Multi-Layer Perceptron module used in Ernie model.
|
| 349 |
+
"""
|
| 350 |
+
|
| 351 |
+
def __init__(self, config, intermediate_size=None):
|
| 352 |
+
"""
|
| 353 |
+
Initialize the MLP module with configuration options.
|
| 354 |
+
|
| 355 |
+
Args:
|
| 356 |
+
config: Model configuration object with attributes:
|
| 357 |
+
- hidden_size: int
|
| 358 |
+
- intermediate_size: int
|
| 359 |
+
- use_bias: bool
|
| 360 |
+
layer_idx (int): Index of current layer (default: 0)
|
| 361 |
+
"""
|
| 362 |
+
super().__init__()
|
| 363 |
+
self.config = config
|
| 364 |
+
self.hidden_size = config.hidden_size
|
| 365 |
+
self.intermediate_size = (
|
| 366 |
+
intermediate_size
|
| 367 |
+
if intermediate_size is not None
|
| 368 |
+
else config.intermediate_size
|
| 369 |
+
)
|
| 370 |
+
self.gate_proj = nn.Linear(
|
| 371 |
+
self.hidden_size, self.intermediate_size, bias=config.use_bias
|
| 372 |
+
)
|
| 373 |
+
self.up_proj = nn.Linear(
|
| 374 |
+
self.hidden_size, self.intermediate_size, bias=config.use_bias
|
| 375 |
+
)
|
| 376 |
+
self.down_proj = nn.Linear(
|
| 377 |
+
self.intermediate_size, self.hidden_size, bias=config.use_bias
|
| 378 |
+
)
|
| 379 |
+
|
| 380 |
+
def forward(self, x):
|
| 381 |
+
"""
|
| 382 |
+
Args:
|
| 383 |
+
x (Tensor): shape [batch_size, seq_len, hidden_size]
|
| 384 |
+
|
| 385 |
+
Returns:
|
| 386 |
+
Tensor: shape [batch_size, seq_len, hidden_size]
|
| 387 |
+
"""
|
| 388 |
+
down_proj = self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
|
| 389 |
+
return down_proj
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
class Ernie4_5_MoeStatics(nn.Module):
|
| 393 |
+
"""
|
| 394 |
+
Stores MoE (Mixture of Experts) statistics
|
| 395 |
+
and expert usage information.
|
| 396 |
+
"""
|
| 397 |
+
|
| 398 |
+
def __init__(self, config):
|
| 399 |
+
"""
|
| 400 |
+
Initialize MoE statistics tracking.
|
| 401 |
+
|
| 402 |
+
Args:
|
| 403 |
+
config: Model configuration containing MoE parameters
|
| 404 |
+
"""
|
| 405 |
+
super().__init__()
|
| 406 |
+
|
| 407 |
+
num_experts = config.moe_num_experts
|
| 408 |
+
num_experts_groups = 1
|
| 409 |
+
|
| 410 |
+
self.e_score_correction_bias = nn.Parameter(
|
| 411 |
+
torch.zeros(num_experts_groups, num_experts, dtype=torch.float32),
|
| 412 |
+
requires_grad=False,
|
| 413 |
+
)
|
| 414 |
+
|
| 415 |
+
|
| 416 |
+
class Ernie4_5_MoeMLP(nn.Module):
|
| 417 |
+
"""Mixture of Experts (MoE) variant of ERNIE's MLP layer."""
|
| 418 |
+
|
| 419 |
+
def __init__(self, config):
|
| 420 |
+
super().__init__()
|
| 421 |
+
self.config = config
|
| 422 |
+
self.k = config.moe_k
|
| 423 |
+
self.sinkhorn_2gate = config.sinkhorn_2gate
|
| 424 |
+
self.sinkhorn_temp = config.sinkhorn_temp
|
| 425 |
+
|
| 426 |
+
moe_intermediate_size = (
|
| 427 |
+
config.moe_intermediate_size
|
| 428 |
+
if config.moe_intermediate_size
|
| 429 |
+
else config.intermediate_size
|
| 430 |
+
)
|
| 431 |
+
self.gate = nn.Linear(
|
| 432 |
+
config.hidden_size, config.moe_num_experts, bias=False, dtype=torch.float32
|
| 433 |
+
)
|
| 434 |
+
if config.moe_gate_act == "softmax":
|
| 435 |
+
self.gate_act = partial(F.softmax, dim=-1)
|
| 436 |
+
elif config.moe_gate_act == "sigmoid":
|
| 437 |
+
self.gate_act = F.sigmoid
|
| 438 |
+
else:
|
| 439 |
+
raise ValueError(f"{config.moe_gate_act} is not supported.")
|
| 440 |
+
|
| 441 |
+
self.experts = nn.ModuleList(
|
| 442 |
+
[
|
| 443 |
+
Ernie4_5_MLP(config, moe_intermediate_size)
|
| 444 |
+
for i in range(config.moe_num_experts)
|
| 445 |
+
]
|
| 446 |
+
)
|
| 447 |
+
|
| 448 |
+
if config.moe_use_aux_free:
|
| 449 |
+
self.moe_statics = Ernie4_5_MoeStatics(config)
|
| 450 |
+
|
| 451 |
+
self.use_correction_bias = config.moe_use_aux_free
|
| 452 |
+
self.num_local_experts = len(self.experts)
|
| 453 |
+
|
| 454 |
+
self.shared_experts = self._init_shared_experts()
|
| 455 |
+
|
| 456 |
+
def _init_shared_experts(self):
|
| 457 |
+
"""
|
| 458 |
+
Initialize the shared expert module.
|
| 459 |
+
|
| 460 |
+
Returns:
|
| 461 |
+
shared_experts: Shared expert module, returns None if no shared experts are needed.
|
| 462 |
+
|
| 463 |
+
"""
|
| 464 |
+
cfg = deepcopy(self.config)
|
| 465 |
+
if getattr(cfg, "moe_num_shared_experts", 0) > 0:
|
| 466 |
+
if getattr(cfg, "moe_intermediate_size", None):
|
| 467 |
+
cfg.intermediate_size = (
|
| 468 |
+
cfg.moe_intermediate_size * cfg.moe_num_shared_experts
|
| 469 |
+
)
|
| 470 |
+
else:
|
| 471 |
+
cfg.intermediate_size = (
|
| 472 |
+
cfg.intermediate_size * cfg.moe_num_shared_experts
|
| 473 |
+
)
|
| 474 |
+
shared_experts = Ernie4_5_MLP(cfg, cfg.intermediate_size)
|
| 475 |
+
else:
|
| 476 |
+
shared_experts = None
|
| 477 |
+
return shared_experts
|
| 478 |
+
|
| 479 |
+
def forward(
|
| 480 |
+
self,
|
| 481 |
+
input: torch.Tensor,
|
| 482 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 483 |
+
"""
|
| 484 |
+
Forward pass through MoE layer.
|
| 485 |
+
|
| 486 |
+
Args:
|
| 487 |
+
input (Tensor): Input tensor of shape [s, d].
|
| 488 |
+
token_type_ids: Optional tensor for token types.
|
| 489 |
+
|
| 490 |
+
Returns:
|
| 491 |
+
tuple: (output, combine_weights, router_loss, gate_logits)
|
| 492 |
+
"""
|
| 493 |
+
|
| 494 |
+
if input.dim() == 3:
|
| 495 |
+
orig_shape = input.shape
|
| 496 |
+
input = input.reshape(-1, input.shape[-1])
|
| 497 |
+
else:
|
| 498 |
+
orig_shape = None
|
| 499 |
+
assert (
|
| 500 |
+
input.dim() == 2
|
| 501 |
+
), f"input Tensor must have dimensions: (s)equence, (d)im, got:{input.shape}"
|
| 502 |
+
|
| 503 |
+
assert self.gate is not None
|
| 504 |
+
|
| 505 |
+
gate_input = input
|
| 506 |
+
|
| 507 |
+
(
|
| 508 |
+
dispatched_input,
|
| 509 |
+
combine_weights,
|
| 510 |
+
dispatch_mask,
|
| 511 |
+
scatter_index,
|
| 512 |
+
router_loss,
|
| 513 |
+
gate_logits,
|
| 514 |
+
gate_prob,
|
| 515 |
+
) = self.gate_and_dispatch(gate_input)
|
| 516 |
+
|
| 517 |
+
expert_out = self.forward_experts(dispatched_input)
|
| 518 |
+
|
| 519 |
+
combined_output = self.combine_expert_output(
|
| 520 |
+
expert_out, combine_weights, scatter_index
|
| 521 |
+
)
|
| 522 |
+
|
| 523 |
+
if self.shared_experts is not None:
|
| 524 |
+
shared_expert_out = self.shared_experts(gate_input)
|
| 525 |
+
combined_output += shared_expert_out
|
| 526 |
+
|
| 527 |
+
if orig_shape:
|
| 528 |
+
combined_output = combined_output.reshape(
|
| 529 |
+
orig_shape[:-1] + (combined_output.shape[-1],)
|
| 530 |
+
)
|
| 531 |
+
|
| 532 |
+
return combined_output, combine_weights, router_loss, gate_logits
|
| 533 |
+
|
| 534 |
+
def forward_experts(self, dispatched_input: torch.Tensor) -> torch.Tensor:
|
| 535 |
+
"""
|
| 536 |
+
Forward pass through experts sequentially.
|
| 537 |
+
|
| 538 |
+
Args:
|
| 539 |
+
dispatched_input (Tensor): Input tensor of shape [num_experts, capacity, dim].
|
| 540 |
+
|
| 541 |
+
Returns:
|
| 542 |
+
Tensor: Expert outputs of shape [num_experts, capacity, dim].
|
| 543 |
+
"""
|
| 544 |
+
true_experts = self.experts
|
| 545 |
+
dispatched_input = dispatched_input.reshape(
|
| 546 |
+
1, self.num_local_experts, -1, dispatched_input.shape[-1]
|
| 547 |
+
)
|
| 548 |
+
expert_outputs = []
|
| 549 |
+
if isinstance(self.experts, nn.ModuleList):
|
| 550 |
+
chunks = dispatched_input.permute(1, 0, 2, 3).contiguous().unbind(0)
|
| 551 |
+
assert len(chunks) == len(
|
| 552 |
+
true_experts
|
| 553 |
+
), f"{len(chunks)}, {len(true_experts)}"
|
| 554 |
+
for chunk, expert in zip(chunks, true_experts):
|
| 555 |
+
expert_outputs.append(expert(chunk))
|
| 556 |
+
else:
|
| 557 |
+
dispatched_input = dispatched_input.permute(1, 0, 2, 3).contiguous()
|
| 558 |
+
orig_shape = dispatched_input.shape
|
| 559 |
+
chunks = dispatched_input.reshape(orig_shape[0], -1, orig_shape[-1])
|
| 560 |
+
chunks = self.experts(chunks)
|
| 561 |
+
chunks = chunks.reshape(orig_shape[:-1] + (chunks.shape[-1],)).unbind(0)
|
| 562 |
+
expert_outputs.extend(chunks)
|
| 563 |
+
|
| 564 |
+
expert_output = torch.stack(expert_outputs, dim=1)
|
| 565 |
+
return expert_output
|
| 566 |
+
|
| 567 |
+
def moe_gate_dispatch(
|
| 568 |
+
self,
|
| 569 |
+
x: torch.Tensor,
|
| 570 |
+
gate_logits: torch.Tensor,
|
| 571 |
+
k: int,
|
| 572 |
+
capacity: Optional[int],
|
| 573 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 574 |
+
"""
|
| 575 |
+
Dispatch inputs to experts based on their routing probabilities.
|
| 576 |
+
"""
|
| 577 |
+
S, H = x.shape
|
| 578 |
+
E = gate_logits.shape[1]
|
| 579 |
+
device = x.device
|
| 580 |
+
topk_prob, topk_idx = torch.topk(gate_logits, k, dim=-1)
|
| 581 |
+
combine_weights = topk_prob
|
| 582 |
+
expert_id = topk_idx
|
| 583 |
+
y = x.new_zeros((E, capacity, H))
|
| 584 |
+
scatter_index = x.new_full((k, S), -1, dtype=torch.int32)
|
| 585 |
+
|
| 586 |
+
# per-expert slot counters
|
| 587 |
+
slot_counter = torch.zeros(E, dtype=torch.int32, device=device)
|
| 588 |
+
|
| 589 |
+
for tok in range(S):
|
| 590 |
+
for route in range(k):
|
| 591 |
+
e = expert_id[tok, route].item()
|
| 592 |
+
slot = slot_counter[e].item()
|
| 593 |
+
if slot >= capacity:
|
| 594 |
+
combine_weights[tok, route] = 0.0
|
| 595 |
+
continue
|
| 596 |
+
|
| 597 |
+
# record mapping & dispatch activation
|
| 598 |
+
scatter_index[route, tok] = e * capacity + slot
|
| 599 |
+
y[e, slot] = x[tok]
|
| 600 |
+
slot_counter[e] += 1
|
| 601 |
+
|
| 602 |
+
expert_offset = torch.cumsum(slot_counter, 0, dtype=torch.int64)
|
| 603 |
+
|
| 604 |
+
return y, combine_weights, scatter_index, expert_offset, expert_id
|
| 605 |
+
|
| 606 |
+
def combine_expert_output(
|
| 607 |
+
self,
|
| 608 |
+
expert_output: torch.Tensor,
|
| 609 |
+
combine_weights: torch.Tensor,
|
| 610 |
+
scatter_index: torch.Tensor,
|
| 611 |
+
) -> torch.Tensor:
|
| 612 |
+
"""
|
| 613 |
+
Combine expert outputs using combination weights.
|
| 614 |
+
|
| 615 |
+
Args:
|
| 616 |
+
expert_output (Tensor): Expert outputs [num_experts, capacity, dim].
|
| 617 |
+
combine_weights (Tensor): Combination weights.
|
| 618 |
+
scatter_index (Tensor): Scatter indices.
|
| 619 |
+
|
| 620 |
+
Returns:
|
| 621 |
+
Tensor: Combined output [seqlen, dim].
|
| 622 |
+
"""
|
| 623 |
+
expert_output = expert_output.reshape(-1, expert_output.shape[-1])
|
| 624 |
+
combined_output = self.combining(expert_output, combine_weights, scatter_index)
|
| 625 |
+
return combined_output
|
| 626 |
+
|
| 627 |
+
def combining(self, x, combine_weights, scatter_index):
|
| 628 |
+
"""
|
| 629 |
+
Combines and aggregates input matrix using combination weights.
|
| 630 |
+
|
| 631 |
+
Args:
|
| 632 |
+
x (Tensor): Input tensor of shape [num_experts * capacity, dim]
|
| 633 |
+
combine_weights (Tensor): Combination weights of shape [seq, 2]
|
| 634 |
+
scatter_index (Tensor): Scatter indices of shape [seq, 2]
|
| 635 |
+
|
| 636 |
+
Returns:
|
| 637 |
+
Tensor: Combined output tensor of shape [seq, dim]
|
| 638 |
+
"""
|
| 639 |
+
dim = x.shape[-1]
|
| 640 |
+
|
| 641 |
+
scatter_index = scatter_index.reshape([-1])
|
| 642 |
+
num_k = combine_weights.shape[-1]
|
| 643 |
+
|
| 644 |
+
combine_weights = combine_weights.unsqueeze(1)
|
| 645 |
+
|
| 646 |
+
x = x[scatter_index].reshape([-1, num_k, dim])
|
| 647 |
+
|
| 648 |
+
return torch.matmul(combine_weights, x).squeeze(1)
|
| 649 |
+
|
| 650 |
+
def gate_and_dispatch(self, input):
|
| 651 |
+
"""
|
| 652 |
+
Calculate gate and dispatch inputs.
|
| 653 |
+
|
| 654 |
+
Args:
|
| 655 |
+
input: Input tensor of shape [seq, dim]
|
| 656 |
+
|
| 657 |
+
Returns:
|
| 658 |
+
tuple: (dispatched_input, combine_weights, dispatch_mask,
|
| 659 |
+
scatter_index, router_loss, gate_logits, gate_prob)
|
| 660 |
+
"""
|
| 661 |
+
gate_logits, capacity, router_loss = topk_gate_func(self, input)
|
| 662 |
+
|
| 663 |
+
# capacity no use
|
| 664 |
+
prob = self.gate_act(gate_logits)
|
| 665 |
+
(
|
| 666 |
+
dispatched_input,
|
| 667 |
+
combine_weights_unnorm,
|
| 668 |
+
scatter_index,
|
| 669 |
+
dispatch_mask,
|
| 670 |
+
_,
|
| 671 |
+
) = self.moe_gate_dispatch(input, prob, k=self.k, capacity=capacity)
|
| 672 |
+
dispatch_mask = torch.diff(F.pad(dispatch_mask, (1, 0)))
|
| 673 |
+
|
| 674 |
+
scatter_index.detach()
|
| 675 |
+
dispatch_mask.detach()
|
| 676 |
+
|
| 677 |
+
scatter_index = scatter_index.transpose(0, 1) # [k, s] -> [s, k]
|
| 678 |
+
combine_weights = combine_weights_unnorm / torch.clamp(
|
| 679 |
+
combine_weights_unnorm.sum(dim=-1, keepdim=True), min=1e-12
|
| 680 |
+
)
|
| 681 |
+
combine_weights = combine_weights.to(dtype=dispatched_input.dtype)
|
| 682 |
+
|
| 683 |
+
return (
|
| 684 |
+
dispatched_input,
|
| 685 |
+
combine_weights,
|
| 686 |
+
dispatch_mask,
|
| 687 |
+
scatter_index,
|
| 688 |
+
router_loss,
|
| 689 |
+
gate_logits,
|
| 690 |
+
prob,
|
| 691 |
+
)
|
| 692 |
+
|
| 693 |
+
def get_capacity(self, num_tokens, cap_factor=None):
|
| 694 |
+
"""
|
| 695 |
+
Calculate capacity based on number of tokens.
|
| 696 |
+
|
| 697 |
+
Args:
|
| 698 |
+
num_tokens: Number of input tokens
|
| 699 |
+
cap_factor: Optional capacity factor override
|
| 700 |
+
|
| 701 |
+
Returns:
|
| 702 |
+
int: Calculated capacity
|
| 703 |
+
"""
|
| 704 |
+
num_experts = self.config.moe_num_experts
|
| 705 |
+
if cap_factor is not None:
|
| 706 |
+
cap = cap_factor
|
| 707 |
+
else:
|
| 708 |
+
if self.training:
|
| 709 |
+
cap = self.config.moe_capacity[0]
|
| 710 |
+
elif num_tokens < num_experts:
|
| 711 |
+
cap = self.config.moe_capacity[2]
|
| 712 |
+
else:
|
| 713 |
+
cap = self.config.moe_capacity[1]
|
| 714 |
+
|
| 715 |
+
capacity = int(cap * num_tokens // num_experts)
|
| 716 |
+
assert (
|
| 717 |
+
capacity > 0
|
| 718 |
+
), f"requires capacity to >= 0. cap={cap}, num_tokens={num_tokens}"
|
| 719 |
+
return capacity
|
| 720 |
+
|
| 721 |
+
|
| 722 |
+
class Ernie4_5_RMSNorm(nn.Module):
|
| 723 |
+
"""
|
| 724 |
+
Ernie Root Mean Square Layer Normalization (Ernie4_5_RMSNorm) implementation.
|
| 725 |
+
|
| 726 |
+
Ernie4_5_RMSNorm is a simplified version of LayerNorm that focuses on the root mean square of inputs,
|
| 727 |
+
omitting the mean-centering operation. This provides computational efficiency while maintaining
|
| 728 |
+
good performance.
|
| 729 |
+
|
| 730 |
+
"""
|
| 731 |
+
|
| 732 |
+
def __init__(self, config):
|
| 733 |
+
"""
|
| 734 |
+
Initialize RMSNorm layer.
|
| 735 |
+
|
| 736 |
+
Args:
|
| 737 |
+
config (ErnieConfig): Model configuration.
|
| 738 |
+
"""
|
| 739 |
+
super().__init__()
|
| 740 |
+
self.config = config
|
| 741 |
+
self.hidden_size = config.hidden_size
|
| 742 |
+
self.weight = nn.Parameter(torch.ones(config.hidden_size))
|
| 743 |
+
self.variance_epsilon = config.rms_norm_eps
|
| 744 |
+
|
| 745 |
+
def forward(self, hidden_states):
|
| 746 |
+
"""
|
| 747 |
+
Apply RMS normalization to input hidden states.
|
| 748 |
+
|
| 749 |
+
Args:
|
| 750 |
+
hidden_states (Tensor): Input tensor of shape [batch_size, seq_len, hidden_size]
|
| 751 |
+
|
| 752 |
+
Returns:
|
| 753 |
+
Tensor: Normalized output tensor of same shape as input
|
| 754 |
+
"""
|
| 755 |
+
input_dtype = hidden_states.dtype
|
| 756 |
+
hidden_states = hidden_states.to(torch.float32)
|
| 757 |
+
variance = hidden_states.pow(2).mean(dim=-1, keepdim=True)
|
| 758 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
| 759 |
+
|
| 760 |
+
return self.weight * hidden_states.to(input_dtype)
|
| 761 |
+
|
| 762 |
+
|
| 763 |
+
class Ernie4_5_RopeEmbedding(nn.Module):
|
| 764 |
+
"""
|
| 765 |
+
Implements Rotary Position Embedding (RoPE) for Ernie4_5_MoeModel.
|
| 766 |
+
"""
|
| 767 |
+
|
| 768 |
+
def __init__(self, config: Ernie4_5_MoeConfig, device=None):
|
| 769 |
+
super().__init__()
|
| 770 |
+
# BC: "rope_type" was originally "type"
|
| 771 |
+
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
| 772 |
+
self.rope_type = config.rope_scaling.get(
|
| 773 |
+
"rope_type", config.rope_scaling.get("type")
|
| 774 |
+
)
|
| 775 |
+
else:
|
| 776 |
+
self.rope_type = "default"
|
| 777 |
+
self.max_seq_len_cached = config.max_position_embeddings
|
| 778 |
+
self.original_max_seq_len = config.max_position_embeddings
|
| 779 |
+
|
| 780 |
+
self.config = config
|
| 781 |
+
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
| 782 |
+
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
| 783 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 784 |
+
self.original_inv_freq = self.inv_freq
|
| 785 |
+
|
| 786 |
+
@torch.no_grad()
|
| 787 |
+
def forward(self, x, position_ids):
|
| 788 |
+
inv_freq_expanded = self.inv_freq[None, None, :].float()
|
| 789 |
+
position_ids_expanded = position_ids[..., None].float()
|
| 790 |
+
freqs = inv_freq_expanded.float() * position_ids_expanded.float()
|
| 791 |
+
cos = torch.cos(freqs) * self.attention_scaling
|
| 792 |
+
sin = torch.sin(freqs) * self.attention_scaling
|
| 793 |
+
return cos, sin
|
| 794 |
+
# return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 795 |
+
|
| 796 |
+
|
| 797 |
+
class Ernie4_5_DecoderLayer(nn.Module):
|
| 798 |
+
"""A single transformer decoder layer in ERNIE-MoE model.
|
| 799 |
+
|
| 800 |
+
Contains self-attention and feed-forward components with optional MoE (Mixture of Experts)
|
| 801 |
+
support, residual connections, and layer normalization.
|
| 802 |
+
"""
|
| 803 |
+
|
| 804 |
+
def __init__(self, config, layer_idx):
|
| 805 |
+
"""Initialize the decoder layer.
|
| 806 |
+
|
| 807 |
+
Args:
|
| 808 |
+
config (ErnieMoEConfig): Model configuration.
|
| 809 |
+
layer_idx (int): Index of this layer in the transformer stack
|
| 810 |
+
"""
|
| 811 |
+
super().__init__()
|
| 812 |
+
self.hidden_size = config.hidden_size
|
| 813 |
+
self.layer_idx = layer_idx
|
| 814 |
+
self.config = config
|
| 815 |
+
self.use_moe = config.use_moe
|
| 816 |
+
self.self_attn = Ernie4_5_Attention(config, layer_idx)
|
| 817 |
+
|
| 818 |
+
moe_layer_start_index = (
|
| 819 |
+
min(config.moe_layer_start_index)
|
| 820 |
+
if isinstance(config.moe_layer_start_index, (tuple, list))
|
| 821 |
+
else config.moe_layer_start_index
|
| 822 |
+
)
|
| 823 |
+
moe_layer_end_index = (
|
| 824 |
+
max(config.moe_layer_end_index)
|
| 825 |
+
if isinstance(config.moe_layer_end_index, (tuple, list))
|
| 826 |
+
else config.moe_layer_end_index
|
| 827 |
+
)
|
| 828 |
+
|
| 829 |
+
if (
|
| 830 |
+
self.use_moe
|
| 831 |
+
and ((layer_idx + 1) % config.moe_layer_interval == 0)
|
| 832 |
+
and layer_idx >= moe_layer_start_index
|
| 833 |
+
and layer_idx <= moe_layer_end_index
|
| 834 |
+
):
|
| 835 |
+
self.mlp = Ernie4_5_MoeMLP(config)
|
| 836 |
+
else:
|
| 837 |
+
self.mlp = Ernie4_5_MLP(config)
|
| 838 |
+
|
| 839 |
+
self.input_layernorm = Ernie4_5_RMSNorm(config)
|
| 840 |
+
self.post_attention_layernorm = Ernie4_5_RMSNorm(config)
|
| 841 |
+
|
| 842 |
+
self.residual_add1 = Ernie4_5_ResidualWithDropout(config.hidden_dropout_prob)
|
| 843 |
+
self.residual_add2 = Ernie4_5_ResidualWithDropout(config.hidden_dropout_prob)
|
| 844 |
+
|
| 845 |
+
def forward(
|
| 846 |
+
self,
|
| 847 |
+
hidden_states: torch.Tensor,
|
| 848 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 849 |
+
position_ids: Optional[torch.Tensor] = None,
|
| 850 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
| 851 |
+
output_attentions: Optional[bool] = False,
|
| 852 |
+
use_cache: Optional[bool] = False,
|
| 853 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 854 |
+
position_embeddings: Optional[
|
| 855 |
+
tuple[torch.Tensor, torch.Tensor]
|
| 856 |
+
] = None, # necessary, but kept here for BC
|
| 857 |
+
output_router_loss: bool = True,
|
| 858 |
+
output_gate_logits: bool = True,
|
| 859 |
+
**kwargs: Unpack[FlashAttentionKwargs],
|
| 860 |
+
) -> tuple[
|
| 861 |
+
torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]
|
| 862 |
+
]:
|
| 863 |
+
"""Forward pass through the decoder layer.
|
| 864 |
+
|
| 865 |
+
Args:
|
| 866 |
+
hidden_states (torch.Tensor): Input tensor [batch_size, seq_len, hidden_size]
|
| 867 |
+
attention_mask (Optional[torch.Tensor]): Attention mask tensor
|
| 868 |
+
position_ids (Optional[torch.Tensor]): Position indices for rotary embeddings
|
| 869 |
+
past_key_value (Optional[Tuple[torch.Tensor]]): Cached key/value states
|
| 870 |
+
output_attentions (Optional[bool]): Whether to return attention weights
|
| 871 |
+
use_cache (Optional[bool]): Whether to cache key/value states
|
| 872 |
+
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
| 873 |
+
Indices depicting the position of the input sequence tokens in the sequence.
|
| 874 |
+
position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
|
| 875 |
+
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
|
| 876 |
+
with `head_dim` being the embedding dimension of each attention head.
|
| 877 |
+
output_router_loss (bool): Whether to return MoE router loss
|
| 878 |
+
output_gate_logits (bool): Whether to return MoE gate logits
|
| 879 |
+
|
| 880 |
+
Returns:
|
| 881 |
+
Union: Various output combinations depending on arguments:
|
| 882 |
+
- Base case: Hidden states tensor
|
| 883 |
+
- With attention: Tuple of (hidden_states, attention_weights)
|
| 884 |
+
- With router loss: May include gate logits in output tuple
|
| 885 |
+
- With MoE gate logits: May include gate logits in output tuple
|
| 886 |
+
"""
|
| 887 |
+
residual = hidden_states
|
| 888 |
+
|
| 889 |
+
hidden_states = self.input_layernorm(hidden_states)
|
| 890 |
+
|
| 891 |
+
# Self Attention
|
| 892 |
+
hidden_states, self_attn_weights = self.self_attn(
|
| 893 |
+
hidden_states=hidden_states,
|
| 894 |
+
attention_mask=attention_mask,
|
| 895 |
+
past_key_value=past_key_value,
|
| 896 |
+
position_ids=position_ids,
|
| 897 |
+
use_cache=use_cache,
|
| 898 |
+
cache_position=cache_position,
|
| 899 |
+
position_embeddings=position_embeddings,
|
| 900 |
+
**kwargs,
|
| 901 |
+
)
|
| 902 |
+
|
| 903 |
+
hidden_states = self.residual_add1(hidden_states, residual)
|
| 904 |
+
|
| 905 |
+
# Fully Connected
|
| 906 |
+
residual = hidden_states
|
| 907 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
| 908 |
+
|
| 909 |
+
router_loss = None
|
| 910 |
+
gate_logits = None
|
| 911 |
+
|
| 912 |
+
if isinstance(self.mlp, Ernie4_5_MoeMLP):
|
| 913 |
+
hidden_states, _, router_loss, gate_logits = self.mlp(hidden_states)
|
| 914 |
+
else:
|
| 915 |
+
hidden_states = self.mlp(hidden_states)
|
| 916 |
+
|
| 917 |
+
hidden_states = self.residual_add2(hidden_states, residual)
|
| 918 |
+
|
| 919 |
+
outputs = (hidden_states,)
|
| 920 |
+
|
| 921 |
+
if output_attentions:
|
| 922 |
+
outputs += (self_attn_weights,)
|
| 923 |
+
|
| 924 |
+
if output_router_loss:
|
| 925 |
+
outputs += (router_loss,)
|
| 926 |
+
|
| 927 |
+
if output_gate_logits:
|
| 928 |
+
outputs += (gate_logits,)
|
| 929 |
+
|
| 930 |
+
return outputs
|
| 931 |
+
|
| 932 |
+
|
| 933 |
+
@auto_docstring
|
| 934 |
+
class Ernie4_5_PretrainedModel(PreTrainedModel):
|
| 935 |
+
"""Base class for ERNIE pretrained models."""
|
| 936 |
+
|
| 937 |
+
config_class = Ernie4_5_MoeConfig
|
| 938 |
+
base_model_prefix = "model"
|
| 939 |
+
supports_gradient_checkpointing = True
|
| 940 |
+
_no_split_modules = ["Ernie4_5_DecoderLayer"]
|
| 941 |
+
_skip_keys_device_placement = ["past_key_values"]
|
| 942 |
+
_supports_flash_attn_2 = True
|
| 943 |
+
_supports_sdpa = True
|
| 944 |
+
_supports_flex_attn = True
|
| 945 |
+
_supports_cache_class = True
|
| 946 |
+
_supports_quantized_cache = True
|
| 947 |
+
_supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
|
| 948 |
+
|
| 949 |
+
|
| 950 |
+
def subbatch(f, arg_idx, axis, bs, out_idx, same_arg_idx={}):
|
| 951 |
+
"""
|
| 952 |
+
Converts a function to one that applies to subbatch of an input dimension.
|
| 953 |
+
Useful for processing large tensors in smaller chunks to reduce memory usage.
|
| 954 |
+
|
| 955 |
+
Args:
|
| 956 |
+
f (Callable): Function to be subbatched.
|
| 957 |
+
arg_idx ([int]): Indices of the inputs to be subbatched.
|
| 958 |
+
axis ([int]): Indices of the dimensions to be subbatched for each input.
|
| 959 |
+
bs (int): Subbatch size.
|
| 960 |
+
out_idx (int): Dimension to concatenate outputs along.
|
| 961 |
+
same_arg_idx (dict): Mapping of argument indices that share the same tensor.
|
| 962 |
+
|
| 963 |
+
Returns:
|
| 964 |
+
Callable: New function that processes inputs in subbatches.
|
| 965 |
+
"""
|
| 966 |
+
|
| 967 |
+
@functools.wraps(f)
|
| 968 |
+
def wrapper(*args, **kwargs):
|
| 969 |
+
|
| 970 |
+
assert len(arg_idx) == len(
|
| 971 |
+
axis
|
| 972 |
+
), "Number of batching args and number of batching dims should match."
|
| 973 |
+
|
| 974 |
+
inps = [args[i] for i in arg_idx]
|
| 975 |
+
axis_width = [inp.shape[d] for inp, d in zip(inps, axis)]
|
| 976 |
+
assert len(set(axis_width)) == 1, "Batch sizes should be kept equal."
|
| 977 |
+
|
| 978 |
+
inp_axis = {idx: d for idx, d in zip(arg_idx, axis)}
|
| 979 |
+
|
| 980 |
+
axis_width = axis_width[0]
|
| 981 |
+
if axis_width < bs:
|
| 982 |
+
return f(*args, **kwargs)
|
| 983 |
+
|
| 984 |
+
outs = []
|
| 985 |
+
for slice_at in range(0, axis_width, bs):
|
| 986 |
+
_args = []
|
| 987 |
+
for i, inp in enumerate(args):
|
| 988 |
+
if i in same_arg_idx:
|
| 989 |
+
assert (
|
| 990 |
+
i > same_arg_idx[i]
|
| 991 |
+
), f"expect i > same_arg_idx[i], but got i: {i} and same_arg_idx[i]: {same_arg_idx[i]}"
|
| 992 |
+
_args.append(_args[same_arg_idx[i]])
|
| 993 |
+
elif i in arg_idx:
|
| 994 |
+
d = inp_axis[i]
|
| 995 |
+
start = slice_at
|
| 996 |
+
end = min(inp.shape[d], slice_at + bs)
|
| 997 |
+
# Build slice for all dims, only slice along axis d
|
| 998 |
+
slices = [slice(None)] * inp.ndim
|
| 999 |
+
slices[d] = slice(start, end)
|
| 1000 |
+
_args.append(inp[tuple(slices)])
|
| 1001 |
+
else:
|
| 1002 |
+
_args.append(inp)
|
| 1003 |
+
|
| 1004 |
+
out = f(*_args, **kwargs)
|
| 1005 |
+
outs.append(out)
|
| 1006 |
+
|
| 1007 |
+
return torch.cat(outs, dim=out_idx)
|
| 1008 |
+
|
| 1009 |
+
return wrapper
|
| 1010 |
+
|
| 1011 |
+
|
| 1012 |
+
class ErniePretrainingCriterion(nn.Module):
|
| 1013 |
+
"""Criterion for ERNIE pretraining task."""
|
| 1014 |
+
|
| 1015 |
+
def __init__(self, config, return_tuple=True):
|
| 1016 |
+
"""Initialize the pretraining criterion.
|
| 1017 |
+
|
| 1018 |
+
Args:
|
| 1019 |
+
config (ErnieConfig): Model configuration.
|
| 1020 |
+
return_tuple (bool): Whether to return loss as tuple (loss, loss_sum). Defaults to True.
|
| 1021 |
+
"""
|
| 1022 |
+
super().__init__()
|
| 1023 |
+
self.ignored_index = getattr(config, "ignored_index", -100)
|
| 1024 |
+
self.config = config
|
| 1025 |
+
self.return_tuple = return_tuple
|
| 1026 |
+
|
| 1027 |
+
self.loss_func = nn.CrossEntropyLoss(reduction="none")
|
| 1028 |
+
|
| 1029 |
+
def forward(self, prediction_scores, masked_lm_labels, loss_mask, router_loss=None):
|
| 1030 |
+
"""Compute the combined pretraining loss.
|
| 1031 |
+
|
| 1032 |
+
Args:
|
| 1033 |
+
prediction_scores: Prediction scores tensor, [batch_size, seq_len, vocab_size]
|
| 1034 |
+
masked_lm_labels: Target labels tensor [batch_size, seq_len]
|
| 1035 |
+
loss_mask: Optional mask for valid tokens
|
| 1036 |
+
router_loss: Optional MoE router loss tensor
|
| 1037 |
+
|
| 1038 |
+
Returns:
|
| 1039 |
+
Union:
|
| 1040 |
+
- If return_tuple=True: Tuple of (combined_loss, mlm_loss_sum)
|
| 1041 |
+
- If return_tuple=False: Combined loss tensor
|
| 1042 |
+
"""
|
| 1043 |
+
res = self.forward_impl(prediction_scores, masked_lm_labels, loss_mask)
|
| 1044 |
+
|
| 1045 |
+
if self.return_tuple:
|
| 1046 |
+
loss, loss_sum = res
|
| 1047 |
+
else:
|
| 1048 |
+
loss, loss_sum = res, None
|
| 1049 |
+
|
| 1050 |
+
if router_loss is not None and isinstance(router_loss, torch.Tensor):
|
| 1051 |
+
loss = loss + router_loss - router_loss.detach()
|
| 1052 |
+
|
| 1053 |
+
return loss, loss_sum
|
| 1054 |
+
|
| 1055 |
+
def loss_impl(
|
| 1056 |
+
self, prediction_scores: torch.Tensor, masked_lm_labels: torch.Tensor
|
| 1057 |
+
) -> torch.Tensor:
|
| 1058 |
+
"""
|
| 1059 |
+
Core loss computation without reduction (but per-token).
|
| 1060 |
+
|
| 1061 |
+
Args:
|
| 1062 |
+
prediction_scores (torch.Tensor): Logits tensor [batch_size, seq_len, vocab_size].
|
| 1063 |
+
masked_lm_labels (torch.Tensor): Target labels tensor [batch_size, seq_len].
|
| 1064 |
+
|
| 1065 |
+
Returns:
|
| 1066 |
+
torch.Tensor: Unreduced loss tensor of shape [batch_size, seq_len].
|
| 1067 |
+
Losses are calculated in float32.
|
| 1068 |
+
"""
|
| 1069 |
+
scores_float32 = prediction_scores.to(torch.float32)
|
| 1070 |
+
# prediction_scores: [batch_size, seq_len, vocab_size]
|
| 1071 |
+
# masked_lm_labels: [batch_size, seq_len]
|
| 1072 |
+
# Transpose prediction_scores to [batch_size, vocab_size, seq_len]
|
| 1073 |
+
unreduced_loss = self.loss_func(
|
| 1074 |
+
scores_float32.transpose(1, 2), # Shape: [batch_size, vocab_size, seq_len]
|
| 1075 |
+
masked_lm_labels.long(), # Shape: [batch_size, seq_len], ensure long type
|
| 1076 |
+
)
|
| 1077 |
+
# unreduced_loss will be of shape [batch_size, seq_len] and dtype float32
|
| 1078 |
+
return unreduced_loss
|
| 1079 |
+
|
| 1080 |
+
def forward_impl(self, prediction_scores, masked_lm_labels, loss_mask=None):
|
| 1081 |
+
"""
|
| 1082 |
+
Loss function forward pass implementation.
|
| 1083 |
+
"""
|
| 1084 |
+
prediction_scores_dims = len(prediction_scores.shape)
|
| 1085 |
+
|
| 1086 |
+
loss_subbatch_seqlen_config_key = "loss_subbatch_seqlen"
|
| 1087 |
+
default_loss_subbatch_seqlen = 32768
|
| 1088 |
+
|
| 1089 |
+
current_loss_subbatch_seqlen = self.config.get(
|
| 1090 |
+
loss_subbatch_seqlen_config_key, default_loss_subbatch_seqlen
|
| 1091 |
+
)
|
| 1092 |
+
|
| 1093 |
+
if (
|
| 1094 |
+
prediction_scores_dims == 2
|
| 1095 |
+
and prediction_scores.shape[0] > current_loss_subbatch_seqlen
|
| 1096 |
+
):
|
| 1097 |
+
sb_loss_func = subbatch(
|
| 1098 |
+
self.loss_impl, [0, 1], [0, 0], current_loss_subbatch_seqlen, 0
|
| 1099 |
+
)
|
| 1100 |
+
masked_lm_loss = sb_loss_func(prediction_scores, masked_lm_labels)
|
| 1101 |
+
elif (
|
| 1102 |
+
prediction_scores_dims == 3
|
| 1103 |
+
and prediction_scores.shape[1] > current_loss_subbatch_seqlen
|
| 1104 |
+
):
|
| 1105 |
+
sb_loss_func = subbatch(
|
| 1106 |
+
self.loss_impl, [0, 1], [1, 1], current_loss_subbatch_seqlen, 1
|
| 1107 |
+
)
|
| 1108 |
+
masked_lm_loss = sb_loss_func(prediction_scores, masked_lm_labels)
|
| 1109 |
+
else:
|
| 1110 |
+
masked_lm_loss = self.loss_impl(prediction_scores, masked_lm_labels)
|
| 1111 |
+
|
| 1112 |
+
if loss_mask is None:
|
| 1113 |
+
loss_mask = masked_lm_labels != self.ignored_index
|
| 1114 |
+
|
| 1115 |
+
loss_mask = loss_mask.reshape(-1).to(torch.float32)
|
| 1116 |
+
|
| 1117 |
+
masked_lm_loss = torch.sum(
|
| 1118 |
+
masked_lm_loss.to(torch.float32).reshape(-1) * loss_mask
|
| 1119 |
+
)
|
| 1120 |
+
|
| 1121 |
+
# The division will be in float32
|
| 1122 |
+
loss = masked_lm_loss / loss_mask.sum()
|
| 1123 |
+
|
| 1124 |
+
loss_sum = masked_lm_loss.sum().detach()
|
| 1125 |
+
|
| 1126 |
+
if not self.return_tuple:
|
| 1127 |
+
if self.training:
|
| 1128 |
+
return loss
|
| 1129 |
+
return loss_sum
|
| 1130 |
+
return loss, loss_sum
|
| 1131 |
+
|
| 1132 |
+
|
| 1133 |
+
@auto_docstring
|
| 1134 |
+
class Ernie4_5_Model(Ernie4_5_PretrainedModel):
|
| 1135 |
+
"""The core ERNIE transformer model with MoE (Mixture of Experts) support."""
|
| 1136 |
+
|
| 1137 |
+
_keep_in_fp32_modules = ["gate"]
|
| 1138 |
+
|
| 1139 |
+
def __init__(self, config: Ernie4_5_MoeConfig):
|
| 1140 |
+
"""Initialize the ERNIE model architecture."""
|
| 1141 |
+
super().__init__(config)
|
| 1142 |
+
self.padding_idx = config.pad_token_id
|
| 1143 |
+
self.vocab_size = config.vocab_size
|
| 1144 |
+
self.hidden_size = config.hidden_size
|
| 1145 |
+
self.config = config
|
| 1146 |
+
|
| 1147 |
+
self.embed_tokens = nn.Embedding(
|
| 1148 |
+
self.vocab_size,
|
| 1149 |
+
self.hidden_size,
|
| 1150 |
+
)
|
| 1151 |
+
|
| 1152 |
+
self.layers = nn.ModuleList(
|
| 1153 |
+
[Ernie4_5_DecoderLayer(config, i) for i in range(config.num_hidden_layers)]
|
| 1154 |
+
)
|
| 1155 |
+
self.norm = Ernie4_5_RMSNorm(config)
|
| 1156 |
+
self.rotary_emb = Ernie4_5_RopeEmbedding(config=config)
|
| 1157 |
+
|
| 1158 |
+
self.gradient_checkpointing = False
|
| 1159 |
+
|
| 1160 |
+
self.post_init()
|
| 1161 |
+
|
| 1162 |
+
def get_input_embeddings(self):
|
| 1163 |
+
"""Get the input embedding layer."""
|
| 1164 |
+
return self.embed_tokens
|
| 1165 |
+
|
| 1166 |
+
def set_input_embeddings(self, value):
|
| 1167 |
+
"""Set new input embeddings."""
|
| 1168 |
+
self.embed_tokens = value
|
| 1169 |
+
|
| 1170 |
+
def forward(
|
| 1171 |
+
self,
|
| 1172 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 1173 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1174 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 1175 |
+
past_key_values: Optional[Cache] = None,
|
| 1176 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 1177 |
+
use_cache: Optional[bool] = None,
|
| 1178 |
+
output_attentions: Optional[bool] = None,
|
| 1179 |
+
output_hidden_states: Optional[bool] = None,
|
| 1180 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 1181 |
+
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
|
| 1182 |
+
):
|
| 1183 |
+
"""Forward pass through the ERNIE model."""
|
| 1184 |
+
output_attentions = (
|
| 1185 |
+
output_attentions
|
| 1186 |
+
if output_attentions is not None
|
| 1187 |
+
else self.config.output_attentions
|
| 1188 |
+
)
|
| 1189 |
+
output_hidden_states = (
|
| 1190 |
+
output_hidden_states
|
| 1191 |
+
if output_hidden_states is not None
|
| 1192 |
+
else self.config.output_hidden_states
|
| 1193 |
+
)
|
| 1194 |
+
|
| 1195 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 1196 |
+
raise ValueError(
|
| 1197 |
+
"You must specify exactly one of input_ids or inputs_embeds"
|
| 1198 |
+
)
|
| 1199 |
+
|
| 1200 |
+
if self.gradient_checkpointing and self.training:
|
| 1201 |
+
if use_cache:
|
| 1202 |
+
logger.warning_once(
|
| 1203 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
| 1204 |
+
)
|
| 1205 |
+
use_cache = False
|
| 1206 |
+
|
| 1207 |
+
if use_cache and past_key_values is None:
|
| 1208 |
+
past_key_values = DynamicCache()
|
| 1209 |
+
|
| 1210 |
+
if inputs_embeds is None:
|
| 1211 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
| 1212 |
+
|
| 1213 |
+
inputs_embeds = inputs_embeds.to(self.embed_tokens.weight.dtype)
|
| 1214 |
+
|
| 1215 |
+
if cache_position is None:
|
| 1216 |
+
past_seen_tokens = (
|
| 1217 |
+
past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 1218 |
+
)
|
| 1219 |
+
cache_position = torch.arange(
|
| 1220 |
+
past_seen_tokens,
|
| 1221 |
+
past_seen_tokens + inputs_embeds.shape[1],
|
| 1222 |
+
device=inputs_embeds.device,
|
| 1223 |
+
)
|
| 1224 |
+
if position_ids is None:
|
| 1225 |
+
position_ids = cache_position.unsqueeze(0)
|
| 1226 |
+
|
| 1227 |
+
causal_mask = self._update_causal_mask(
|
| 1228 |
+
attention_mask,
|
| 1229 |
+
inputs_embeds,
|
| 1230 |
+
cache_position,
|
| 1231 |
+
past_key_values,
|
| 1232 |
+
output_attentions,
|
| 1233 |
+
)
|
| 1234 |
+
|
| 1235 |
+
hidden_states = inputs_embeds
|
| 1236 |
+
|
| 1237 |
+
# create position embeddings to be shared across the decoder layers
|
| 1238 |
+
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
| 1239 |
+
|
| 1240 |
+
# decoder layers
|
| 1241 |
+
all_hidden_states = () if output_hidden_states else None
|
| 1242 |
+
all_self_attns = () if output_attentions else None
|
| 1243 |
+
all_router_loss = (
|
| 1244 |
+
torch.tensor(0.0, device=inputs_embeds.device)
|
| 1245 |
+
if self.config.use_moe
|
| 1246 |
+
else None
|
| 1247 |
+
)
|
| 1248 |
+
all_gate_logits = ()
|
| 1249 |
+
|
| 1250 |
+
for decoder_layer in self.layers:
|
| 1251 |
+
if output_hidden_states:
|
| 1252 |
+
all_hidden_states += (hidden_states,)
|
| 1253 |
+
|
| 1254 |
+
if self.gradient_checkpointing and self.training:
|
| 1255 |
+
layer_outputs = self._gradient_checkpointing_func(
|
| 1256 |
+
partial(decoder_layer.__call__, **flash_attn_kwargs),
|
| 1257 |
+
hidden_states,
|
| 1258 |
+
causal_mask,
|
| 1259 |
+
position_ids,
|
| 1260 |
+
past_key_values,
|
| 1261 |
+
output_attentions,
|
| 1262 |
+
use_cache,
|
| 1263 |
+
cache_position,
|
| 1264 |
+
position_embeddings,
|
| 1265 |
+
)
|
| 1266 |
+
else:
|
| 1267 |
+
layer_outputs = decoder_layer(
|
| 1268 |
+
hidden_states,
|
| 1269 |
+
causal_mask,
|
| 1270 |
+
position_ids,
|
| 1271 |
+
past_key_values,
|
| 1272 |
+
output_attentions,
|
| 1273 |
+
use_cache,
|
| 1274 |
+
cache_position,
|
| 1275 |
+
position_embeddings,
|
| 1276 |
+
**flash_attn_kwargs,
|
| 1277 |
+
)
|
| 1278 |
+
|
| 1279 |
+
hidden_states = layer_outputs[0]
|
| 1280 |
+
|
| 1281 |
+
if output_attentions:
|
| 1282 |
+
all_self_attns += (layer_outputs[1],)
|
| 1283 |
+
|
| 1284 |
+
if self.config.use_moe:
|
| 1285 |
+
layer_outputs, gate_logits = layer_outputs[:-1], layer_outputs[-1]
|
| 1286 |
+
all_gate_logits = all_gate_logits + (gate_logits,)
|
| 1287 |
+
|
| 1288 |
+
hidden_states = self.norm(hidden_states)
|
| 1289 |
+
|
| 1290 |
+
# add hidden states from the last decoder layer
|
| 1291 |
+
if output_hidden_states:
|
| 1292 |
+
all_hidden_states += (hidden_states,)
|
| 1293 |
+
|
| 1294 |
+
# assert all_router_loss is None, f'moe not support `return-dict`'
|
| 1295 |
+
return Erine4_5_MoeModelOutputWithPast(
|
| 1296 |
+
last_hidden_state=hidden_states,
|
| 1297 |
+
past_key_values=past_key_values,
|
| 1298 |
+
hidden_states=all_hidden_states,
|
| 1299 |
+
attentions=all_self_attns,
|
| 1300 |
+
router_loss=all_router_loss,
|
| 1301 |
+
gate_logits=all_gate_logits,
|
| 1302 |
+
)
|
| 1303 |
+
|
| 1304 |
+
def _update_causal_mask(
|
| 1305 |
+
self,
|
| 1306 |
+
attention_mask: Union[torch.Tensor, "BlockMask"],
|
| 1307 |
+
input_tensor: torch.Tensor,
|
| 1308 |
+
cache_position: torch.Tensor,
|
| 1309 |
+
past_key_values: Cache,
|
| 1310 |
+
output_attentions: bool = False,
|
| 1311 |
+
):
|
| 1312 |
+
if self.config._attn_implementation == "flash_attention_2":
|
| 1313 |
+
if attention_mask is not None and past_key_values is not None:
|
| 1314 |
+
is_padding_right = (
|
| 1315 |
+
attention_mask[:, -1].sum().item() != input_tensor.size()[0]
|
| 1316 |
+
)
|
| 1317 |
+
if is_padding_right:
|
| 1318 |
+
raise ValueError(
|
| 1319 |
+
"You are attempting to perform batched generation with padding_side='right'"
|
| 1320 |
+
" this may lead to unexpected behaviour for Flash Attention version of Qwen3. Make sure to "
|
| 1321 |
+
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
|
| 1322 |
+
)
|
| 1323 |
+
if attention_mask is not None and 0.0 in attention_mask:
|
| 1324 |
+
return attention_mask
|
| 1325 |
+
return None
|
| 1326 |
+
if self.config._attn_implementation == "flex_attention":
|
| 1327 |
+
if isinstance(attention_mask, torch.Tensor):
|
| 1328 |
+
attention_mask = make_flex_block_causal_mask(attention_mask)
|
| 1329 |
+
return attention_mask
|
| 1330 |
+
|
| 1331 |
+
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
| 1332 |
+
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
| 1333 |
+
# to infer the attention mask.
|
| 1334 |
+
past_seen_tokens = (
|
| 1335 |
+
past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 1336 |
+
)
|
| 1337 |
+
using_static_cache = isinstance(past_key_values, StaticCache)
|
| 1338 |
+
using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
|
| 1339 |
+
|
| 1340 |
+
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
| 1341 |
+
if (
|
| 1342 |
+
self.config._attn_implementation == "sdpa"
|
| 1343 |
+
and not (using_static_cache or using_sliding_window_cache)
|
| 1344 |
+
and not output_attentions
|
| 1345 |
+
):
|
| 1346 |
+
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
| 1347 |
+
attention_mask,
|
| 1348 |
+
inputs_embeds=input_tensor,
|
| 1349 |
+
past_key_values_length=past_seen_tokens,
|
| 1350 |
+
sliding_window=self.config.sliding_window,
|
| 1351 |
+
is_training=self.training,
|
| 1352 |
+
):
|
| 1353 |
+
return None
|
| 1354 |
+
|
| 1355 |
+
dtype = input_tensor.dtype
|
| 1356 |
+
min_dtype = torch.finfo(dtype).min
|
| 1357 |
+
sequence_length = input_tensor.shape[1]
|
| 1358 |
+
# SlidingWindowCache or StaticCache
|
| 1359 |
+
if using_sliding_window_cache or using_static_cache:
|
| 1360 |
+
target_length = past_key_values.get_max_cache_shape()
|
| 1361 |
+
# DynamicCache or no cache
|
| 1362 |
+
else:
|
| 1363 |
+
target_length = (
|
| 1364 |
+
attention_mask.shape[-1]
|
| 1365 |
+
if isinstance(attention_mask, torch.Tensor)
|
| 1366 |
+
else past_seen_tokens + sequence_length + 1
|
| 1367 |
+
)
|
| 1368 |
+
|
| 1369 |
+
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
| 1370 |
+
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
| 1371 |
+
attention_mask,
|
| 1372 |
+
sequence_length=sequence_length,
|
| 1373 |
+
target_length=target_length,
|
| 1374 |
+
dtype=dtype,
|
| 1375 |
+
cache_position=cache_position,
|
| 1376 |
+
batch_size=input_tensor.shape[0],
|
| 1377 |
+
config=self.config,
|
| 1378 |
+
past_key_values=past_key_values,
|
| 1379 |
+
)
|
| 1380 |
+
|
| 1381 |
+
if (
|
| 1382 |
+
self.config._attn_implementation == "sdpa"
|
| 1383 |
+
and attention_mask is not None
|
| 1384 |
+
and attention_mask.device.type in ["cuda", "xpu", "npu"]
|
| 1385 |
+
and not output_attentions
|
| 1386 |
+
):
|
| 1387 |
+
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
| 1388 |
+
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
| 1389 |
+
# Details: https://github.com/pytorch/pytorch/issues/110213
|
| 1390 |
+
causal_mask = AttentionMaskConverter._unmask_unattended(
|
| 1391 |
+
causal_mask, min_dtype
|
| 1392 |
+
)
|
| 1393 |
+
|
| 1394 |
+
return causal_mask
|
| 1395 |
+
|
| 1396 |
+
@staticmethod
|
| 1397 |
+
def _prepare_4d_causal_attention_mask_with_cache_position(
|
| 1398 |
+
attention_mask: torch.Tensor,
|
| 1399 |
+
sequence_length: int,
|
| 1400 |
+
target_length: int,
|
| 1401 |
+
dtype: torch.dtype,
|
| 1402 |
+
cache_position: torch.Tensor,
|
| 1403 |
+
batch_size: int,
|
| 1404 |
+
config: Ernie4_5_MoeConfig,
|
| 1405 |
+
past_key_values: Cache,
|
| 1406 |
+
):
|
| 1407 |
+
"""
|
| 1408 |
+
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
| 1409 |
+
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
| 1410 |
+
|
| 1411 |
+
Args:
|
| 1412 |
+
attention_mask (`torch.Tensor`):
|
| 1413 |
+
A 2D attention mask of shape `(batch_size, key_value_length)`,
|
| 1414 |
+
or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
|
| 1415 |
+
sequence_length (`int`):
|
| 1416 |
+
The sequence length being processed.
|
| 1417 |
+
target_length (`int`):
|
| 1418 |
+
The target length: when generating with static cache, the mask should be as long as the static cache,
|
| 1419 |
+
to account for the 0 padding, the part of the cache that is not filled yet.
|
| 1420 |
+
dtype (`torch.dtype`):
|
| 1421 |
+
The dtype to use for the 4D attention mask.
|
| 1422 |
+
cache_position (`torch.Tensor`):
|
| 1423 |
+
Indices depicting the position of the input sequence tokens in the sequence.
|
| 1424 |
+
batch_size (`torch.Tensor`):
|
| 1425 |
+
Batch size.
|
| 1426 |
+
config (`Ernie4_5_MoeConfig`):
|
| 1427 |
+
The model's configuration class
|
| 1428 |
+
past_key_values (`Cache`):
|
| 1429 |
+
The cache class that is being used currently to generate
|
| 1430 |
+
"""
|
| 1431 |
+
if attention_mask is not None and attention_mask.dim() == 4:
|
| 1432 |
+
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
| 1433 |
+
causal_mask = attention_mask
|
| 1434 |
+
else:
|
| 1435 |
+
min_dtype = torch.finfo(dtype).min
|
| 1436 |
+
causal_mask = torch.full(
|
| 1437 |
+
(sequence_length, target_length),
|
| 1438 |
+
fill_value=min_dtype,
|
| 1439 |
+
dtype=dtype,
|
| 1440 |
+
device=cache_position.device,
|
| 1441 |
+
)
|
| 1442 |
+
diagonal_attend_mask = torch.arange(
|
| 1443 |
+
target_length, device=cache_position.device
|
| 1444 |
+
) > cache_position.reshape(-1, 1)
|
| 1445 |
+
text_config = config.get_text_config()
|
| 1446 |
+
if (
|
| 1447 |
+
getattr(text_config, "use_sliding_window", True)
|
| 1448 |
+
and text_config.sliding_window is not None
|
| 1449 |
+
):
|
| 1450 |
+
if (
|
| 1451 |
+
not isinstance(past_key_values, SlidingWindowCache)
|
| 1452 |
+
or sequence_length > target_length
|
| 1453 |
+
):
|
| 1454 |
+
sliding_attend_mask = torch.arange(
|
| 1455 |
+
target_length, device=cache_position.device
|
| 1456 |
+
) <= (cache_position.reshape(-1, 1) - text_config.sliding_window)
|
| 1457 |
+
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
|
| 1458 |
+
causal_mask *= diagonal_attend_mask
|
| 1459 |
+
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
| 1460 |
+
if attention_mask is not None:
|
| 1461 |
+
causal_mask = (
|
| 1462 |
+
causal_mask.clone()
|
| 1463 |
+
) # copy to contiguous memory for in-place edit
|
| 1464 |
+
if attention_mask.shape[-1] > target_length:
|
| 1465 |
+
attention_mask = attention_mask[:, :target_length]
|
| 1466 |
+
mask_length = attention_mask.shape[-1]
|
| 1467 |
+
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[
|
| 1468 |
+
:, None, None, :
|
| 1469 |
+
].to(causal_mask.device)
|
| 1470 |
+
padding_mask = padding_mask == 0
|
| 1471 |
+
causal_mask[:, :, :, :mask_length] = causal_mask[
|
| 1472 |
+
:, :, :, :mask_length
|
| 1473 |
+
].masked_fill(padding_mask, min_dtype)
|
| 1474 |
+
return causal_mask
|
| 1475 |
+
|
| 1476 |
+
|
| 1477 |
+
@auto_docstring
|
| 1478 |
+
class Ernie4_5_MoeForCausalLM(Ernie4_5_PretrainedModel, GenerationMixin):
|
| 1479 |
+
"""ERNIE Mixture of Experts (MoE) model for causal language modeling."""
|
| 1480 |
+
|
| 1481 |
+
_tied_weights_keys = ["lm_head.weight"]
|
| 1482 |
+
_tp_plan = {"lm_head": "colwise_rep"}
|
| 1483 |
+
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
| 1484 |
+
|
| 1485 |
+
def __init__(self, config):
|
| 1486 |
+
"""
|
| 1487 |
+
Initializes the ERNIE MoE model for causal language modeling.
|
| 1488 |
+
|
| 1489 |
+
Args:
|
| 1490 |
+
config (dict): Model configuration.
|
| 1491 |
+
"""
|
| 1492 |
+
super().__init__(config)
|
| 1493 |
+
self.config = config
|
| 1494 |
+
self.model = Ernie4_5_Model(config)
|
| 1495 |
+
self.lm_head = nn.Linear(
|
| 1496 |
+
config.hidden_size,
|
| 1497 |
+
config.vocab_size,
|
| 1498 |
+
bias=config.weight_share_add_bias and config.use_bias,
|
| 1499 |
+
) # TODO
|
| 1500 |
+
self.loss_function = ErniePretrainingCriterion(config)
|
| 1501 |
+
|
| 1502 |
+
# Initialize weights and apply final processing
|
| 1503 |
+
self.post_init()
|
| 1504 |
+
|
| 1505 |
+
def get_input_embeddings(self):
|
| 1506 |
+
"""Returns the input embeddings layer."""
|
| 1507 |
+
return self.model.embed_tokens
|
| 1508 |
+
|
| 1509 |
+
def set_input_embeddings(self, value):
|
| 1510 |
+
"""Sets the input embeddings layer."""
|
| 1511 |
+
self.ernie.embed_tokens = value
|
| 1512 |
+
|
| 1513 |
+
def get_output_embeddings(self):
|
| 1514 |
+
"""Returns the output embeddings (LM head)."""
|
| 1515 |
+
return self.lm_head
|
| 1516 |
+
|
| 1517 |
+
def set_output_embeddings(self, new_embeddings):
|
| 1518 |
+
"""Sets the output embeddings layer."""
|
| 1519 |
+
self.lm_head = new_embeddings
|
| 1520 |
+
|
| 1521 |
+
def set_decoder(self, decoder):
|
| 1522 |
+
"""Sets the ERNIE decoder model."""
|
| 1523 |
+
self.model = decoder
|
| 1524 |
+
|
| 1525 |
+
def get_decoder(self):
|
| 1526 |
+
"""Get the transformer decoder."""
|
| 1527 |
+
return self.model
|
| 1528 |
+
|
| 1529 |
+
@can_return_tuple
|
| 1530 |
+
def forward(
|
| 1531 |
+
self,
|
| 1532 |
+
input_ids,
|
| 1533 |
+
attention_mask=None,
|
| 1534 |
+
position_ids=None,
|
| 1535 |
+
past_key_values: Optional[list[torch.FloatTensor]] = None,
|
| 1536 |
+
inputs_embeds=None,
|
| 1537 |
+
labels=None,
|
| 1538 |
+
loss_mask=None,
|
| 1539 |
+
use_cache=False,
|
| 1540 |
+
output_attentions: Optional[bool] = None,
|
| 1541 |
+
output_hidden_states: Optional[bool] = None,
|
| 1542 |
+
**kwargs: Unpack[KwargsForCausalLM],
|
| 1543 |
+
):
|
| 1544 |
+
"""
|
| 1545 |
+
Forward pass for causal language modeling.
|
| 1546 |
+
"""
|
| 1547 |
+
output_attentions = (
|
| 1548 |
+
output_attentions
|
| 1549 |
+
if output_attentions is not None
|
| 1550 |
+
else self.config.output_attentions
|
| 1551 |
+
)
|
| 1552 |
+
output_hidden_states = (
|
| 1553 |
+
output_hidden_states
|
| 1554 |
+
if output_hidden_states is not None
|
| 1555 |
+
else self.config.output_hidden_states
|
| 1556 |
+
)
|
| 1557 |
+
|
| 1558 |
+
outputs = self.model(
|
| 1559 |
+
input_ids,
|
| 1560 |
+
position_ids=position_ids,
|
| 1561 |
+
attention_mask=attention_mask,
|
| 1562 |
+
inputs_embeds=inputs_embeds,
|
| 1563 |
+
use_cache=use_cache,
|
| 1564 |
+
past_key_values=past_key_values,
|
| 1565 |
+
output_attentions=output_attentions,
|
| 1566 |
+
output_hidden_states=output_hidden_states,
|
| 1567 |
+
**kwargs,
|
| 1568 |
+
)
|
| 1569 |
+
|
| 1570 |
+
hidden_states = outputs.last_hidden_state
|
| 1571 |
+
logits = self.lm_head(hidden_states)
|
| 1572 |
+
|
| 1573 |
+
loss, router_loss = None, None
|
| 1574 |
+
if getattr(self.config, "use_moe", False):
|
| 1575 |
+
router_loss = outputs.router_loss
|
| 1576 |
+
|
| 1577 |
+
if labels is not None:
|
| 1578 |
+
loss, _ = self.loss_function(logits, labels, loss_mask, router_loss)
|
| 1579 |
+
|
| 1580 |
+
return Ernie4_5_MoeCausalLMOutputWithPast(
|
| 1581 |
+
loss=loss,
|
| 1582 |
+
logits=logits,
|
| 1583 |
+
past_key_values=outputs.past_key_values,
|
| 1584 |
+
hidden_states=outputs.hidden_states,
|
| 1585 |
+
attentions=outputs.attentions,
|
| 1586 |
+
router_loss=router_loss,
|
| 1587 |
+
)
|
| 1588 |
+
|
| 1589 |
+
|
| 1590 |
+
__all__ = ["Ernie4_5_Model", "Ernie4_5_MoeForCausalLM", "Ernie4_5_PretrainedModel"]
|
quantization_config.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
special_tokens_map.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"bos_token": "<s>", "eos_token": "</s>", "pad_token": "<unk>", "unk_token": "<unk>", "cls_token": "<|begin_of_sentence|>", "sep_token": "<|end_of_sentence|>", "mask_token": "<mask:1>", "sys_start_token": "<mask:4>", "sys_end_token": "<mask:5>", "header_start_token": "<mask:6>", "header_end_token": "<mask:7>", "additional_special_tokens": ["<|IMAGE_PLACEHOLDER|>", "<|AUDIO_PLACEHOLDER|>", "<|LOC_0|>", "<|LOC_1|>", "<|LOC_2|>", "<|LOC_3|>", "<|LOC_4|>", "<|LOC_5|>", "<|LOC_6|>", "<|LOC_7|>", "<|LOC_8|>", "<|LOC_9|>", "<|LOC_10|>", "<|LOC_11|>", "<|LOC_12|>", "<|LOC_13|>", "<|LOC_14|>", "<|LOC_15|>", "<|LOC_16|>", "<|LOC_17|>", "<|LOC_18|>", "<|LOC_19|>", "<|LOC_20|>", "<|LOC_21|>", "<|LOC_22|>", "<|LOC_23|>", "<|LOC_24|>", "<|LOC_25|>", "<|LOC_26|>", "<|LOC_27|>", "<|LOC_28|>", "<|LOC_29|>", "<|LOC_30|>", "<|LOC_31|>", "<|LOC_32|>", "<|LOC_33|>", "<|LOC_34|>", "<|LOC_35|>", "<|LOC_36|>", "<|LOC_37|>", "<|LOC_38|>", "<|LOC_39|>", "<|LOC_40|>", "<|LOC_41|>", "<|LOC_42|>", "<|LOC_43|>", "<|LOC_44|>", "<|LOC_45|>", "<|LOC_46|>", "<|LOC_47|>", "<|LOC_48|>", "<|LOC_49|>", "<|LOC_50|>", "<|LOC_51|>", "<|LOC_52|>", "<|LOC_53|>", "<|LOC_54|>", "<|LOC_55|>", "<|LOC_56|>", "<|LOC_57|>", "<|LOC_58|>", "<|LOC_59|>", "<|LOC_60|>", "<|LOC_61|>", "<|LOC_62|>", "<|LOC_63|>", "<|LOC_64|>", "<|LOC_65|>", "<|LOC_66|>", "<|LOC_67|>", "<|LOC_68|>", "<|LOC_69|>", "<|LOC_70|>", "<|LOC_71|>", "<|LOC_72|>", "<|LOC_73|>", "<|LOC_74|>", "<|LOC_75|>", "<|LOC_76|>", "<|LOC_77|>", "<|LOC_78|>", "<|LOC_79|>", "<|LOC_80|>", "<|LOC_81|>", "<|LOC_82|>", "<|LOC_83|>", "<|LOC_84|>", "<|LOC_85|>", "<|LOC_86|>", "<|LOC_87|>", "<|LOC_88|>", "<|LOC_89|>", "<|LOC_90|>", "<|LOC_91|>", "<|LOC_92|>", "<|LOC_93|>", "<|LOC_94|>", "<|LOC_95|>", "<|LOC_96|>", "<|LOC_97|>", "<|LOC_98|>", "<|LOC_99|>", "<|LOC_100|>", "<|LOC_101|>", "<|LOC_102|>", "<|LOC_103|>", "<|LOC_104|>", "<|LOC_105|>", "<|LOC_106|>", "<|LOC_107|>", "<|LOC_108|>", "<|LOC_109|>", "<|LOC_110|>", "<|LOC_111|>", "<|LOC_112|>", "<|LOC_113|>", "<|LOC_114|>", "<|LOC_115|>", "<|LOC_116|>", "<|LOC_117|>", "<|LOC_118|>", "<|LOC_119|>", "<|LOC_120|>", "<|LOC_121|>", "<|LOC_122|>", "<|LOC_123|>", "<|LOC_124|>", "<|LOC_125|>", "<|LOC_126|>", "<|LOC_127|>", "<|LOC_128|>", "<|LOC_129|>", "<|LOC_130|>", "<|LOC_131|>", "<|LOC_132|>", "<|LOC_133|>", "<|LOC_134|>", "<|LOC_135|>", "<|LOC_136|>", "<|LOC_137|>", "<|LOC_138|>", "<|LOC_139|>", "<|LOC_140|>", "<|LOC_141|>", "<|LOC_142|>", "<|LOC_143|>", "<|LOC_144|>", "<|LOC_145|>", "<|LOC_146|>", "<|LOC_147|>", "<|LOC_148|>", "<|LOC_149|>", "<|LOC_150|>", "<|LOC_151|>", "<|LOC_152|>", "<|LOC_153|>", "<|LOC_154|>", "<|LOC_155|>", "<|LOC_156|>", "<|LOC_157|>", "<|LOC_158|>", "<|LOC_159|>", "<|LOC_160|>", "<|LOC_161|>", "<|LOC_162|>", "<|LOC_163|>", "<|LOC_164|>", "<|LOC_165|>", "<|LOC_166|>", "<|LOC_167|>", "<|LOC_168|>", "<|LOC_169|>", "<|LOC_170|>", "<|LOC_171|>", "<|LOC_172|>", "<|LOC_173|>", "<|LOC_174|>", "<|LOC_175|>", "<|LOC_176|>", "<|LOC_177|>", "<|LOC_178|>", "<|LOC_179|>", "<|LOC_180|>", "<|LOC_181|>", "<|LOC_182|>", "<|LOC_183|>", "<|LOC_184|>", "<|LOC_185|>", "<|LOC_186|>", "<|LOC_187|>", "<|LOC_188|>", "<|LOC_189|>", "<|LOC_190|>", "<|LOC_191|>", "<|LOC_192|>", "<|LOC_193|>", "<|LOC_194|>", "<|LOC_195|>", "<|LOC_196|>", "<|LOC_197|>", "<|LOC_198|>", "<|LOC_199|>", "<|LOC_200|>", "<|LOC_201|>", "<|LOC_202|>", "<|LOC_203|>", "<|LOC_204|>", "<|LOC_205|>", "<|LOC_206|>", "<|LOC_207|>", "<|LOC_208|>", "<|LOC_209|>", "<|LOC_210|>", "<|LOC_211|>", "<|LOC_212|>", "<|LOC_213|>", "<|LOC_214|>", "<|LOC_215|>", "<|LOC_216|>", "<|LOC_217|>", "<|LOC_218|>", "<|LOC_219|>", "<|LOC_220|>", "<|LOC_221|>", "<|LOC_222|>", "<|LOC_223|>", "<|LOC_224|>", "<|LOC_225|>", "<|LOC_226|>", "<|LOC_227|>", "<|LOC_228|>", "<|LOC_229|>", "<|LOC_230|>", "<|LOC_231|>", "<|LOC_232|>", "<|LOC_233|>", "<|LOC_234|>", "<|LOC_235|>", "<|LOC_236|>", "<|LOC_237|>", "<|LOC_238|>", "<|LOC_239|>", "<|LOC_240|>", "<|LOC_241|>", "<|LOC_242|>", "<|LOC_243|>", "<|LOC_244|>", "<|LOC_245|>", "<|LOC_246|>", "<|LOC_247|>", "<|LOC_248|>", "<|LOC_249|>", "<|LOC_250|>", "<|LOC_251|>", "<|LOC_252|>", "<|LOC_253|>", "<|LOC_254|>", "<|LOC_255|>", "<|LOC_256|>", "<|LOC_257|>", "<|LOC_258|>", "<|LOC_259|>", "<|LOC_260|>", "<|LOC_261|>", "<|LOC_262|>", "<|LOC_263|>", "<|LOC_264|>", "<|LOC_265|>", "<|LOC_266|>", "<|LOC_267|>", "<|LOC_268|>", "<|LOC_269|>", "<|LOC_270|>", "<|LOC_271|>", "<|LOC_272|>", "<|LOC_273|>", "<|LOC_274|>", "<|LOC_275|>", "<|LOC_276|>", "<|LOC_277|>", "<|LOC_278|>", "<|LOC_279|>", "<|LOC_280|>", "<|LOC_281|>", "<|LOC_282|>", "<|LOC_283|>", "<|LOC_284|>", "<|LOC_285|>", "<|LOC_286|>", "<|LOC_287|>", "<|LOC_288|>", "<|LOC_289|>", "<|LOC_290|>", "<|LOC_291|>", "<|LOC_292|>", "<|LOC_293|>", "<|LOC_294|>", "<|LOC_295|>", "<|LOC_296|>", "<|LOC_297|>", "<|LOC_298|>", "<|LOC_299|>", "<|LOC_300|>", "<|LOC_301|>", "<|LOC_302|>", "<|LOC_303|>", "<|LOC_304|>", "<|LOC_305|>", "<|LOC_306|>", "<|LOC_307|>", "<|LOC_308|>", "<|LOC_309|>", "<|LOC_310|>", "<|LOC_311|>", "<|LOC_312|>", "<|LOC_313|>", "<|LOC_314|>", "<|LOC_315|>", "<|LOC_316|>", "<|LOC_317|>", "<|LOC_318|>", "<|LOC_319|>", "<|LOC_320|>", "<|LOC_321|>", "<|LOC_322|>", "<|LOC_323|>", "<|LOC_324|>", "<|LOC_325|>", "<|LOC_326|>", "<|LOC_327|>", "<|LOC_328|>", "<|LOC_329|>", "<|LOC_330|>", "<|LOC_331|>", "<|LOC_332|>", "<|LOC_333|>", "<|LOC_334|>", "<|LOC_335|>", "<|LOC_336|>", "<|LOC_337|>", "<|LOC_338|>", "<|LOC_339|>", "<|LOC_340|>", "<|LOC_341|>", "<|LOC_342|>", "<|LOC_343|>", "<|LOC_344|>", "<|LOC_345|>", "<|LOC_346|>", "<|LOC_347|>", "<|LOC_348|>", "<|LOC_349|>", "<|LOC_350|>", "<|LOC_351|>", "<|LOC_352|>", "<|LOC_353|>", "<|LOC_354|>", "<|LOC_355|>", "<|LOC_356|>", "<|LOC_357|>", "<|LOC_358|>", "<|LOC_359|>", "<|LOC_360|>", "<|LOC_361|>", "<|LOC_362|>", "<|LOC_363|>", "<|LOC_364|>", "<|LOC_365|>", "<|LOC_366|>", "<|LOC_367|>", "<|LOC_368|>", "<|LOC_369|>", "<|LOC_370|>", "<|LOC_371|>", "<|LOC_372|>", "<|LOC_373|>", "<|LOC_374|>", "<|LOC_375|>", "<|LOC_376|>", "<|LOC_377|>", "<|LOC_378|>", "<|LOC_379|>", "<|LOC_380|>", "<|LOC_381|>", "<|LOC_382|>", "<|LOC_383|>", "<|LOC_384|>", "<|LOC_385|>", "<|LOC_386|>", "<|LOC_387|>", "<|LOC_388|>", "<|LOC_389|>", "<|LOC_390|>", "<|LOC_391|>", "<|LOC_392|>", "<|LOC_393|>", "<|LOC_394|>", "<|LOC_395|>", "<|LOC_396|>", "<|LOC_397|>", "<|LOC_398|>", "<|LOC_399|>", "<|LOC_400|>", "<|LOC_401|>", "<|LOC_402|>", "<|LOC_403|>", "<|LOC_404|>", "<|LOC_405|>", "<|LOC_406|>", "<|LOC_407|>", "<|LOC_408|>", "<|LOC_409|>", "<|LOC_410|>", "<|LOC_411|>", "<|LOC_412|>", "<|LOC_413|>", "<|LOC_414|>", "<|LOC_415|>", "<|LOC_416|>", "<|LOC_417|>", "<|LOC_418|>", "<|LOC_419|>", "<|LOC_420|>", "<|LOC_421|>", "<|LOC_422|>", "<|LOC_423|>", "<|LOC_424|>", "<|LOC_425|>", "<|LOC_426|>", "<|LOC_427|>", "<|LOC_428|>", "<|LOC_429|>", "<|LOC_430|>", "<|LOC_431|>", "<|LOC_432|>", "<|LOC_433|>", "<|LOC_434|>", "<|LOC_435|>", "<|LOC_436|>", "<|LOC_437|>", "<|LOC_438|>", "<|LOC_439|>", "<|LOC_440|>", "<|LOC_441|>", "<|LOC_442|>", "<|LOC_443|>", "<|LOC_444|>", "<|LOC_445|>", "<|LOC_446|>", "<|LOC_447|>", "<|LOC_448|>", "<|LOC_449|>", "<|LOC_450|>", "<|LOC_451|>", "<|LOC_452|>", "<|LOC_453|>", "<|LOC_454|>", "<|LOC_455|>", "<|LOC_456|>", "<|LOC_457|>", "<|LOC_458|>", "<|LOC_459|>", "<|LOC_460|>", "<|LOC_461|>", "<|LOC_462|>", "<|LOC_463|>", "<|LOC_464|>", "<|LOC_465|>", "<|LOC_466|>", "<|LOC_467|>", "<|LOC_468|>", "<|LOC_469|>", "<|LOC_470|>", "<|LOC_471|>", "<|LOC_472|>", "<|LOC_473|>", "<|LOC_474|>", "<|LOC_475|>", "<|LOC_476|>", "<|LOC_477|>", "<|LOC_478|>", "<|LOC_479|>", "<|LOC_480|>", "<|LOC_481|>", "<|LOC_482|>", "<|LOC_483|>", "<|LOC_484|>", "<|LOC_485|>", "<|LOC_486|>", "<|LOC_487|>", "<|LOC_488|>", "<|LOC_489|>", "<|LOC_490|>", "<|LOC_491|>", "<|LOC_492|>", "<|LOC_493|>", "<|LOC_494|>", "<|LOC_495|>", "<|LOC_496|>", "<|LOC_497|>", "<|LOC_498|>", "<|LOC_499|>", "<|LOC_500|>", "<|LOC_501|>", "<|LOC_502|>", "<|LOC_503|>", "<|LOC_504|>", "<|LOC_505|>", "<|LOC_506|>", "<|LOC_507|>", "<|LOC_508|>", "<|LOC_509|>", "<|LOC_510|>", "<|LOC_511|>", "<|LOC_512|>", "<|LOC_513|>", "<|LOC_514|>", "<|LOC_515|>", "<|LOC_516|>", "<|LOC_517|>", "<|LOC_518|>", "<|LOC_519|>", "<|LOC_520|>", "<|LOC_521|>", "<|LOC_522|>", "<|LOC_523|>", "<|LOC_524|>", "<|LOC_525|>", "<|LOC_526|>", "<|LOC_527|>", "<|LOC_528|>", "<|LOC_529|>", "<|LOC_530|>", "<|LOC_531|>", "<|LOC_532|>", "<|LOC_533|>", "<|LOC_534|>", "<|LOC_535|>", "<|LOC_536|>", "<|LOC_537|>", "<|LOC_538|>", "<|LOC_539|>", "<|LOC_540|>", "<|LOC_541|>", "<|LOC_542|>", "<|LOC_543|>", "<|LOC_544|>", "<|LOC_545|>", "<|LOC_546|>", "<|LOC_547|>", "<|LOC_548|>", "<|LOC_549|>", "<|LOC_550|>", "<|LOC_551|>", "<|LOC_552|>", "<|LOC_553|>", "<|LOC_554|>", "<|LOC_555|>", "<|LOC_556|>", "<|LOC_557|>", "<|LOC_558|>", "<|LOC_559|>", "<|LOC_560|>", "<|LOC_561|>", "<|LOC_562|>", "<|LOC_563|>", "<|LOC_564|>", "<|LOC_565|>", "<|LOC_566|>", "<|LOC_567|>", "<|LOC_568|>", "<|LOC_569|>", "<|LOC_570|>", "<|LOC_571|>", "<|LOC_572|>", "<|LOC_573|>", "<|LOC_574|>", "<|LOC_575|>", "<|LOC_576|>", "<|LOC_577|>", "<|LOC_578|>", "<|LOC_579|>", "<|LOC_580|>", "<|LOC_581|>", "<|LOC_582|>", "<|LOC_583|>", "<|LOC_584|>", "<|LOC_585|>", "<|LOC_586|>", "<|LOC_587|>", "<|LOC_588|>", "<|LOC_589|>", "<|LOC_590|>", "<|LOC_591|>", "<|LOC_592|>", "<|LOC_593|>", "<|LOC_594|>", "<|LOC_595|>", "<|LOC_596|>", "<|LOC_597|>", "<|LOC_598|>", "<|LOC_599|>", "<|LOC_600|>", "<|LOC_601|>", "<|LOC_602|>", "<|LOC_603|>", "<|LOC_604|>", "<|LOC_605|>", "<|LOC_606|>", "<|LOC_607|>", "<|LOC_608|>", "<|LOC_609|>", "<|LOC_610|>", "<|LOC_611|>", "<|LOC_612|>", "<|LOC_613|>", "<|LOC_614|>", "<|LOC_615|>", "<|LOC_616|>", "<|LOC_617|>", "<|LOC_618|>", "<|LOC_619|>", "<|LOC_620|>", "<|LOC_621|>", "<|LOC_622|>", "<|LOC_623|>", "<|LOC_624|>", "<|LOC_625|>", "<|LOC_626|>", "<|LOC_627|>", "<|LOC_628|>", "<|LOC_629|>", "<|LOC_630|>", "<|LOC_631|>", "<|LOC_632|>", "<|LOC_633|>", "<|LOC_634|>", "<|LOC_635|>", "<|LOC_636|>", "<|LOC_637|>", "<|LOC_638|>", "<|LOC_639|>", "<|LOC_640|>", "<|LOC_641|>", "<|LOC_642|>", "<|LOC_643|>", "<|LOC_644|>", "<|LOC_645|>", "<|LOC_646|>", "<|LOC_647|>", "<|LOC_648|>", "<|LOC_649|>", "<|LOC_650|>", "<|LOC_651|>", "<|LOC_652|>", "<|LOC_653|>", "<|LOC_654|>", "<|LOC_655|>", "<|LOC_656|>", "<|LOC_657|>", "<|LOC_658|>", "<|LOC_659|>", "<|LOC_660|>", "<|LOC_661|>", "<|LOC_662|>", "<|LOC_663|>", "<|LOC_664|>", "<|LOC_665|>", "<|LOC_666|>", "<|LOC_667|>", "<|LOC_668|>", "<|LOC_669|>", "<|LOC_670|>", "<|LOC_671|>", "<|LOC_672|>", "<|LOC_673|>", "<|LOC_674|>", "<|LOC_675|>", "<|LOC_676|>", "<|LOC_677|>", "<|LOC_678|>", "<|LOC_679|>", "<|LOC_680|>", "<|LOC_681|>", "<|LOC_682|>", "<|LOC_683|>", "<|LOC_684|>", "<|LOC_685|>", "<|LOC_686|>", "<|LOC_687|>", "<|LOC_688|>", "<|LOC_689|>", "<|LOC_690|>", "<|LOC_691|>", "<|LOC_692|>", "<|LOC_693|>", "<|LOC_694|>", "<|LOC_695|>", "<|LOC_696|>", "<|LOC_697|>", "<|LOC_698|>", "<|LOC_699|>", "<|LOC_700|>", "<|LOC_701|>", "<|LOC_702|>", "<|LOC_703|>", "<|LOC_704|>", "<|LOC_705|>", "<|LOC_706|>", "<|LOC_707|>", "<|LOC_708|>", "<|LOC_709|>", "<|LOC_710|>", "<|LOC_711|>", "<|LOC_712|>", "<|LOC_713|>", "<|LOC_714|>", "<|LOC_715|>", "<|LOC_716|>", "<|LOC_717|>", "<|LOC_718|>", "<|LOC_719|>", "<|LOC_720|>", "<|LOC_721|>", "<|LOC_722|>", "<|LOC_723|>", "<|LOC_724|>", "<|LOC_725|>", "<|LOC_726|>", "<|LOC_727|>", "<|LOC_728|>", "<|LOC_729|>", "<|LOC_730|>", "<|LOC_731|>", "<|LOC_732|>", "<|LOC_733|>", "<|LOC_734|>", "<|LOC_735|>", "<|LOC_736|>", "<|LOC_737|>", "<|LOC_738|>", "<|LOC_739|>", "<|LOC_740|>", "<|LOC_741|>", "<|LOC_742|>", "<|LOC_743|>", "<|LOC_744|>", "<|LOC_745|>", "<|LOC_746|>", "<|LOC_747|>", "<|LOC_748|>", "<|LOC_749|>", "<|LOC_750|>", "<|LOC_751|>", "<|LOC_752|>", "<|LOC_753|>", "<|LOC_754|>", "<|LOC_755|>", "<|LOC_756|>", "<|LOC_757|>", "<|LOC_758|>", "<|LOC_759|>", "<|LOC_760|>", "<|LOC_761|>", "<|LOC_762|>", "<|LOC_763|>", "<|LOC_764|>", "<|LOC_765|>", "<|LOC_766|>", "<|LOC_767|>", "<|LOC_768|>", "<|LOC_769|>", "<|LOC_770|>", "<|LOC_771|>", "<|LOC_772|>", "<|LOC_773|>", "<|LOC_774|>", "<|LOC_775|>", "<|LOC_776|>", "<|LOC_777|>", "<|LOC_778|>", "<|LOC_779|>", "<|LOC_780|>", "<|LOC_781|>", "<|LOC_782|>", "<|LOC_783|>", "<|LOC_784|>", "<|LOC_785|>", "<|LOC_786|>", "<|LOC_787|>", "<|LOC_788|>", "<|LOC_789|>", "<|LOC_790|>", "<|LOC_791|>", "<|LOC_792|>", "<|LOC_793|>", "<|LOC_794|>", "<|LOC_795|>", "<|LOC_796|>", "<|LOC_797|>", "<|LOC_798|>", "<|LOC_799|>", "<|LOC_800|>", "<|LOC_801|>", "<|LOC_802|>", "<|LOC_803|>", "<|LOC_804|>", "<|LOC_805|>", "<|LOC_806|>", "<|LOC_807|>", "<|LOC_808|>", "<|LOC_809|>", "<|LOC_810|>", "<|LOC_811|>", "<|LOC_812|>", "<|LOC_813|>", "<|LOC_814|>", "<|LOC_815|>", "<|LOC_816|>", "<|LOC_817|>", "<|LOC_818|>", "<|LOC_819|>", "<|LOC_820|>", "<|LOC_821|>", "<|LOC_822|>", "<|LOC_823|>", "<|LOC_824|>", "<|LOC_825|>", "<|LOC_826|>", "<|LOC_827|>", "<|LOC_828|>", "<|LOC_829|>", "<|LOC_830|>", "<|LOC_831|>", "<|LOC_832|>", "<|LOC_833|>", "<|LOC_834|>", "<|LOC_835|>", "<|LOC_836|>", "<|LOC_837|>", "<|LOC_838|>", "<|LOC_839|>", "<|LOC_840|>", "<|LOC_841|>", "<|LOC_842|>", "<|LOC_843|>", "<|LOC_844|>", "<|LOC_845|>", "<|LOC_846|>", "<|LOC_847|>", "<|LOC_848|>", "<|LOC_849|>", "<|LOC_850|>", "<|LOC_851|>", "<|LOC_852|>", "<|LOC_853|>", "<|LOC_854|>", "<|LOC_855|>", "<|LOC_856|>", "<|LOC_857|>", "<|LOC_858|>", "<|LOC_859|>", "<|LOC_860|>", "<|LOC_861|>", "<|LOC_862|>", "<|LOC_863|>", "<|LOC_864|>", "<|LOC_865|>", "<|LOC_866|>", "<|LOC_867|>", "<|LOC_868|>", "<|LOC_869|>", "<|LOC_870|>", "<|LOC_871|>", "<|LOC_872|>", "<|LOC_873|>", "<|LOC_874|>", "<|LOC_875|>", "<|LOC_876|>", "<|LOC_877|>", "<|LOC_878|>", "<|LOC_879|>", "<|LOC_880|>", "<|LOC_881|>", "<|LOC_882|>", "<|LOC_883|>", "<|LOC_884|>", "<|LOC_885|>", "<|LOC_886|>", "<|LOC_887|>", "<|LOC_888|>", "<|LOC_889|>", "<|LOC_890|>", "<|LOC_891|>", "<|LOC_892|>", "<|LOC_893|>", "<|LOC_894|>", "<|LOC_895|>", "<|LOC_896|>", "<|LOC_897|>", "<|LOC_898|>", "<|LOC_899|>", "<|LOC_900|>", "<|LOC_901|>", "<|LOC_902|>", "<|LOC_903|>", "<|LOC_904|>", "<|LOC_905|>", "<|LOC_906|>", "<|LOC_907|>", "<|LOC_908|>", "<|LOC_909|>", "<|LOC_910|>", "<|LOC_911|>", "<|LOC_912|>", "<|LOC_913|>", "<|LOC_914|>", "<|LOC_915|>", "<|LOC_916|>", "<|LOC_917|>", "<|LOC_918|>", "<|LOC_919|>", "<|LOC_920|>", "<|LOC_921|>", "<|LOC_922|>", "<|LOC_923|>", "<|LOC_924|>", "<|LOC_925|>", "<|LOC_926|>", "<|LOC_927|>", "<|LOC_928|>", "<|LOC_929|>", "<|LOC_930|>", "<|LOC_931|>", "<|LOC_932|>", "<|LOC_933|>", "<|LOC_934|>", "<|LOC_935|>", "<|LOC_936|>", "<|LOC_937|>", "<|LOC_938|>", "<|LOC_939|>", "<|LOC_940|>", "<|LOC_941|>", "<|LOC_942|>", "<|LOC_943|>", "<|LOC_944|>", "<|LOC_945|>", "<|LOC_946|>", "<|LOC_947|>", "<|LOC_948|>", "<|LOC_949|>", "<|LOC_950|>", "<|LOC_951|>", "<|LOC_952|>", "<|LOC_953|>", "<|LOC_954|>", "<|LOC_955|>", "<|LOC_956|>", "<|LOC_957|>", "<|LOC_958|>", "<|LOC_959|>", "<|LOC_960|>", "<|LOC_961|>", "<|LOC_962|>", "<|LOC_963|>", "<|LOC_964|>", "<|LOC_965|>", "<|LOC_966|>", "<|LOC_967|>", "<|LOC_968|>", "<|LOC_969|>", "<|LOC_970|>", "<|LOC_971|>", "<|LOC_972|>", "<|LOC_973|>", "<|LOC_974|>", "<|LOC_975|>", "<|LOC_976|>", "<|LOC_977|>", "<|LOC_978|>", "<|LOC_979|>", "<|LOC_980|>", "<|LOC_981|>", "<|LOC_982|>", "<|LOC_983|>", "<|LOC_984|>", "<|LOC_985|>", "<|LOC_986|>", "<|LOC_987|>", "<|LOC_988|>", "<|LOC_989|>", "<|LOC_990|>", "<|LOC_991|>", "<|LOC_992|>", "<|LOC_993|>", "<|LOC_994|>", "<|LOC_995|>", "<|LOC_996|>", "<|LOC_997|>", "<|LOC_998|>", "<|LOC_999|>", "<|LOC_1000|>", "<|LOC_BEGIN|>", "<|LOC_END|>", "<|LOC_SEP|>", "<|CROP_COL_SEP|>", "<|CROP_ROW_SEP|>", "<|IMAGE_SEP|>"]}
|
tokenization_ernie4_5.py
ADDED
|
@@ -0,0 +1,352 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""Ernie4_5_Tokenizer"""
|
| 15 |
+
|
| 16 |
+
import os
|
| 17 |
+
from shutil import copyfile
|
| 18 |
+
from typing import Dict, List, Optional, Tuple, Union
|
| 19 |
+
import torch
|
| 20 |
+
import numpy as np
|
| 21 |
+
import sentencepiece as spm
|
| 22 |
+
|
| 23 |
+
from transformers.tokenization_utils import PreTrainedTokenizer
|
| 24 |
+
from transformers.tokenization_utils_base import (
|
| 25 |
+
PaddingStrategy,
|
| 26 |
+
)
|
| 27 |
+
from transformers.utils import logging
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
logger = logging.get_logger(__name__)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class Ernie4_5_Tokenizer(PreTrainedTokenizer):
|
| 35 |
+
"""
|
| 36 |
+
Ernie4_5_Tokenizer
|
| 37 |
+
vocab_files_names (dict): Mapping vocabulary-related config name to actual filename.
|
| 38 |
+
model_input_names (List): Model input names expected by the tokenizer
|
| 39 |
+
padding_side (str): Padding side (where to add padding tokens)
|
| 40 |
+
"""
|
| 41 |
+
vocab_files_names = {
|
| 42 |
+
"vocab_file": "tokenizer.model",
|
| 43 |
+
}
|
| 44 |
+
model_input_names = ["input_ids", "position_ids", "attention_mask", "labels"]
|
| 45 |
+
padding_side = "right"
|
| 46 |
+
|
| 47 |
+
def __init__(
|
| 48 |
+
self,
|
| 49 |
+
vocab_file,
|
| 50 |
+
bos_token="<s>",
|
| 51 |
+
cls_token="<cls>",
|
| 52 |
+
eos_token="</s>",
|
| 53 |
+
mask_token="<mask:0>",
|
| 54 |
+
pad_token="<pad>",
|
| 55 |
+
sep_token="<sep>",
|
| 56 |
+
unk_token="<unk>",
|
| 57 |
+
additional_special_tokens=None,
|
| 58 |
+
split_special_tokens=False,
|
| 59 |
+
tokenizer_alpha=None,
|
| 60 |
+
**kwargs
|
| 61 |
+
):
|
| 62 |
+
"""
|
| 63 |
+
Initialize the ERNIE tokenizer.
|
| 64 |
+
|
| 65 |
+
Args:
|
| 66 |
+
vocab_file (str): Path to the SentencePiece model file.
|
| 67 |
+
bos_token (str, optional): Beginning of sentence token. Defaults to "<s>".
|
| 68 |
+
cls_token (str, optional): Classification token. Defaults to "<cls>".
|
| 69 |
+
eos_token (str, optional): End of sentence token. Defaults to "</s>".
|
| 70 |
+
mask_token (str, optional): Mask token. Defaults to "<mask:0>".
|
| 71 |
+
pad_token (str, optional): Padding token. Defaults to "<pad>".
|
| 72 |
+
sep_token (str, optional): Separator token. Defaults to "<sep>".
|
| 73 |
+
unk_token (str, optional): Unknown token. Defaults to "<unk>".
|
| 74 |
+
additional_special_tokens (List[str], optional): Additional special tokens.
|
| 75 |
+
Defaults to ["<mask:1>", "<mask:7>"].
|
| 76 |
+
split_special_tokens (bool, optional): Whether to split special tokens. Defaults to False.
|
| 77 |
+
tokenizer_alpha (float, optional): Alpha parameter for SentencePiece sampling.
|
| 78 |
+
**kwargs: Additional keyword arguments passed to the parent class.
|
| 79 |
+
"""
|
| 80 |
+
|
| 81 |
+
self.vocab_file = vocab_file
|
| 82 |
+
self.sp_model = spm.SentencePieceProcessor()
|
| 83 |
+
self.sp_model.Load(vocab_file)
|
| 84 |
+
self.tokenizer_alpha = tokenizer_alpha
|
| 85 |
+
|
| 86 |
+
if additional_special_tokens is None:
|
| 87 |
+
additional_special_tokens = ["<mask:1>", "<mask:7>"]
|
| 88 |
+
super().__init__(
|
| 89 |
+
bos_token=bos_token,
|
| 90 |
+
cls_token=cls_token,
|
| 91 |
+
eos_token=eos_token,
|
| 92 |
+
mask_token=mask_token,
|
| 93 |
+
pad_token=pad_token,
|
| 94 |
+
sep_token=sep_token,
|
| 95 |
+
unk_token=unk_token,
|
| 96 |
+
additional_special_tokens=additional_special_tokens,
|
| 97 |
+
split_special_tokens=split_special_tokens,
|
| 98 |
+
**kwargs,
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
@property
|
| 102 |
+
def vocab_size(self):
|
| 103 |
+
"""Returns the size of the vocabulary.
|
| 104 |
+
|
| 105 |
+
Returns:
|
| 106 |
+
int: The number of tokens in the vocabulary.
|
| 107 |
+
"""
|
| 108 |
+
return self.sp_model.vocab_size()
|
| 109 |
+
|
| 110 |
+
def get_vocab(self):
|
| 111 |
+
"""Get the vocabulary as a dictionary mapping tokens to their IDs.
|
| 112 |
+
|
| 113 |
+
Returns:
|
| 114 |
+
dict: A dictionary mapping tokens to their corresponding IDs.
|
| 115 |
+
"""
|
| 116 |
+
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
|
| 117 |
+
vocab.update(self.added_tokens_encoder)
|
| 118 |
+
return vocab
|
| 119 |
+
|
| 120 |
+
def _tokenize(self, text):
|
| 121 |
+
"""Tokenize text using SentencePiece.
|
| 122 |
+
|
| 123 |
+
Args:
|
| 124 |
+
text (str): The text to tokenize.
|
| 125 |
+
|
| 126 |
+
Returns:
|
| 127 |
+
list: A list of tokens.
|
| 128 |
+
"""
|
| 129 |
+
if self.tokenizer_alpha is not None:
|
| 130 |
+
return self.sp_model.encode_as_pieces(
|
| 131 |
+
text,
|
| 132 |
+
enable_sampling=True,
|
| 133 |
+
nbest_size=-1,
|
| 134 |
+
alpha=self.tokenizer_alpha,
|
| 135 |
+
)
|
| 136 |
+
else:
|
| 137 |
+
return self.sp_model.encode_as_pieces(text)
|
| 138 |
+
|
| 139 |
+
def _convert_token_to_id(self, token):
|
| 140 |
+
"""Convert a token (str) to an ID using the vocabulary.
|
| 141 |
+
|
| 142 |
+
Args:
|
| 143 |
+
token (str): The token to convert.
|
| 144 |
+
|
| 145 |
+
Returns:
|
| 146 |
+
int: The corresponding token ID.
|
| 147 |
+
"""
|
| 148 |
+
return self.sp_model.piece_to_id(token)
|
| 149 |
+
|
| 150 |
+
def _convert_id_to_token(self, id):
|
| 151 |
+
"""Convert an ID to a token (str) using the vocabulary.
|
| 152 |
+
|
| 153 |
+
Args:
|
| 154 |
+
id (int): The token ID to convert.
|
| 155 |
+
|
| 156 |
+
Returns:
|
| 157 |
+
str: The corresponding token.
|
| 158 |
+
"""
|
| 159 |
+
if id >= self.vocab_size:
|
| 160 |
+
return self.unk_token
|
| 161 |
+
else:
|
| 162 |
+
return self.sp_model.id_to_piece(id)
|
| 163 |
+
|
| 164 |
+
def convert_tokens_to_string(self, tokens):
|
| 165 |
+
"""Convert a sequence of tokens back to a single string.
|
| 166 |
+
|
| 167 |
+
Args:
|
| 168 |
+
tokens (List[str]): A list of tokens to convert.
|
| 169 |
+
|
| 170 |
+
Returns:
|
| 171 |
+
str: The reconstructed string.
|
| 172 |
+
"""
|
| 173 |
+
current_sub_tokens = []
|
| 174 |
+
out_string = ""
|
| 175 |
+
prev_is_special = False
|
| 176 |
+
for token in tokens:
|
| 177 |
+
# make sure that special tokens are not decoded using sentencepiece model
|
| 178 |
+
if token in self.all_special_tokens:
|
| 179 |
+
if not prev_is_special:
|
| 180 |
+
out_string += " "
|
| 181 |
+
out_string += self.sp_model.decode(current_sub_tokens) + token
|
| 182 |
+
prev_is_special = True
|
| 183 |
+
current_sub_tokens = []
|
| 184 |
+
else:
|
| 185 |
+
current_sub_tokens.append(token)
|
| 186 |
+
prev_is_special = False
|
| 187 |
+
out_string += self.sp_model.decode(current_sub_tokens)
|
| 188 |
+
return out_string
|
| 189 |
+
|
| 190 |
+
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
|
| 191 |
+
"""Build model inputs by adding special tokens to sequences.
|
| 192 |
+
|
| 193 |
+
Args:
|
| 194 |
+
token_ids_0 (List[int]): List of token IDs for the first sequence.
|
| 195 |
+
token_ids_1 (List[int], optional): List of token IDs for the second sequence.
|
| 196 |
+
|
| 197 |
+
Returns:
|
| 198 |
+
List[int]: List of token IDs with special tokens added.
|
| 199 |
+
"""
|
| 200 |
+
output = token_ids_0
|
| 201 |
+
last_cls_index = -1
|
| 202 |
+
last_sep_index = -1
|
| 203 |
+
if self.cls_token_id in output:
|
| 204 |
+
last_cls_index = len(output) - output[::-1].index(self.cls_token_id) - 1
|
| 205 |
+
if self.sep_token_id in output:
|
| 206 |
+
last_sep_index = len(output) - output[::-1].index(self.sep_token_id) - 1
|
| 207 |
+
|
| 208 |
+
if last_cls_index > last_sep_index:
|
| 209 |
+
next_token_id = self.sep_token_id
|
| 210 |
+
elif last_sep_index > last_cls_index:
|
| 211 |
+
next_token_id = self.cls_token_id
|
| 212 |
+
else:
|
| 213 |
+
output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
|
| 214 |
+
next_token_id = self.cls_token_id
|
| 215 |
+
|
| 216 |
+
output = [self.bos_token_id] + output
|
| 217 |
+
# Assume no markup in text if token_ids_1 is given.
|
| 218 |
+
if token_ids_1 is not None:
|
| 219 |
+
output = output + token_ids_1 + [next_token_id]
|
| 220 |
+
return output
|
| 221 |
+
|
| 222 |
+
def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False):
|
| 223 |
+
"""Get a mask showing which tokens are special tokens.
|
| 224 |
+
|
| 225 |
+
Args:
|
| 226 |
+
token_ids_0 (List[int]): List of token IDs for the first sequence.
|
| 227 |
+
token_ids_1 (List[int], optional): List of token IDs for the second sequence.
|
| 228 |
+
already_has_special_tokens (bool): Whether the tokens already include special tokens.
|
| 229 |
+
|
| 230 |
+
Returns:
|
| 231 |
+
List[int]: A mask where 1 indicates special tokens and 0 indicates regular tokens.
|
| 232 |
+
"""
|
| 233 |
+
if already_has_special_tokens:
|
| 234 |
+
return super().get_special_tokens_mask(token_ids_0, token_ids_1, already_has_special_tokens=True)
|
| 235 |
+
|
| 236 |
+
# [bos_token, cls_token, tokens_0, sep_token]
|
| 237 |
+
if token_ids_1 is None:
|
| 238 |
+
return [1, 1] + ([0] * len(token_ids_0)) + [1]
|
| 239 |
+
# [bos_token, cls_token, tokens_0, sep_token, tokens_1, cls_token]
|
| 240 |
+
return [1, 1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
|
| 241 |
+
|
| 242 |
+
def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
| 243 |
+
"""
|
| 244 |
+
Save the vocabulary and special tokens file to a directory.
|
| 245 |
+
|
| 246 |
+
Args:
|
| 247 |
+
save_directory (str): The directory in which to save the vocabulary.
|
| 248 |
+
filename_prefix (Optional[str]): Optional prefix for the saved filename.
|
| 249 |
+
|
| 250 |
+
Returns:
|
| 251 |
+
Tuple[str]: Paths to the files saved.
|
| 252 |
+
|
| 253 |
+
Raises:
|
| 254 |
+
ValueError: If the save_directory is not a valid directory.
|
| 255 |
+
"""
|
| 256 |
+
if not os.path.isdir(save_directory):
|
| 257 |
+
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
|
| 258 |
+
return
|
| 259 |
+
out_vocab_file = os.path.join(
|
| 260 |
+
save_directory,
|
| 261 |
+
(filename_prefix + "-" if filename_prefix else "") + self.vocab_files_names["vocab_file"],
|
| 262 |
+
)
|
| 263 |
+
|
| 264 |
+
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
|
| 265 |
+
copyfile(self.vocab_file, out_vocab_file)
|
| 266 |
+
elif not os.path.isfile(self.vocab_file):
|
| 267 |
+
with open(out_vocab_file, "wb") as fi:
|
| 268 |
+
content_spiece_model = self.sp_model.serialized_model_proto()
|
| 269 |
+
fi.write(content_spiece_model)
|
| 270 |
+
|
| 271 |
+
return (out_vocab_file,)
|
| 272 |
+
|
| 273 |
+
def _pad(
|
| 274 |
+
self,
|
| 275 |
+
encoded_inputs: Union[Dict],
|
| 276 |
+
max_length: Optional[int] = None,
|
| 277 |
+
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
|
| 278 |
+
pad_to_multiple_of: Optional[int] = None,
|
| 279 |
+
padding_side: Optional[str] = None,
|
| 280 |
+
return_attention_mask: Optional[bool] = None,
|
| 281 |
+
) -> dict:
|
| 282 |
+
"""
|
| 283 |
+
Pad encoded inputs according to specified strategy.
|
| 284 |
+
|
| 285 |
+
Args:
|
| 286 |
+
encoded_inputs (Union[Dict]): Dictionary of encoded inputs.
|
| 287 |
+
max_length (Optional[int]): Maximum length to pad to.
|
| 288 |
+
padding_strategy (PaddingStrategy): Strategy for padding.
|
| 289 |
+
pad_to_multiple_of (Optional[int]): Pad to a multiple of this value.
|
| 290 |
+
return_attention_mask (Optional[bool]): Whether to return attention mask.
|
| 291 |
+
|
| 292 |
+
Returns:
|
| 293 |
+
dict: Dictionary with padded inputs and optional attention mask.
|
| 294 |
+
|
| 295 |
+
Raises:
|
| 296 |
+
ValueError: If attention_mask has unexpected type or invalid padding strategy.
|
| 297 |
+
"""
|
| 298 |
+
if return_attention_mask is None:
|
| 299 |
+
return_attention_mask = "attention_mask" in self.model_input_names
|
| 300 |
+
if return_attention_mask:
|
| 301 |
+
required_input = encoded_inputs[self.model_input_names[0]]
|
| 302 |
+
if padding_strategy == PaddingStrategy.LONGEST:
|
| 303 |
+
max_length = len(required_input)
|
| 304 |
+
if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
|
| 305 |
+
max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
|
| 306 |
+
needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
|
| 307 |
+
|
| 308 |
+
if "attention_mask" in encoded_inputs and encoded_inputs["attention_mask"] is not None:
|
| 309 |
+
attention_mask = encoded_inputs.pop("attention_mask")
|
| 310 |
+
if isinstance(attention_mask, torch.Tensor):
|
| 311 |
+
attention_mask = attention_mask.numpy()
|
| 312 |
+
elif isinstance(attention_mask, list):
|
| 313 |
+
attention_mask = np.array(attention_mask)
|
| 314 |
+
elif not isinstance(attention_mask, np.ndarray):
|
| 315 |
+
raise ValueError(f"Unexpected type {type(attention_mask)} of attention_mask, ")
|
| 316 |
+
else:
|
| 317 |
+
# Create default attention mask if none provided
|
| 318 |
+
attention_mask = np.tril(np.ones((len(required_input), len(required_input)), dtype=np.int64))
|
| 319 |
+
attention_mask = np.expand_dims(attention_mask, axis=0)
|
| 320 |
+
|
| 321 |
+
if needs_to_be_padded:
|
| 322 |
+
difference = max_length - len(required_input)
|
| 323 |
+
if self.padding_side == "right":
|
| 324 |
+
if attention_mask.ndim == 1:
|
| 325 |
+
pad_width = [(0, difference)]
|
| 326 |
+
else:
|
| 327 |
+
pad_width = [(0, 0), (0, difference), (0, difference)]
|
| 328 |
+
elif self.padding_side == "left":
|
| 329 |
+
if attention_mask.ndim == 1:
|
| 330 |
+
pad_width = [(difference, 0)]
|
| 331 |
+
else:
|
| 332 |
+
pad_width = [(0, 0), (difference, 0), (difference, 0)]
|
| 333 |
+
else:
|
| 334 |
+
raise ValueError("Invalid padding strategy:" + str(self.padding_side))
|
| 335 |
+
attention_mask = np.pad(
|
| 336 |
+
attention_mask,
|
| 337 |
+
pad_width=pad_width,
|
| 338 |
+
mode="constant",
|
| 339 |
+
constant_values=0,
|
| 340 |
+
)
|
| 341 |
+
|
| 342 |
+
encoded_inputs = super()._pad(
|
| 343 |
+
encoded_inputs,
|
| 344 |
+
max_length,
|
| 345 |
+
padding_strategy=padding_strategy,
|
| 346 |
+
pad_to_multiple_of=pad_to_multiple_of,
|
| 347 |
+
return_attention_mask=False,
|
| 348 |
+
)
|
| 349 |
+
if return_attention_mask:
|
| 350 |
+
encoded_inputs["attention_mask"] = attention_mask.tolist()
|
| 351 |
+
return encoded_inputs
|
| 352 |
+
|
tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tokenizer.model
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:34ef7db83df785924fb83d7b887b6e822a031c56e15cff40aaf9b982988180df
|
| 3 |
+
size 1614363
|
tokenizer_config.json
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bos_token": "<s>",
|
| 3 |
+
"eos_token": "</s>",
|
| 4 |
+
"pad_token": "<unk>",
|
| 5 |
+
"unk_token": "<unk>",
|
| 6 |
+
"cls_token": "<|begin_of_sentence|>",
|
| 7 |
+
"sep_token": "<|end_of_sentence|>",
|
| 8 |
+
"mask_token": "<mask:1>",
|
| 9 |
+
"sys_start_token": "<mask:4>",
|
| 10 |
+
"sys_end_token": "<mask:5>",
|
| 11 |
+
"header_start_token": "<mask:6>",
|
| 12 |
+
"header_end_token": "<mask:7>",
|
| 13 |
+
"additional_special_tokens": null,
|
| 14 |
+
"chat_template": "{%- if not add_generation_prompt is defined -%}\n {%- set add_generation_prompt = true -%}\n{%- endif -%}\n{%- if not cls_token is defined -%}\n {%- set cls_token = \"<|begin_of_sentence|>\" -%}\n{%- endif -%}\n{%- if not sep_token is defined -%}\n {%- set sep_token = \"<|end_of_sentence|>\" -%}\n{%- endif -%}\n{{- cls_token -}}\n{%- for message in messages -%}\n {%- if message[\"role\"] == \"user\" -%}\n {{- \"User: \" + message[\"content\"] + \"\n\" -}}\n {%- elif message[\"role\"] == \"assistant\" -%}\n {{- \"Assistant: \" + message[\"content\"] + sep_token -}}\n {%- elif message[\"role\"] == \"system\" -%}\n {{- message[\"content\"] + \"\n\" -}}\n {%- endif -%}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n {{- \"Assistant: \" -}}\n{%- endif -%}",
|
| 15 |
+
"tokenizer_class": "Ernie4_5_Tokenizer",
|
| 16 |
+
"auto_map": {
|
| 17 |
+
"AutoTokenizer": [
|
| 18 |
+
"tokenization_ernie4_5.Ernie4_5_Tokenizer",
|
| 19 |
+
"tokenization_ernie4_5.Ernie4_5_Tokenizer"
|
| 20 |
+
]
|
| 21 |
+
}
|
| 22 |
+
}
|