rulixiang committed
Commit b1e036f · 1 Parent(s): e4effe5

First model version

config.json ADDED
@@ -0,0 +1,58 @@
+ {
+   "_name_or_path": "patchmoe",
+   "architectures": [
+     "PatchMoEForPrediction"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_patch_moe.PatchMoeConfig",
+     "AutoModelForCausalLM": "modeling_patch_moe.PatchMoEForPrediction"
+   },
+   "disable_bias_linear": false,
+   "do_base_forecast": false,
+   "do_expert_forecast": true,
+   "expert_num_layers": 4,
+   "ffn_hidden_size": 4096,
+   "heterogeneous_moe_layer": false,
+   "hidden_size": 1024,
+   "init_method_std": 0.06,
+   "is_revin": true,
+   "k_layernorm": false,
+   "kv_channels": 64,
+   "mask_pad_value": 255.0,
+   "model_type": "patch_moe",
+   "moe_expert_final_layernorm": true,
+   "moe_ffn_hidden_size": 4096,
+   "moe_router_enable_expert_bias": false,
+   "moe_router_input_size": 2880,
+   "moe_router_pre_softmax": true,
+   "moe_router_score_function": "softmax",
+   "moe_router_topk": 1,
+   "moe_shared_expert_intermediate_size": 4096,
+   "multi_forecast_head_list": [
+     24,
+     96,
+     336
+   ],
+   "num_attention_heads": 16,
+   "num_hidden_layers": 2,
+   "num_moe_experts": 4,
+   "torch_dtype": "bfloat16",
+   "patch_size_list": [
+     120,
+     96,
+     64,
+     36
+   ],
+   "pred_length": 336,
+   "q_layernorm": false,
+   "residual_backcast": true,
+   "rotary_base": 1000000,
+   "rotary_interleaved": false,
+   "seq_length": 2880,
+   "shared_patch_size": 32,
+   "tie_word_embeddings": false,
+   "transformer_input_layernorm": true,
+   "transformers_version": "4.40.1",
+   "use_cache": true,
+   "use_cpu_initialization": true
+ }
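The `auto_map` entries above register the custom `PatchMoeConfig` and `PatchMoEForPrediction` classes, so the checkpoint is intended to be loaded with `trust_remote_code=True`. A minimal loading sketch follows; the repo id is a placeholder, and the exact forecasting/`generate` call signature lives in `modeling_patch_moe.py`, which is only partially shown in this excerpt:

```python
# Minimal loading sketch; "rulixiang/patchmoe" is a placeholder repository id.
import torch
from transformers import AutoConfig, AutoModelForCausalLM

repo_id = "rulixiang/patchmoe"  # substitute the actual Hub repository id
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    trust_remote_code=True,       # classes are resolved through the auto_map above
    torch_dtype=torch.bfloat16,   # matches "torch_dtype" in config.json
)
print(config.seq_length, config.pred_length)  # 2880, 336 for this checkpoint
```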
configuration_patch_moe.py ADDED
@@ -0,0 +1,203 @@
+ """
+ Configuration class for PatchMoE model.
+
+ This module defines the configuration for PatchMoE, a large-scale time series foundation model
+ that utilizes Mixture of Experts (MoE) architecture with multiple patch tokenizers.
+ """
+
+ from typing import List, Optional
+ from transformers import PretrainedConfig
+
+
+ class PatchMoeConfig(PretrainedConfig):
+     """
+     Configuration class for PatchMoE model.
+
+     PatchMoE is a time series foundation model that uses Mixture of Experts architecture
+     with multiple patch tokenizers for efficient time series forecasting.
+
+     This configuration inherits from [`PretrainedConfig`] and can be used to control the model
+     output. Read the documentation from [`PretrainedConfig`] for more information.
+
+     Args:
+         hidden_size (`int`, *optional*, defaults to 1024):
+             Dimensionality of the encoder layers and the pooler layer.
+         ffn_hidden_size (`int`, *optional*, defaults to 4096):
+             Dimensionality of the feed-forward networks in the transformer layers.
+         seq_length (`int`, *optional*, defaults to 2880):
+             Maximum sequence length that the model can handle.
+         add_bias_linear (`bool`, *optional*, defaults to `False`):
+             Whether to add bias in linear layers.
+         rope_theta (`int`, *optional*, defaults to 10000):
+             The base period of the RoPE embeddings.
+         num_hidden_layers (`int`, *optional*, defaults to 3):
+             Number of hidden layers in the transformer encoder.
+         num_attention_heads (`int`, *optional*, defaults to 16):
+             Number of attention heads for each attention layer in the transformer encoder.
+         mask_pad_value (`float`, *optional*, defaults to 255.0):
+             Value used for padding/masking in input sequences.
+         expert_num_layers (`int`, *optional*, defaults to 4):
+             Number of transformer layers within each expert.
+         shared_patch_size (`int`, *optional*, defaults to 64):
+             Size of patches for the shared expert.
+         patch_size_list (`List[int]`, *optional*, defaults to [96, 64, 48, 24]):
+             List of patch sizes for different experts.
+         multi_forecast_head_list (`List[int]`, *optional*, defaults to [24, 96, 336]):
+             List of forecast lengths for multi-head prediction.
+         is_revin (`bool`, *optional*, defaults to `True`):
+             Whether to use RevIN (Reversible Instance Normalization).
+         params_dtype (`str`, *optional*, defaults to "bfloat16"):
+             Data type for model parameters.
+         use_cpu_initialization (`bool`, *optional*, defaults to `False`):
+             Whether to initialize model parameters on CPU.
+         rotary_interleaved (`bool`, *optional*, defaults to `False`):
+             Whether to use interleaved rotary position embeddings.
+         do_expert_forecast (`bool`, *optional*, defaults to `True`):
+             Whether experts perform forecasting.
+         residual_backcast (`bool`, *optional*, defaults to `True`):
+             Whether to use residual connections for backcast.
+         do_base_forecast (`bool`, *optional*, defaults to `False`):
+             Whether to use base forecasting.
+         heterogeneous_moe_layer (`bool`, *optional*, defaults to `True`):
+             Whether to use heterogeneous MoE layers.
+         test_data_seq_len (`int`, *optional*, defaults to 2880):
+             Sequence length for test data.
+         test_data_test_len (`int`, *optional*, defaults to 720):
+             Test length for test data.
+         autoregressive_step_list (`List[int]`, *optional*, defaults to [2, 4, 1]):
+             List of autoregressive steps for different forecast heads.
+         multi_forecast_head_type (`str`, *optional*, defaults to "single"):
+             Type of multi-forecast head aggregation.
+         num_experts (`int`, *optional*, defaults to 4):
+             Number of experts in the MoE layer.
+         moe_router_topk (`int`, *optional*, defaults to 2):
+             Number of top experts to route each token to.
+         moe_ffn_hidden_size (`int`, *optional*, defaults to 4096):
+             Hidden size for MoE feed-forward networks.
+         moe_shared_expert_intermediate_size (`int`, *optional*, defaults to 4096):
+             Intermediate size for shared experts.
+         init_method_std (`float`, *optional*, defaults to 0.06):
+             Standard deviation for weight initialization.
+         initializer_range (`float`, *optional*, defaults to 0.02):
+             Range for weight initialization.
+         moe_router_enable_expert_bias (`bool`, *optional*, defaults to `False`):
+             Whether to enable expert bias in routing.
+         moe_expert_final_layernorm (`bool`, *optional*, defaults to `True`):
+             Whether to apply layer normalization at the end of each expert.
+         transformer_input_layernorm (`bool`, *optional*, defaults to `True`):
+             Whether to apply layer normalization to transformer inputs.
+         moe_router_pre_softmax (`bool`, *optional*, defaults to `True`):
+             Whether to apply softmax before routing.
+         q_layernorm (`bool`, *optional*, defaults to `False`):
+             Whether to apply layer normalization to query vectors.
+         k_layernorm (`bool`, *optional*, defaults to `False`):
+             Whether to apply layer normalization to key vectors.
+         moe_router_score_function (`str`, *optional*, defaults to "softmax"):
+             Score function for MoE routing ("softmax" or "sigmoid").
+         tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+             Whether to tie word embeddings.
+     """
+
+     model_type = "patch_moe"
+     keys_to_ignore_at_inference = ["past_key_values"]
+
+     def __init__(
+         self,
+         hidden_size: int = 1024,
+         ffn_hidden_size: int = 4096,
+         seq_length: int = 2880,
+         add_bias_linear: bool = False,
+         rope_theta: int = 10000,
+         num_hidden_layers: int = 3,
+         num_attention_heads: int = 16,
+         mask_pad_value: float = 255.0,
+         expert_num_layers: int = 4,
+         shared_patch_size: int = 64,
+         patch_size_list: Optional[List[int]] = None,
+         multi_forecast_head_list: Optional[List[int]] = None,
+         is_revin: bool = True,
+         use_cpu_initialization: bool = False,
+         rotary_interleaved: bool = False,
+         do_expert_forecast: bool = True,
+         residual_backcast: bool = True,
+         do_base_forecast: bool = False,
+         heterogeneous_moe_layer: bool = True,
+         test_data_seq_len: int = 2880,
+         test_data_test_len: int = 720,
+         autoregressive_step_list: Optional[List[int]] = None,
+         multi_forecast_head_type: str = "single",
+         num_experts: int = 4,
+         moe_router_topk: int = 2,
+         moe_ffn_hidden_size: int = 4096,
+         moe_shared_expert_intermediate_size: int = 4096,
+         init_method_std: float = 0.06,
+         initializer_range: float = 0.02,
+         moe_router_enable_expert_bias: bool = False,
+         moe_expert_final_layernorm: bool = True,
+         transformer_input_layernorm: bool = True,
+         moe_router_pre_softmax: bool = True,
+         q_layernorm: bool = False,
+         k_layernorm: bool = False,
+         moe_router_score_function: str = "softmax",
+         tie_word_embeddings: bool = False,
+         **kwargs,
+     ):
+         """Initialize PatchMoE configuration."""
+         # Set default values for list parameters
+         if patch_size_list is None:
+             patch_size_list = [96, 64, 48, 24]
+         if multi_forecast_head_list is None:
+             multi_forecast_head_list = [24, 96, 336]
+         if autoregressive_step_list is None:
+             autoregressive_step_list = [2, 4, 1]
+         # patchmoe inference specific
+         self.test_data_seq_len = test_data_seq_len
+         self.inference_length = test_data_test_len
+         self.autoregressive_step_list = autoregressive_step_list
+         self.multi_forecast_head_type = multi_forecast_head_type
+         self.use_cache = True
+
+         # patchmoe specific
+         self.hidden_size = hidden_size
+         self.ffn_hidden_size = ffn_hidden_size
+         self.num_attention_heads = num_attention_heads
+         self.init_method_std = init_method_std
+         self.initializer_range = initializer_range
+         self.seq_length = seq_length
+         self.multi_forecast_head_list = multi_forecast_head_list
+         self.kv_channels = self.hidden_size // self.num_attention_heads
+         self.rotary_base = rope_theta
+         self.num_hidden_layers = num_hidden_layers
+         self.mask_pad_value = mask_pad_value
+         self.pred_length = max(self.multi_forecast_head_list)
+         self.add_bias_linear = add_bias_linear
+         self.is_revin = is_revin
+         self.do_base_forecast = do_base_forecast
+         self.do_expert_forecast = do_expert_forecast
+         self.residual_backcast = residual_backcast
+         self.heterogeneous_moe_layer = heterogeneous_moe_layer
+         self.use_cpu_initialization = use_cpu_initialization
+         self.rotary_interleaved = rotary_interleaved
+
+         # expert specific
+         self.patch_size_list = patch_size_list
+         self.num_moe_experts = num_experts
+         self.shared_patch_size = shared_patch_size
+         self.expert_num_layers = expert_num_layers
+         self.moe_router_input_size = self.seq_length
+         self.moe_router_topk = moe_router_topk
+         self.moe_router_score_function = moe_router_score_function
+         self.moe_ffn_hidden_size = moe_ffn_hidden_size
+         self.moe_shared_expert_intermediate_size = moe_shared_expert_intermediate_size
+         self.moe_router_enable_expert_bias = moe_router_enable_expert_bias
+         self.moe_expert_final_layernorm = moe_expert_final_layernorm
+         self.transformer_input_layernorm = transformer_input_layernorm
+         self.moe_router_pre_softmax = moe_router_pre_softmax
+         self.q_layernorm = q_layernorm
+         self.k_layernorm = k_layernorm
+
+         kwargs.pop("tie_word_embeddings", None)
+         super().__init__(
+             tie_word_embeddings=tie_word_embeddings,
+             **kwargs,
+         )
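For orientation, a small sketch of how the configuration behaves when instantiated directly; the values below follow the defaults and derivations visible in `__init__`, and the import path assumes the file is available locally (on the Hub it is resolved through the `auto_map`):

```python
# Sketch: PatchMoeConfig derives several fields from its arguments in __init__.
from configuration_patch_moe import PatchMoeConfig  # assumed local import path

cfg = PatchMoeConfig(hidden_size=1024, num_attention_heads=16,
                     multi_forecast_head_list=[24, 96, 336])
assert cfg.kv_channels == 1024 // 16             # hidden_size // num_attention_heads
assert cfg.pred_length == 336                    # max(multi_forecast_head_list)
assert cfg.moe_router_input_size == cfg.seq_length
assert cfg.num_moe_experts == 4                  # stored as num_moe_experts, not num_experts
```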
generation_config.json ADDED
@@ -0,0 +1,4 @@
+ {
+   "_from_model_config": true,
+   "transformers_version": "4.40.1"
+ }
model-00001-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8534fa131034e86c50ec43cc14e3d6f17af1d5d4161a11ada2218d12067e1c4c
+ size 3718382544
model-00002-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e5a15d1fcb6388aed06deb70f77918cd38899476dc0c4b1ac7dc57391cf8a477
+ size 1264771376
model.safetensors.index.json ADDED
@@ -0,0 +1,292 @@
1
+ {
2
+ "metadata": {
3
+ "total_size": 4983109888
4
+ },
5
+ "weight_map": {
6
+ "model.decoder.layers.0.router.weight": "model-00001-of-00002.safetensors",
7
+ "model.decoder.layers.0.backcast_layernorm.weight": "model-00001-of-00002.safetensors",
8
+ "model.decoder.layers.0.experts.local_experts.0.layers.0.input_layernorm.weight": "model-00001-of-00002.safetensors",
9
+ "model.decoder.layers.0.experts.local_experts.0.layers.0.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
10
+ "model.decoder.layers.0.experts.local_experts.0.layers.0.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
11
+ "model.decoder.layers.0.experts.local_experts.0.layers.0.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
12
+ "model.decoder.layers.0.experts.local_experts.0.layers.0.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
13
+ "model.decoder.layers.0.experts.local_experts.0.layers.0.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
14
+ "model.decoder.layers.0.experts.local_experts.0.layers.1.input_layernorm.weight": "model-00001-of-00002.safetensors",
15
+ "model.decoder.layers.0.experts.local_experts.0.layers.1.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
16
+ "model.decoder.layers.0.experts.local_experts.0.layers.1.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
17
+ "model.decoder.layers.0.experts.local_experts.0.layers.1.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
18
+ "model.decoder.layers.0.experts.local_experts.0.layers.1.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
19
+ "model.decoder.layers.0.experts.local_experts.0.layers.1.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
20
+ "model.decoder.layers.0.experts.local_experts.0.layers.2.input_layernorm.weight": "model-00001-of-00002.safetensors",
21
+ "model.decoder.layers.0.experts.local_experts.0.layers.2.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
22
+ "model.decoder.layers.0.experts.local_experts.0.layers.2.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
23
+ "model.decoder.layers.0.experts.local_experts.0.layers.2.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
24
+ "model.decoder.layers.0.experts.local_experts.0.layers.2.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
25
+ "model.decoder.layers.0.experts.local_experts.0.layers.2.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
26
+ "model.decoder.layers.0.experts.local_experts.0.layers.3.input_layernorm.weight": "model-00001-of-00002.safetensors",
27
+ "model.decoder.layers.0.experts.local_experts.0.layers.3.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
28
+ "model.decoder.layers.0.experts.local_experts.0.layers.3.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
29
+ "model.decoder.layers.0.experts.local_experts.0.layers.3.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
30
+ "model.decoder.layers.0.experts.local_experts.0.layers.3.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
31
+ "model.decoder.layers.0.experts.local_experts.0.layers.3.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
32
+ "model.decoder.layers.0.experts.local_experts.0.final_layernorm.weight": "model-00001-of-00002.safetensors",
33
+ "model.decoder.layers.0.experts.local_experts.0.patch_embedding.linear_fc1.weight": "model-00001-of-00002.safetensors",
34
+ "model.decoder.layers.0.experts.local_experts.0.patch_embedding.linear_fc2.weight": "model-00001-of-00002.safetensors",
35
+ "model.decoder.layers.0.experts.local_experts.0.output_layer.weight": "model-00001-of-00002.safetensors",
36
+ "model.decoder.layers.0.experts.local_experts.1.layers.0.input_layernorm.weight": "model-00001-of-00002.safetensors",
37
+ "model.decoder.layers.0.experts.local_experts.1.layers.0.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
38
+ "model.decoder.layers.0.experts.local_experts.1.layers.0.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
39
+ "model.decoder.layers.0.experts.local_experts.1.layers.0.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
40
+ "model.decoder.layers.0.experts.local_experts.1.layers.0.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
41
+ "model.decoder.layers.0.experts.local_experts.1.layers.0.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
42
+ "model.decoder.layers.0.experts.local_experts.1.layers.1.input_layernorm.weight": "model-00001-of-00002.safetensors",
43
+ "model.decoder.layers.0.experts.local_experts.1.layers.1.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
44
+ "model.decoder.layers.0.experts.local_experts.1.layers.1.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
45
+ "model.decoder.layers.0.experts.local_experts.1.layers.1.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
46
+ "model.decoder.layers.0.experts.local_experts.1.layers.1.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
47
+ "model.decoder.layers.0.experts.local_experts.1.layers.1.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
48
+ "model.decoder.layers.0.experts.local_experts.1.layers.2.input_layernorm.weight": "model-00001-of-00002.safetensors",
49
+ "model.decoder.layers.0.experts.local_experts.1.layers.2.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
50
+ "model.decoder.layers.0.experts.local_experts.1.layers.2.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
51
+ "model.decoder.layers.0.experts.local_experts.1.layers.2.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
52
+ "model.decoder.layers.0.experts.local_experts.1.layers.2.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
53
+ "model.decoder.layers.0.experts.local_experts.1.layers.2.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
54
+ "model.decoder.layers.0.experts.local_experts.1.layers.3.input_layernorm.weight": "model-00001-of-00002.safetensors",
55
+ "model.decoder.layers.0.experts.local_experts.1.layers.3.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
56
+ "model.decoder.layers.0.experts.local_experts.1.layers.3.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
57
+ "model.decoder.layers.0.experts.local_experts.1.layers.3.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
58
+ "model.decoder.layers.0.experts.local_experts.1.layers.3.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
59
+ "model.decoder.layers.0.experts.local_experts.1.layers.3.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
60
+ "model.decoder.layers.0.experts.local_experts.1.final_layernorm.weight": "model-00001-of-00002.safetensors",
61
+ "model.decoder.layers.0.experts.local_experts.1.patch_embedding.linear_fc1.weight": "model-00001-of-00002.safetensors",
62
+ "model.decoder.layers.0.experts.local_experts.1.patch_embedding.linear_fc2.weight": "model-00001-of-00002.safetensors",
63
+ "model.decoder.layers.0.experts.local_experts.1.output_layer.weight": "model-00001-of-00002.safetensors",
64
+ "model.decoder.layers.0.experts.local_experts.2.layers.0.input_layernorm.weight": "model-00001-of-00002.safetensors",
65
+ "model.decoder.layers.0.experts.local_experts.2.layers.0.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
66
+ "model.decoder.layers.0.experts.local_experts.2.layers.0.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
67
+ "model.decoder.layers.0.experts.local_experts.2.layers.0.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
68
+ "model.decoder.layers.0.experts.local_experts.2.layers.0.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
69
+ "model.decoder.layers.0.experts.local_experts.2.layers.0.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
70
+ "model.decoder.layers.0.experts.local_experts.2.layers.1.input_layernorm.weight": "model-00001-of-00002.safetensors",
71
+ "model.decoder.layers.0.experts.local_experts.2.layers.1.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
72
+ "model.decoder.layers.0.experts.local_experts.2.layers.1.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
73
+ "model.decoder.layers.0.experts.local_experts.2.layers.1.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
74
+ "model.decoder.layers.0.experts.local_experts.2.layers.1.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
75
+ "model.decoder.layers.0.experts.local_experts.2.layers.1.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
76
+ "model.decoder.layers.0.experts.local_experts.2.layers.2.input_layernorm.weight": "model-00001-of-00002.safetensors",
77
+ "model.decoder.layers.0.experts.local_experts.2.layers.2.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
78
+ "model.decoder.layers.0.experts.local_experts.2.layers.2.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
79
+ "model.decoder.layers.0.experts.local_experts.2.layers.2.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
80
+ "model.decoder.layers.0.experts.local_experts.2.layers.2.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
81
+ "model.decoder.layers.0.experts.local_experts.2.layers.2.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
82
+ "model.decoder.layers.0.experts.local_experts.2.layers.3.input_layernorm.weight": "model-00001-of-00002.safetensors",
83
+ "model.decoder.layers.0.experts.local_experts.2.layers.3.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
84
+ "model.decoder.layers.0.experts.local_experts.2.layers.3.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
85
+ "model.decoder.layers.0.experts.local_experts.2.layers.3.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
86
+ "model.decoder.layers.0.experts.local_experts.2.layers.3.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
87
+ "model.decoder.layers.0.experts.local_experts.2.layers.3.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
88
+ "model.decoder.layers.0.experts.local_experts.2.final_layernorm.weight": "model-00001-of-00002.safetensors",
89
+ "model.decoder.layers.0.experts.local_experts.2.patch_embedding.linear_fc1.weight": "model-00001-of-00002.safetensors",
90
+ "model.decoder.layers.0.experts.local_experts.2.patch_embedding.linear_fc2.weight": "model-00001-of-00002.safetensors",
91
+ "model.decoder.layers.0.experts.local_experts.2.output_layer.weight": "model-00001-of-00002.safetensors",
92
+ "model.decoder.layers.0.experts.local_experts.3.layers.0.input_layernorm.weight": "model-00001-of-00002.safetensors",
93
+ "model.decoder.layers.0.experts.local_experts.3.layers.0.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
94
+ "model.decoder.layers.0.experts.local_experts.3.layers.0.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
95
+ "model.decoder.layers.0.experts.local_experts.3.layers.0.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
96
+ "model.decoder.layers.0.experts.local_experts.3.layers.0.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
97
+ "model.decoder.layers.0.experts.local_experts.3.layers.0.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
98
+ "model.decoder.layers.0.experts.local_experts.3.layers.1.input_layernorm.weight": "model-00001-of-00002.safetensors",
99
+ "model.decoder.layers.0.experts.local_experts.3.layers.1.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
100
+ "model.decoder.layers.0.experts.local_experts.3.layers.1.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
101
+ "model.decoder.layers.0.experts.local_experts.3.layers.1.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
102
+ "model.decoder.layers.0.experts.local_experts.3.layers.1.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
103
+ "model.decoder.layers.0.experts.local_experts.3.layers.1.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
104
+ "model.decoder.layers.0.experts.local_experts.3.layers.2.input_layernorm.weight": "model-00001-of-00002.safetensors",
105
+ "model.decoder.layers.0.experts.local_experts.3.layers.2.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
106
+ "model.decoder.layers.0.experts.local_experts.3.layers.2.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
107
+ "model.decoder.layers.0.experts.local_experts.3.layers.2.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
108
+ "model.decoder.layers.0.experts.local_experts.3.layers.2.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
109
+ "model.decoder.layers.0.experts.local_experts.3.layers.2.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
110
+ "model.decoder.layers.0.experts.local_experts.3.layers.3.input_layernorm.weight": "model-00001-of-00002.safetensors",
111
+ "model.decoder.layers.0.experts.local_experts.3.layers.3.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
112
+ "model.decoder.layers.0.experts.local_experts.3.layers.3.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
113
+ "model.decoder.layers.0.experts.local_experts.3.layers.3.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
114
+ "model.decoder.layers.0.experts.local_experts.3.layers.3.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
115
+ "model.decoder.layers.0.experts.local_experts.3.layers.3.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
116
+ "model.decoder.layers.0.experts.local_experts.3.final_layernorm.weight": "model-00001-of-00002.safetensors",
117
+ "model.decoder.layers.0.experts.local_experts.3.patch_embedding.linear_fc1.weight": "model-00001-of-00002.safetensors",
118
+ "model.decoder.layers.0.experts.local_experts.3.patch_embedding.linear_fc2.weight": "model-00001-of-00002.safetensors",
119
+ "model.decoder.layers.0.experts.local_experts.3.output_layer.weight": "model-00001-of-00002.safetensors",
120
+ "model.decoder.layers.0.shared_experts.layers.0.input_layernorm.weight": "model-00001-of-00002.safetensors",
121
+ "model.decoder.layers.0.shared_experts.layers.0.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
122
+ "model.decoder.layers.0.shared_experts.layers.0.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
123
+ "model.decoder.layers.0.shared_experts.layers.0.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
124
+ "model.decoder.layers.0.shared_experts.layers.0.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
125
+ "model.decoder.layers.0.shared_experts.layers.0.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
126
+ "model.decoder.layers.0.shared_experts.layers.1.input_layernorm.weight": "model-00001-of-00002.safetensors",
127
+ "model.decoder.layers.0.shared_experts.layers.1.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
128
+ "model.decoder.layers.0.shared_experts.layers.1.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
129
+ "model.decoder.layers.0.shared_experts.layers.1.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
130
+ "model.decoder.layers.0.shared_experts.layers.1.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
131
+ "model.decoder.layers.0.shared_experts.layers.1.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
132
+ "model.decoder.layers.0.shared_experts.layers.2.input_layernorm.weight": "model-00001-of-00002.safetensors",
133
+ "model.decoder.layers.0.shared_experts.layers.2.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
134
+ "model.decoder.layers.0.shared_experts.layers.2.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
135
+ "model.decoder.layers.0.shared_experts.layers.2.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
136
+ "model.decoder.layers.0.shared_experts.layers.2.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
137
+ "model.decoder.layers.0.shared_experts.layers.2.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
138
+ "model.decoder.layers.0.shared_experts.layers.3.input_layernorm.weight": "model-00001-of-00002.safetensors",
139
+ "model.decoder.layers.0.shared_experts.layers.3.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
140
+ "model.decoder.layers.0.shared_experts.layers.3.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
141
+ "model.decoder.layers.0.shared_experts.layers.3.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
142
+ "model.decoder.layers.0.shared_experts.layers.3.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
143
+ "model.decoder.layers.0.shared_experts.layers.3.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
144
+ "model.decoder.layers.0.shared_experts.final_layernorm.weight": "model-00001-of-00002.safetensors",
145
+ "model.decoder.layers.0.shared_experts.patch_embedding.linear_fc1.weight": "model-00001-of-00002.safetensors",
146
+ "model.decoder.layers.0.shared_experts.patch_embedding.linear_fc2.weight": "model-00001-of-00002.safetensors",
147
+ "model.decoder.layers.0.shared_experts.output_layer.weight": "model-00001-of-00002.safetensors",
148
+ "model.decoder.layers.1.router.weight": "model-00001-of-00002.safetensors",
149
+ "model.decoder.layers.1.backcast_layernorm.weight": "model-00001-of-00002.safetensors",
150
+ "model.decoder.layers.1.experts.local_experts.0.layers.0.input_layernorm.weight": "model-00001-of-00002.safetensors",
151
+ "model.decoder.layers.1.experts.local_experts.0.layers.0.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
152
+ "model.decoder.layers.1.experts.local_experts.0.layers.0.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
153
+ "model.decoder.layers.1.experts.local_experts.0.layers.0.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
154
+ "model.decoder.layers.1.experts.local_experts.0.layers.0.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
155
+ "model.decoder.layers.1.experts.local_experts.0.layers.0.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
156
+ "model.decoder.layers.1.experts.local_experts.0.layers.1.input_layernorm.weight": "model-00001-of-00002.safetensors",
157
+ "model.decoder.layers.1.experts.local_experts.0.layers.1.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
158
+ "model.decoder.layers.1.experts.local_experts.0.layers.1.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
159
+ "model.decoder.layers.1.experts.local_experts.0.layers.1.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
160
+ "model.decoder.layers.1.experts.local_experts.0.layers.1.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
161
+ "model.decoder.layers.1.experts.local_experts.0.layers.1.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
162
+ "model.decoder.layers.1.experts.local_experts.0.layers.2.input_layernorm.weight": "model-00001-of-00002.safetensors",
163
+ "model.decoder.layers.1.experts.local_experts.0.layers.2.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
164
+ "model.decoder.layers.1.experts.local_experts.0.layers.2.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
165
+ "model.decoder.layers.1.experts.local_experts.0.layers.2.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
166
+ "model.decoder.layers.1.experts.local_experts.0.layers.2.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
167
+ "model.decoder.layers.1.experts.local_experts.0.layers.2.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
168
+ "model.decoder.layers.1.experts.local_experts.0.layers.3.input_layernorm.weight": "model-00001-of-00002.safetensors",
169
+ "model.decoder.layers.1.experts.local_experts.0.layers.3.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
170
+ "model.decoder.layers.1.experts.local_experts.0.layers.3.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
171
+ "model.decoder.layers.1.experts.local_experts.0.layers.3.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
172
+ "model.decoder.layers.1.experts.local_experts.0.layers.3.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
173
+ "model.decoder.layers.1.experts.local_experts.0.layers.3.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
174
+ "model.decoder.layers.1.experts.local_experts.0.final_layernorm.weight": "model-00001-of-00002.safetensors",
175
+ "model.decoder.layers.1.experts.local_experts.0.patch_embedding.linear_fc1.weight": "model-00001-of-00002.safetensors",
176
+ "model.decoder.layers.1.experts.local_experts.0.patch_embedding.linear_fc2.weight": "model-00001-of-00002.safetensors",
177
+ "model.decoder.layers.1.experts.local_experts.0.output_layer.weight": "model-00001-of-00002.safetensors",
178
+ "model.decoder.layers.1.experts.local_experts.1.layers.0.input_layernorm.weight": "model-00001-of-00002.safetensors",
179
+ "model.decoder.layers.1.experts.local_experts.1.layers.0.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
180
+ "model.decoder.layers.1.experts.local_experts.1.layers.0.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
181
+ "model.decoder.layers.1.experts.local_experts.1.layers.0.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
182
+ "model.decoder.layers.1.experts.local_experts.1.layers.0.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
183
+ "model.decoder.layers.1.experts.local_experts.1.layers.0.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
184
+ "model.decoder.layers.1.experts.local_experts.1.layers.1.input_layernorm.weight": "model-00001-of-00002.safetensors",
185
+ "model.decoder.layers.1.experts.local_experts.1.layers.1.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
186
+ "model.decoder.layers.1.experts.local_experts.1.layers.1.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
187
+ "model.decoder.layers.1.experts.local_experts.1.layers.1.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
188
+ "model.decoder.layers.1.experts.local_experts.1.layers.1.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
189
+ "model.decoder.layers.1.experts.local_experts.1.layers.1.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
190
+ "model.decoder.layers.1.experts.local_experts.1.layers.2.input_layernorm.weight": "model-00001-of-00002.safetensors",
191
+ "model.decoder.layers.1.experts.local_experts.1.layers.2.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
192
+ "model.decoder.layers.1.experts.local_experts.1.layers.2.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
193
+ "model.decoder.layers.1.experts.local_experts.1.layers.2.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
194
+ "model.decoder.layers.1.experts.local_experts.1.layers.2.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
195
+ "model.decoder.layers.1.experts.local_experts.1.layers.2.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
196
+ "model.decoder.layers.1.experts.local_experts.1.layers.3.input_layernorm.weight": "model-00001-of-00002.safetensors",
197
+ "model.decoder.layers.1.experts.local_experts.1.layers.3.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
198
+ "model.decoder.layers.1.experts.local_experts.1.layers.3.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
199
+ "model.decoder.layers.1.experts.local_experts.1.layers.3.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
200
+ "model.decoder.layers.1.experts.local_experts.1.layers.3.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
201
+ "model.decoder.layers.1.experts.local_experts.1.layers.3.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
202
+ "model.decoder.layers.1.experts.local_experts.1.final_layernorm.weight": "model-00001-of-00002.safetensors",
203
+ "model.decoder.layers.1.experts.local_experts.1.patch_embedding.linear_fc1.weight": "model-00001-of-00002.safetensors",
204
+ "model.decoder.layers.1.experts.local_experts.1.patch_embedding.linear_fc2.weight": "model-00001-of-00002.safetensors",
205
+ "model.decoder.layers.1.experts.local_experts.1.output_layer.weight": "model-00001-of-00002.safetensors",
206
+ "model.decoder.layers.1.experts.local_experts.2.layers.0.input_layernorm.weight": "model-00001-of-00002.safetensors",
207
+ "model.decoder.layers.1.experts.local_experts.2.layers.0.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
208
+ "model.decoder.layers.1.experts.local_experts.2.layers.0.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
209
+ "model.decoder.layers.1.experts.local_experts.2.layers.0.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
210
+ "model.decoder.layers.1.experts.local_experts.2.layers.0.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
211
+ "model.decoder.layers.1.experts.local_experts.2.layers.0.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
212
+ "model.decoder.layers.1.experts.local_experts.2.layers.1.input_layernorm.weight": "model-00001-of-00002.safetensors",
213
+ "model.decoder.layers.1.experts.local_experts.2.layers.1.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
214
+ "model.decoder.layers.1.experts.local_experts.2.layers.1.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
215
+ "model.decoder.layers.1.experts.local_experts.2.layers.1.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
216
+ "model.decoder.layers.1.experts.local_experts.2.layers.1.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
217
+ "model.decoder.layers.1.experts.local_experts.2.layers.1.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
218
+ "model.decoder.layers.1.experts.local_experts.2.layers.2.input_layernorm.weight": "model-00001-of-00002.safetensors",
219
+ "model.decoder.layers.1.experts.local_experts.2.layers.2.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
220
+ "model.decoder.layers.1.experts.local_experts.2.layers.2.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
221
+ "model.decoder.layers.1.experts.local_experts.2.layers.2.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
222
+ "model.decoder.layers.1.experts.local_experts.2.layers.2.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
223
+ "model.decoder.layers.1.experts.local_experts.2.layers.2.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
224
+ "model.decoder.layers.1.experts.local_experts.2.layers.3.input_layernorm.weight": "model-00001-of-00002.safetensors",
225
+ "model.decoder.layers.1.experts.local_experts.2.layers.3.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
226
+ "model.decoder.layers.1.experts.local_experts.2.layers.3.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
227
+ "model.decoder.layers.1.experts.local_experts.2.layers.3.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
228
+ "model.decoder.layers.1.experts.local_experts.2.layers.3.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
229
+ "model.decoder.layers.1.experts.local_experts.2.layers.3.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
230
+ "model.decoder.layers.1.experts.local_experts.2.final_layernorm.weight": "model-00001-of-00002.safetensors",
231
+ "model.decoder.layers.1.experts.local_experts.2.patch_embedding.linear_fc1.weight": "model-00001-of-00002.safetensors",
232
+ "model.decoder.layers.1.experts.local_experts.2.patch_embedding.linear_fc2.weight": "model-00001-of-00002.safetensors",
233
+ "model.decoder.layers.1.experts.local_experts.2.output_layer.weight": "model-00001-of-00002.safetensors",
234
+ "model.decoder.layers.1.experts.local_experts.3.layers.0.input_layernorm.weight": "model-00001-of-00002.safetensors",
235
+ "model.decoder.layers.1.experts.local_experts.3.layers.0.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
236
+ "model.decoder.layers.1.experts.local_experts.3.layers.0.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
237
+ "model.decoder.layers.1.experts.local_experts.3.layers.0.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
238
+ "model.decoder.layers.1.experts.local_experts.3.layers.0.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
239
+ "model.decoder.layers.1.experts.local_experts.3.layers.0.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
240
+ "model.decoder.layers.1.experts.local_experts.3.layers.1.input_layernorm.weight": "model-00001-of-00002.safetensors",
241
+ "model.decoder.layers.1.experts.local_experts.3.layers.1.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
242
+ "model.decoder.layers.1.experts.local_experts.3.layers.1.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
243
+ "model.decoder.layers.1.experts.local_experts.3.layers.1.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
244
+ "model.decoder.layers.1.experts.local_experts.3.layers.1.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
245
+ "model.decoder.layers.1.experts.local_experts.3.layers.1.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
246
+ "model.decoder.layers.1.experts.local_experts.3.layers.2.input_layernorm.weight": "model-00001-of-00002.safetensors",
247
+ "model.decoder.layers.1.experts.local_experts.3.layers.2.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
248
+ "model.decoder.layers.1.experts.local_experts.3.layers.2.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
249
+ "model.decoder.layers.1.experts.local_experts.3.layers.2.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
250
+ "model.decoder.layers.1.experts.local_experts.3.layers.2.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
251
+ "model.decoder.layers.1.experts.local_experts.3.layers.2.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
252
+ "model.decoder.layers.1.experts.local_experts.3.layers.3.input_layernorm.weight": "model-00001-of-00002.safetensors",
253
+ "model.decoder.layers.1.experts.local_experts.3.layers.3.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
254
+ "model.decoder.layers.1.experts.local_experts.3.layers.3.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
255
+ "model.decoder.layers.1.experts.local_experts.3.layers.3.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
256
+ "model.decoder.layers.1.experts.local_experts.3.layers.3.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
257
+ "model.decoder.layers.1.experts.local_experts.3.layers.3.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
258
+ "model.decoder.layers.1.experts.local_experts.3.final_layernorm.weight": "model-00001-of-00002.safetensors",
259
+ "model.decoder.layers.1.experts.local_experts.3.patch_embedding.linear_fc1.weight": "model-00001-of-00002.safetensors",
260
+ "model.decoder.layers.1.experts.local_experts.3.patch_embedding.linear_fc2.weight": "model-00001-of-00002.safetensors",
261
+ "model.decoder.layers.1.experts.local_experts.3.output_layer.weight": "model-00002-of-00002.safetensors",
262
+ "model.decoder.layers.1.shared_experts.layers.0.input_layernorm.weight": "model-00002-of-00002.safetensors",
263
+ "model.decoder.layers.1.shared_experts.layers.0.self_attention.linear_proj.weight": "model-00002-of-00002.safetensors",
264
+ "model.decoder.layers.1.shared_experts.layers.0.self_attention.linear_qkv.weight": "model-00002-of-00002.safetensors",
265
+ "model.decoder.layers.1.shared_experts.layers.0.pre_mlp_layernorm.weight": "model-00002-of-00002.safetensors",
266
+ "model.decoder.layers.1.shared_experts.layers.0.mlp.linear_fc1.weight": "model-00002-of-00002.safetensors",
267
+ "model.decoder.layers.1.shared_experts.layers.0.mlp.linear_fc2.weight": "model-00002-of-00002.safetensors",
268
+ "model.decoder.layers.1.shared_experts.layers.1.input_layernorm.weight": "model-00002-of-00002.safetensors",
269
+ "model.decoder.layers.1.shared_experts.layers.1.self_attention.linear_proj.weight": "model-00002-of-00002.safetensors",
270
+ "model.decoder.layers.1.shared_experts.layers.1.self_attention.linear_qkv.weight": "model-00002-of-00002.safetensors",
271
+ "model.decoder.layers.1.shared_experts.layers.1.pre_mlp_layernorm.weight": "model-00002-of-00002.safetensors",
272
+ "model.decoder.layers.1.shared_experts.layers.1.mlp.linear_fc1.weight": "model-00002-of-00002.safetensors",
273
+ "model.decoder.layers.1.shared_experts.layers.1.mlp.linear_fc2.weight": "model-00002-of-00002.safetensors",
274
+ "model.decoder.layers.1.shared_experts.layers.2.input_layernorm.weight": "model-00002-of-00002.safetensors",
275
+ "model.decoder.layers.1.shared_experts.layers.2.self_attention.linear_proj.weight": "model-00002-of-00002.safetensors",
276
+ "model.decoder.layers.1.shared_experts.layers.2.self_attention.linear_qkv.weight": "model-00002-of-00002.safetensors",
277
+ "model.decoder.layers.1.shared_experts.layers.2.pre_mlp_layernorm.weight": "model-00002-of-00002.safetensors",
278
+ "model.decoder.layers.1.shared_experts.layers.2.mlp.linear_fc1.weight": "model-00002-of-00002.safetensors",
279
+ "model.decoder.layers.1.shared_experts.layers.2.mlp.linear_fc2.weight": "model-00002-of-00002.safetensors",
280
+ "model.decoder.layers.1.shared_experts.layers.3.input_layernorm.weight": "model-00002-of-00002.safetensors",
281
+ "model.decoder.layers.1.shared_experts.layers.3.self_attention.linear_proj.weight": "model-00002-of-00002.safetensors",
282
+ "model.decoder.layers.1.shared_experts.layers.3.self_attention.linear_qkv.weight": "model-00002-of-00002.safetensors",
283
+ "model.decoder.layers.1.shared_experts.layers.3.pre_mlp_layernorm.weight": "model-00002-of-00002.safetensors",
284
+ "model.decoder.layers.1.shared_experts.layers.3.mlp.linear_fc1.weight": "model-00002-of-00002.safetensors",
285
+ "model.decoder.layers.1.shared_experts.layers.3.mlp.linear_fc2.weight": "model-00002-of-00002.safetensors",
286
+ "model.decoder.layers.1.shared_experts.final_layernorm.weight": "model-00002-of-00002.safetensors",
287
+ "model.decoder.layers.1.shared_experts.patch_embedding.linear_fc1.weight": "model-00002-of-00002.safetensors",
288
+ "model.decoder.layers.1.shared_experts.patch_embedding.linear_fc2.weight": "model-00002-of-00002.safetensors",
289
+ "model.decoder.layers.1.shared_experts.output_layer.weight": "model-00002-of-00002.safetensors",
290
+ "model.output_layer.weight": "model-00002-of-00002.safetensors"
291
+ }
292
+ }
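The index above follows the standard Hugging Face sharded-checkpoint layout: `metadata.total_size` is the checkpoint size in bytes and `weight_map` maps each parameter name to the shard file that stores it. A quick way to inspect the split (sketch, run from the repository root):

```python
# Sketch: count how many parameters live in each safetensors shard.
import json
from collections import Counter

with open("model.safetensors.index.json") as f:
    index = json.load(f)

print(index["metadata"]["total_size"])        # total size in bytes (4983109888 here)
print(Counter(index["weight_map"].values()))  # parameters per shard file
```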
modeling_patch_moe.py ADDED
@@ -0,0 +1,1326 @@
+ import torch
+ from typing import Optional
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch import Tensor
+ import math
+ from functools import reduce
+ from abc import ABC, abstractmethod
+ from .configuration_patch_moe import PatchMoeConfig
+ from .ts_generation_mixin import PatchMoEGenerationMixin
+ from transformers import PreTrainedModel
+
+
+ def _rotate_half(x: Tensor, rotary_interleaved: bool) -> Tensor:
+     """Change sign so the last dimension becomes [-odd, +even]
+
+     Args:
+         x (Tensor): Input tensor
+
+     Returns:
+         Tensor: Tensor rotated half
+     """
+     if not rotary_interleaved:
+         x1, x2 = torch.chunk(x, 2, dim=-1)
+         return torch.cat((-x2, x1), dim=-1)
+     else:
+         x1 = x[:, :, :, ::2]
+         x2 = x[:, :, :, 1::2]
+         x_new = torch.stack((-x2, x1), dim=-1)
+         return x_new.view(x_new.shape[0], x_new.shape[1], x_new.shape[2], -1)
+
+
+ def _apply_rotary_pos_emb_bshd(
+     t: Tensor,
+     freqs: Tensor,
+     rotary_interleaved: bool = False,
+     multi_latent_attention: bool = False,
+     mscale: float = 1.0,
+ ) -> Tensor:
+     """Apply rotary positional embedding to input tensor T.
+
+     check https://kexue.fm/archives/8265 for detailed formulas
+
+     Args:
+         t (Tensor): Input tensor T is of shape [seq_length, ... , dim]
+         freqs (Tensor): Rotary Positional embedding tensor freq is of shape [seq_length, ..., dim]
+
+     Returns:
+         Tensor: The input tensor after applying RoPE
+     """
+     freqs = freqs.to(t.device)
+     rot_dim = freqs.shape[-1]
+
+     # ideally t_pass is empty so rotary pos embedding is applied to all tensor t
+     t, t_pass = t[..., :rot_dim], t[..., rot_dim:]
+
+     if multi_latent_attention:
+         x1 = t[..., 0::2]
+         x2 = t[..., 1::2]
+         t = torch.cat((x1, x2), dim=-1)
+
+     # first part is cosine component
+     # second part is sine component, need to change signs with _rotate_half method
+     cos_ = (torch.cos(freqs) * mscale).to(t.dtype)
+     sin_ = (torch.sin(freqs) * mscale).to(t.dtype)
+
+     t = (t * cos_) + (_rotate_half(t, rotary_interleaved) * sin_)
+     return torch.cat((t, t_pass), dim=-1)
+
+
+ def topk_softmax_with_capacity(
+     logits: torch.Tensor,
+     topk: int,
+     use_pre_softmax: bool = False,
+     score_function: str = "softmax",
+     expert_bias: Optional[torch.Tensor] = None,
+ ):
+     """Apply capacity and padding to the top-k selection.
+     Args:
+         logits (torch.Tensor): Logits tensor.
+         topk (int): The number of experts to select for each token.
+         use_pre_softmax (bool): Whether to apply softmax or sigmoid before top-k selection.
+         score_function (str): The score function to use. Can be either "softmax" or "sigmoid".
+         expert_bias (torch.Tensor): The bias added to logits for expert routing.
+     Returns:
+         Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+             - routing_probs (torch.Tensor): A tensor of shape [num_tokens, num_experts] containing
+               the routing probabilities for each token to each expert.
+             - routing_map (torch.Tensor): A mask tensor of shape [num_tokens, num_experts]
+               indicating which experts were selected for each token. True values represent
+               the selected experts.
+             - tokens_per_expert (torch.Tensor): A tensor of shape [num_experts] containing
+               the number of local tokens assigned to each expert before dropping and padding.
+     """
+     assert logits.dim() == 2, f"Expected 2D logits [num_tokens, num_experts], got {logits.dim()}."
+
+     def compute_topk(
+         scores,
+         topk,
+     ):
+         return torch.topk(scores, k=topk, dim=1)
+
+     if score_function == "softmax":
+         if use_pre_softmax:
+             scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits)
+             probs, top_indices = compute_topk(
+                 scores,
+                 topk,
+             )
+         else:
+             scores, top_indices = compute_topk(
+                 logits,
+                 topk,
+             )
+             probs = torch.softmax(scores, dim=-1, dtype=torch.float32).type_as(logits)
+     elif score_function == "sigmoid":
+         scores = torch.sigmoid(logits.float()).type_as(logits)
+         if expert_bias is not None:
+             scores_for_routing = scores + expert_bias
+             _, top_indices = compute_topk(
+                 scores_for_routing,
+                 topk,
+             )
+             scores = torch.gather(scores, dim=1, index=top_indices).type_as(logits)
+         else:
+             scores, top_indices = compute_topk(
+                 scores,
+                 topk,
+             )
+         probs = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if topk > 1 else scores
+     else:
+         raise ValueError(f"Invalid score_function: {score_function}")
+
+     # TODO Try using element-wise operations instead of scatter?
+     topk_masked_gates = torch.zeros_like(logits).scatter(1, top_indices, probs)
+     topk_map = torch.zeros_like(logits).int().scatter(1, top_indices, 1).bool()
+     # TODO: Reset topk_map to realize load-balancing?
+     tokens_per_expert = topk_map.sum(dim=0)
+
+     return topk_masked_gates, topk_map, tokens_per_expert
+
+
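As an aside, a tiny illustrative call of `topk_softmax_with_capacity` (this snippet is not part of the repository file): with `use_pre_softmax=True` and top-1 routing, every token receives exactly one non-zero gate.

```python
# Illustration only, not part of modeling_patch_moe.py.
import torch

logits = torch.randn(8, 4)  # [num_tokens, num_experts]
gates, routing_map, tokens_per_expert = topk_softmax_with_capacity(
    logits, topk=1, use_pre_softmax=True, score_function="softmax"
)
assert gates.shape == (8, 4) and routing_map.dtype == torch.bool
assert routing_map.sum(dim=1).eq(1).all()    # top-1: exactly one expert per token
assert tokens_per_expert.sum().item() == 8   # every token is assigned once
```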
143
+ class RotaryEmbedding(nn.Module):
144
+ """Rotary Embedding.
145
+
146
+ Args:
147
+ kv_channels (int): Projection weights dimension in multi-head attention. Obtained
148
+ from transformer config
149
+ rotary_interleaved (bool, optional): If True, interleaved rotary position embeddings.
150
+ Defaults to False.
151
+ rotary_base (int, optional): Base period for rotary position embeddings. Defaults to
152
+ 10000.
153
+ use_cpu_initialization (bool, optional): If False, initialize the inv_freq directly
154
+ on the GPU. Defaults to False
155
+ """
156
+
157
+ def __init__(
158
+ self,
159
+ kv_channels: int,
160
+ rotary_interleaved: bool = False,
161
+ rotary_base: int = 10000,
162
+ use_cpu_initialization: bool = False,
163
+ ) -> None:
164
+ super().__init__()
165
+
166
+ dim = kv_channels
167
+ self.rotary_interleaved = rotary_interleaved
168
+ device = "cpu" if use_cpu_initialization else torch.cuda.current_device()
169
+ self.inv_freq = 1.0 / (
170
+ rotary_base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
171
+ )
172
+
173
+ def get_freqs_non_repeated(self, max_seq_len: int, offset: int = 0) -> Tensor:
174
+ """Generates matrix of frequencies based on positions in the sequence,
175
+ used to create positional encodings"""
176
+ seq = (
177
+ torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
178
+ + offset
179
+ )
180
+ freqs = torch.outer(seq, self.inv_freq) # [seq len, dim]
181
+ return freqs
182
+
183
+ def forward(
184
+ self, max_seq_len: int, offset: int = 0, packed_seq: bool = False, device=None
185
+ ) -> Tensor:
186
+ """Forward pass of RoPE embedding.
187
+
188
+ Args:
189
+ max_seq_len (int): Maximum size of sequence
190
+ offset (int, optional): RoPE offset. Defaults to 0.
191
+ packed_seq (bool, optional): Whether to use packed sequence. Defaults to False.
192
+
193
+ Returns:
194
+ Tensor: Embeddings after applying RoPE.
195
+ """
196
+ if device is None:
197
+ device = self.inv_freq.device
198
+ if self.inv_freq.device.type == "cpu":
199
+ # move `inv_freq` to GPU once at the first micro-batch forward pass
200
+ self.inv_freq = self.inv_freq.to(device=device)
201
+
202
+ freqs = self.get_freqs_non_repeated(max_seq_len, offset).to(device)
203
+ # first part even vector components, second part odd vector components,
204
+ # 2 * dim in dimension size
205
+ if not self.rotary_interleaved:
206
+ emb = torch.cat((freqs, freqs), dim=-1)
207
+ else:
208
+ emb = torch.stack((freqs.view(-1, 1), freqs.view(-1, 1)), dim=-1).view(
209
+ freqs.shape[0], -1
210
+ )
211
+ # emb [seq_length, .., dim]
212
+ emb = emb[:, None, None, :]
213
+ return emb.to(device)
214
+
215
+ def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
216
+ state_dict.pop(f"{prefix}inv_freq", None)
217
+ return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
218
+
219
+ def get_rotary_seq_len(
220
+ self,
221
+ transformer_input: Tensor,
222
+ ) -> float:
223
+ """Function to get the rotary sequence length.
224
+ Args:
225
+ transformer_input (Tensor): Input tensor to the transformer
226
+ Returns:
227
+ float: The rotary sequence length
228
+ """
229
+ rotary_seq_len = transformer_input.size(0)
230
+ return rotary_seq_len
231
+
232
+
233
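As a sanity check on the rotary embedding shape, a minimal sketch (the constructor values mirror config.json but are otherwise arbitrary; `use_cpu_initialization=True` avoids needing a GPU):

    import torch

    rope = RotaryEmbedding(kv_channels=64, rotary_base=1000000, use_cpu_initialization=True)
    emb = rope(max_seq_len=30)   # e.g. 2880 // 96 = 30 patches
    print(emb.shape)             # torch.Size([30, 1, 1, 64]) -> [seq_len, 1, 1, kv_channels]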
+ class IdentityOp(nn.Module):
234
+ def forward(self, x):
235
+ return x
236
+
237
+
238
+ class IdentityFuncOp(nn.Module):
239
+ def forward(self, x):
240
+ return x
241
+
242
+
243
+ class RMSNorm(nn.Module):
244
+ def __init__(self, hidden_size, eps=1e-5):
245
+ super().__init__()
246
+ self.weight = nn.Parameter(torch.ones(hidden_size))
247
+ self.variance_epsilon = eps
248
+
249
+ def forward(self, hidden_states):
250
+ """
251
+ hidden_states [bs, patch_num, d_model]
252
+ """
253
+ input_dtype = hidden_states.dtype
254
+ hidden_states = hidden_states.to(torch.float32)
255
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
256
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
257
+ return self.weight * hidden_states.to(input_dtype)
258
+
259
+
260
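The normalization above is standard RMSNorm; a minimal numerical check against the formula x / sqrt(mean(x^2) + eps), assuming the class is imported from this module and its weight is still at its initial value of ones:

    import torch

    x = torch.randn(2, 4, 8)                       # [bs, patch_num, d_model]
    norm = RMSNorm(hidden_size=8)
    reference = x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + 1e-5)
    assert torch.allclose(norm(x), reference, atol=1e-5)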
+ class TEDotProductAttention(nn.Module):
261
+ """Implement the scaled dot product attention with softmax.
262
+ Arguments
263
+ ---------
264
+ softmax_scale: The temperature to use for the softmax attention.
265
+ (default: 1/sqrt(d_keys) where d_keys is computed at
266
+ runtime)
267
+ attention_dropout: The dropout rate to apply to the attention
268
+ (default: 0.0)
269
+ """
270
+
271
+ def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
272
+ super().__init__()
273
+ self.causal = causal
274
+ self.softmax_scale = softmax_scale
275
+ self.drop = nn.Dropout(attention_dropout)
276
+
277
+ def forward(
278
+ self,
279
+ q,
280
+ k,
281
+ v,
282
+ attention_mask,
283
+ causal=None,
284
+ ):
285
+ """Implements the multihead softmax attention.
286
+ Arguments
287
+ ---------
288
+ q, k, v: Query, key, and value tensors of shape (S, B, H, D); they are transposed to batch-first internally.
289
+ attention_mask: mask applied to the attention scores; positions where the mask is 0 are masked out.
290
+ causal: if passed, will override self.causal
291
+ (B: batch size, S: sequence length, H: number of heads, D: head dimension)
292
+ """
293
+ causal = self.causal if causal is None else causal
294
+
295
+ q = q.transpose(0, 1).contiguous()
296
+ k = k.transpose(0, 1).contiguous()
297
+ v = v.transpose(0, 1).contiguous()
298
+
299
+ batch_size, seq_len = q.shape[0], q.shape[1]
300
+ softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
301
+ # scores
302
+ scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
303
+ scores = scores.masked_fill(attention_mask == 0, float("-1e9"))
304
+ # Softmax
305
+ attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
306
+ # Dropout
307
+ attention_drop = self.drop(attention)
308
+ output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
309
+ output = output.reshape(batch_size, seq_len, -1).transpose(0, 1).contiguous()
310
+ return output
311
+
312
+
313
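A shape-only walk-through of the attention core above, as a standalone sketch with small made-up sizes (the layout at the boundary is (S, B, H, D), batch-first inside):

    import math
    import torch

    S, B, H, D = 5, 2, 4, 8
    q, k, v = (torch.randn(S, B, H, D) for _ in range(3))
    mask = torch.ones(B, 1, S, S)                                      # 1 = attend, 0 = masked out

    qb, kb, vb = (t.transpose(0, 1) for t in (q, k, v))                # -> (B, S, H, D)
    scores = torch.einsum("bthd,bshd->bhts", qb, kb / math.sqrt(D))    # [B, H, S, S]
    scores = scores.masked_fill(mask == 0, -1e9)
    attn = torch.softmax(scores, dim=-1)
    out = torch.einsum("bhts,bshd->bthd", attn, vb)                    # [B, S, H, D]
    out = out.reshape(B, S, H * D).transpose(0, 1)                     # back to (S, B, H*D)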
+ class SelfAttention(nn.Module):
314
+ def __init__(
315
+ self,
316
+ config,
317
+ ):
318
+ super().__init__()
319
+ self.config = config
320
+ q_layernorm = config.q_layernorm
321
+ k_layernorm = config.k_layernorm
322
+ self.hidden_size = config.hidden_size
323
+ self.core_attention = TEDotProductAttention()
324
+ self.linear_proj = nn.Linear(
325
+ self.hidden_size,
326
+ self.hidden_size,
327
+ bias=config.add_bias_linear,
328
+ )
329
+ self.linear_qkv = nn.Linear(
330
+ self.hidden_size,
331
+ 3 * self.hidden_size,
332
+ bias=config.add_bias_linear,
333
+ )
334
+ if q_layernorm:
335
+ self.q_layernorm = RMSNorm(self.hidden_size)
336
+ else:
337
+ self.q_layernorm = IdentityOp()
338
+ if k_layernorm:
339
+ self.k_layernorm = RMSNorm(self.hidden_size)
340
+ else:
341
+ self.k_layernorm = IdentityOp()
342
+
343
+ def forward(self, x, attention_mask, rotary_pos_emb):
344
+ qkv = self.linear_qkv(x)
345
+ qkv = qkv.view(qkv.size(0), qkv.size(1), self.config.num_attention_heads, -1)
346
+ q, k, v = qkv.chunk(3, dim=-1)
347
+
348
+ # q/k norm
349
+ q = self.q_layernorm(q)
350
+ k = self.k_layernorm(k)
351
+
352
+ # Apply rotary encoding to q and k
353
+ rotary_pos_emb = (rotary_pos_emb,) * 2
354
+ q_pos_emb, k_pos_emb = rotary_pos_emb
355
+ q = _apply_rotary_pos_emb_bshd(q, q_pos_emb)
356
+ k = _apply_rotary_pos_emb_bshd(k, k_pos_emb)
357
+
358
+ # attention
359
+ attn_output = self.core_attention(q, k, v, attention_mask)
360
+ output = self.linear_proj(attn_output)
361
+ return output
362
+
363
+
364
+ class MLP(nn.Module):
365
+ def __init__(self, config, in_features):
366
+ super().__init__()
367
+ self.config = config
368
+ self.linear_fc1 = nn.Linear(
369
+ in_features,
370
+ self.config.moe_ffn_hidden_size * 2,
371
+ bias=self.config.add_bias_linear,
372
+ )
373
+ self.linear_fc2 = nn.Linear(
374
+ self.config.moe_ffn_hidden_size,
375
+ self.config.hidden_size,
376
+ bias=self.config.add_bias_linear,
377
+ )
378
+
379
+ def forward(self, x):
380
+ x = self.swiglu(self.linear_fc1(x))
381
+ x = self.linear_fc2(x)
382
+ return x
383
+
384
+ def swiglu(self, y):
385
+ """Performs SwiGLU (Swish-Gated Linear Unit) activation function.
386
+
387
+ Args:
388
+ y (torch.Tensor): Input tensor to be split into two halves along the last dimension.
389
+
390
+ Returns:
391
+ torch.Tensor: Result of SwiGLU activation: SiLU(y1) * y2, where y1, y2 are the split halves.
392
+ """
393
+ y_1, y_2 = torch.chunk(y, 2, -1)
394
+ return F.silu(y_1) * y_2
395
+
396
+
397
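The SwiGLU split above is why linear_fc1 projects to 2 * moe_ffn_hidden_size: one half gates the other. The same activation in isolation, with arbitrary small sizes:

    import torch
    import torch.nn.functional as F

    y = torch.randn(3, 8)                # stand-in for linear_fc1's output (2x the FFN width)
    y1, y2 = torch.chunk(y, 2, dim=-1)   # gate half and value half, 4 features each
    out = F.silu(y1) * y2                # SwiGLU: SiLU(y1) * y2 -> shape [3, 4]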
+ class TransformerLayer(nn.Module):
398
+ def __init__(self, config, input_layernorm):
399
+ super().__init__()
400
+ self.config = config
401
+ if input_layernorm:
402
+ self.input_layernorm = RMSNorm(self.config.hidden_size)
403
+ else:
404
+ self.input_layernorm = IdentityOp()
405
+ self.self_attention = SelfAttention(config)
406
+ self.pre_mlp_layernorm = RMSNorm(self.config.hidden_size)
407
+ self.mlp = MLP(config, self.config.hidden_size)
408
+
409
+ def forward(self, x, attention_mask, rotary_pos_emb):
410
+ residual = x
411
+ x = self.input_layernorm(x)
412
+ x = self.self_attention(x, attention_mask, rotary_pos_emb)
413
+ x = x + residual
414
+ residual = x
415
+ x = self.pre_mlp_layernorm(x)
416
+ x = self.mlp(x)
417
+ x = x + residual
418
+ return x
419
+
420
+
421
+ class PatchMoEExpert_v2(nn.Module):
422
+ def __init__(self, config, patch_input_size=32, expert_output_size=336, final_layernorm=True):
423
+ super().__init__()
424
+ self.config = config
425
+ self.patch_size = patch_input_size
426
+ self.seq_length = config.seq_length
427
+ assert (
428
+ self.seq_length % self.patch_size == 0
429
+ ), f"invalid patch_size: {self.patch_size} when seq_length={self.seq_length}"
430
+ self.patch_num = self.seq_length // self.patch_size
431
+ self.flatten_size = self.patch_num * self.config.hidden_size
432
+
433
+ self.layers = nn.ModuleList(
434
+ [
435
+ TransformerLayer(config, input_layernorm=config.transformer_input_layernorm)
436
+ for _ in range(self.config.expert_num_layers)
437
+ ]
438
+ )
439
+ if final_layernorm:
440
+ self.final_layernorm = RMSNorm(self.config.hidden_size)
441
+ else:
442
+ self.final_layernorm = IdentityOp()
443
+ self.patch_embedding = MLP(config, in_features=patch_input_size)
444
+ self.output_layer = nn.Linear(
445
+ in_features=self.flatten_size,
446
+ out_features=expert_output_size,
447
+ bias=False,
448
+ )
449
+
450
+ def _forward_patch_embedding(
451
+ self,
452
+ input: Tensor, # [batch_size, seq_len]
453
+ ):
454
+ """
455
+ Perform patch embedding on the input time series.
456
+
457
+ This method applies a linear transformation to the input tensor to
458
+ convert it into patches and then embeds these patches using a linear layer.
459
+ """
460
+ batch_size, seq_len = input.shape
461
+ assert (
462
+ seq_len == self.seq_length
463
+ ), f"Expected sequence length {self.seq_length}, but got {seq_len}"
464
+
465
+ # Create input_mask based on pad_length
466
+ # When a time point is masked, its value is mask_pad_value(default:255.)
467
+ input_mask = (
468
+ input != self.config.mask_pad_value
469
+ ) # 0: mask, 1: unmask [batch_size, seq_len]
470
+
471
+ # Zero out masked time points so that, together with the attention mask built below, they do not influence the result
472
+ input_data = input * input_mask # [batch_size, seq_len]
473
+
474
+ # Patchify the input
475
+ input_data = input_data.unfold(
476
+ dimension=-1, size=self.patch_size, step=self.patch_size
477
+ ).contiguous() # input [batch_size, patch_num, patch_size]
478
+ hidden_states = self.patch_embedding(
479
+ input_data
480
+ ) # hidden_states [batch_size, patch_num, hidden_size]
481
+ hidden_states = hidden_states.transpose(
482
+ 0, 1
483
+ ).contiguous() # hidden_states [patch_num, batch_size, hidden_size], To adapt to the Megatron
484
+
485
+ # Patchify the mask: a patch is treated as unmasked only if every time point inside it is unmasked
486
+ attention_mask = input_mask.unfold(
487
+ dimension=-1, size=self.patch_size, step=self.patch_size
488
+ ).contiguous() # [batch_size, patch_num, patch_size]
489
+ attention_mask = (
490
+ attention_mask.sum(-1) == self.patch_size
491
+ ) # [batch_size, patch_num] # 0: mask, 1: unmask
492
+ attention_mask[:, -1] = True # The last patch is not masked
493
+ _, patch_num = attention_mask.shape
494
+ attention_mask = attention_mask.unsqueeze(2).repeat(
495
+ 1, 1, patch_num
496
+ ) * attention_mask.unsqueeze(1).repeat(
497
+ 1, patch_num, 1
498
+ ) # [batch_size, patch_num, patch_num]
499
+ attention_mask = attention_mask.unsqueeze(
500
+ 1
501
+ ).contiguous() # [batch_size, 1, patch_num, patch_num]
502
+
503
+ return hidden_states, attention_mask, input_mask
504
+
505
+ def _forward_output(
506
+ self, hidden_states, output_scale=None, input_mask=None, inference_context=None
507
+ ):
508
+ """
509
+ Perform a forward pass through the output layer.
510
+
511
+ Args:
512
+ hidden_states (Tensor): Transformed hidden states of shape [patch_num, batch_size, hidden_size]
513
+ output_scale (Tensor, optional): Expert probabilities used to scale the output [batch_size]
514
+ input_mask (Tensor, optional): Expert input mask of shape [batch_size, seq_len], 0: mask, 1: unmask
515
+ inference_context (optional): Unused; kept for interface compatibility
516
+
517
+ Returns:
518
+ expert_output (Tensor): Expert output of shape [batch_size, expert_output_size]
519
+ """
520
+
521
+ # [patch_num, batch_size, hidden_size] -> [batch_size, flatten_size (patch_num * hidden_size)]
522
+ patch_num, batch_size, hidden_size = hidden_states.shape
523
+ assert (
524
+ patch_num * hidden_size
525
+ ) == self.flatten_size, f"patch_num ({patch_num}) * hidden_size ({hidden_size}) != flatten_size ({self.flatten_size})"
526
+ hidden_states = hidden_states.transpose(0, 1).reshape(-1, self.flatten_size).contiguous()
527
+ expert_output = self.output_layer(hidden_states) # [batch_size, expert_output_size]
528
+ if output_scale is not None:
529
+ original_dtype = expert_output.dtype
530
+ expert_output = expert_output * output_scale.unsqueeze(-1)
531
+ expert_output = expert_output.to(original_dtype)
532
+
533
+ return expert_output
534
+
535
+ def forward(self, expert_input, rotary_pos_emb, expert_probs=None):
536
+ hidden_states, attention_mask, input_mask = self._forward_patch_embedding(expert_input)
537
+ for layer in self.layers:
538
+ hidden_states = layer(
539
+ hidden_states, attention_mask, rotary_pos_emb[: hidden_states.shape[0]]
540
+ )
541
+ hidden_states = self.final_layernorm(hidden_states)
542
+ expert_output = self._forward_output(hidden_states, expert_probs, input_mask)
543
+ return expert_output
544
+
545
+
546
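The patchify step in _forward_patch_embedding is a non-overlapping unfold; a small illustration with the shipped config sizes (seq_length 2880, one expert's patch_size 96, mask_pad_value 255.0):

    import torch

    x = torch.randn(2, 2880)                                    # [batch_size, seq_len]
    patches = x.unfold(dimension=-1, size=96, step=96)          # [2, 30, 96]: 2880 / 96 = 30 patches
    mask = (x != 255.0).unfold(dimension=-1, size=96, step=96)
    patch_keep = mask.sum(-1) == 96                             # a patch is kept only if fully unmasked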
+ class SequentialPatchMoE(nn.Module):
547
+ def __init__(self, config, expert_output_size=336):
548
+ super().__init__()
549
+ self.config = config
550
+ self.expert_output_size = expert_output_size
551
+ self.local_experts = nn.ModuleList(
552
+ [
553
+ PatchMoEExpert_v2(
554
+ config,
555
+ expert_output_size=expert_output_size,
556
+ patch_input_size=config.patch_size_list[expert_id],
557
+ final_layernorm=config.moe_expert_final_layernorm,
558
+ )
559
+ for expert_id in range(config.num_moe_experts)
560
+ ]
561
+ )
562
+
563
+ def forward(self, input, routing_map, rotary_pos_emb, expert_probs):
564
+ expert_output_list = []
565
+ batch_size, seq_len = input.size()
566
+
567
+ for i, expert in enumerate(self.local_experts):
568
+ token_mask = routing_map[:, i].bool() # shape (batch,)
569
+ current_inputs = input[token_mask] # (num_tokens_for_expert, seq_len)
570
+ current_probs = expert_probs[token_mask, i]
571
+
572
+ if current_inputs.numel() == 0:
573
+ expert_output = torch.zeros(
574
+ 0, self.expert_output_size, device=input.device, dtype=input.dtype
575
+ )
576
+ else:
577
+ expert_output = expert(current_inputs, rotary_pos_emb, current_probs)
578
+
579
+ full_output = torch.zeros(
580
+ batch_size, self.expert_output_size, device=input.device, dtype=input.dtype
581
+ )
582
+ full_output[token_mask] = expert_output
583
+ expert_output_list.append(full_output)
584
+
585
+ expert_output = reduce(torch.add, expert_output_list)
586
+ return expert_output
587
+
588
+
589
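The per-expert dispatch above relies on boolean indexing over the routing map; the pattern reduced to its essentials, with the experts replaced by a no-op purely for illustration:

    import torch
    from functools import reduce

    batch, seq_len, num_experts = 4, 8, 2
    x = torch.randn(batch, seq_len)
    routing_map = torch.tensor([[1, 0], [0, 1], [1, 0], [0, 1]], dtype=torch.bool)

    outputs = []
    for i in range(num_experts):
        token_mask = routing_map[:, i]
        full = torch.zeros_like(x)
        full[token_mask] = x[token_mask]      # a real expert would transform x[token_mask] here
        outputs.append(full)
    combined = reduce(torch.add, outputs)     # per-expert results scattered back and summed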
+ class RouterGatingLinearFunction(torch.autograd.Function):
590
+ """
591
+ Autograd function for router gating linear.
592
+ """
593
+
594
+ @staticmethod
595
+ def forward(ctx, inp: torch.Tensor, weight: torch.Tensor, router_dtype: torch.dtype):
596
+ """
597
+ Forward pass of the RouterGatingLinearFunction function.
598
+ """
599
+ ctx.router_dtype = router_dtype
600
+ ctx.input_dtype = inp.dtype
601
+ ctx.weight_dtype = weight.dtype
602
+ inp_shape = inp.shape
603
+ inp = inp.view(-1, inp_shape[-1])
604
+
605
+ output = torch.mm(inp.to(router_dtype), weight.to(router_dtype).t())
606
+
607
+ output = output.view(*inp_shape[:-1], -1)
608
+ return output
609
+
610
+
611
+ def router_gating_linear(inp: torch.Tensor, weight: torch.Tensor, router_dtype: torch.dtype):
612
+ """
613
+ Customized linear layer for router gating.
614
+ This linear layer accepts bfloat16 input and weight, and can return output with router_dtype.
615
+ It can reduce the memory usage by avoiding saving the intermediate high precision tensors.
616
+ """
617
+ return RouterGatingLinearFunction.apply(inp, weight, router_dtype)
618
+
619
+
620
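Functionally, the router gate above is a plain matrix multiply in the router dtype (only a forward pass is defined on the custom Function, so this sketch assumes inference). With made-up float32 tensors sized like the shipped config:

    import torch

    inp = torch.randn(4, 2880)              # [batch, moe_router_input_size]
    weight = torch.randn(4, 2880)           # [num_experts, moe_router_input_size]
    logits = torch.mm(inp, weight.t())      # [batch, num_experts] router logits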
+ class Router(ABC, nn.Module):
621
+ """Base Router class"""
622
+
623
+ def __init__(
624
+ self,
625
+ config: PatchMoeConfig,
626
+ ) -> None:
627
+ """
628
+ Initialize the Router module.
629
+
630
+ Args:
631
+ config (PatchMoeConfig): Configuration object for the PatchMoE model.
633
+ """
634
+ super().__init__()
635
+ self.config = config
636
+
637
+ # Initialize the gate weights.
638
+
639
+ if self.config.patch_size_list is not None:
640
+ assert self.config.moe_router_input_size is not None
641
+ self.weight = torch.nn.Parameter(
642
+ torch.empty(
643
+ (self.config.num_moe_experts, self.config.moe_router_input_size),
644
+ dtype=torch.float32,
645
+ )
646
+ )
647
+ else:
648
+ self.weight = torch.nn.Parameter(
649
+ torch.empty(
650
+ (self.config.num_moe_experts, self.config.hidden_size), dtype=torch.float32
651
+ )
652
+ )
653
+ self.reset_parameters()
654
+
655
+ def reset_parameters(self):
656
+ """Reset the router parameters."""
657
+ torch.nn.init.normal_(self.weight, mean=0, std=self.config.init_method_std)
658
+ self.weight.data = self.weight.data.to(dtype=self.config.torch_dtype)
659
+
660
+ def gating(self, input: torch.Tensor):
661
+ """Forward pass of the router gate.
662
+
663
+ Args:
664
+ input (torch.Tensor): Input tensor.
665
+
666
+ Returns:
667
+ torch.Tensor: Logits tensor.
668
+ """
669
+ if self.weight.device != input.device:
670
+ self.weight.data = self.weight.data.to(input.device)  # move in place; reassigning a registered Parameter with a plain Tensor would raise
671
+ router_dtype = input.dtype
672
+ logits = router_gating_linear(input, self.weight, router_dtype)
673
+ return logits
674
+
675
+ @abstractmethod
676
+ def routing(self, logits: torch.Tensor):
677
+ """Routing function.
678
+
679
+ Args:
680
+ logits (torch.Tensor): Logits tensor.
681
+
682
+ Returns:
683
+ Tuple[torch.Tensor, torch.Tensor]: A tuple containing token assignment
684
+ probabilities and mapping.
685
+ """
686
+ raise NotImplementedError("Routing function not implemented.")
687
+
688
+ @abstractmethod
689
+ def forward(self, input: torch.Tensor):
690
+ """
691
+ Forward pass of the router.
692
+
693
+ Args:
694
+ input (torch.Tensor): Input tensor.
695
+ """
696
+ raise NotImplementedError("Forward function not implemented.")
697
+
698
+
699
+ class TopKRouter(Router):
700
+ """Route each token to the top-k experts."""
701
+
702
+ def __init__(
703
+ self,
704
+ config: PatchMoeConfig,
705
+ ) -> None:
706
+ """Initialize the zero token dropping router.
707
+
708
+ Args:
709
+ config (TransformerConfig): The configuration for the transformer model.
710
+ model_comm_pgs (ModelCommProcessGroups, optional): Process groups for MoE operations.
711
+ """
712
+ super().__init__(config=config)
713
+ self.topk = self.config.moe_router_topk
714
+ self.score_function = self.config.moe_router_score_function
715
+
716
+ self.enable_expert_bias = self.config.moe_router_enable_expert_bias
717
+ if self.enable_expert_bias:
718
+ self.register_buffer(
719
+ "local_tokens_per_expert",
720
+ torch.zeros(self.config.num_moe_experts, dtype=torch.float32),
721
+ persistent=False,
722
+ )
723
+ self.register_buffer(
724
+ "expert_bias", torch.zeros(self.config.num_moe_experts, dtype=torch.float32)
725
+ )
726
+ else:
727
+ self.local_tokens_per_expert = None
728
+ self.expert_bias = None
729
+
730
+ def routing(self, logits: torch.Tensor):
731
+ """Top-k routing function
732
+
733
+ Args:
734
+ logits (torch.Tensor): Logits tensor after gating.
735
+
736
+ Returns:
737
+ probs (torch.Tensor): The probabilities of token to experts assignment.
738
+ routing_map (torch.Tensor): The mapping of token to experts assignment,
739
+ with shape [num_tokens, num_experts].
740
+ """
741
+ logits = logits.view(-1, self.config.num_moe_experts)
742
+
743
+ scores, routing_map, tokens_per_expert = topk_softmax_with_capacity(
744
+ logits,
745
+ self.topk,
746
+ use_pre_softmax=self.config.moe_router_pre_softmax,
747
+ score_function=self.score_function,
748
+ expert_bias=self.expert_bias,
749
+ )
750
+ return scores, routing_map
751
+
752
+ def forward(self, input: torch.Tensor):
753
+ """
754
+ Forward pass of the router.
755
+
756
+ Args:
757
+ input (torch.Tensor): Input tensor.
758
+ """
759
+ logits = self.gating(input)
760
+
761
+ scores, routing_map = self.routing(logits)
762
+
763
+ return scores, routing_map
764
+
765
+
766
+ class PatchMoEMoELayer(nn.Module):
767
+ def __init__(self, config, layer_number):
768
+ super().__init__()
769
+ self.config = config
770
+ self.seq_length = config.seq_length
771
+ self.router = TopKRouter(config)
772
+ self.layer_number = layer_number
773
+ self.pred_length = config.pred_length
774
+ self.is_last_layer = self.layer_number == config.num_hidden_layers
775
+ if self.is_last_layer and self.config.heterogeneous_moe_layer:
776
+ self.expert_output_size = config.pred_length
777
+ else:
778
+ if self.config.do_expert_forecast:
779
+ self.expert_output_size = config.seq_length + config.pred_length
780
+ else:
781
+ self.expert_output_size = config.seq_length
782
+
783
+ if self.is_last_layer and self.config.heterogeneous_moe_layer:
784
+ # If heterogeneous_moe_layer is True, the backcast will be None
785
+ self.backcast_layernorm = None
786
+ else:
787
+ self.backcast_layernorm = RMSNorm(self.seq_length)
788
+
789
+ self.experts = SequentialPatchMoE(
790
+ config,
791
+ expert_output_size=self.expert_output_size,
792
+ )
793
+ self.shared_experts = PatchMoEExpert_v2(
794
+ config,
795
+ expert_output_size=self.expert_output_size,
796
+ patch_input_size=config.shared_patch_size,
797
+ final_layernorm=config.moe_expert_final_layernorm,
798
+ )
799
+
800
+ def time_series_preprocess(self, input: torch.Tensor):
801
+ """
802
+ Preprocess time series(sample) for dispatch.
803
+
804
+ Builds the input mask (0: mask, 1: unmask) from mask_pad_value and stores it on the layer; RevIN itself is applied earlier in PatchMoEModel.
805
+
806
+ Args:
807
+ input (torch.Tensor): The input time series (samples) to the MoE layer. [batch_size, seq_len]
808
+
809
+ Returns:
810
+ input (torch.Tensor): The input time series (samples), passed through unchanged. [batch_size, seq_len]
813
+ """
814
+
815
+ batch_size, seq_len = input.shape
816
+ assert seq_len == self.seq_length, f"seq_len {seq_len} != self.seq_length {self.seq_length}"
817
+
818
+ # Create input_mask based on pad_length
819
+ # When a time point is masked, its value is mask_pad_value(default:255.)
820
+ input_mask = (
821
+ input != self.config.mask_pad_value
822
+ ) # 0: mask, 1: unmask [batch_size, seq_len]
823
+
824
+ self.input_mask = input_mask
825
+
826
+ return input
827
+
828
+ def router_and_preprocess(self, backcast: torch.Tensor):
829
+ """Compute and preprocess time series(sample) routing for dispatch.
830
+
831
+ This method uses the router to determine which experts to send each time series(sample) to,
832
+ producing routing probabilities and a mapping. It then preprocesses the
833
+ input time series (samples) and probabilities for the time series(sample) dispatcher. The original
834
+ input time series (samples) are returned as a residual connection.
835
+ """
836
+ # backcast: [batch_size, seq_len]
837
+ backcast = self.time_series_preprocess(backcast)
838
+
839
+ residual = backcast # residual: [batch_size, seq_len], the input to the shared experts
840
+
841
+ # TODO: Check the effective of the masked value to the router
842
+ probs, routing_map = self.router(
843
+ backcast * self.input_mask
844
+ ) # probs/routing_map: [batch_size, num_experts]
845
+
846
+ return backcast, probs, residual, routing_map
847
+
848
+ def experts_compute(
849
+ self,
850
+ input: torch.Tensor, # [num_permuted_samples_after_dispatch, seq_len]
851
+ probs: torch.Tensor, # [num_permuted_samples_after_dispatch]
852
+ residual: torch.Tensor, # [batch_size, seq_len]
853
+ rotary_pos_emb: torch.Tensor, # [seq_len, 1, 1, kv_channels (hidden_size // num_heads)]
854
+ routing_map: torch.Tensor, # [batch_size, num_experts]
855
+ ):
856
+ """Computes the output of the experts on the dispatched time series(sample).
857
+
858
+ This method first post-processes the dispatched input to get permuted time series(sample)
859
+ for each expert. It then passes the time series(sample) through the local experts.
860
+ If a shared expert is configured and not overlapped with communication,
861
+ it is also applied. The output from the experts is preprocessed for the
862
+ combine step.
863
+ """
864
+ # shared_expert_output: [batch_size, seq_len (+ pred_len)]
865
+ shared_experts_output = self.shared_experts(residual, rotary_pos_emb)
866
+
867
+ # input: [batch_size, seq_len], routing_map / probs: [batch_size, num_experts]
868
+ # experts_output: [batch_size, expert_output_size]
869
+ # (samples are routed inside SequentialPatchMoE; no separate token dispatcher is used here)
870
+
871
+ experts_output = self.experts(input, routing_map, rotary_pos_emb, probs)
872
+
873
+ return experts_output, shared_experts_output
874
+
875
+ def postprocess(
876
+ self,
877
+ backcast: torch.Tensor, # [batch_size, seq_len]
878
+ forecast: torch.Tensor, # [batch_size, pred_len]
879
+ output_backcast: torch.Tensor, # [batch_size, seq_len]
880
+ output_forecast: torch.Tensor, # [batch_size, pred_len]
881
+ ):
882
+ """
883
+ Args:
884
+ backcast (torch.Tensor): The previous layer's backcast time series (samples). [batch_size, seq_len]
885
+ forecast (torch.Tensor): The previous layer's forecast time series (samples). [batch_size, pred_len]
886
+ output_backcast (torch.Tensor): The current layer's output backcast time series (samples). [batch_size, seq_len]
887
+ output_forecast (torch.Tensor): The current layer's output forecast time series (samples). [batch_size, pred_len]
888
+ """
892
+ if output_backcast is not None:
893
+ output_backcast = self.backcast_layernorm(output_backcast) # LayerNorm
894
+ if self.config.residual_backcast:
895
+ output_backcast = backcast - output_backcast
896
+
897
+ output_backcast[~self.input_mask] = (
898
+ self.config.mask_pad_value
899
+ ) # Important! Recover the mask time point back to mask_pad_value(default:255.)
900
+
901
+ if (
902
+ self.config.do_expert_forecast and forecast is not None
903
+ ): # The first layer's forecast is None
904
+ output_forecast = forecast + output_forecast
905
+
906
+ return output_backcast, output_forecast
907
+
908
+ def combine(
909
+ self,
910
+ experts_output: torch.Tensor,
911
+ shared_experts_output: torch.Tensor,
912
+ ):
913
+ """Combines expert outputs via communication and adds shared expert output.
914
+
915
+ This method uses the time series(sample) dispatcher to combine the outputs from different
916
+ experts (e.g., via an All-to-All communication). It then adds the output
917
+ from the shared expert if it exists.
918
+ """
919
+ assert (
920
+ experts_output.shape == shared_experts_output.shape
921
+ ), f"experts_output shape {experts_output.shape} doesn't equal to shared_experts_output shape:{shared_experts_output.shape}"
922
+ output = experts_output + shared_experts_output
923
+
924
+ if self.is_last_layer and self.config.heterogeneous_moe_layer:
925
+ output_backcast = None
926
+ output_forecast = output
927
+ assert (
928
+ output_forecast.shape[1] == self.pred_length
929
+ ), f"heterogeneous_moe_layer=True, expected the last moe layer's output pred len: {self.pred_length}, but got {output_forecast.shape[1]}"
930
+ else:
931
+ # Note: masked time points here may no longer equal mask_pad_value (default: 255.); they are restored in postprocess
932
+ output_backcast = output[:, : self.seq_length] # [batch_size, seq_len]
933
+
934
+ if self.config.do_expert_forecast:
935
+ output_forecast = output[:, self.seq_length :] # [batch_size, pred_len]
936
+ assert (
937
+ output_forecast.shape[1] == self.pred_length
938
+ ), f"do_expert_forecast=True, expected the last moe layer's output pred len: {self.pred_length}, but got {output_forecast.shape[1]}"
939
+ else:
940
+ output_forecast = None
941
+
942
+ return output_backcast, output_forecast
943
+
944
+ def forward(self, backcast, forecast, rotary_pos_emb):
945
+ inputs, probs, residual, routing_map = self.router_and_preprocess(backcast)
946
+ experts_output, shared_experts_output = self.experts_compute(
947
+ inputs, probs, residual, rotary_pos_emb, routing_map
948
+ )
949
+ output_backcast, output_forecast = self.combine(experts_output, shared_experts_output)
950
+ output_backcast, output_forecast = self.postprocess(
951
+ backcast, forecast, output_backcast, output_forecast
952
+ )
953
+ return output_backcast, output_forecast
954
+
955
+
956
+ class PatchMoEBlock(nn.Module):
957
+ def __init__(self, config):
958
+ super().__init__()
959
+ self.config = config
960
+ self.layers = nn.ModuleList(
961
+ [
962
+ PatchMoEMoELayer(config, layer_num + 1)
963
+ for layer_num in range(self.config.num_hidden_layers)
964
+ ]
965
+ )
966
+
967
+ def forward(self, x, rotary_pos_emb):
968
+ backcast = x
969
+ forecast = None
970
+ for layer in self.layers:
971
+ backcast, forecast = layer(backcast, forecast, rotary_pos_emb)
972
+ return backcast, forecast
973
+
974
+
975
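The block above threads a (backcast, forecast) pair through the stacked MoE layers in a doubly residual fashion: with residual_backcast set, each layer subtracts what it explained from the backcast and adds its contribution to the forecast. The bare accumulation pattern, with the layers replaced by a stand-in callable:

    import torch

    def toy_layer(backcast, forecast):
        explained = 0.5 * backcast                              # stand-in for the layer's backcast output
        contribution = torch.full((backcast.shape[0], 4), 0.1)  # stand-in for the layer's forecast output
        new_backcast = backcast - explained                     # residual backcast
        new_forecast = contribution if forecast is None else forecast + contribution
        return new_backcast, new_forecast

    backcast, forecast = torch.randn(2, 16), None
    for _ in range(2):                                          # num_hidden_layers
        backcast, forecast = toy_layer(backcast, forecast)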
+ class PatchMoEPreTrainedModel(PreTrainedModel):
976
+ config_class = PatchMoeConfig
977
+ base_model_prefix = "model"
978
+ supports_gradient_checkpointing = True
979
+ _no_split_modules = ["PatchMoEMoELayer"]
980
+ _skip_keys_device_placement = "past_key_values"
981
+ _supports_flash_attn_2 = True
982
+ _supports_sdpa = False
983
+ _supports_cache_class = True
984
+
985
+ def _init_weights(self, module):
986
+ if isinstance(module, nn.Linear):
987
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
988
+ if module.bias is not None:
989
+ module.bias.data.zero_()
990
+ elif isinstance(module, nn.Embedding):
991
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
992
+ if module.padding_idx is not None:
993
+ module.weight.data[module.padding_idx].zero_()
994
+
995
+
996
+ class PatchMoEModel(PatchMoEPreTrainedModel):
997
+ def __init__(self, config: PatchMoeConfig):
998
+ super().__init__(config)
999
+ self.config = config
1000
+ self.seq_length = config.seq_length
1001
+ self.rotary_pos_emb = RotaryEmbedding(
1002
+ kv_channels=self.config.kv_channels,
1003
+ rotary_base=config.rotary_base,
1004
+ use_cpu_initialization=self.config.use_cpu_initialization,
1005
+ rotary_interleaved=self.config.rotary_interleaved,
1006
+ )
1007
+ self.decoder = PatchMoEBlock(config=config)
1008
+ if self.config.do_expert_forecast and self.config.heterogeneous_moe_layer:
1009
+ self.output_layer = IdentityOp()
1010
+ else:
1011
+ self.output_layer = nn.Linear(
1012
+ in_features=self.seq_length,
1013
+ out_features=self.config.pred_length,
1014
+ bias=self.config.add_bias_linear,
1015
+ )
1016
+
1017
+ def revin(
1018
+ self,
1019
+ input: Tensor, # [batch_size, seq_len]
1020
+ input_mask: Tensor, # [batch_size, seq_len] 0:mask, 1:unmask
1021
+ ):
1022
+ """Normalization from Non-stationary Transformer"""
1023
+
1024
+ input_data = input * input_mask
1025
+ sum_per_sample = torch.sum(
1026
+ input_data, dim=1, keepdim=True
1027
+ ).detach() # [batch_size, 1], torch.bfloat16
1028
+ count_per_sample = torch.sum(
1029
+ input_mask, dim=1, keepdim=True
1030
+ ).detach() # [batch_size, 1], torch.int64
1031
+ assert (
1032
+ not torch.any(count_per_sample == 0)
1033
+ ), f"count_per_sample contains zeros; fully masked samples: {input[torch.where(count_per_sample.squeeze(1) == 0)[0]]}"
1034
+ means = sum_per_sample / count_per_sample # [batch_size, 1]
1035
+ input_data = input_data - means
1036
+ input_data = input_data * input_mask
1037
+ var_per_sample = (
1038
+ torch.sum(input_data**2, dim=1, keepdim=True).detach() / count_per_sample
1039
+ ) # [batch_size, 1]
1040
+ stdev = torch.sqrt(var_per_sample + 1e-9)
1041
+ input_data = input_data / stdev
1042
+ input_data = input_data * input_mask
1043
+
1044
+ # recover the mask_pad_value(default:255.)
1045
+ input = input * ~(input_mask) + input_data
1046
+
1047
+ return input, means, stdev
1048
+
1049
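The masked RevIN above ignores padded positions when computing per-sample statistics; the same computation on a toy series (a sketch, not the method itself, with mask_pad_value 255.0):

    import torch

    x = torch.tensor([[255.0, 255.0, 1.0, 2.0, 3.0]])       # first two points are mask padding
    mask = (x != 255.0).float()
    count = mask.sum(dim=1, keepdim=True)                    # 3 valid points
    mean = (x * mask).sum(dim=1, keepdim=True) / count       # 2.0
    var = (((x - mean) * mask) ** 2).sum(dim=1, keepdim=True) / count
    stdev = torch.sqrt(var + 1e-9)
    normed = ((x - mean) / stdev) * mask                     # masked slots stay 0 here; the model
                                                             # then restores mask_pad_value at those slots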
+ def forward(self, input, revin):
1050
+ batch_size, input_len = input.shape
1051
+ if input_len > self.seq_length:
1052
+ input = input[:, -self.seq_length :]
1053
+ elif input_len < self.seq_length:
1054
+ pad_len = self.seq_length - input_len
1055
+ input = F.pad(
1056
+ input, pad=(pad_len, 0), mode="constant", value=self.config.mask_pad_value
1057
+ )
1058
+ input_len = self.seq_length
1059
+
1060
+ input_mask = input != self.config.mask_pad_value
1061
+
1062
+ # Step1. RevIN
1063
+ if revin:
1064
+ input, means, stdev = self.revin(input, input_mask)
1065
+
1066
+ # Step2. Get rotary_pos_emb
1067
+ # rotary_pos_emb [input_len, 1, 1, kv_channels(hidden_size // num_heads)]
1068
+ rotary_pos_emb = self.rotary_pos_emb(input_len, device=input.device)
1069
+
1070
+ # Step3. Do one-step inference to get mixed forecasts from multiple forecast heads
1071
+ # mixed_pred: [batch_size, sum(multi_forecast_head)]
1072
+ mixed_pred = self._inference_step(
1073
+ input=input, input_mask=input_mask, rotary_pos_emb=rotary_pos_emb
1074
+ )
1075
+
1076
+ # Step4. Based on the mixed forecasts, do auto-regressive inference according to
1077
+ # the step list of each forecast head
1078
+ if self.config.multi_forecast_head_type == "single":
1079
+ final_output = self._auto_regressive_single_head(
1080
+ input=input,
1081
+ input_mask=input_mask,
1082
+ patchmoe_forecast=mixed_pred,
1083
+ rotary_pos_emb=rotary_pos_emb,
1084
+ )
1085
+ else:
1086
+ raise NotImplementedError
1087
+
1088
+ # Step5. RevIN
1089
+ if revin:
1090
+ final_output = final_output * (stdev.repeat(1, self.config.inference_length))
1091
+ final_output = final_output + (means.repeat(1, self.config.inference_length))
1092
+
1093
+ return final_output.detach().float()
1094
+
1095
+ def _inference_step(
1096
+ self,
1097
+ input,
1098
+ input_mask,
1099
+ rotary_pos_emb,
1100
+ ):
1101
+ if self.config.do_base_forecast:
1102
+ base_forecast, _ = self.base_output_layer(input)
1103
+ else:
1104
+ base_forecast = None
1105
+
1106
+ decoder_backcast, decoder_forecast = self.decoder(
1107
+ input, # [batch_size, seq_len]
1108
+ rotary_pos_emb, # [input_len, 1, 1, kv_channels(hidden_size // num_heads)]
1109
+ )
1110
+
1111
+ if self.config.do_expert_forecast:
1112
+ assert decoder_forecast is not None, f"decoder_forecast is None"
1113
+ if self.config.heterogeneous_moe_layer:
1114
+ decoder_forecast = self.output_layer(decoder_forecast) # IdentityOp
1115
+ else:
1116
+ final_forecast = self.output_layer(decoder_backcast * input_mask)
1117
+ decoder_forecast = decoder_forecast + final_forecast
1118
+ else:
1119
+ # The decoder_backcast contains the mask_pad_val(default:255.)
1120
+ decoder_forecast = self.output_layer(decoder_backcast * input_mask)
1121
+
1122
+ if self.config.do_base_forecast:
1123
+ assert base_forecast is not None, f"base_forecast is None"
1124
+ patchmoe_forecast = base_forecast + decoder_forecast
1125
+ else:
1126
+ patchmoe_forecast = decoder_forecast
1127
+
1128
+ return patchmoe_forecast
1129
+
1130
+ def _auto_regressive_single_head(
1131
+ self,
1132
+ input, # [batch_size, seq_len]
1133
+ input_mask, # [batch_size, seq_len]
1134
+ patchmoe_forecast, # [batch_size, max(multi_forecast_head)]
1135
+ rotary_pos_emb, # [seq_len, 1, 1, kv_channels(hidden_size // num_heads)]
1136
+ auto_regressive_strategy="from_long_to_short",
1137
+ ):
1138
+ """auto regressive prediction with [single] head"""
1139
+ assert (
1140
+ self.config.multi_forecast_head_type == "single"
1141
+ ), f"_auto_regressive_single_head only support multi_forecast_head_type==single "
1142
+
1143
+ if auto_regressive_strategy == "from_long_to_short":
1144
+ # From long to short
1145
+ multi_forecast_head_list = sorted(self.config.multi_forecast_head_list, reverse=True)
1146
+
1147
+ final_output = patchmoe_forecast
1148
+ while final_output.shape[1] < self.config.inference_length:
1149
+ # adaptive choose the forecast head
1150
+ remain_pred_len = self.config.inference_length - final_output.shape[1]
1151
+ for idx, head_pred_len in enumerate(multi_forecast_head_list):
1152
+ if head_pred_len <= remain_pred_len:
1153
+ break
1154
+ if idx == len(multi_forecast_head_list):
1155
+ idx = len(multi_forecast_head_list) - 1
1156
+ head_pred_len = multi_forecast_head_list[idx]
1157
+
1158
+ # one-step model prediction
1159
+ input = torch.cat([input, patchmoe_forecast], dim=1)[
1160
+ :, -self.seq_length :
1161
+ ].contiguous()
1162
+ input_mask = torch.cat(
1163
+ [
1164
+ input_mask,
1165
+ torch.ones(
1166
+ patchmoe_forecast.shape,
1167
+ dtype=input_mask.dtype,
1168
+ device=input_mask.device,
1169
+ ),
1170
+ ],
1171
+ dim=1,
1172
+ )[
1173
+ :, -self.seq_length :
1174
+ ].contiguous() # 0:mask, 1:unmask
1175
+
1176
+ patchmoe_forecast = self._inference_step(
1177
+ input=input,
1178
+ input_mask=input_mask,
1179
+ rotary_pos_emb=rotary_pos_emb,
1180
+ )
1181
+
1182
+ # the core idea of multi forecast head type of [single]
1183
+ patchmoe_forecast = patchmoe_forecast[:, :head_pred_len]
1184
+
1185
+ final_output = torch.cat([final_output, patchmoe_forecast], dim=1)
1186
+
1187
+ final_output = final_output[:, : self.config.inference_length]
1188
+
1189
+ elif auto_regressive_strategy == "from_short_to_long":
1190
+ # From short to long
1191
+ # in validate_args, it has been sorted, and check the valid config
1192
+ multi_forecast_head_list = sorted(self.config.multi_forecast_head_list)
1193
+ multi_forecast_head_dict = {}
1194
+ for idx, head_pred_len in enumerate(self.config.multi_forecast_head_list):
1195
+ if idx == len(multi_forecast_head_list) - 1:
1196
+ ar_step = math.ceil(self.config.inference_length / head_pred_len)
1197
+ else:
1198
+ ar_step = min(
1199
+ self.config.autoregressive_step_list[idx],
1200
+ self.config.multi_forecast_head_list[idx + 1]
1201
+ // self.config.multi_forecast_head_list[idx],
1202
+ )
1203
+ # ar_step = multi_forecast_head_list[idx + 1] // multi_forecast_head_list[idx]
1204
+
1205
+ multi_forecast_head_dict[head_pred_len] = ar_step
1206
+
1207
+ # the core idea of strategy [from_short_to_long]
1208
+ mixed_pred = patchmoe_forecast
1209
+ output_list = []
1210
+ cur_pred = None
1211
+ cur_pred_len = 0
1212
+
1213
+ # from the first(shortest) as begining
1214
+ for idx, head_pred_len in enumerate(self.config.multi_forecast_head_list):
1215
+ # assert cur_pred_len <= head_pred_len, \
1216
+ # "Accumulated prediction length exceeds the prediction length of current forecast head"
1217
+
1218
+ ar_step = multi_forecast_head_dict[head_pred_len]
1219
+ if ar_step == 0:
1220
+ # Ignore the current forecast head
1221
+ continue
1222
+
1223
+ # Add current head's first auto-regressive step of prediction
1224
+ head_pred = mixed_pred[:, :head_pred_len] # [single]
1225
+ output_list.append(head_pred[:, cur_pred_len:])
1226
+ cur_pred = torch.cat(output_list, dim=1)
1227
+ cur_pred_len = cur_pred.shape[1]
1228
+ if cur_pred_len >= self.config.inference_length:
1229
+ break
1230
+
1231
+ # Do auto-regressive of the rest of the steps
1232
+ for _ in range(1, ar_step + 1):
1233
+ # one-step model prediction
1234
+ cur_input = torch.cat([input, cur_pred], dim=1)[
1235
+ :, -self.seq_length :
1236
+ ].contiguous()
1237
+ cur_input_mask = torch.cat(
1238
+ [
1239
+ input_mask,
1240
+ torch.ones(
1241
+ cur_pred.shape, dtype=input_mask.dtype, device=input_mask.device
1242
+ ),
1243
+ ],
1244
+ dim=1,
1245
+ )[
1246
+ :, -self.seq_length :
1247
+ ].contiguous() # 0:mask, 1:unmask
1248
+
1249
+ patchmoe_forecast = self._inference_step(
1250
+ input=cur_input,
1251
+ input_mask=cur_input_mask,
1252
+ rotary_pos_emb=rotary_pos_emb,
1253
+ )
1254
+
1255
+ head_pred = patchmoe_forecast[:, :head_pred_len]
1256
+ output_list.append(head_pred)
1257
+ cur_pred = torch.cat(output_list, dim=1)
1258
+ cur_pred_len = cur_pred.shape[1]
1259
+ if cur_pred_len >= self.config.inference_length:
1260
+ break
1261
+
1262
+ if cur_pred_len >= self.config.inference_length:
1263
+ break
1264
+
1265
+ final_output = cur_pred[
1266
+ :, : self.config.inference_length
1267
+ ] # [batch_size, inference_len]
1268
+
1269
+ assert final_output.shape[1] == self.config.inference_length
1270
+ return final_output
1271
+
1272
+
1273
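For the from_long_to_short strategy above, each autoregressive step picks the longest head that still fits the remaining horizon, falling back to the shortest one. A small trace of that selection logic in isolation, using the head lengths from config.json and a hypothetical 720-step horizon:

    heads = sorted([24, 96, 336], reverse=True)   # multi_forecast_head_list, longest first
    inference_length = 720                        # hypothetical requested horizon
    produced = 336                                # length of the first one-step forecast (pred_length)
    while produced < inference_length:
        remain = inference_length - produced
        head = next((h for h in heads if h <= remain), heads[-1])
        print(remain, "->", head)
        produced += head
    # prints 384 -> 336, then 48 -> 24, then 24 -> 24; the output is finally truncated to inference_length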
+ class PatchMoEForPrediction(PatchMoEPreTrainedModel, PatchMoEGenerationMixin):
1274
+ def __init__(self, config: PatchMoeConfig):
1275
+ super().__init__(config)
1276
+ self.config = config
1277
+ self.model = PatchMoEModel(self.config)
1278
+ self.post_init()
1279
+
1280
+ def forward(
1281
+ self,
1282
+ input_ids: torch.FloatTensor,
1283
+ attention_mask: Optional[torch.Tensor] = None,
1284
+ labels: Optional[torch.FloatTensor] = None,
1285
+ return_dict: Optional[bool] = False,
1286
+ max_output_length: Optional[int] = None,
1287
+ revin: Optional[bool] = False,
1288
+ ):
1289
+ self.model.config.inference_length = max_output_length
1290
+ outputs = self.model(input=input_ids, revin=revin)
1291
+
1292
+ loss = None
1293
+ logits = outputs
1294
+
1295
+ if labels is not None:
1296
+ loss_fn = nn.MSELoss()
1297
+ loss = loss_fn(logits, labels)
1298
+
1299
+ if not return_dict:
1300
+ output = (logits,)
1301
+ return ((loss,) + output) if loss is not None else output
1302
+
1303
+ return logits
1304
+
1305
+ def prepare_inputs_for_generation(
1306
+ self,
1307
+ input_ids,
1308
+ past_key_values=None,
1309
+ attention_mask=None,
1310
+ inputs_embeds=None,
1311
+ revin=False,
1312
+ **kwargs,
1313
+ ):
1314
+ """
1315
+ Prepare model inputs for autoregressive generation.
1316
+ """
1317
+
1318
+ model_inputs = {"input_ids": input_ids}
1319
+
1320
+ model_inputs.update(
1321
+ {
1322
+ "revin": revin,
1323
+ }
1324
+ )
1325
+
1326
+ return model_inputs
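Tying the pieces together, loading and forecasting with this repository would presumably look like the following. This is a hedged sketch only: the repo id is a placeholder, and the mapping of max_new_tokens to the forecast horizon relies on the generation mixin defined in the next file.

    import torch
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        "<namespace>/patchmoe",        # placeholder repo id
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    )
    context = torch.randn(1, 2880)     # [batch_size, seq_len]; shorter histories are left-padded internally
    forecast = model.generate(context, max_new_tokens=336, revin=True)
    print(forecast.shape)              # expected [1, 336, 1]: [batch, pred_len, channels]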
ts_generation_mixin.py ADDED
@@ -0,0 +1,172 @@
1
+ """
2
+ Time Series Generation Mixin for PatchMoE
3
+
4
+ This module provides generation capabilities specifically designed for time series
5
+ forecasting tasks. It extends the standard Transformers GenerationMixin to handle
6
+ time series data with proper input/output reshaping and autoregressive generation.
7
+ """
8
+
9
+ from typing import List, Optional, Union, Callable
10
+ import torch
11
+ from transformers import GenerationMixin, LogitsProcessorList, StoppingCriteriaList
12
+ from transformers.generation.utils import (
13
+ GenerateNonBeamOutput,
14
+ GenerationConfig,
15
+ GenerateOutput,
16
+ )
17
+
18
+
19
+ class PatchMoEGenerationMixin(GenerationMixin):
20
+ """
21
+ Generation mixin class for PatchMoE time series forecasting.
22
+
23
+ This class extends the standard Transformers GenerationMixin to provide
24
+ specialized generation capabilities for time series data, including proper
25
+ handling of multi-channel inputs and autoregressive forecasting.
26
+ """
27
+
28
+ @torch.no_grad()
29
+ def generate(
30
+ self,
31
+ inputs: Optional[torch.Tensor] = None,
32
+ generation_config: Optional[GenerationConfig] = None,
33
+ logits_processor: Optional[LogitsProcessorList] = None,
34
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
35
+ prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
36
+ synced_gpus: Optional[bool] = None,
37
+ assistant_model: Optional["PreTrainedModel"] = None,
38
+ streamer: Optional["BaseStreamer"] = None,
39
+ negative_prompt_ids: Optional[torch.Tensor] = None,
40
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
41
+ revin: Optional[bool] = True,
42
+ num_samples: Optional[int] = 1,
43
+ **kwargs,
44
+ ) -> Union[GenerateOutput, torch.LongTensor]:
45
+ """
46
+ Generate time series forecasts using the PatchMoE model.
47
+
48
+ This method handles the generation of time series forecasts with proper
49
+ input preprocessing and output postprocessing for multi-channel data.
50
+
51
+ Args:
52
+ inputs (torch.Tensor): Input time series data of shape:
53
+ - [batch_size, seq_len] for single-channel
54
+ - [batch_size, seq_len, channels] for multi-channel
55
+ generation_config (GenerationConfig, optional): Generation configuration
56
+ logits_processor (LogitsProcessorList, optional): Logits processors
57
+ stopping_criteria (StoppingCriteriaList, optional): Stopping criteria
58
+ prefix_allowed_tokens_fn (Callable, optional): Prefix token function
59
+ synced_gpus (bool, optional): Whether to sync GPUs
60
+ assistant_model (PreTrainedModel, optional): Assistant model
61
+ streamer (BaseStreamer, optional): Output streamer
62
+ negative_prompt_ids (torch.Tensor, optional): Negative prompt IDs
63
+ negative_prompt_attention_mask (torch.Tensor, optional): Negative attention mask
64
+ revin (bool, optional): Whether to apply RevIN normalization
65
+ num_samples (int, optional): Number of samples to generate
66
+ **kwargs: Additional keyword arguments
67
+
68
+ Returns:
69
+ torch.Tensor: Generated forecasts of shape [batch_size, pred_len, channels]
70
+
71
+ Raises:
72
+ ValueError: If input shape is not supported
73
+ """
74
+ # Extract input dimensions
75
+ batch_size = inputs.shape[0]
76
+ length = inputs.shape[1]
77
+ channel = 1
78
+
79
+ # Handle multi-channel inputs
80
+ if len(inputs.shape) == 3:
81
+ channel = inputs.shape[2]
82
+ # Reshape to [batch_size * channels, seq_len] for processing
83
+ inputs = inputs.reshape(batch_size * channel, length)
84
+ elif len(inputs.shape) > 3:
85
+ raise ValueError("Input shape must be [batch, seq_len, channel] or [batch, seq_len]")
86
+
87
+ # Call parent generation method
88
+ outputs = super().generate(
89
+ inputs=inputs,
90
+ generation_config=generation_config,
91
+ logits_processor=logits_processor,
92
+ stopping_criteria=stopping_criteria,
93
+ prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
94
+ synced_gpus=synced_gpus,
95
+ assistant_model=assistant_model,
96
+ streamer=streamer,
97
+ negative_prompt_ids=negative_prompt_ids,
98
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
99
+ revin=revin,
100
+ **kwargs,
101
+ )
102
+
103
+ # Reshape outputs back to [batch_size, pred_len, channels]
104
+ pred_len = outputs.shape[1]
105
+ outputs = outputs.reshape(batch_size, channel, pred_len)
106
+ outputs = outputs.transpose(1, 2).contiguous()
107
+ return outputs
108
+
109
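With the channels moved next to the batch dimension before flattening (note the transpose above, so each row is a single channel's history), the reshape round trip used by generate looks like this standalone sketch with arbitrary small sizes:

    import torch

    batch, seq_len, channels, pred_len = 2, 8, 3, 4
    x = torch.arange(batch * seq_len * channels, dtype=torch.float32).reshape(batch, seq_len, channels)
    flat = x.transpose(1, 2).reshape(batch * channels, seq_len)        # row b*channels + c == x[b, :, c]
    out = flat[:, -pred_len:]                                          # stand-in for per-series forecasts
    restored = out.reshape(batch, channels, pred_len).transpose(1, 2)  # [batch, pred_len, channels]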
+ def _greedy_search(
110
+ self,
111
+ input_ids: torch.Tensor,
112
+ logits_processor: Optional[LogitsProcessorList] = None,
113
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
114
+ max_length: Optional[int] = None,
115
+ pad_token_id: Optional[int] = None,
116
+ eos_token_id: Optional[Union[int, List[int]]] = None,
117
+ output_attentions: Optional[bool] = None,
118
+ output_hidden_states: Optional[bool] = None,
119
+ output_scores: Optional[bool] = None,
120
+ output_logits: Optional[bool] = None,
121
+ return_dict_in_generate: Optional[bool] = None,
122
+ synced_gpus: bool = False,
123
+ streamer: Optional["BaseStreamer"] = None,
124
+ **model_kwargs,
125
+ ) -> Union[GenerateNonBeamOutput, torch.Tensor]:
126
+ """
127
+ Perform greedy search generation for time series forecasting.
128
+
129
+ This method implements greedy decoding specifically for time series data,
130
+ where the model generates forecasts autoregressively.
131
+
132
+ Args:
133
+ input_ids (torch.Tensor): Input time series data
134
+ logits_processor (LogitsProcessorList, optional): Logits processors
135
+ stopping_criteria (StoppingCriteriaList, optional): Stopping criteria
136
+ max_length (int, optional): Maximum generation length
137
+ pad_token_id (int, optional): Padding token ID (not used for time series)
138
+ eos_token_id (int or List[int], optional): End-of-sequence token ID
139
+ output_attentions (bool, optional): Whether to output attentions
140
+ output_hidden_states (bool, optional): Whether to output hidden states
141
+ output_scores (bool, optional): Whether to output scores
142
+ output_logits (bool, optional): Whether to output logits
143
+ return_dict_in_generate (bool, optional): Whether to return dict
144
+ synced_gpus (bool): Whether to sync GPUs
145
+ streamer (BaseStreamer, optional): Output streamer
146
+ **model_kwargs: Additional model arguments
147
+
148
+ Returns:
149
+ torch.Tensor: Generated time series forecasts
150
+ """
151
+ # Move inputs to model device
152
+ input_ids = input_ids.to(self.device)
153
+ batch_size, cur_len = input_ids.shape
154
+
155
+ # Initialize processors and criteria if not provided
156
+ logits_processor = (
157
+ logits_processor if logits_processor is not None else LogitsProcessorList()
158
+ )
159
+ stopping_criteria = (
160
+ stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
161
+ )
162
+
163
+ # Prepare model inputs for generation
164
+ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
165
+
166
+ # Generate forecasts with specified output length
167
+ outputs = self(
168
+ **model_inputs,
169
+ return_dict=True,
170
+ max_output_length=stopping_criteria.max_length - cur_len,
171
+ )
172
+ return outputs