rulixiang commited on
Commit
303bbc0
·
1 Parent(s): 067f25c

Update ckpt

Browse files
config.json CHANGED
@@ -1,58 +1,57 @@
1
  {
2
  "_name_or_path": "FalconTST",
 
 
3
  "architectures": [
4
  "FalconTSTForPrediction"
5
  ],
6
  "auto_map": {
7
  "AutoConfig": "configuration_FalconTST.FalconTSTConfig",
8
- "AutoModelForCausalLM": "modeling_FalconTST.FalconTSTForPrediction"
9
  },
10
- "disable_bias_linear": false,
11
- "do_base_forecast": false,
12
- "do_expert_forecast": true,
13
- "expert_num_layers": 4,
14
- "ffn_hidden_size": 4096,
15
- "heterogeneous_moe_layer": false,
16
  "hidden_size": 1024,
17
- "init_method_std": 0.06,
18
- "is_revin": true,
19
- "k_layernorm": false,
20
- "kv_channels": 64,
21
- "mask_pad_value": 255.0,
22
- "model_type": "FalconTST",
23
- "moe_expert_final_layernorm": true,
24
- "moe_ffn_hidden_size": 4096,
25
- "moe_router_enable_expert_bias": false,
26
- "moe_router_input_size": 2880,
27
- "moe_router_pre_softmax": true,
28
- "moe_router_score_function": "softmax",
29
- "moe_router_topk": 1,
30
- "moe_shared_expert_intermediate_size": 4096,
31
- "multi_forecast_head_list": [
32
- 24,
33
- 96,
34
- 336
35
- ],
36
  "num_attention_heads": 16,
37
- "num_hidden_layers": 2,
38
- "num_moe_experts": 4,
39
- "torch_dtype": "bfloat16",
 
40
  "patch_size_list": [
41
  120,
42
  96,
43
  64,
44
  36
45
  ],
46
- "pred_length": 336,
47
- "q_layernorm": false,
48
  "residual_backcast": true,
 
 
 
 
 
 
 
 
 
 
49
  "rotary_base": 1000000,
50
  "rotary_interleaved": false,
51
- "seq_length": 2880,
52
- "shared_patch_size": 32,
53
- "tie_word_embeddings": false,
54
  "transformer_input_layernorm": true,
55
- "transformers_version": "4.40.1",
56
- "use_cache": true,
57
- "use_cpu_initialization": true
 
 
 
 
 
 
 
 
 
 
58
  }
 
1
  {
2
  "_name_or_path": "FalconTST",
3
+ "model_type": "FalconTST",
4
+ "transformers_version": "4.40.1",
5
  "architectures": [
6
  "FalconTSTForPrediction"
7
  ],
8
  "auto_map": {
9
  "AutoConfig": "configuration_FalconTST.FalconTSTConfig",
10
+ "AutoModel": "modeling_FalconTST.FalconTSTForPrediction"
11
  },
12
+
13
+ "add_bias_linear": false,
14
+ "num_hidden_layers": 2,
 
 
 
15
  "hidden_size": 1024,
16
+ "ffn_hidden_size": 4096,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  "num_attention_heads": 16,
18
+ "seq_length": 2880,
19
+ "mask_pad_value": 255.0,
20
+ "is_revin": true,
21
+ "shared_patch_size": 32,
22
  "patch_size_list": [
23
  120,
24
  96,
25
  64,
26
  36
27
  ],
 
 
28
  "residual_backcast": true,
29
+ "do_base_forecast": false,
30
+ "do_expert_forecast": true,
31
+ "heterogeneous_moe_layer": false,
32
+ "expert_num_layers": 4,
33
+ "multi_forecast_head_list": [
34
+ 24,
35
+ 96,
36
+ 336
37
+ ],
38
+ "multi_forecast_head_type": "single",
39
  "rotary_base": 1000000,
40
  "rotary_interleaved": false,
41
+ "q_layernorm": false,
42
+ "k_layernorm": false,
 
43
  "transformer_input_layernorm": true,
44
+
45
+ "num_experts": 4,
46
+ "moe_router_topk": 1,
47
+ "moe_router_pre_softmax": true,
48
+ "moe_router_score_function": "softmax",
49
+ "moe_ffn_hidden_size": 4096,
50
+ "moe_shared_expert_intermediate_size": 4096,
51
+ "moe_router_enable_expert_bias": false,
52
+ "moe_expert_final_layernorm": true,
53
+
54
+ "use_cpu_initialization": true,
55
+ "init_method_std": 0.06,
56
+ "use_cache": true
57
  }
configuration_FalconTST.py CHANGED
@@ -100,112 +100,110 @@ class FalconTSTConfig(PretrainedConfig):
100
  """
101
 
102
  model_type = "FalconTST"
103
- keys_to_ignore_at_inference = ["past_key_values"]
104
 
105
  def __init__(
106
  self,
107
- hidden_size: int = 1024,
108
- ffn_hidden_size: int = 4096,
109
- seq_length: int = 2880,
110
  add_bias_linear: bool = False,
111
- rope_theta: int = 10000,
112
  num_hidden_layers: int = 3,
 
 
113
  num_attention_heads: int = 16,
 
114
  mask_pad_value: float = 255.0,
115
- expert_num_layers: int = 4,
116
- shared_patch_size: int = 64,
117
-
118
- patch_size_list: Optional[List[int]] = None,
119
- multi_forecast_head_list: Optional[List[int]] = None,
120
  is_revin: bool = True,
121
- use_cpu_initialization: bool = False,
122
- rotary_interleaved: bool = False,
123
- do_expert_forecast: bool = True,
124
  residual_backcast: bool = True,
125
  do_base_forecast: bool = False,
126
- heterogeneous_moe_layer: bool = True,
127
- test_data_seq_len: int = 2880,
128
- test_data_test_len: int = 720,
129
- autoregressive_step_list: Optional[List[int]] = None,
130
  multi_forecast_head_type: str = "single",
 
 
 
131
 
 
132
  num_experts: int = 4,
133
  moe_router_topk: int = 2,
 
 
134
  moe_ffn_hidden_size: int = 4096,
135
  moe_shared_expert_intermediate_size: int = 4096,
136
- init_method_std: float = 0.06,
137
- initializer_range: float = 0.02,
138
  moe_router_enable_expert_bias: bool = False,
139
  moe_expert_final_layernorm: bool = True,
140
- transformer_input_layernorm: bool = True,
141
- moe_router_pre_softmax: bool = True,
142
- q_layernorm: bool = False,
143
- k_layernorm: bool = False,
144
- moe_router_score_function: str = "softmax",
145
- tie_word_embeddings: bool = False,
 
 
 
 
 
146
  **kwargs,
147
  ):
148
  """Initialize FalconTST configuration."""
149
 
150
- # Set default values for list parameters
151
- if patch_size_list is None:
152
- patch_size_list = [96, 64, 48, 24]
153
- if multi_forecast_head_list is None:
154
- multi_forecast_head_list = [24, 96, 336]
155
- if autoregressive_step_list is None:
156
- autoregressive_step_list = [2, 4, 1]
157
-
158
- # FalconTST inference specific
159
- self.test_data_seq_len = test_data_seq_len
160
- self.inference_length = test_data_test_len
161
- self.autoregressive_step_list = autoregressive_step_list
162
- self.multi_forecast_head_type = multi_forecast_head_type
163
- self.use_cache = True
164
-
165
- # FalconTST specific
166
  self.hidden_size = hidden_size
167
  self.ffn_hidden_size = ffn_hidden_size
168
  self.num_attention_heads = num_attention_heads
169
- self.init_method_std = init_method_std
170
- self.initializer_range = initializer_range
171
  self.seq_length = seq_length
172
- self.multi_forecast_head_list = multi_forecast_head_list
173
- self.kv_channels=self.hidden_size // self.num_attention_heads
174
- self.rotary_base = rope_theta
175
- self.num_hidden_layers = num_hidden_layers
176
  self.mask_pad_value = mask_pad_value
177
- self.pred_length = max(self.multi_forecast_head_list)
178
- self.add_bias_linear = add_bias_linear
179
  self.is_revin = is_revin
 
 
 
 
 
180
  self.do_base_forecast = do_base_forecast
181
  self.do_expert_forecast = do_expert_forecast
182
- self.residual_backcast = residual_backcast
183
  self.heterogeneous_moe_layer = heterogeneous_moe_layer
184
- self.use_cpu_initialization = use_cpu_initialization
 
 
 
 
 
 
185
  self.rotary_interleaved = rotary_interleaved
186
-
187
- # expert specific
188
- self.patch_size_list = patch_size_list
189
  self.num_moe_experts = num_experts
190
- self.shared_patch_size = shared_patch_size
191
- self.expert_num_layers = expert_num_layers
192
- self.moe_router_input_size = self.seq_length
193
  self.moe_router_topk = moe_router_topk
 
 
194
  self.moe_router_score_function = moe_router_score_function
195
  self.moe_ffn_hidden_size = moe_ffn_hidden_size
196
- self.moe_shared_expert_intermediate_size=moe_shared_expert_intermediate_size
197
  self.moe_router_enable_expert_bias = moe_router_enable_expert_bias
198
  self.moe_expert_final_layernorm = moe_expert_final_layernorm
199
- self.transformer_input_layernorm = transformer_input_layernorm
200
- self.moe_router_pre_softmax = moe_router_pre_softmax
201
- self.q_layernorm = q_layernorm
202
- self.k_layernorm = k_layernorm
203
 
 
 
 
 
 
 
 
 
 
 
 
204
 
 
205
 
206
- kwargs.pop('tie_word_embeddings', None)
207
  super().__init__(
208
- tie_word_embeddings=tie_word_embeddings,
209
  **kwargs,
210
  )
211
 
 
100
  """
101
 
102
  model_type = "FalconTST"
 
103
 
104
  def __init__(
105
  self,
106
+
107
+ # model configs
 
108
  add_bias_linear: bool = False,
 
109
  num_hidden_layers: int = 3,
110
+ hidden_size: int = 1024,
111
+ ffn_hidden_size: int = 4096,
112
  num_attention_heads: int = 16,
113
+ seq_length: int = 2880,
114
  mask_pad_value: float = 255.0,
 
 
 
 
 
115
  is_revin: bool = True,
116
+ shared_patch_size: int = 32,
117
+ patch_size_list: Optional[List[int]] = None,
 
118
  residual_backcast: bool = True,
119
  do_base_forecast: bool = False,
120
+ do_expert_forecast: bool = True,
121
+ heterogeneous_moe_layer: bool = False,
122
+ expert_num_layers: int = 4,
123
+ multi_forecast_head_list: Optional[List[int]] = None,
124
  multi_forecast_head_type: str = "single",
125
+ rope_theta: int = 1000000,
126
+ rotary_interleaved: bool = False,
127
+ block_input_layernorm: bool = True,
128
 
129
+ # moe configs
130
  num_experts: int = 4,
131
  moe_router_topk: int = 2,
132
+ moe_router_pre_softmax: bool = True,
133
+ moe_router_score_function: str = "softmax",
134
  moe_ffn_hidden_size: int = 4096,
135
  moe_shared_expert_intermediate_size: int = 4096,
 
 
136
  moe_router_enable_expert_bias: bool = False,
137
  moe_expert_final_layernorm: bool = True,
138
+
139
+ # initial configs
140
+ use_cpu_initialization: bool = False,
141
+ init_method_std: float = 0.06,
142
+ initializer_range: float = 0.02,
143
+
144
+ # test configs
145
+ test_data_seq_len: int = 2880,
146
+ test_data_test_len: int = 720,
147
+ autoregressive_step_list: Optional[List[int]] = None,
148
+
149
  **kwargs,
150
  ):
151
  """Initialize FalconTST configuration."""
152
 
153
+ # model configs
154
+ self.add_bias_linear = add_bias_linear
155
+ self.num_hidden_layers = num_hidden_layers
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  self.hidden_size = hidden_size
157
  self.ffn_hidden_size = ffn_hidden_size
158
  self.num_attention_heads = num_attention_heads
159
+ self.kv_channels = self.hidden_size // self.num_attention_heads
 
160
  self.seq_length = seq_length
 
 
 
 
161
  self.mask_pad_value = mask_pad_value
 
 
162
  self.is_revin = is_revin
163
+ self.shared_patch_size = shared_patch_size
164
+ if patch_size_list is None:
165
+ patch_size_list = [96, 64, 48, 24]
166
+ self.patch_size_list = patch_size_list
167
+ self.residual_backcast = residual_backcast
168
  self.do_base_forecast = do_base_forecast
169
  self.do_expert_forecast = do_expert_forecast
 
170
  self.heterogeneous_moe_layer = heterogeneous_moe_layer
171
+ self.expert_num_layers = expert_num_layers
172
+ if multi_forecast_head_list is None:
173
+ multi_forecast_head_list = [24, 96, 336]
174
+ self.multi_forecast_head_list = multi_forecast_head_list
175
+ self.pred_length = max(self.multi_forecast_head_list)
176
+ self.multi_forecast_head_type = multi_forecast_head_type
177
+ self.rotary_base = rope_theta
178
  self.rotary_interleaved = rotary_interleaved
179
+ self.block_input_layernorm = block_input_layernorm
180
+
181
+ # moe configs
182
  self.num_moe_experts = num_experts
 
 
 
183
  self.moe_router_topk = moe_router_topk
184
+ self.moe_router_input_size = self.seq_length
185
+ self.moe_router_pre_softmax = moe_router_pre_softmax
186
  self.moe_router_score_function = moe_router_score_function
187
  self.moe_ffn_hidden_size = moe_ffn_hidden_size
188
+ self.moe_shared_expert_intermediate_size = moe_shared_expert_intermediate_size
189
  self.moe_router_enable_expert_bias = moe_router_enable_expert_bias
190
  self.moe_expert_final_layernorm = moe_expert_final_layernorm
 
 
 
 
191
 
192
+ # initial configs
193
+ self.use_cpu_initialization = use_cpu_initialization
194
+ self.init_method_std = init_method_std
195
+ self.initializer_range = initializer_range
196
+
197
+ # test configs
198
+ self.test_data_seq_len = test_data_seq_len
199
+ self.inference_length = test_data_test_len
200
+ if autoregressive_step_list is None:
201
+ autoregressive_step_list = [2, 4, 1]
202
+ self.autoregressive_step_list = autoregressive_step_list
203
 
204
+ self.use_cache = True
205
 
 
206
  super().__init__(
 
207
  **kwargs,
208
  )
209
 
generation_config.json DELETED
@@ -1,4 +0,0 @@
1
- {
2
- "_from_model_config": true,
3
- "transformers_version": "4.40.1"
4
- }
 
 
 
 
 
model-00002-of-00002.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:e5a15d1fcb6388aed06deb70f77918cd38899476dc0c4b1ac7dc57391cf8a477
3
- size 1264771376
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2a7689d19b8af45f5261b86f22d6d57ffa6feb1690170c1afc6a43d8be8f46ca
3
+ size 1264777232
model.safetensors.index.json CHANGED
@@ -1,292 +1,293 @@
1
  {
2
- "metadata": {
3
- "total_size": 4983109888
4
- },
5
- "weight_map": {
6
- "model.decoder.layers.0.router.weight": "model-00001-of-00002.safetensors",
7
- "model.decoder.layers.0.backcast_layernorm.weight": "model-00001-of-00002.safetensors",
8
- "model.decoder.layers.0.experts.local_experts.0.layers.0.input_layernorm.weight": "model-00001-of-00002.safetensors",
9
- "model.decoder.layers.0.experts.local_experts.0.layers.0.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
10
- "model.decoder.layers.0.experts.local_experts.0.layers.0.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
11
- "model.decoder.layers.0.experts.local_experts.0.layers.0.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
12
- "model.decoder.layers.0.experts.local_experts.0.layers.0.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
13
- "model.decoder.layers.0.experts.local_experts.0.layers.0.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
14
- "model.decoder.layers.0.experts.local_experts.0.layers.1.input_layernorm.weight": "model-00001-of-00002.safetensors",
15
- "model.decoder.layers.0.experts.local_experts.0.layers.1.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
16
- "model.decoder.layers.0.experts.local_experts.0.layers.1.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
17
- "model.decoder.layers.0.experts.local_experts.0.layers.1.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
18
- "model.decoder.layers.0.experts.local_experts.0.layers.1.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
19
- "model.decoder.layers.0.experts.local_experts.0.layers.1.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
20
- "model.decoder.layers.0.experts.local_experts.0.layers.2.input_layernorm.weight": "model-00001-of-00002.safetensors",
21
- "model.decoder.layers.0.experts.local_experts.0.layers.2.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
22
- "model.decoder.layers.0.experts.local_experts.0.layers.2.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
23
- "model.decoder.layers.0.experts.local_experts.0.layers.2.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
24
- "model.decoder.layers.0.experts.local_experts.0.layers.2.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
25
- "model.decoder.layers.0.experts.local_experts.0.layers.2.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
26
- "model.decoder.layers.0.experts.local_experts.0.layers.3.input_layernorm.weight": "model-00001-of-00002.safetensors",
27
- "model.decoder.layers.0.experts.local_experts.0.layers.3.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
28
- "model.decoder.layers.0.experts.local_experts.0.layers.3.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
29
- "model.decoder.layers.0.experts.local_experts.0.layers.3.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
30
- "model.decoder.layers.0.experts.local_experts.0.layers.3.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
31
- "model.decoder.layers.0.experts.local_experts.0.layers.3.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
32
- "model.decoder.layers.0.experts.local_experts.0.final_layernorm.weight": "model-00001-of-00002.safetensors",
33
- "model.decoder.layers.0.experts.local_experts.0.patch_embedding.linear_fc1.weight": "model-00001-of-00002.safetensors",
34
- "model.decoder.layers.0.experts.local_experts.0.patch_embedding.linear_fc2.weight": "model-00001-of-00002.safetensors",
35
- "model.decoder.layers.0.experts.local_experts.0.output_layer.weight": "model-00001-of-00002.safetensors",
36
- "model.decoder.layers.0.experts.local_experts.1.layers.0.input_layernorm.weight": "model-00001-of-00002.safetensors",
37
- "model.decoder.layers.0.experts.local_experts.1.layers.0.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
38
- "model.decoder.layers.0.experts.local_experts.1.layers.0.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
39
- "model.decoder.layers.0.experts.local_experts.1.layers.0.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
40
- "model.decoder.layers.0.experts.local_experts.1.layers.0.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
41
- "model.decoder.layers.0.experts.local_experts.1.layers.0.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
42
- "model.decoder.layers.0.experts.local_experts.1.layers.1.input_layernorm.weight": "model-00001-of-00002.safetensors",
43
- "model.decoder.layers.0.experts.local_experts.1.layers.1.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
44
- "model.decoder.layers.0.experts.local_experts.1.layers.1.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
45
- "model.decoder.layers.0.experts.local_experts.1.layers.1.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
46
- "model.decoder.layers.0.experts.local_experts.1.layers.1.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
47
- "model.decoder.layers.0.experts.local_experts.1.layers.1.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
48
- "model.decoder.layers.0.experts.local_experts.1.layers.2.input_layernorm.weight": "model-00001-of-00002.safetensors",
49
- "model.decoder.layers.0.experts.local_experts.1.layers.2.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
50
- "model.decoder.layers.0.experts.local_experts.1.layers.2.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
51
- "model.decoder.layers.0.experts.local_experts.1.layers.2.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
52
- "model.decoder.layers.0.experts.local_experts.1.layers.2.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
53
- "model.decoder.layers.0.experts.local_experts.1.layers.2.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
54
- "model.decoder.layers.0.experts.local_experts.1.layers.3.input_layernorm.weight": "model-00001-of-00002.safetensors",
55
- "model.decoder.layers.0.experts.local_experts.1.layers.3.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
56
- "model.decoder.layers.0.experts.local_experts.1.layers.3.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
57
- "model.decoder.layers.0.experts.local_experts.1.layers.3.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
58
- "model.decoder.layers.0.experts.local_experts.1.layers.3.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
59
- "model.decoder.layers.0.experts.local_experts.1.layers.3.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
60
- "model.decoder.layers.0.experts.local_experts.1.final_layernorm.weight": "model-00001-of-00002.safetensors",
61
- "model.decoder.layers.0.experts.local_experts.1.patch_embedding.linear_fc1.weight": "model-00001-of-00002.safetensors",
62
- "model.decoder.layers.0.experts.local_experts.1.patch_embedding.linear_fc2.weight": "model-00001-of-00002.safetensors",
63
- "model.decoder.layers.0.experts.local_experts.1.output_layer.weight": "model-00001-of-00002.safetensors",
64
- "model.decoder.layers.0.experts.local_experts.2.layers.0.input_layernorm.weight": "model-00001-of-00002.safetensors",
65
- "model.decoder.layers.0.experts.local_experts.2.layers.0.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
66
- "model.decoder.layers.0.experts.local_experts.2.layers.0.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
67
- "model.decoder.layers.0.experts.local_experts.2.layers.0.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
68
- "model.decoder.layers.0.experts.local_experts.2.layers.0.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
69
- "model.decoder.layers.0.experts.local_experts.2.layers.0.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
70
- "model.decoder.layers.0.experts.local_experts.2.layers.1.input_layernorm.weight": "model-00001-of-00002.safetensors",
71
- "model.decoder.layers.0.experts.local_experts.2.layers.1.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
72
- "model.decoder.layers.0.experts.local_experts.2.layers.1.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
73
- "model.decoder.layers.0.experts.local_experts.2.layers.1.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
74
- "model.decoder.layers.0.experts.local_experts.2.layers.1.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
75
- "model.decoder.layers.0.experts.local_experts.2.layers.1.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
76
- "model.decoder.layers.0.experts.local_experts.2.layers.2.input_layernorm.weight": "model-00001-of-00002.safetensors",
77
- "model.decoder.layers.0.experts.local_experts.2.layers.2.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
78
- "model.decoder.layers.0.experts.local_experts.2.layers.2.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
79
- "model.decoder.layers.0.experts.local_experts.2.layers.2.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
80
- "model.decoder.layers.0.experts.local_experts.2.layers.2.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
81
- "model.decoder.layers.0.experts.local_experts.2.layers.2.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
82
- "model.decoder.layers.0.experts.local_experts.2.layers.3.input_layernorm.weight": "model-00001-of-00002.safetensors",
83
- "model.decoder.layers.0.experts.local_experts.2.layers.3.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
84
- "model.decoder.layers.0.experts.local_experts.2.layers.3.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
85
- "model.decoder.layers.0.experts.local_experts.2.layers.3.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
86
- "model.decoder.layers.0.experts.local_experts.2.layers.3.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
87
- "model.decoder.layers.0.experts.local_experts.2.layers.3.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
88
- "model.decoder.layers.0.experts.local_experts.2.final_layernorm.weight": "model-00001-of-00002.safetensors",
89
- "model.decoder.layers.0.experts.local_experts.2.patch_embedding.linear_fc1.weight": "model-00001-of-00002.safetensors",
90
- "model.decoder.layers.0.experts.local_experts.2.patch_embedding.linear_fc2.weight": "model-00001-of-00002.safetensors",
91
- "model.decoder.layers.0.experts.local_experts.2.output_layer.weight": "model-00001-of-00002.safetensors",
92
- "model.decoder.layers.0.experts.local_experts.3.layers.0.input_layernorm.weight": "model-00001-of-00002.safetensors",
93
- "model.decoder.layers.0.experts.local_experts.3.layers.0.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
94
- "model.decoder.layers.0.experts.local_experts.3.layers.0.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
95
- "model.decoder.layers.0.experts.local_experts.3.layers.0.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
96
- "model.decoder.layers.0.experts.local_experts.3.layers.0.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
97
- "model.decoder.layers.0.experts.local_experts.3.layers.0.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
98
- "model.decoder.layers.0.experts.local_experts.3.layers.1.input_layernorm.weight": "model-00001-of-00002.safetensors",
99
- "model.decoder.layers.0.experts.local_experts.3.layers.1.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
100
- "model.decoder.layers.0.experts.local_experts.3.layers.1.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
101
- "model.decoder.layers.0.experts.local_experts.3.layers.1.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
102
- "model.decoder.layers.0.experts.local_experts.3.layers.1.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
103
- "model.decoder.layers.0.experts.local_experts.3.layers.1.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
104
- "model.decoder.layers.0.experts.local_experts.3.layers.2.input_layernorm.weight": "model-00001-of-00002.safetensors",
105
- "model.decoder.layers.0.experts.local_experts.3.layers.2.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
106
- "model.decoder.layers.0.experts.local_experts.3.layers.2.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
107
- "model.decoder.layers.0.experts.local_experts.3.layers.2.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
108
- "model.decoder.layers.0.experts.local_experts.3.layers.2.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
109
- "model.decoder.layers.0.experts.local_experts.3.layers.2.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
110
- "model.decoder.layers.0.experts.local_experts.3.layers.3.input_layernorm.weight": "model-00001-of-00002.safetensors",
111
- "model.decoder.layers.0.experts.local_experts.3.layers.3.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
112
- "model.decoder.layers.0.experts.local_experts.3.layers.3.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
113
- "model.decoder.layers.0.experts.local_experts.3.layers.3.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
114
- "model.decoder.layers.0.experts.local_experts.3.layers.3.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
115
- "model.decoder.layers.0.experts.local_experts.3.layers.3.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
116
- "model.decoder.layers.0.experts.local_experts.3.final_layernorm.weight": "model-00001-of-00002.safetensors",
117
- "model.decoder.layers.0.experts.local_experts.3.patch_embedding.linear_fc1.weight": "model-00001-of-00002.safetensors",
118
- "model.decoder.layers.0.experts.local_experts.3.patch_embedding.linear_fc2.weight": "model-00001-of-00002.safetensors",
119
- "model.decoder.layers.0.experts.local_experts.3.output_layer.weight": "model-00001-of-00002.safetensors",
120
- "model.decoder.layers.0.shared_experts.layers.0.input_layernorm.weight": "model-00001-of-00002.safetensors",
121
- "model.decoder.layers.0.shared_experts.layers.0.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
122
- "model.decoder.layers.0.shared_experts.layers.0.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
123
- "model.decoder.layers.0.shared_experts.layers.0.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
124
- "model.decoder.layers.0.shared_experts.layers.0.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
125
- "model.decoder.layers.0.shared_experts.layers.0.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
126
- "model.decoder.layers.0.shared_experts.layers.1.input_layernorm.weight": "model-00001-of-00002.safetensors",
127
- "model.decoder.layers.0.shared_experts.layers.1.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
128
- "model.decoder.layers.0.shared_experts.layers.1.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
129
- "model.decoder.layers.0.shared_experts.layers.1.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
130
- "model.decoder.layers.0.shared_experts.layers.1.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
131
- "model.decoder.layers.0.shared_experts.layers.1.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
132
- "model.decoder.layers.0.shared_experts.layers.2.input_layernorm.weight": "model-00001-of-00002.safetensors",
133
- "model.decoder.layers.0.shared_experts.layers.2.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
134
- "model.decoder.layers.0.shared_experts.layers.2.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
135
- "model.decoder.layers.0.shared_experts.layers.2.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
136
- "model.decoder.layers.0.shared_experts.layers.2.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
137
- "model.decoder.layers.0.shared_experts.layers.2.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
138
- "model.decoder.layers.0.shared_experts.layers.3.input_layernorm.weight": "model-00001-of-00002.safetensors",
139
- "model.decoder.layers.0.shared_experts.layers.3.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
140
- "model.decoder.layers.0.shared_experts.layers.3.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
141
- "model.decoder.layers.0.shared_experts.layers.3.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
142
- "model.decoder.layers.0.shared_experts.layers.3.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
143
- "model.decoder.layers.0.shared_experts.layers.3.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
144
- "model.decoder.layers.0.shared_experts.final_layernorm.weight": "model-00001-of-00002.safetensors",
145
- "model.decoder.layers.0.shared_experts.patch_embedding.linear_fc1.weight": "model-00001-of-00002.safetensors",
146
- "model.decoder.layers.0.shared_experts.patch_embedding.linear_fc2.weight": "model-00001-of-00002.safetensors",
147
- "model.decoder.layers.0.shared_experts.output_layer.weight": "model-00001-of-00002.safetensors",
148
- "model.decoder.layers.1.router.weight": "model-00001-of-00002.safetensors",
149
- "model.decoder.layers.1.backcast_layernorm.weight": "model-00001-of-00002.safetensors",
150
- "model.decoder.layers.1.experts.local_experts.0.layers.0.input_layernorm.weight": "model-00001-of-00002.safetensors",
151
- "model.decoder.layers.1.experts.local_experts.0.layers.0.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
152
- "model.decoder.layers.1.experts.local_experts.0.layers.0.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
153
- "model.decoder.layers.1.experts.local_experts.0.layers.0.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
154
- "model.decoder.layers.1.experts.local_experts.0.layers.0.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
155
- "model.decoder.layers.1.experts.local_experts.0.layers.0.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
156
- "model.decoder.layers.1.experts.local_experts.0.layers.1.input_layernorm.weight": "model-00001-of-00002.safetensors",
157
- "model.decoder.layers.1.experts.local_experts.0.layers.1.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
158
- "model.decoder.layers.1.experts.local_experts.0.layers.1.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
159
- "model.decoder.layers.1.experts.local_experts.0.layers.1.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
160
- "model.decoder.layers.1.experts.local_experts.0.layers.1.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
161
- "model.decoder.layers.1.experts.local_experts.0.layers.1.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
162
- "model.decoder.layers.1.experts.local_experts.0.layers.2.input_layernorm.weight": "model-00001-of-00002.safetensors",
163
- "model.decoder.layers.1.experts.local_experts.0.layers.2.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
164
- "model.decoder.layers.1.experts.local_experts.0.layers.2.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
165
- "model.decoder.layers.1.experts.local_experts.0.layers.2.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
166
- "model.decoder.layers.1.experts.local_experts.0.layers.2.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
167
- "model.decoder.layers.1.experts.local_experts.0.layers.2.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
168
- "model.decoder.layers.1.experts.local_experts.0.layers.3.input_layernorm.weight": "model-00001-of-00002.safetensors",
169
- "model.decoder.layers.1.experts.local_experts.0.layers.3.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
170
- "model.decoder.layers.1.experts.local_experts.0.layers.3.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
171
- "model.decoder.layers.1.experts.local_experts.0.layers.3.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
172
- "model.decoder.layers.1.experts.local_experts.0.layers.3.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
173
- "model.decoder.layers.1.experts.local_experts.0.layers.3.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
174
- "model.decoder.layers.1.experts.local_experts.0.final_layernorm.weight": "model-00001-of-00002.safetensors",
175
- "model.decoder.layers.1.experts.local_experts.0.patch_embedding.linear_fc1.weight": "model-00001-of-00002.safetensors",
176
- "model.decoder.layers.1.experts.local_experts.0.patch_embedding.linear_fc2.weight": "model-00001-of-00002.safetensors",
177
- "model.decoder.layers.1.experts.local_experts.0.output_layer.weight": "model-00001-of-00002.safetensors",
178
- "model.decoder.layers.1.experts.local_experts.1.layers.0.input_layernorm.weight": "model-00001-of-00002.safetensors",
179
- "model.decoder.layers.1.experts.local_experts.1.layers.0.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
180
- "model.decoder.layers.1.experts.local_experts.1.layers.0.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
181
- "model.decoder.layers.1.experts.local_experts.1.layers.0.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
182
- "model.decoder.layers.1.experts.local_experts.1.layers.0.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
183
- "model.decoder.layers.1.experts.local_experts.1.layers.0.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
184
- "model.decoder.layers.1.experts.local_experts.1.layers.1.input_layernorm.weight": "model-00001-of-00002.safetensors",
185
- "model.decoder.layers.1.experts.local_experts.1.layers.1.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
186
- "model.decoder.layers.1.experts.local_experts.1.layers.1.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
187
- "model.decoder.layers.1.experts.local_experts.1.layers.1.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
188
- "model.decoder.layers.1.experts.local_experts.1.layers.1.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
189
- "model.decoder.layers.1.experts.local_experts.1.layers.1.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
190
- "model.decoder.layers.1.experts.local_experts.1.layers.2.input_layernorm.weight": "model-00001-of-00002.safetensors",
191
- "model.decoder.layers.1.experts.local_experts.1.layers.2.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
192
- "model.decoder.layers.1.experts.local_experts.1.layers.2.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
193
- "model.decoder.layers.1.experts.local_experts.1.layers.2.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
194
- "model.decoder.layers.1.experts.local_experts.1.layers.2.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
195
- "model.decoder.layers.1.experts.local_experts.1.layers.2.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
196
- "model.decoder.layers.1.experts.local_experts.1.layers.3.input_layernorm.weight": "model-00001-of-00002.safetensors",
197
- "model.decoder.layers.1.experts.local_experts.1.layers.3.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
198
- "model.decoder.layers.1.experts.local_experts.1.layers.3.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
199
- "model.decoder.layers.1.experts.local_experts.1.layers.3.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
200
- "model.decoder.layers.1.experts.local_experts.1.layers.3.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
201
- "model.decoder.layers.1.experts.local_experts.1.layers.3.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
202
- "model.decoder.layers.1.experts.local_experts.1.final_layernorm.weight": "model-00001-of-00002.safetensors",
203
- "model.decoder.layers.1.experts.local_experts.1.patch_embedding.linear_fc1.weight": "model-00001-of-00002.safetensors",
204
- "model.decoder.layers.1.experts.local_experts.1.patch_embedding.linear_fc2.weight": "model-00001-of-00002.safetensors",
205
- "model.decoder.layers.1.experts.local_experts.1.output_layer.weight": "model-00001-of-00002.safetensors",
206
- "model.decoder.layers.1.experts.local_experts.2.layers.0.input_layernorm.weight": "model-00001-of-00002.safetensors",
207
- "model.decoder.layers.1.experts.local_experts.2.layers.0.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
208
- "model.decoder.layers.1.experts.local_experts.2.layers.0.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
209
- "model.decoder.layers.1.experts.local_experts.2.layers.0.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
210
- "model.decoder.layers.1.experts.local_experts.2.layers.0.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
211
- "model.decoder.layers.1.experts.local_experts.2.layers.0.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
212
- "model.decoder.layers.1.experts.local_experts.2.layers.1.input_layernorm.weight": "model-00001-of-00002.safetensors",
213
- "model.decoder.layers.1.experts.local_experts.2.layers.1.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
214
- "model.decoder.layers.1.experts.local_experts.2.layers.1.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
215
- "model.decoder.layers.1.experts.local_experts.2.layers.1.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
216
- "model.decoder.layers.1.experts.local_experts.2.layers.1.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
217
- "model.decoder.layers.1.experts.local_experts.2.layers.1.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
218
- "model.decoder.layers.1.experts.local_experts.2.layers.2.input_layernorm.weight": "model-00001-of-00002.safetensors",
219
- "model.decoder.layers.1.experts.local_experts.2.layers.2.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
220
- "model.decoder.layers.1.experts.local_experts.2.layers.2.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
221
- "model.decoder.layers.1.experts.local_experts.2.layers.2.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
222
- "model.decoder.layers.1.experts.local_experts.2.layers.2.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
223
- "model.decoder.layers.1.experts.local_experts.2.layers.2.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
224
- "model.decoder.layers.1.experts.local_experts.2.layers.3.input_layernorm.weight": "model-00001-of-00002.safetensors",
225
- "model.decoder.layers.1.experts.local_experts.2.layers.3.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
226
- "model.decoder.layers.1.experts.local_experts.2.layers.3.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
227
- "model.decoder.layers.1.experts.local_experts.2.layers.3.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
228
- "model.decoder.layers.1.experts.local_experts.2.layers.3.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
229
- "model.decoder.layers.1.experts.local_experts.2.layers.3.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
230
- "model.decoder.layers.1.experts.local_experts.2.final_layernorm.weight": "model-00001-of-00002.safetensors",
231
- "model.decoder.layers.1.experts.local_experts.2.patch_embedding.linear_fc1.weight": "model-00001-of-00002.safetensors",
232
- "model.decoder.layers.1.experts.local_experts.2.patch_embedding.linear_fc2.weight": "model-00001-of-00002.safetensors",
233
- "model.decoder.layers.1.experts.local_experts.2.output_layer.weight": "model-00001-of-00002.safetensors",
234
- "model.decoder.layers.1.experts.local_experts.3.layers.0.input_layernorm.weight": "model-00001-of-00002.safetensors",
235
- "model.decoder.layers.1.experts.local_experts.3.layers.0.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
236
- "model.decoder.layers.1.experts.local_experts.3.layers.0.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
237
- "model.decoder.layers.1.experts.local_experts.3.layers.0.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
238
- "model.decoder.layers.1.experts.local_experts.3.layers.0.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
239
- "model.decoder.layers.1.experts.local_experts.3.layers.0.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
240
- "model.decoder.layers.1.experts.local_experts.3.layers.1.input_layernorm.weight": "model-00001-of-00002.safetensors",
241
- "model.decoder.layers.1.experts.local_experts.3.layers.1.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
242
- "model.decoder.layers.1.experts.local_experts.3.layers.1.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
243
- "model.decoder.layers.1.experts.local_experts.3.layers.1.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
244
- "model.decoder.layers.1.experts.local_experts.3.layers.1.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
245
- "model.decoder.layers.1.experts.local_experts.3.layers.1.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
246
- "model.decoder.layers.1.experts.local_experts.3.layers.2.input_layernorm.weight": "model-00001-of-00002.safetensors",
247
- "model.decoder.layers.1.experts.local_experts.3.layers.2.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
248
- "model.decoder.layers.1.experts.local_experts.3.layers.2.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
249
- "model.decoder.layers.1.experts.local_experts.3.layers.2.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
250
- "model.decoder.layers.1.experts.local_experts.3.layers.2.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
251
- "model.decoder.layers.1.experts.local_experts.3.layers.2.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
252
- "model.decoder.layers.1.experts.local_experts.3.layers.3.input_layernorm.weight": "model-00001-of-00002.safetensors",
253
- "model.decoder.layers.1.experts.local_experts.3.layers.3.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
254
- "model.decoder.layers.1.experts.local_experts.3.layers.3.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
255
- "model.decoder.layers.1.experts.local_experts.3.layers.3.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
256
- "model.decoder.layers.1.experts.local_experts.3.layers.3.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
257
- "model.decoder.layers.1.experts.local_experts.3.layers.3.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
258
- "model.decoder.layers.1.experts.local_experts.3.final_layernorm.weight": "model-00001-of-00002.safetensors",
259
- "model.decoder.layers.1.experts.local_experts.3.patch_embedding.linear_fc1.weight": "model-00001-of-00002.safetensors",
260
- "model.decoder.layers.1.experts.local_experts.3.patch_embedding.linear_fc2.weight": "model-00001-of-00002.safetensors",
261
- "model.decoder.layers.1.experts.local_experts.3.output_layer.weight": "model-00002-of-00002.safetensors",
262
- "model.decoder.layers.1.shared_experts.layers.0.input_layernorm.weight": "model-00002-of-00002.safetensors",
263
- "model.decoder.layers.1.shared_experts.layers.0.self_attention.linear_proj.weight": "model-00002-of-00002.safetensors",
264
- "model.decoder.layers.1.shared_experts.layers.0.self_attention.linear_qkv.weight": "model-00002-of-00002.safetensors",
265
- "model.decoder.layers.1.shared_experts.layers.0.pre_mlp_layernorm.weight": "model-00002-of-00002.safetensors",
266
- "model.decoder.layers.1.shared_experts.layers.0.mlp.linear_fc1.weight": "model-00002-of-00002.safetensors",
267
- "model.decoder.layers.1.shared_experts.layers.0.mlp.linear_fc2.weight": "model-00002-of-00002.safetensors",
268
- "model.decoder.layers.1.shared_experts.layers.1.input_layernorm.weight": "model-00002-of-00002.safetensors",
269
- "model.decoder.layers.1.shared_experts.layers.1.self_attention.linear_proj.weight": "model-00002-of-00002.safetensors",
270
- "model.decoder.layers.1.shared_experts.layers.1.self_attention.linear_qkv.weight": "model-00002-of-00002.safetensors",
271
- "model.decoder.layers.1.shared_experts.layers.1.pre_mlp_layernorm.weight": "model-00002-of-00002.safetensors",
272
- "model.decoder.layers.1.shared_experts.layers.1.mlp.linear_fc1.weight": "model-00002-of-00002.safetensors",
273
- "model.decoder.layers.1.shared_experts.layers.1.mlp.linear_fc2.weight": "model-00002-of-00002.safetensors",
274
- "model.decoder.layers.1.shared_experts.layers.2.input_layernorm.weight": "model-00002-of-00002.safetensors",
275
- "model.decoder.layers.1.shared_experts.layers.2.self_attention.linear_proj.weight": "model-00002-of-00002.safetensors",
276
- "model.decoder.layers.1.shared_experts.layers.2.self_attention.linear_qkv.weight": "model-00002-of-00002.safetensors",
277
- "model.decoder.layers.1.shared_experts.layers.2.pre_mlp_layernorm.weight": "model-00002-of-00002.safetensors",
278
- "model.decoder.layers.1.shared_experts.layers.2.mlp.linear_fc1.weight": "model-00002-of-00002.safetensors",
279
- "model.decoder.layers.1.shared_experts.layers.2.mlp.linear_fc2.weight": "model-00002-of-00002.safetensors",
280
- "model.decoder.layers.1.shared_experts.layers.3.input_layernorm.weight": "model-00002-of-00002.safetensors",
281
- "model.decoder.layers.1.shared_experts.layers.3.self_attention.linear_proj.weight": "model-00002-of-00002.safetensors",
282
- "model.decoder.layers.1.shared_experts.layers.3.self_attention.linear_qkv.weight": "model-00002-of-00002.safetensors",
283
- "model.decoder.layers.1.shared_experts.layers.3.pre_mlp_layernorm.weight": "model-00002-of-00002.safetensors",
284
- "model.decoder.layers.1.shared_experts.layers.3.mlp.linear_fc1.weight": "model-00002-of-00002.safetensors",
285
- "model.decoder.layers.1.shared_experts.layers.3.mlp.linear_fc2.weight": "model-00002-of-00002.safetensors",
286
- "model.decoder.layers.1.shared_experts.final_layernorm.weight": "model-00002-of-00002.safetensors",
287
- "model.decoder.layers.1.shared_experts.patch_embedding.linear_fc1.weight": "model-00002-of-00002.safetensors",
288
- "model.decoder.layers.1.shared_experts.patch_embedding.linear_fc2.weight": "model-00002-of-00002.safetensors",
289
- "model.decoder.layers.1.shared_experts.output_layer.weight": "model-00002-of-00002.safetensors",
290
- "model.output_layer.weight": "model-00002-of-00002.safetensors"
291
- }
 
292
  }
 
1
  {
2
+ "metadata": {
3
+ "total_size": 4983115648
4
+ },
5
+ "weight_map": {
6
+ "model.decoder.layers.0.router.weight": "model-00001-of-00002.safetensors",
7
+ "model.decoder.layers.0.backcast_layernorm.weight": "model-00001-of-00002.safetensors",
8
+ "model.decoder.layers.0.experts.local_experts.0.layers.0.input_layernorm.weight": "model-00001-of-00002.safetensors",
9
+ "model.decoder.layers.0.experts.local_experts.0.layers.0.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
10
+ "model.decoder.layers.0.experts.local_experts.0.layers.0.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
11
+ "model.decoder.layers.0.experts.local_experts.0.layers.0.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
12
+ "model.decoder.layers.0.experts.local_experts.0.layers.0.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
13
+ "model.decoder.layers.0.experts.local_experts.0.layers.0.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
14
+ "model.decoder.layers.0.experts.local_experts.0.layers.1.input_layernorm.weight": "model-00001-of-00002.safetensors",
15
+ "model.decoder.layers.0.experts.local_experts.0.layers.1.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
16
+ "model.decoder.layers.0.experts.local_experts.0.layers.1.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
17
+ "model.decoder.layers.0.experts.local_experts.0.layers.1.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
18
+ "model.decoder.layers.0.experts.local_experts.0.layers.1.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
19
+ "model.decoder.layers.0.experts.local_experts.0.layers.1.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
20
+ "model.decoder.layers.0.experts.local_experts.0.layers.2.input_layernorm.weight": "model-00001-of-00002.safetensors",
21
+ "model.decoder.layers.0.experts.local_experts.0.layers.2.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
22
+ "model.decoder.layers.0.experts.local_experts.0.layers.2.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
23
+ "model.decoder.layers.0.experts.local_experts.0.layers.2.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
24
+ "model.decoder.layers.0.experts.local_experts.0.layers.2.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
25
+ "model.decoder.layers.0.experts.local_experts.0.layers.2.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
26
+ "model.decoder.layers.0.experts.local_experts.0.layers.3.input_layernorm.weight": "model-00001-of-00002.safetensors",
27
+ "model.decoder.layers.0.experts.local_experts.0.layers.3.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
28
+ "model.decoder.layers.0.experts.local_experts.0.layers.3.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
29
+ "model.decoder.layers.0.experts.local_experts.0.layers.3.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
30
+ "model.decoder.layers.0.experts.local_experts.0.layers.3.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
31
+ "model.decoder.layers.0.experts.local_experts.0.layers.3.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
32
+ "model.decoder.layers.0.experts.local_experts.0.final_layernorm.weight": "model-00001-of-00002.safetensors",
33
+ "model.decoder.layers.0.experts.local_experts.0.patch_embedding.linear_fc1.weight": "model-00001-of-00002.safetensors",
34
+ "model.decoder.layers.0.experts.local_experts.0.patch_embedding.linear_fc2.weight": "model-00001-of-00002.safetensors",
35
+ "model.decoder.layers.0.experts.local_experts.0.output_layer.weight": "model-00001-of-00002.safetensors",
36
+ "model.decoder.layers.0.experts.local_experts.1.layers.0.input_layernorm.weight": "model-00001-of-00002.safetensors",
37
+ "model.decoder.layers.0.experts.local_experts.1.layers.0.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
38
+ "model.decoder.layers.0.experts.local_experts.1.layers.0.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
39
+ "model.decoder.layers.0.experts.local_experts.1.layers.0.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
40
+ "model.decoder.layers.0.experts.local_experts.1.layers.0.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
41
+ "model.decoder.layers.0.experts.local_experts.1.layers.0.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
42
+ "model.decoder.layers.0.experts.local_experts.1.layers.1.input_layernorm.weight": "model-00001-of-00002.safetensors",
43
+ "model.decoder.layers.0.experts.local_experts.1.layers.1.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
44
+ "model.decoder.layers.0.experts.local_experts.1.layers.1.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
45
+ "model.decoder.layers.0.experts.local_experts.1.layers.1.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
46
+ "model.decoder.layers.0.experts.local_experts.1.layers.1.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
47
+ "model.decoder.layers.0.experts.local_experts.1.layers.1.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
48
+ "model.decoder.layers.0.experts.local_experts.1.layers.2.input_layernorm.weight": "model-00001-of-00002.safetensors",
49
+ "model.decoder.layers.0.experts.local_experts.1.layers.2.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
50
+ "model.decoder.layers.0.experts.local_experts.1.layers.2.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
51
+ "model.decoder.layers.0.experts.local_experts.1.layers.2.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
52
+ "model.decoder.layers.0.experts.local_experts.1.layers.2.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
53
+ "model.decoder.layers.0.experts.local_experts.1.layers.2.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
54
+ "model.decoder.layers.0.experts.local_experts.1.layers.3.input_layernorm.weight": "model-00001-of-00002.safetensors",
55
+ "model.decoder.layers.0.experts.local_experts.1.layers.3.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
56
+ "model.decoder.layers.0.experts.local_experts.1.layers.3.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
57
+ "model.decoder.layers.0.experts.local_experts.1.layers.3.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
58
+ "model.decoder.layers.0.experts.local_experts.1.layers.3.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
59
+ "model.decoder.layers.0.experts.local_experts.1.layers.3.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
60
+ "model.decoder.layers.0.experts.local_experts.1.final_layernorm.weight": "model-00001-of-00002.safetensors",
61
+ "model.decoder.layers.0.experts.local_experts.1.patch_embedding.linear_fc1.weight": "model-00001-of-00002.safetensors",
62
+ "model.decoder.layers.0.experts.local_experts.1.patch_embedding.linear_fc2.weight": "model-00001-of-00002.safetensors",
63
+ "model.decoder.layers.0.experts.local_experts.1.output_layer.weight": "model-00001-of-00002.safetensors",
64
+ "model.decoder.layers.0.experts.local_experts.2.layers.0.input_layernorm.weight": "model-00001-of-00002.safetensors",
65
+ "model.decoder.layers.0.experts.local_experts.2.layers.0.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
66
+ "model.decoder.layers.0.experts.local_experts.2.layers.0.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
67
+ "model.decoder.layers.0.experts.local_experts.2.layers.0.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
68
+ "model.decoder.layers.0.experts.local_experts.2.layers.0.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
69
+ "model.decoder.layers.0.experts.local_experts.2.layers.0.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
70
+ "model.decoder.layers.0.experts.local_experts.2.layers.1.input_layernorm.weight": "model-00001-of-00002.safetensors",
71
+ "model.decoder.layers.0.experts.local_experts.2.layers.1.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
72
+ "model.decoder.layers.0.experts.local_experts.2.layers.1.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
73
+ "model.decoder.layers.0.experts.local_experts.2.layers.1.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
74
+ "model.decoder.layers.0.experts.local_experts.2.layers.1.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
75
+ "model.decoder.layers.0.experts.local_experts.2.layers.1.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
76
+ "model.decoder.layers.0.experts.local_experts.2.layers.2.input_layernorm.weight": "model-00001-of-00002.safetensors",
77
+ "model.decoder.layers.0.experts.local_experts.2.layers.2.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
78
+ "model.decoder.layers.0.experts.local_experts.2.layers.2.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
79
+ "model.decoder.layers.0.experts.local_experts.2.layers.2.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
80
+ "model.decoder.layers.0.experts.local_experts.2.layers.2.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
81
+ "model.decoder.layers.0.experts.local_experts.2.layers.2.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
82
+ "model.decoder.layers.0.experts.local_experts.2.layers.3.input_layernorm.weight": "model-00001-of-00002.safetensors",
83
+ "model.decoder.layers.0.experts.local_experts.2.layers.3.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
84
+ "model.decoder.layers.0.experts.local_experts.2.layers.3.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
85
+ "model.decoder.layers.0.experts.local_experts.2.layers.3.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
86
+ "model.decoder.layers.0.experts.local_experts.2.layers.3.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
87
+ "model.decoder.layers.0.experts.local_experts.2.layers.3.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
88
+ "model.decoder.layers.0.experts.local_experts.2.final_layernorm.weight": "model-00001-of-00002.safetensors",
89
+ "model.decoder.layers.0.experts.local_experts.2.patch_embedding.linear_fc1.weight": "model-00001-of-00002.safetensors",
90
+ "model.decoder.layers.0.experts.local_experts.2.patch_embedding.linear_fc2.weight": "model-00001-of-00002.safetensors",
91
+ "model.decoder.layers.0.experts.local_experts.2.output_layer.weight": "model-00001-of-00002.safetensors",
92
+ "model.decoder.layers.0.experts.local_experts.3.layers.0.input_layernorm.weight": "model-00001-of-00002.safetensors",
93
+ "model.decoder.layers.0.experts.local_experts.3.layers.0.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
94
+ "model.decoder.layers.0.experts.local_experts.3.layers.0.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
95
+ "model.decoder.layers.0.experts.local_experts.3.layers.0.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
96
+ "model.decoder.layers.0.experts.local_experts.3.layers.0.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
97
+ "model.decoder.layers.0.experts.local_experts.3.layers.0.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
98
+ "model.decoder.layers.0.experts.local_experts.3.layers.1.input_layernorm.weight": "model-00001-of-00002.safetensors",
99
+ "model.decoder.layers.0.experts.local_experts.3.layers.1.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
100
+ "model.decoder.layers.0.experts.local_experts.3.layers.1.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
101
+ "model.decoder.layers.0.experts.local_experts.3.layers.1.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
102
+ "model.decoder.layers.0.experts.local_experts.3.layers.1.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
103
+ "model.decoder.layers.0.experts.local_experts.3.layers.1.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
104
+ "model.decoder.layers.0.experts.local_experts.3.layers.2.input_layernorm.weight": "model-00001-of-00002.safetensors",
105
+ "model.decoder.layers.0.experts.local_experts.3.layers.2.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
106
+ "model.decoder.layers.0.experts.local_experts.3.layers.2.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
107
+ "model.decoder.layers.0.experts.local_experts.3.layers.2.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
108
+ "model.decoder.layers.0.experts.local_experts.3.layers.2.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
109
+ "model.decoder.layers.0.experts.local_experts.3.layers.2.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
110
+ "model.decoder.layers.0.experts.local_experts.3.layers.3.input_layernorm.weight": "model-00001-of-00002.safetensors",
111
+ "model.decoder.layers.0.experts.local_experts.3.layers.3.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
112
+ "model.decoder.layers.0.experts.local_experts.3.layers.3.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
113
+ "model.decoder.layers.0.experts.local_experts.3.layers.3.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
114
+ "model.decoder.layers.0.experts.local_experts.3.layers.3.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
115
+ "model.decoder.layers.0.experts.local_experts.3.layers.3.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
116
+ "model.decoder.layers.0.experts.local_experts.3.final_layernorm.weight": "model-00001-of-00002.safetensors",
117
+ "model.decoder.layers.0.experts.local_experts.3.patch_embedding.linear_fc1.weight": "model-00001-of-00002.safetensors",
118
+ "model.decoder.layers.0.experts.local_experts.3.patch_embedding.linear_fc2.weight": "model-00001-of-00002.safetensors",
119
+ "model.decoder.layers.0.experts.local_experts.3.output_layer.weight": "model-00001-of-00002.safetensors",
120
+ "model.decoder.layers.0.shared_experts.layers.0.input_layernorm.weight": "model-00001-of-00002.safetensors",
121
+ "model.decoder.layers.0.shared_experts.layers.0.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
122
+ "model.decoder.layers.0.shared_experts.layers.0.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
123
+ "model.decoder.layers.0.shared_experts.layers.0.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
124
+ "model.decoder.layers.0.shared_experts.layers.0.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
125
+ "model.decoder.layers.0.shared_experts.layers.0.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
126
+ "model.decoder.layers.0.shared_experts.layers.1.input_layernorm.weight": "model-00001-of-00002.safetensors",
127
+ "model.decoder.layers.0.shared_experts.layers.1.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
128
+ "model.decoder.layers.0.shared_experts.layers.1.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
129
+ "model.decoder.layers.0.shared_experts.layers.1.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
130
+ "model.decoder.layers.0.shared_experts.layers.1.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
131
+ "model.decoder.layers.0.shared_experts.layers.1.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
132
+ "model.decoder.layers.0.shared_experts.layers.2.input_layernorm.weight": "model-00001-of-00002.safetensors",
133
+ "model.decoder.layers.0.shared_experts.layers.2.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
134
+ "model.decoder.layers.0.shared_experts.layers.2.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
135
+ "model.decoder.layers.0.shared_experts.layers.2.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
136
+ "model.decoder.layers.0.shared_experts.layers.2.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
137
+ "model.decoder.layers.0.shared_experts.layers.2.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
138
+ "model.decoder.layers.0.shared_experts.layers.3.input_layernorm.weight": "model-00001-of-00002.safetensors",
139
+ "model.decoder.layers.0.shared_experts.layers.3.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
140
+ "model.decoder.layers.0.shared_experts.layers.3.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
141
+ "model.decoder.layers.0.shared_experts.layers.3.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
142
+ "model.decoder.layers.0.shared_experts.layers.3.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
143
+ "model.decoder.layers.0.shared_experts.layers.3.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
144
+ "model.decoder.layers.0.shared_experts.final_layernorm.weight": "model-00001-of-00002.safetensors",
145
+ "model.decoder.layers.0.shared_experts.patch_embedding.linear_fc1.weight": "model-00001-of-00002.safetensors",
146
+ "model.decoder.layers.0.shared_experts.patch_embedding.linear_fc2.weight": "model-00001-of-00002.safetensors",
147
+ "model.decoder.layers.0.shared_experts.output_layer.weight": "model-00001-of-00002.safetensors",
148
+ "model.decoder.layers.1.router.weight": "model-00001-of-00002.safetensors",
149
+ "model.decoder.layers.1.backcast_layernorm.weight": "model-00001-of-00002.safetensors",
150
+ "model.decoder.layers.1.experts.local_experts.0.layers.0.input_layernorm.weight": "model-00001-of-00002.safetensors",
151
+ "model.decoder.layers.1.experts.local_experts.0.layers.0.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
152
+ "model.decoder.layers.1.experts.local_experts.0.layers.0.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
153
+ "model.decoder.layers.1.experts.local_experts.0.layers.0.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
154
+ "model.decoder.layers.1.experts.local_experts.0.layers.0.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
155
+ "model.decoder.layers.1.experts.local_experts.0.layers.0.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
156
+ "model.decoder.layers.1.experts.local_experts.0.layers.1.input_layernorm.weight": "model-00001-of-00002.safetensors",
157
+ "model.decoder.layers.1.experts.local_experts.0.layers.1.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
158
+ "model.decoder.layers.1.experts.local_experts.0.layers.1.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
159
+ "model.decoder.layers.1.experts.local_experts.0.layers.1.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
160
+ "model.decoder.layers.1.experts.local_experts.0.layers.1.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
161
+ "model.decoder.layers.1.experts.local_experts.0.layers.1.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
162
+ "model.decoder.layers.1.experts.local_experts.0.layers.2.input_layernorm.weight": "model-00001-of-00002.safetensors",
163
+ "model.decoder.layers.1.experts.local_experts.0.layers.2.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
164
+ "model.decoder.layers.1.experts.local_experts.0.layers.2.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
165
+ "model.decoder.layers.1.experts.local_experts.0.layers.2.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
166
+ "model.decoder.layers.1.experts.local_experts.0.layers.2.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
167
+ "model.decoder.layers.1.experts.local_experts.0.layers.2.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
168
+ "model.decoder.layers.1.experts.local_experts.0.layers.3.input_layernorm.weight": "model-00001-of-00002.safetensors",
169
+ "model.decoder.layers.1.experts.local_experts.0.layers.3.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
170
+ "model.decoder.layers.1.experts.local_experts.0.layers.3.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
171
+ "model.decoder.layers.1.experts.local_experts.0.layers.3.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
172
+ "model.decoder.layers.1.experts.local_experts.0.layers.3.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
173
+ "model.decoder.layers.1.experts.local_experts.0.layers.3.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
174
+ "model.decoder.layers.1.experts.local_experts.0.final_layernorm.weight": "model-00001-of-00002.safetensors",
175
+ "model.decoder.layers.1.experts.local_experts.0.patch_embedding.linear_fc1.weight": "model-00001-of-00002.safetensors",
176
+ "model.decoder.layers.1.experts.local_experts.0.patch_embedding.linear_fc2.weight": "model-00001-of-00002.safetensors",
177
+ "model.decoder.layers.1.experts.local_experts.0.output_layer.weight": "model-00001-of-00002.safetensors",
178
+ "model.decoder.layers.1.experts.local_experts.1.layers.0.input_layernorm.weight": "model-00001-of-00002.safetensors",
179
+ "model.decoder.layers.1.experts.local_experts.1.layers.0.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
180
+ "model.decoder.layers.1.experts.local_experts.1.layers.0.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
181
+ "model.decoder.layers.1.experts.local_experts.1.layers.0.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
182
+ "model.decoder.layers.1.experts.local_experts.1.layers.0.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
183
+ "model.decoder.layers.1.experts.local_experts.1.layers.0.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
184
+ "model.decoder.layers.1.experts.local_experts.1.layers.1.input_layernorm.weight": "model-00001-of-00002.safetensors",
185
+ "model.decoder.layers.1.experts.local_experts.1.layers.1.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
186
+ "model.decoder.layers.1.experts.local_experts.1.layers.1.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
187
+ "model.decoder.layers.1.experts.local_experts.1.layers.1.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
188
+ "model.decoder.layers.1.experts.local_experts.1.layers.1.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
189
+ "model.decoder.layers.1.experts.local_experts.1.layers.1.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
190
+ "model.decoder.layers.1.experts.local_experts.1.layers.2.input_layernorm.weight": "model-00001-of-00002.safetensors",
191
+ "model.decoder.layers.1.experts.local_experts.1.layers.2.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
192
+ "model.decoder.layers.1.experts.local_experts.1.layers.2.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
193
+ "model.decoder.layers.1.experts.local_experts.1.layers.2.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
194
+ "model.decoder.layers.1.experts.local_experts.1.layers.2.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
195
+ "model.decoder.layers.1.experts.local_experts.1.layers.2.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
196
+ "model.decoder.layers.1.experts.local_experts.1.layers.3.input_layernorm.weight": "model-00001-of-00002.safetensors",
197
+ "model.decoder.layers.1.experts.local_experts.1.layers.3.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
198
+ "model.decoder.layers.1.experts.local_experts.1.layers.3.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
199
+ "model.decoder.layers.1.experts.local_experts.1.layers.3.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
200
+ "model.decoder.layers.1.experts.local_experts.1.layers.3.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
201
+ "model.decoder.layers.1.experts.local_experts.1.layers.3.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
202
+ "model.decoder.layers.1.experts.local_experts.1.final_layernorm.weight": "model-00001-of-00002.safetensors",
203
+ "model.decoder.layers.1.experts.local_experts.1.patch_embedding.linear_fc1.weight": "model-00001-of-00002.safetensors",
204
+ "model.decoder.layers.1.experts.local_experts.1.patch_embedding.linear_fc2.weight": "model-00001-of-00002.safetensors",
205
+ "model.decoder.layers.1.experts.local_experts.1.output_layer.weight": "model-00001-of-00002.safetensors",
206
+ "model.decoder.layers.1.experts.local_experts.2.layers.0.input_layernorm.weight": "model-00001-of-00002.safetensors",
207
+ "model.decoder.layers.1.experts.local_experts.2.layers.0.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
208
+ "model.decoder.layers.1.experts.local_experts.2.layers.0.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
209
+ "model.decoder.layers.1.experts.local_experts.2.layers.0.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
210
+ "model.decoder.layers.1.experts.local_experts.2.layers.0.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
211
+ "model.decoder.layers.1.experts.local_experts.2.layers.0.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
212
+ "model.decoder.layers.1.experts.local_experts.2.layers.1.input_layernorm.weight": "model-00001-of-00002.safetensors",
213
+ "model.decoder.layers.1.experts.local_experts.2.layers.1.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
214
+ "model.decoder.layers.1.experts.local_experts.2.layers.1.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
215
+ "model.decoder.layers.1.experts.local_experts.2.layers.1.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
216
+ "model.decoder.layers.1.experts.local_experts.2.layers.1.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
217
+ "model.decoder.layers.1.experts.local_experts.2.layers.1.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
218
+ "model.decoder.layers.1.experts.local_experts.2.layers.2.input_layernorm.weight": "model-00001-of-00002.safetensors",
219
+ "model.decoder.layers.1.experts.local_experts.2.layers.2.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
220
+ "model.decoder.layers.1.experts.local_experts.2.layers.2.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
221
+ "model.decoder.layers.1.experts.local_experts.2.layers.2.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
222
+ "model.decoder.layers.1.experts.local_experts.2.layers.2.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
223
+ "model.decoder.layers.1.experts.local_experts.2.layers.2.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
224
+ "model.decoder.layers.1.experts.local_experts.2.layers.3.input_layernorm.weight": "model-00001-of-00002.safetensors",
225
+ "model.decoder.layers.1.experts.local_experts.2.layers.3.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
226
+ "model.decoder.layers.1.experts.local_experts.2.layers.3.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
227
+ "model.decoder.layers.1.experts.local_experts.2.layers.3.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
228
+ "model.decoder.layers.1.experts.local_experts.2.layers.3.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
229
+ "model.decoder.layers.1.experts.local_experts.2.layers.3.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
230
+ "model.decoder.layers.1.experts.local_experts.2.final_layernorm.weight": "model-00001-of-00002.safetensors",
231
+ "model.decoder.layers.1.experts.local_experts.2.patch_embedding.linear_fc1.weight": "model-00001-of-00002.safetensors",
232
+ "model.decoder.layers.1.experts.local_experts.2.patch_embedding.linear_fc2.weight": "model-00001-of-00002.safetensors",
233
+ "model.decoder.layers.1.experts.local_experts.2.output_layer.weight": "model-00001-of-00002.safetensors",
234
+ "model.decoder.layers.1.experts.local_experts.3.layers.0.input_layernorm.weight": "model-00001-of-00002.safetensors",
235
+ "model.decoder.layers.1.experts.local_experts.3.layers.0.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
236
+ "model.decoder.layers.1.experts.local_experts.3.layers.0.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
237
+ "model.decoder.layers.1.experts.local_experts.3.layers.0.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
238
+ "model.decoder.layers.1.experts.local_experts.3.layers.0.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
239
+ "model.decoder.layers.1.experts.local_experts.3.layers.0.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
240
+ "model.decoder.layers.1.experts.local_experts.3.layers.1.input_layernorm.weight": "model-00001-of-00002.safetensors",
241
+ "model.decoder.layers.1.experts.local_experts.3.layers.1.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
242
+ "model.decoder.layers.1.experts.local_experts.3.layers.1.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
243
+ "model.decoder.layers.1.experts.local_experts.3.layers.1.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
244
+ "model.decoder.layers.1.experts.local_experts.3.layers.1.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
245
+ "model.decoder.layers.1.experts.local_experts.3.layers.1.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
246
+ "model.decoder.layers.1.experts.local_experts.3.layers.2.input_layernorm.weight": "model-00001-of-00002.safetensors",
247
+ "model.decoder.layers.1.experts.local_experts.3.layers.2.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
248
+ "model.decoder.layers.1.experts.local_experts.3.layers.2.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
249
+ "model.decoder.layers.1.experts.local_experts.3.layers.2.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
250
+ "model.decoder.layers.1.experts.local_experts.3.layers.2.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
251
+ "model.decoder.layers.1.experts.local_experts.3.layers.2.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
252
+ "model.decoder.layers.1.experts.local_experts.3.layers.3.input_layernorm.weight": "model-00001-of-00002.safetensors",
253
+ "model.decoder.layers.1.experts.local_experts.3.layers.3.self_attention.linear_proj.weight": "model-00001-of-00002.safetensors",
254
+ "model.decoder.layers.1.experts.local_experts.3.layers.3.self_attention.linear_qkv.weight": "model-00001-of-00002.safetensors",
255
+ "model.decoder.layers.1.experts.local_experts.3.layers.3.pre_mlp_layernorm.weight": "model-00001-of-00002.safetensors",
256
+ "model.decoder.layers.1.experts.local_experts.3.layers.3.mlp.linear_fc1.weight": "model-00001-of-00002.safetensors",
257
+ "model.decoder.layers.1.experts.local_experts.3.layers.3.mlp.linear_fc2.weight": "model-00001-of-00002.safetensors",
258
+ "model.decoder.layers.1.experts.local_experts.3.final_layernorm.weight": "model-00001-of-00002.safetensors",
259
+ "model.decoder.layers.1.experts.local_experts.3.patch_embedding.linear_fc1.weight": "model-00001-of-00002.safetensors",
260
+ "model.decoder.layers.1.experts.local_experts.3.patch_embedding.linear_fc2.weight": "model-00001-of-00002.safetensors",
261
+ "model.decoder.layers.1.experts.local_experts.3.output_layer.weight": "model-00002-of-00002.safetensors",
262
+ "model.decoder.layers.1.shared_experts.layers.0.input_layernorm.weight": "model-00002-of-00002.safetensors",
263
+ "model.decoder.layers.1.shared_experts.layers.0.self_attention.linear_proj.weight": "model-00002-of-00002.safetensors",
264
+ "model.decoder.layers.1.shared_experts.layers.0.self_attention.linear_qkv.weight": "model-00002-of-00002.safetensors",
265
+ "model.decoder.layers.1.shared_experts.layers.0.pre_mlp_layernorm.weight": "model-00002-of-00002.safetensors",
266
+ "model.decoder.layers.1.shared_experts.layers.0.mlp.linear_fc1.weight": "model-00002-of-00002.safetensors",
267
+ "model.decoder.layers.1.shared_experts.layers.0.mlp.linear_fc2.weight": "model-00002-of-00002.safetensors",
268
+ "model.decoder.layers.1.shared_experts.layers.1.input_layernorm.weight": "model-00002-of-00002.safetensors",
269
+ "model.decoder.layers.1.shared_experts.layers.1.self_attention.linear_proj.weight": "model-00002-of-00002.safetensors",
270
+ "model.decoder.layers.1.shared_experts.layers.1.self_attention.linear_qkv.weight": "model-00002-of-00002.safetensors",
271
+ "model.decoder.layers.1.shared_experts.layers.1.pre_mlp_layernorm.weight": "model-00002-of-00002.safetensors",
272
+ "model.decoder.layers.1.shared_experts.layers.1.mlp.linear_fc1.weight": "model-00002-of-00002.safetensors",
273
+ "model.decoder.layers.1.shared_experts.layers.1.mlp.linear_fc2.weight": "model-00002-of-00002.safetensors",
274
+ "model.decoder.layers.1.shared_experts.layers.2.input_layernorm.weight": "model-00002-of-00002.safetensors",
275
+ "model.decoder.layers.1.shared_experts.layers.2.self_attention.linear_proj.weight": "model-00002-of-00002.safetensors",
276
+ "model.decoder.layers.1.shared_experts.layers.2.self_attention.linear_qkv.weight": "model-00002-of-00002.safetensors",
277
+ "model.decoder.layers.1.shared_experts.layers.2.pre_mlp_layernorm.weight": "model-00002-of-00002.safetensors",
278
+ "model.decoder.layers.1.shared_experts.layers.2.mlp.linear_fc1.weight": "model-00002-of-00002.safetensors",
279
+ "model.decoder.layers.1.shared_experts.layers.2.mlp.linear_fc2.weight": "model-00002-of-00002.safetensors",
280
+ "model.decoder.layers.1.shared_experts.layers.3.input_layernorm.weight": "model-00002-of-00002.safetensors",
281
+ "model.decoder.layers.1.shared_experts.layers.3.self_attention.linear_proj.weight": "model-00002-of-00002.safetensors",
282
+ "model.decoder.layers.1.shared_experts.layers.3.self_attention.linear_qkv.weight": "model-00002-of-00002.safetensors",
283
+ "model.decoder.layers.1.shared_experts.layers.3.pre_mlp_layernorm.weight": "model-00002-of-00002.safetensors",
284
+ "model.decoder.layers.1.shared_experts.layers.3.mlp.linear_fc1.weight": "model-00002-of-00002.safetensors",
285
+ "model.decoder.layers.1.shared_experts.layers.3.mlp.linear_fc2.weight": "model-00002-of-00002.safetensors",
286
+ "model.decoder.layers.1.shared_experts.final_layernorm.weight": "model-00002-of-00002.safetensors",
287
+ "model.decoder.layers.1.shared_experts.patch_embedding.linear_fc1.weight": "model-00002-of-00002.safetensors",
288
+ "model.decoder.layers.1.shared_experts.patch_embedding.linear_fc2.weight": "model-00002-of-00002.safetensors",
289
+ "model.decoder.layers.1.shared_experts.output_layer.weight": "model-00002-of-00002.safetensors",
290
+ "model.decoder.input_layernorm.weight": "model-00002-of-00002.safetensors",
291
+ "model.output_layer.weight": "model-00002-of-00002.safetensors"
292
+ }
293
  }
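
The weight_map above is the standard safetensors index: every parameter name points at the shard file that stores it. As a minimal sketch (file names and the chosen parameter are taken from the map above; paths assume the index and shards have been downloaded into the working directory), resolving one tensor from such an index looks like this:

import json
from safetensors import safe_open

with open("model.safetensors.index.json") as f:
    index = json.load(f)

name = "model.decoder.layers.0.shared_experts.output_layer.weight"
shard_file = index["weight_map"][name]      # e.g. "model-00001-of-00002.safetensors"

with safe_open(shard_file, framework="pt") as reader:
    tensor = reader.get_tensor(name)        # the single parameter, as a torch.Tensor
print(name, tuple(tensor.shape), "->", shard_file)
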
modeling_FalconTST.py CHANGED
@@ -10,7 +10,6 @@ from einops import rearrange, repeat
10
  from functools import reduce
11
  from abc import ABC, abstractmethod
12
  from .configuration_FalconTST import FalconTSTConfig
13
- from .ts_generation_mixin import FalconTSTGenerationMixin
14
  from transformers import PreTrainedModel, Cache, DynamicCache
15
  from transformers.activations import ACT2FN
16
  from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
@@ -74,63 +73,6 @@ def _apply_rotary_pos_emb_bshd(
74
  return torch.cat((t, t_pass), dim=-1)
75
 
76
 
77
- def topk_softmax_with_capacity(
78
- logits: torch.Tensor,
79
- topk: int,
80
- use_pre_softmax: bool = False,
81
- score_function: str = "softmax",
82
- expert_bias: Optional[torch.Tensor] = None,
83
- ):
84
- """Apply capacity and padding to the top-k selection.
85
- Args:
86
- logits (torch.Tensor): Logits tensor.
87
- topk (int): The number of experts to select for each token.
88
- use_pre_softmax (bool): Whether to apply softmax or sigmoid before top-k selection.
89
- score_function (str): The score function to use. Can be either "softmax" or "sigmoid".
90
- expert_bias (torch.Tensor): The bias added to logits for expert routing.
91
- Returns:
92
- Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
93
- - routing_probs (torch.Tensor): A tensor of shape [num_tokens, num_experts] containing
94
- the routing probabilities for each token to each expert.
95
- - routing_map (torch.Tensor): A mask tensor of shape [num_tokens, num_experts]
96
- indicating which experts were selected for each token. True values represent
97
- the selected experts.
98
- - tokens_per_expert (torch.Tensor): A tensor of shape [num_experts] containing
99
- the number of local tokens assigned to each expert before dropping and padding.
100
- """
101
- assert logits.dim() == 2, f"Expected 2D logits [num_tokens, num_experts], got {logits.dim()}."
102
-
103
- def compute_topk(scores, topk,):
104
- return torch.topk(scores, k=topk, dim=1)
105
-
106
- if score_function == "softmax":
107
- if use_pre_softmax:
108
- scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits)
109
- probs, top_indices = compute_topk(scores, topk, )
110
- else:
111
- scores, top_indices = compute_topk(logits, topk, )
112
- probs = torch.softmax(scores, dim=-1, dtype=torch.float32).type_as(logits)
113
- elif score_function == "sigmoid":
114
- scores = torch.sigmoid(logits.float()).type_as(logits)
115
- if expert_bias is not None:
116
- scores_for_routing = scores + expert_bias
117
- _, top_indices = compute_topk(scores_for_routing, topk, )
118
- scores = torch.gather(scores, dim=1, index=top_indices).type_as(logits)
119
- else:
120
- scores, top_indices = compute_topk(scores, topk,)
121
- probs = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if topk > 1 else scores
122
- else:
123
- raise ValueError(f"Invalid score_function: {score_function}")
124
-
125
- # TODO Try using element-wise operations instead of scatter?
126
- topk_masked_gates = torch.zeros_like(logits).scatter(1, top_indices, probs)
127
- topk_map = torch.zeros_like(logits).int().scatter(1, top_indices, 1).bool()
128
- # TODO: Reset topk_map to realize load-balancing?
129
- tokens_per_expert = topk_map.sum(dim=0)
130
-
131
- return topk_masked_gates, topk_map, tokens_per_expert
132
-
133
-
134
  class RotaryEmbedding(nn.Module):
135
  """Rotary Embedding.
136
 
@@ -156,7 +98,10 @@ class RotaryEmbedding(nn.Module):
156
 
157
  dim = kv_channels
158
  self.rotary_interleaved = rotary_interleaved
159
- device = 'cpu' if use_cpu_initialization else torch.cuda.current_device()
 
 
 
160
  self.inv_freq = 1.0 / (
161
  rotary_base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
162
  )
@@ -225,11 +170,6 @@ class IdentityOp(nn.Module):
225
  return x
226
 
227
 
228
- class IdentityFuncOp(nn.Module):
229
- def forward(self, x):
230
- return x
231
-
232
-
233
  class RMSNorm(nn.Module):
234
  def __init__(self, hidden_size, eps=1e-5):
235
  super().__init__()
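
For reference, the RMSNorm used throughout this file (its body is only partially visible in these hunks) follows the usual root-mean-square normalization: scale each hidden vector by the inverse of its RMS and apply a learned per-channel weight, with no mean subtraction. A minimal functional sketch:

import torch

def rms_norm(x, weight, eps=1e-5):
    # Normalize by the root-mean-square over the hidden dimension, then scale.
    return weight * x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

x = torch.randn(4, 1024)
print(rms_norm(x, torch.ones(1024)).shape)   # torch.Size([4, 1024])
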
@@ -264,24 +204,21 @@ class TEDotProductAttention(nn.Module):
264
  self.softmax_scale = softmax_scale
265
  self.drop = nn.Dropout(attention_dropout)
266
 
267
- def forward(self, q,k,v,attention_mask,causal=None, ):
268
  """Implements the multihead softmax attention.
269
  Arguments
270
  ---------
271
- qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
272
- causal: if passed, will override self.causal
273
- key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
274
- False means to mask out. (B, S)
275
  """
276
- causal = self.causal if causal is None else causal
277
-
278
  q = q.transpose(0,1).contiguous()
279
  k = k.transpose(0,1).contiguous()
280
  v = v.transpose(0,1).contiguous()
281
 
282
  batch_size, seq_len = q.shape[0], q.shape[1]
283
  softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
284
- # scores
285
  scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
286
  scores = scores.masked_fill(attention_mask == 0, float('-1e9'))
287
  # Softmax
@@ -289,42 +226,37 @@ class TEDotProductAttention(nn.Module):
289
  # Dropout
290
  attention_drop = self.drop(attention)
291
  output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
292
- output = output.reshape(batch_size, seq_len, -1).transpose(0,1).contiguous()
293
- return output
294
 
 
 
295
 
296
 
297
  class SelfAttention(nn.Module):
298
  def __init__(self,config,):
299
  super().__init__()
300
  self.config = config
301
- q_layernorm=config.q_layernorm
302
- k_layernorm=config.k_layernorm
303
  self.hidden_size = config.hidden_size
304
  self.core_attention = TEDotProductAttention()
305
  self.linear_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.add_bias_linear,)
306
  self.linear_qkv = nn.Linear(self.hidden_size, 3*self.hidden_size, bias=config.add_bias_linear,)
307
- if q_layernorm:
308
- self.q_layernorm = RMSNorm(self.hidden_size)
309
- else:
310
- self.q_layernorm = IdentityOp()
311
- if k_layernorm:
312
- self.k_layernorm = RMSNorm(self.hidden_size)
313
- else:
314
- self.k_layernorm = IdentityOp()
315
 
316
- def forward(self, x, attention_mask,rotary_pos_emb):
 
 
 
 
 
317
  qkv = self.linear_qkv(x)
318
- qkv = qkv.view(qkv.size(0), qkv.size(1), self.config.num_attention_heads,-1)
319
  q, k, v = qkv.chunk(3, dim=-1)
 
320
  # Apply rotary encoding to q and k
321
  rotary_pos_emb = (rotary_pos_emb,) * 2
322
  q_pos_emb, k_pos_emb = rotary_pos_emb
323
  q = _apply_rotary_pos_emb_bshd(q, q_pos_emb)
324
  k = _apply_rotary_pos_emb_bshd(k, k_pos_emb)
325
 
326
- q = self.q_layernorm(q)
327
- k = self.k_layernorm(k)
328
  # attention
329
  attn_output = self.core_attention(q, k, v, attention_mask)
330
  output = self.linear_proj(attn_output)
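
The core attention here is plain masked softmax attention written with einsum over [batch, time, heads, dim] tensors: per head, scores = q·kᵀ/√d, masked positions pushed to a large negative value before the softmax. A self-contained toy check of that einsum form, independent of the model classes:

import math
import torch

b, t, h, d = 2, 5, 4, 8
q, k, v = (torch.randn(b, t, h, d) for _ in range(3))
mask = torch.ones(b, 1, t, t)                     # 1 = keep, 0 = mask out

scale = 1.0 / math.sqrt(d)
scores = torch.einsum("bthd,bshd->bhts", q, k * scale)
scores = scores.masked_fill(mask == 0, float("-1e9"))
attn = torch.softmax(scores, dim=-1)
out = torch.einsum("bhts,bshd->bthd", attn, v)    # [batch, time, heads, dim]
print(out.shape)
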
@@ -333,7 +265,7 @@ class SelfAttention(nn.Module):
333
 
334
 
335
  class MLP(nn.Module):
336
- def __init__(self,config,in_features):
337
  super().__init__()
338
  self.config= config
339
  self.linear_fc1 = nn.Linear(in_features, self.config.moe_ffn_hidden_size*2, bias=self.config.add_bias_linear,)
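
The doubled fc1 width (moe_ffn_hidden_size*2) suggests a gated feed-forward block in which the projection is split into a gate and a value before the down-projection; the actual activation comes from ACT2FN elsewhere in the file and may differ. A sketch under that assumption:

import torch
import torch.nn as nn
import torch.nn.functional as F

class GatedMLP(nn.Module):
    # Illustrative only: assumes a SwiGLU-style split of the doubled fc1 output.
    def __init__(self, hidden_size, ffn_hidden_size, bias=False):
        super().__init__()
        self.linear_fc1 = nn.Linear(hidden_size, ffn_hidden_size * 2, bias=bias)
        self.linear_fc2 = nn.Linear(ffn_hidden_size, hidden_size, bias=bias)

    def forward(self, x):
        gate, value = self.linear_fc1(x).chunk(2, dim=-1)
        return self.linear_fc2(F.silu(gate) * value)
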
@@ -367,9 +299,14 @@ class TransformerLayer(nn.Module):
367
  self.input_layernorm = IdentityOp()
368
  self.self_attention = SelfAttention(config)
369
  self.pre_mlp_layernorm = RMSNorm(self.config.hidden_size)
370
- self.mlp = MLP(config,self.config.hidden_size)
371
 
372
- def forward(self, x, attention_mask,rotary_pos_emb):
 
 
 
 
 
373
  residual = x
374
  x = self.input_layernorm(x)
375
  x = self.self_attention(x, attention_mask, rotary_pos_emb)
@@ -425,7 +362,7 @@ class FalconTSTExpert(nn.Module):
425
 
426
  # Patchify the input
427
  input_data = input_data.unfold(dimension=-1, size=self.patch_size, step=self.patch_size).contiguous() # input [batch_size, patch_num, patch_size]
428
- hidden_states= self.patch_embedding(input_data) # hidden_states [batch_size, patch_num, hidden_size]
429
  hidden_states = hidden_states.transpose(0, 1).contiguous() # hidden_states [patch_num, batch_size, hidden_size], To adapt to the Megatron
430
 
431
  # Patchify the mask: only the entire time points in a patch are masked then this patch is masked
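
The unfold call above turns a [batch_size, seq_len] series into non-overlapping patches of length patch_size (step == size means no overlap). A toy illustration:

import torch

x = torch.arange(12.0).unsqueeze(0)                     # [1, 12], one series of length 12
patches = x.unfold(dimension=-1, size=4, step=4).contiguous()
print(patches.shape)                                    # torch.Size([1, 3, 4]) -> [batch, patch_num, patch_size]
print(patches)
# tensor([[[ 0.,  1.,  2.,  3.],
#          [ 4.,  5.,  6.,  7.],
#          [ 8.,  9., 10., 11.]]])
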
@@ -436,16 +373,13 @@ class FalconTSTExpert(nn.Module):
436
  attention_mask = attention_mask.unsqueeze(2).repeat(1,1,patch_num) * attention_mask.unsqueeze(1).repeat(1,patch_num,1) # [batch_size, patch_num, patch_num]
437
  attention_mask = attention_mask.unsqueeze(1).contiguous() # [batch_size, 1, patch_num, patch_num]
438
 
439
-
440
  return hidden_states, attention_mask, input_mask
441
 
442
-
443
- def _forward_output(self, hidden_states, output_scale=None, input_mask=None, inference_context=None):
444
  """
445
  Perform a forward pass through the output layer.
446
 
447
  Args:
448
- expert_input (Tensor): Expert input of shape [batch_size, seq_len]
449
  hidden_states (Tensor): Transformed hidden states of shape [patch_num, batch_size, hidden_size]
450
  output_scale (Tensor, optional): Expert probabilities for the output layer [batch_size]
451
  input_mask (Tensor, optional): Expert input mask of shape [batch_size, seq_len], 0:mask, 1:unmask
@@ -466,11 +400,17 @@ class FalconTSTExpert(nn.Module):
466
 
467
  return expert_output
468
 
469
- def forward(self, expert_input, rotary_pos_emb,expert_probs=None):
470
  hidden_states, attention_mask, input_mask = self._forward_patch_embedding(expert_input)
 
 
 
 
471
  for layer in self.layers:
472
- hidden_states = layer(hidden_states,attention_mask,rotary_pos_emb[:hidden_states.shape[0]])
 
473
  hidden_states = self.final_layernorm(hidden_states)
 
474
  expert_output = self._forward_output(hidden_states, expert_probs, input_mask)
475
  return expert_output
476
 
@@ -512,174 +452,47 @@ class SequentialFalconTST(nn.Module):
512
  return expert_output
513
 
514
 
515
- class RouterGatingLinearFunction(torch.autograd.Function):
516
- """
517
- Autograd function for router gating linear.
518
- """
519
-
520
- @staticmethod
521
- def forward(ctx, inp: torch.Tensor, weight: torch.Tensor, router_dtype: torch.dtype):
522
- """
523
- Forward pass of the RouterGatingLinearFunction function.
524
- """
525
- ctx.router_dtype = router_dtype
526
- ctx.input_dtype = inp.dtype
527
- ctx.weight_dtype = weight.dtype
528
- inp_shape = inp.shape
529
- inp = inp.view(-1, inp_shape[-1])
530
-
531
- output = torch.mm(inp.to(router_dtype), weight.to(router_dtype).t())
532
-
533
- output = output.view(*inp_shape[:-1], -1)
534
- return output
535
-
536
-
537
- def router_gating_linear(inp: torch.Tensor, weight: torch.Tensor, router_dtype: torch.dtype):
538
- """
539
- Customized linear layer for router gating.
540
- This linear layer accepts bfloat16 input and weight, and can return output with router_dtype.
541
- It can reduce the memory usage by avoiding saving the intermediate high precision tensors.
542
- """
543
- return RouterGatingLinearFunction.apply(inp, weight, router_dtype)
544
-
545
-
546
- class Router(ABC,nn.Module):
547
- """Base Router class"""
548
-
549
- def __init__(
550
- self, config: FalconTSTConfig,
551
- ) -> None:
552
- """
553
- Initialize the Router module.
554
-
555
- Args:
556
- config (TransformerConfig): Configuration object for the Transformer model.
557
- model_comm_pgs (ModelCommProcessGroups, optional): Process groups for MoE operations.
558
- """
559
  super().__init__()
560
  self.config = config
 
561
 
562
- # Initialize the gate weights.
563
-
564
- if self.config.patch_size_list is not None:
565
- assert self.config.moe_router_input_size is not None
566
- self.weight = torch.nn.Parameter(
567
- torch.empty((self.config.num_moe_experts, self.config.moe_router_input_size), dtype=torch.float32)
568
- )
569
- else:
570
- self.weight = torch.nn.Parameter(
571
- torch.empty((self.config.num_moe_experts, self.config.hidden_size), dtype=torch.float32)
572
- )
573
  self.reset_parameters()
574
-
575
- def reset_parameters(self):
576
- """Reset the router parameters."""
577
- torch.nn.init.normal_(self.weight,mean=0,std=self.config.init_method_std)
578
- self.weight.data = self.weight.data.to(dtype=self.config.torch_dtype)
579
-
580
 
581
- def gating(self, input: torch.Tensor):
582
- """Forward pass of the router gate.
583
-
584
- Args:
585
- input (torch.Tensor): Input tensor.
586
-
587
- Returns:
588
- torch.Tensor: Logits tensor.
589
- """
590
- if self.weight.device != input.device:
591
- self.weight = self.weight.to(input.device)
592
- router_dtype = input.dtype
593
- logits = router_gating_linear(input, self.weight, router_dtype)
594
- return logits
595
 
596
- @abstractmethod
597
  def routing(self, logits: torch.Tensor):
598
- """Routing function.
599
-
600
- Args:
601
- logits (torch.Tensor): Logits tensor.
602
-
603
- Returns:
604
- Tuple[torch.Tensor, torch.Tensor]: A tuple containing token assignment
605
- probabilities and mapping.
606
- """
607
- raise NotImplementedError("Routing function not implemented.")
608
-
609
- @abstractmethod
610
- def forward(self, input: torch.Tensor):
611
- """
612
- Forward pass of the router.
613
-
614
- Args:
615
- input (torch.Tensor): Input tensor.
616
- """
617
- raise NotImplementedError("Forward function not implemented.")
618
-
619
-
620
- class TopKRouter(Router):
621
- """Route each token to the top-k experts."""
622
 
623
- def __init__(
624
- self, config: FalconTSTConfig,
625
- ) -> None:
626
- """Initialize the zero token dropping router.
627
-
628
- Args:
629
- config (TransformerConfig): The configuration for the transformer model.
630
- model_comm_pgs (ModelCommProcessGroups, optional): Process groups for MoE operations.
631
- """
632
- super().__init__(config=config)
633
- self.topk = self.config.moe_router_topk
634
- self.score_function = self.config.moe_router_score_function
635
-
636
- self.enable_expert_bias = self.config.moe_router_enable_expert_bias
637
- if self.enable_expert_bias:
638
- self.register_buffer(
639
- 'local_tokens_per_expert',
640
- torch.zeros(self.config.num_moe_experts, dtype=torch.float32),
641
- persistent=False,
642
- )
643
- self.register_buffer(
644
- 'expert_bias', torch.zeros(self.config.num_moe_experts, dtype=torch.float32)
645
- )
646
  else:
647
- self.local_tokens_per_expert = None
648
- self.expert_bias = None
649
-
650
-
651
- def routing(self, logits: torch.Tensor):
652
- """Top-k routing function
653
-
654
- Args:
655
- logits (torch.Tensor): Logits tensor after gating.
656
-
657
- Returns:
658
- probs (torch.Tensor): The probabilities of token to experts assignment.
659
- routing_map (torch.Tensor): The mapping of token to experts assignment,
660
- with shape [num_tokens, num_experts].
661
- """
662
- logits = logits.view(-1, self.config.num_moe_experts)
663
-
664
- scores, routing_map, tokens_per_expert = topk_softmax_with_capacity(
665
- logits,
666
- self.topk,
667
- use_pre_softmax=self.config.moe_router_pre_softmax,
668
- score_function=self.score_function,
669
- expert_bias=self.expert_bias,
670
- )
671
- return scores, routing_map
672
 
 
 
673
  def forward(self, input: torch.Tensor):
674
- """
675
- Forward pass of the router.
676
-
677
- Args:
678
- input (torch.Tensor): Input tensor.
679
- """
680
- logits = self.gating(input)
681
 
682
- scores, routing_map = self.routing(logits)
683
 
684
  return scores, routing_map
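
The removed router above scores each sample against every expert and keeps the top-k; with moe_router_pre_softmax and softmax scoring, probabilities are computed over all experts first and then the top-k are gathered. A simplified sketch of that selection logic (softmax scoring only; the sigmoid and expert-bias paths of topk_softmax_with_capacity are omitted), on toy logits:

import torch

def topk_softmax_routing(logits, topk, pre_softmax=True):
    # logits: [num_samples, num_experts]
    if pre_softmax:
        scores = torch.softmax(logits, dim=-1)
        probs, top_idx = torch.topk(scores, k=topk, dim=-1)
    else:
        top_logits, top_idx = torch.topk(logits, k=topk, dim=-1)
        probs = torch.softmax(top_logits, dim=-1)
    routing_probs = torch.zeros_like(logits).scatter(1, top_idx, probs)
    routing_map = torch.zeros_like(logits).scatter(1, top_idx, 1.0).bool()
    tokens_per_expert = routing_map.sum(dim=0)
    return routing_probs, routing_map, tokens_per_expert

logits = torch.tensor([[0.2, 1.5, -0.3, 0.1],
                       [2.0, 0.0,  0.0, 0.0]])   # 2 samples, 4 experts
probs, routing_map, counts = topk_softmax_routing(logits, topk=1)
print(routing_map)   # one True per row: the selected expert
print(counts)        # how many samples landed on each expert
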
685
 
@@ -702,8 +515,8 @@ class FalconTSTMoELayer(nn.Module):
702
  self.expert_output_size = config.seq_length
703
 
704
  if self.is_last_layer and self.config.heterogeneous_moe_layer:
705
- # If heterogeneous_moe_layer is True, the backcast will be None
706
- self.backcast_layernorm = None
707
  else:
708
  self.backcast_layernorm = RMSNorm(self.seq_length)
709
 
@@ -784,42 +597,9 @@ class FalconTSTMoELayer(nn.Module):
784
  # permuted_probs (global_probs): [num_permuted_samples_after_dispatch_postprocess(sorted)]
785
 
786
  experts_output = self.experts(input, routing_map, rotary_pos_emb, probs)
787
-
788
 
789
  return experts_output, shared_experts_output
790
 
791
- def postprocess(
792
- self,
793
- backcast: torch.Tensor, # [batch_size, seq_len]
794
- forecast: torch.Tensor, # [batch_size, pred_len]
795
- output_backcast: torch.Tensor, # [batch_size, seq_len]
796
- output_forecast: torch.Tensor, # [batch_size, pred_len]
797
- ):
798
- """
799
- Args:
800
- backcast (torch.Tensor): The previous layer's backcast time series (samples). [batch_size, seq_len]
801
- forecast (torch.Tensor): The previous layer's forecast time series (samples). [batch_size, pred_len]
802
- output_backcast (torch.Tensor): The current layer's output backcast time series (samples). [batch_size, seq_len]
803
- output_forecast (torch.Tensor): The current layer's output forecast time series (samples). [batch_size, pred_len]
804
- means (torch.Tensor): The means of the non-masked backcast time series (samples). [batch_size, 1]
805
- stdev (torch.Tensor): The standard deviation of the non-masked backcast time series (samples). [batch_size, 1]
806
- backcast_mask (torch.Tensor): The previous layer's backcast mask of time series (samples) . [batch_size, seq_len]
807
- """
808
- if output_backcast is not None:
809
- # 25/8/14 @modified by xiaming replace the revin with layernorm after the moe layer
810
- # And if we multiply the output_backcast with the input mask, the performance will be hurted
811
- output_backcast = self.backcast_layernorm(output_backcast) # LayerNorm
812
- if self.config.residual_backcast:
813
- output_backcast = backcast - output_backcast
814
-
815
- output_backcast[~self.input_mask] = self.config.mask_pad_value # Important! Recover the mask time point back to mask_pad_value(default:255.)
816
-
817
- if self.config.do_expert_forecast and forecast is not None: # The first layer's forecast is None
818
- output_forecast = forecast + output_forecast
819
-
820
- return output_backcast, output_forecast
821
-
822
-
823
  def combine(
824
  self,
825
  experts_output: torch.Tensor,
@@ -828,8 +608,7 @@ class FalconTSTMoELayer(nn.Module):
828
  """Combines expert outputs via communication and adds shared expert output.
829
 
830
  This method uses the time series(sample) dispatcher to combine the outputs from different
831
- experts (e.g., via an All-to-All communication). It then adds the output
832
- from the shared expert if it exists.
833
  """
834
  assert experts_output.shape == shared_experts_output.shape,\
835
  f'experts_output shape {experts_output.shape} doesn\'t equal to shared_experts_output shape:{shared_experts_output.shape}'
@@ -854,7 +633,36 @@ class FalconTSTMoELayer(nn.Module):
854
  return output_backcast, output_forecast
855
 
856
 
857
- def forward(self, backcast,forecast,rotary_pos_emb):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
858
  inputs, probs, residual, routing_map = self.router_and_preprocess(backcast)
859
  experts_output, shared_experts_output = self.experts_compute(inputs, probs, residual, rotary_pos_emb, routing_map)
860
  output_backcast, output_forecast = self.combine(experts_output, shared_experts_output)
@@ -862,20 +670,31 @@ class FalconTSTMoELayer(nn.Module):
862
  return output_backcast, output_forecast
863
 
864
 
865
-
866
  class FalconTSTBlock(nn.Module):
867
- def __init__(self,config):
868
  super().__init__()
869
  self.config = config
 
 
 
 
 
 
870
  self.layers = nn.ModuleList([
871
- FalconTSTMoELayer(config,layer_num +1)
872
- for layer_num in range(self.config.num_hidden_layers)
873
- ])
874
- def forward(self, x,rotary_pos_emb):
 
875
  backcast = x
876
  forecast = None
 
 
 
 
 
877
  for layer in self.layers:
878
- backcast, forecast = layer(backcast,forecast,rotary_pos_emb)
879
  return backcast,forecast
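
FalconTSTBlock threads two signals through its MoE layers: each layer refines the backcast (with residual_backcast, the layer's reconstruction is subtracted from the incoming backcast) and accumulates its contribution into the running forecast, whose initial value is None. A schematic of that doubly residual loop, with stand-in layer outputs:

import torch

def doubly_residual_stack(x, layers, residual_backcast=True):
    # Each "layer" returns (layer_backcast, layer_forecast) for the current backcast.
    backcast, forecast = x, None
    for layer in layers:
        layer_backcast, layer_forecast = layer(backcast)
        backcast = backcast - layer_backcast if residual_backcast else layer_backcast
        forecast = layer_forecast if forecast is None else forecast + layer_forecast
    return backcast, forecast

# Toy layers: each one "explains" half of what it sees and forecasts a constant.
toy_layer = lambda b: (0.5 * b, torch.ones(b.shape[0], 3))
backcast, forecast = doubly_residual_stack(torch.randn(2, 8), [toy_layer, toy_layer])
print(backcast.shape, forecast.shape)   # torch.Size([2, 8]) torch.Size([2, 3])
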
880
 
881
 
@@ -900,24 +719,28 @@ class FalconTSTPreTrainedModel(PreTrainedModel):
900
  if module.padding_idx is not None:
901
  module.weight.data[module.padding_idx].zero_()
902
 
 
903
  class FalconTSTModel(FalconTSTPreTrainedModel):
904
  def __init__(self, config: FalconTSTConfig):
905
  super().__init__(config)
906
  self.config = config
907
- self.seq_length = config.seq_length
908
  self.rotary_pos_emb = RotaryEmbedding(
909
- kv_channels=self.config.kv_channels,
910
- rotary_base=config.rotary_base,
911
- use_cpu_initialization=self.config.use_cpu_initialization,
912
- rotary_interleaved=self.config.rotary_interleaved
913
  )
914
  self.decoder = FalconTSTBlock(
915
- config=config
916
- )
 
917
  if self.config.do_expert_forecast and self.config.heterogeneous_moe_layer:
918
  self.output_layer = IdentityOp()
919
  else:
920
- self.output_layer = nn.Linear(in_features=self.seq_length, out_features=self.config.pred_length, bias=self.config.add_bias_linear,)
 
 
921
 
922
 
923
  def revin(
@@ -946,13 +769,8 @@ class FalconTSTModel(FalconTSTPreTrainedModel):
946
  return input, means, stdev
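
The revin step is reversible instance normalization: each series is normalized by its own mean and standard deviation before entering the model, and the returned means/stdev are later used to put the forecast back on the original scale. A minimal sketch of the idea (the real method also respects the padding mask when computing the statistics):

import torch

def revin_normalize(x, eps=1e-5):
    # Per-sample statistics over the time dimension.
    means = x.mean(dim=-1, keepdim=True)
    stdev = torch.sqrt(x.var(dim=-1, keepdim=True, unbiased=False) + eps)
    return (x - means) / stdev, means, stdev

def revin_denormalize(y, means, stdev):
    # Applied to the forecast so it returns to the original scale.
    return y * stdev + means

x = torch.randn(2, 16) * 10 + 3
x_norm, means, stdev = revin_normalize(x)
y_pred = x_norm[:, -4:]                       # stand-in "forecast" in normalized space
print(revin_denormalize(y_pred, means, stdev).shape)
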
947
 
948
  def forward(self, input, revin):
949
- # Apply rotary position embeddings
950
- # seq_len = patches.size(1)
951
- # pos_emb = self.rotary_pos_emb(seq_len, patches.device)
952
- # patches = patches + pos_emb
953
-
954
  batch_size, input_len = input.shape
955
- # @created by xiaming @modified by baichun
956
  # realize varied input length
957
  if input_len > self.seq_length:
958
  input = input[:, -self.seq_length:]
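
Variable-length inputs are reconciled to the fixed context length: histories longer than seq_length keep only the most recent points, and, judging from the mask handling elsewhere in the file, shorter ones would be left-padded with mask_pad_value and marked as masked (0 = mask, 1 = unmask). A sketch under that assumption (the padding side and value are not shown in this hunk):

import torch

def fit_to_context(x, seq_length, pad_value=255.0):
    # x: [batch_size, input_len] -> ([batch_size, seq_length], mask with 1 = real point)
    batch_size, input_len = x.shape
    if input_len > seq_length:
        x = x[:, -seq_length:]
        mask = torch.ones_like(x)
    else:
        pad = x.new_full((batch_size, seq_length - input_len), pad_value)
        mask = torch.cat([torch.zeros_like(pad), torch.ones_like(x)], dim=1)
        x = torch.cat([pad, x], dim=1)
    return x, mask

x, mask = fit_to_context(torch.randn(2, 100), seq_length=256)
print(x.shape, mask.sum(dim=1))    # torch.Size([2, 256]) tensor([100., 100.])
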
@@ -972,7 +790,7 @@ class FalconTSTModel(FalconTSTPreTrainedModel):
972
  rotary_pos_emb = self.rotary_pos_emb(input_len, device=input.device)
973
 
974
  # Step3. Do one-step inference to get mixed forecasts from multiple forecast heads
975
- # mixed_pred: [batch_size, sum(multi_forecast_head)]
976
  mixed_pred = self._inference_step(
977
  input=input,
978
  input_mask=input_mask,
@@ -1005,12 +823,12 @@ class FalconTSTModel(FalconTSTPreTrainedModel):
1005
  rotary_pos_emb,
1006
  ):
1007
  if self.config.do_base_forecast:
1008
- base_forecast, _ = self.base_output_layer(input)
1009
  else:
1010
  base_forecast = None
1011
 
1012
  decoder_backcast, decoder_forecast = self.decoder(
1013
- input, # [batch_size, seq_len]
1014
  rotary_pos_emb, # [input_len, 1, 1, kv_channels(hidden_size // num_heads)]
1015
  )
1016
 
@@ -1019,12 +837,12 @@ class FalconTSTModel(FalconTSTPreTrainedModel):
1019
  if self.config.heterogeneous_moe_layer:
1020
  decoder_forecast = self.output_layer(decoder_forecast) # IdentityOp
1021
  else:
1022
- final_forecast= self.output_layer(decoder_backcast * input_mask)
1023
  decoder_forecast = decoder_forecast + final_forecast
1024
  else:
1025
  # The decoder_backcast contains the mask_pad_val(default:255.)
1026
  decoder_forecast, _ = self.output_layer(decoder_backcast * input_mask)
1027
-
1028
  if self.config.do_base_forecast:
1029
  assert base_forecast is not None, f'base_forecast is None'
1030
  FalconTST_forecast = base_forecast + decoder_forecast
@@ -1080,129 +898,62 @@ class FalconTSTModel(FalconTSTPreTrainedModel):
1080
 
1081
  final_output = final_output[:, :self.config.inference_length]
1082
 
1083
- elif auto_regressive_strategy == 'from_short_to_long':
1084
- # From short to long
1085
- # in validate_args, it has been sorted, and check the valid config
1086
- multi_forecast_head_list = sorted(self.config.multi_forecast_head_list)
1087
- multi_forecast_head_dict = {}
1088
- for idx, head_pred_len in enumerate(self.config.multi_forecast_head_list):
1089
- if idx == len(multi_forecast_head_list) - 1:
1090
- ar_step = math.ceil(self.config.inference_length / head_pred_len)
1091
- else:
1092
- ar_step = min(
1093
- self.config.autoregressive_step_list[idx],
1094
- self.config.multi_forecast_head_list[idx + 1] // self.config.multi_forecast_head_list[idx]
1095
- )
1096
- # ar_step = multi_forecast_head_list[idx + 1] // multi_forecast_head_list[idx]
1097
-
1098
- multi_forecast_head_dict[head_pred_len] = ar_step
1099
-
1100
- # the core idea of strategy [from_short_to_long]
1101
- mixed_pred = FalconTST_forecast
1102
- output_list = []
1103
- cur_pred = None
1104
- cur_pred_len = 0
1105
-
1106
- # from the first(shortest) as begining
1107
- for idx, head_pred_len in enumerate(self.config.multi_forecast_head_list):
1108
- # assert cur_pred_len <= head_pred_len, \
1109
- # "Accumulated prediction length exceeds the prediction length of current forecast head"
1110
-
1111
- ar_step = multi_forecast_head_dict[head_pred_len]
1112
- if ar_step == 0:
1113
- # Ignore the current forecast head
1114
- continue
1115
-
1116
- # Add current head's first auto-regressive step of prediction
1117
- head_pred = mixed_pred[:, :head_pred_len] # [single]
1118
- output_list.append(head_pred[:, cur_pred_len:])
1119
- cur_pred = torch.cat(output_list, dim=1)
1120
- cur_pred_len = cur_pred.shape[1]
1121
- if cur_pred_len >= self.config.inference_length:
1122
- break
1123
-
1124
- # Do auto-regressive of the rest of the steps
1125
- for _ in range(1, ar_step + 1):
1126
- # one-step model prediction
1127
- cur_input = torch.cat([input, cur_pred], dim=1)[:, -self.seq_length:].contiguous()
1128
- cur_input_mask = torch.cat(
1129
- [input_mask,
1130
- torch.ones(cur_pred.shape, dtype=input_mask.dtype, device=input_mask.device)],
1131
- dim=1)[:, -self.seq_length:].contiguous() # 0:mask, 1:unmask
1132
-
1133
- FalconTST_forecast = self._inference_step(
1134
- input=cur_input,
1135
- input_mask=cur_input_mask,
1136
- rotary_pos_emb=rotary_pos_emb,
1137
- )
1138
-
1139
- head_pred = FalconTST_forecast[:, :head_pred_len]
1140
- output_list.append(head_pred)
1141
- cur_pred = torch.cat(output_list, dim=1)
1142
- cur_pred_len = cur_pred.shape[1]
1143
- if cur_pred_len >= self.config.inference_length:
1144
- break
1145
-
1146
- if cur_pred_len >= self.config.inference_length:
1147
- break
1148
-
1149
- final_output = cur_pred[:, :self.config.inference_length] # [batch_size, inference_len]
1150
 
1151
  assert final_output.shape[1] == self.config.inference_length
1152
  return final_output
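
The removed from_short_to_long strategy above rolls the model forward autoregressively: take the current prediction, append it to the history, re-run a one-step inference on the refreshed context, and repeat until inference_length points are covered, moving to longer forecast heads as the horizon grows. The basic rollout loop, schematic only (the real code also updates the input mask and switches between forecast heads of different lengths), with the model call abstracted away:

import torch

def autoregressive_rollout(one_step_fn, history, head_pred_len, inference_length, seq_length):
    # one_step_fn(context) -> [batch_size, head_pred_len] forecast for the given context.
    preds = []
    context = history
    while sum(p.shape[1] for p in preds) < inference_length:
        step_pred = one_step_fn(context)[:, :head_pred_len]
        preds.append(step_pred)
        context = torch.cat([context, step_pred], dim=1)[:, -seq_length:]
    return torch.cat(preds, dim=1)[:, :inference_length]

# Toy one-step model that just repeats the last observed value.
one_step = lambda ctx: ctx[:, -1:].repeat(1, 24)
out = autoregressive_rollout(one_step, torch.randn(2, 96), head_pred_len=24,
                             inference_length=60, seq_length=96)
print(out.shape)   # torch.Size([2, 60])
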
1153
 
1154
- class FalconTSTForPrediction(FalconTSTPreTrainedModel, FalconTSTGenerationMixin):
 
1155
  def __init__(self, config: FalconTSTConfig):
1156
  super().__init__(config)
1157
  self.config = config
1158
  self.model = FalconTSTModel(self.config)
1159
  self.post_init()
1160
 
1161
- def forward(
 
1162
  self,
1163
- input_ids: torch.FloatTensor,
1164
- attention_mask: Optional[torch.Tensor] = None,
1165
- labels: Optional[torch.FloatTensor] = None,
1166
- return_dict: Optional[bool] = False,
1167
- max_output_length: Optional[int] = None,
1168
- revin: Optional[bool] = False,
1169
- ):
1170
- self.model.config.inference_length = max_output_length
1171
- outputs = self.model(
1172
- input=input_ids,
1173
- revin=revin
1174
- )
1175
-
1176
- loss = None
1177
- logits = outputs
1178
-
1179
- if labels is not None:
1180
- loss_fn = nn.MSELoss()
1181
- loss = loss_fn(logits, labels)
1182
-
1183
- if not return_dict:
1184
- output = (logits,)
1185
- return ((loss,) + output) if loss is not None else output
1186
 
1187
- return logits
 
 
 
1188
 
1189
- def prepare_inputs_for_generation(
1190
- self,
1191
- input_ids,
1192
- past_key_values=None,
1193
- attention_mask=None,
1194
- inputs_embeds=None,
1195
- revin=False,
1196
- **kwargs
1197
- ):
1198
- """
1199
- Prepare model inputs for autoregressive generation.
1200
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1201
 
1202
- model_inputs = {"input_ids": input_ids}
1203
-
1204
- model_inputs.update({
1205
- "revin": revin,
1206
- })
 
 
1207
 
1208
- return model_inputs
 
10
  from functools import reduce
11
  from abc import ABC, abstractmethod
12
  from .configuration_FalconTST import FalconTSTConfig
 
13
  from transformers import PreTrainedModel, Cache, DynamicCache
14
  from transformers.activations import ACT2FN
15
  from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 
73
  return torch.cat((t, t_pass), dim=-1)
74
 
75
76
  class RotaryEmbedding(nn.Module):
77
  """Rotary Embedding.
78
 
 
98
 
99
  dim = kv_channels
100
  self.rotary_interleaved = rotary_interleaved
101
+ if use_cpu_initialization or not torch.cuda.is_available():
102
+ device = 'cpu'
103
+ else:
104
+ device = torch.cuda.current_device()
105
  self.inv_freq = 1.0 / (
106
  rotary_base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
107
  )
 
170
  return x
171
 
172
 
 
 
 
 
 
173
  class RMSNorm(nn.Module):
174
  def __init__(self, hidden_size, eps=1e-5):
175
  super().__init__()
 
204
  self.softmax_scale = softmax_scale
205
  self.drop = nn.Dropout(attention_dropout)
206
 
207
+ def forward(self, q, k, v, attention_mask):
208
  """Implements the multihead softmax attention.
209
  Arguments
210
  ---------
211
+ q,k,v: The tensor containing the query, key, and value. [seq_len, batch_size, hidden_size]
212
+ attention_mask: boolean mask to apply to the attention weights. True means to keep,
213
+ False means to mask out. [batch_size, 1, seq_len, seq_len]
 
214
  """
 
 
215
  q = q.transpose(0,1).contiguous()
216
  k = k.transpose(0,1).contiguous()
217
  v = v.transpose(0,1).contiguous()
218
 
219
  batch_size, seq_len = q.shape[0], q.shape[1]
220
  softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
221
+ # scores
222
  scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
223
  scores = scores.masked_fill(attention_mask == 0, float('-1e9'))
224
  # Softmax
 
226
  # Dropout
227
  attention_drop = self.drop(attention)
228
  output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
229
+ output = output.reshape(batch_size, seq_len, -1)
 
230
 
231
+ output = output.transpose(0,1).contiguous()
232
+ return output
233
 
234
 
235
  class SelfAttention(nn.Module):
236
  def __init__(self,config,):
237
  super().__init__()
238
  self.config = config
 
 
239
  self.hidden_size = config.hidden_size
240
  self.core_attention = TEDotProductAttention()
241
  self.linear_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.add_bias_linear,)
242
  self.linear_qkv = nn.Linear(self.hidden_size, 3*self.hidden_size, bias=config.add_bias_linear,)
 
 
 
 
 
 
 
 
243
 
244
+ def forward(self, x, attention_mask, rotary_pos_emb):
245
+ '''
246
+ x: [seq_len, batch_size, hidden_size]
247
+ attention_mask: [batch_size, 1, seq_len, seq_len]
248
+ rotary_pos_emb: [seq_len, 1, 1, kv_channels(hidden_size // num_heads)]
249
+ '''
250
  qkv = self.linear_qkv(x)
251
+ qkv = qkv.view(qkv.size(0), qkv.size(1), self.config.num_attention_heads, -1)
252
  q, k, v = qkv.chunk(3, dim=-1)
253
+
254
  # Apply rotary encoding to q and k
255
  rotary_pos_emb = (rotary_pos_emb,) * 2
256
  q_pos_emb, k_pos_emb = rotary_pos_emb
257
  q = _apply_rotary_pos_emb_bshd(q, q_pos_emb)
258
  k = _apply_rotary_pos_emb_bshd(k, k_pos_emb)
259
 
 
 
260
  # attention
261
  attn_output = self.core_attention(q, k, v, attention_mask)
262
  output = self.linear_proj(attn_output)
 
265
 
266
 
267
  class MLP(nn.Module):
268
+ def __init__(self, config, in_features):
269
  super().__init__()
270
  self.config= config
271
  self.linear_fc1 = nn.Linear(in_features, self.config.moe_ffn_hidden_size*2, bias=self.config.add_bias_linear,)
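
The fc1 projection to moe_ffn_hidden_size*2 suggests a gated activation; a hedged sketch of that common pattern (the chunk-and-SiLU gating below is an assumption, since the rest of the MLP body is not shown in this hunk):

import torch
import torch.nn as nn
import torch.nn.functional as F

class GatedMLPSketch(nn.Module):
    def __init__(self, in_features, ffn_hidden_size, bias=False):
        super().__init__()
        self.linear_fc1 = nn.Linear(in_features, ffn_hidden_size * 2, bias=bias)  # gate and up projections fused
        self.linear_fc2 = nn.Linear(ffn_hidden_size, in_features, bias=bias)

    def forward(self, x):
        gate, up = self.linear_fc1(x).chunk(2, dim=-1)
        return self.linear_fc2(F.silu(gate) * up)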
 
299
  self.input_layernorm = IdentityOp()
300
  self.self_attention = SelfAttention(config)
301
  self.pre_mlp_layernorm = RMSNorm(self.config.hidden_size)
302
+ self.mlp = MLP(config, self.config.hidden_size)
303
 
304
+ def forward(self, x, attention_mask, rotary_pos_emb):
305
+ '''
306
+ x: [seq_len, batch_size, hidden_size]
307
+ attention_mask: [batch_size, 1, seq_len, seq_len]
308
+ rotary_pos_emb: [seq_len, 1, 1, kv_channels(hidden_size // num_heads)]
309
+ '''
310
  residual = x
311
  x = self.input_layernorm(x)
312
  x = self.self_attention(x, attention_mask, rotary_pos_emb)
 
362
 
363
  # Patchify the input
364
  input_data = input_data.unfold(dimension=-1, size=self.patch_size, step=self.patch_size).contiguous() # input [batch_size, patch_num, patch_size]
365
+ hidden_states = self.patch_embedding(input_data) # hidden_states [batch_size, patch_num, hidden_size]
366
  hidden_states = hidden_states.transpose(0, 1).contiguous() # hidden_states [patch_num, batch_size, hidden_size], transposed to match the Megatron-style layout
367
 
368
  # Patchify the mask: a patch is treated as masked only if every time point inside it is masked
 
373
  attention_mask = attention_mask.unsqueeze(2).repeat(1,1,patch_num) * attention_mask.unsqueeze(1).repeat(1,patch_num,1) # [batch_size, patch_num, patch_num]
374
  attention_mask = attention_mask.unsqueeze(1).contiguous() # [batch_size, 1, patch_num, patch_num]
375
 
 
376
  return hidden_states, attention_mask, input_mask
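
The unfold-based patchify and the patch-level mask rule can be illustrated in isolation; a small sketch with hypothetical sizes (patch_size=4 is chosen only for illustration, and the mask derivation mirrors the masking comment above rather than code shown in this hunk):

import torch

batch_size, seq_len, patch_size, mask_pad_value = 2, 12, 4, 255.0
x = torch.randn(batch_size, seq_len)
patches = x.unfold(dimension=-1, size=patch_size, step=patch_size).contiguous()
print(patches.shape)  # torch.Size([2, 3, 4]) -> [batch_size, patch_num, patch_size]

point_mask = (x != mask_pad_value)                                      # [batch_size, seq_len], True = observed
patch_observed = point_mask.unfold(-1, patch_size, patch_size).any(-1)  # a patch is masked only if all of its points are masked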
377
 
378
+ def _forward_output(self, hidden_states, output_scale=None, input_mask=None):
 
379
  """
380
  Perform a forward pass through the output layer.
381
 
382
  Args:
 
383
  hidden_states (Tensor): Transformed hidden states of shape [patch_num, batch_size, hidden_size]
384
  output_scale (Tensor, optional): Expert probabilities for the output layer [batch_size]
385
  input_mask (Tensor, optional): Expert input mask of shape [batch_size, seq_len], 0:mask, 1:unmask
 
400
 
401
  return expert_output
402
 
403
+ def forward(self, expert_input, rotary_pos_emb, expert_probs=None):
404
  hidden_states, attention_mask, input_mask = self._forward_patch_embedding(expert_input)
405
+ # hidden_states: [patch_num, batch_size, hidden_size]
406
+ # attention_mask: [batch_size, 1, patch_num, patch_num]
407
+ # input_mask: [batch_size, seq_len]
408
+
409
  for layer in self.layers:
410
+ hidden_states = layer(hidden_states, attention_mask, rotary_pos_emb[:hidden_states.shape[0]])
411
+
412
  hidden_states = self.final_layernorm(hidden_states)
413
+
414
  expert_output = self._forward_output(hidden_states, expert_probs, input_mask)
415
  return expert_output
416
 
 
452
  return expert_output
453
 
454
 
455
+ class TopKRouter(nn.Module):
456
+ def __init__(self, config: FalconTSTConfig):
 
 
 
 
457
  super().__init__()
458
  self.config = config
459
+ self.topk = config.moe_router_topk
460
 
461
+ self.weight = nn.Parameter(
462
+ torch.empty((config.num_moe_experts, config.moe_router_input_size), dtype=torch.float32)
463
+ )
 
 
 
 
 
 
 
 
464
  self.reset_parameters()
 
 
 
 
 
 
465
 
466
+ def reset_parameters(self):
467
+ nn.init.normal_(self.weight, mean=0, std=self.config.init_method_std)
 
 
 
 
 
 
 
 
 
 
 
 
468
 
 
469
  def routing(self, logits: torch.Tensor):
470
+ score_function = self.config.moe_router_score_function
 
 
 
 
471
 
472
+ if score_function == "softmax":
473
+ if self.config.moe_router_pre_softmax:
474
+ scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits)
475
+ probs, top_indices = torch.topk(scores, self.topk, dim=1)
476
+ else:
477
+ scores, top_indices = torch.topk(logits, self.topk, dim=1)
478
+ probs = torch.softmax(scores, dim=-1, dtype=torch.float32).type_as(logits)
 
 
 
 
479
  else:
480
+ raise NotImplementedError
481
+
482
+ routing_probs = torch.zeros_like(logits).scatter_(1, top_indices, probs)
483
+ routing_map = torch.zeros_like(logits, dtype=torch.bool).scatter_(1, top_indices, True)
 
 
 
 
484
 
485
+ return routing_probs, routing_map
486
+
487
  def forward(self, input: torch.Tensor):
488
+ if self.weight.device != input.device:
489
+ self.weight.data = self.weight.data.to(input.device)
490
+
491
+ gating_logits = F.linear(input, self.weight)
492
+ num_tokens = gating_logits.shape[:-1].numel()
493
+ gating_logits = gating_logits.view(num_tokens, self.config.num_moe_experts)
 
494
 
495
+ scores, routing_map = self.routing(gating_logits)
496
 
497
  return scores, routing_map
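
The routing path above (optional pre-softmax scoring, top-k selection, then scatter back to dense per-expert maps) can be condensed into a few lines; a minimal sketch with toy sizes, independent of the class:

import torch

num_tokens, num_experts, topk = 5, 4, 1
logits = torch.randn(num_tokens, num_experts)

scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits)  # moe_router_pre_softmax=True branch
probs, top_indices = torch.topk(scores, topk, dim=1)

routing_probs = torch.zeros_like(logits).scatter_(1, top_indices, probs)
routing_map = torch.zeros_like(logits, dtype=torch.bool).scatter_(1, top_indices, True)
print(routing_probs.sum(dim=1))  # each token carries only the probability mass of its selected expert(s)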
498
 
 
515
  self.expert_output_size = config.seq_length
516
 
517
  if self.is_last_layer and self.config.heterogeneous_moe_layer:
518
+ # When heterogeneous_moe_layer is True, the last layer produces no backcast, so no backcast layernorm is created
519
+ self.backcast_layernorm = None
520
  else:
521
  self.backcast_layernorm = RMSNorm(self.seq_length)
522
 
 
597
  # permuted_probs (global_probs): [num_permuted_samples_after_dispatch_postprocess(sorted)]
598
 
599
  experts_output = self.experts(input, routing_map, rotary_pos_emb, probs)
 
600
 
601
  return experts_output, shared_experts_output
602
 
 
 
 
603
  def combine(
604
  self,
605
  experts_output: torch.Tensor,
 
608
  """Combines expert outputs via communication and adds shared expert output.
609
 
610
  This method uses the time series (sample) dispatcher to combine the outputs from different
611
+ experts. It then adds the output from the shared expert if it exists.
 
612
  """
613
  assert experts_output.shape == shared_experts_output.shape,\
614
  f'experts_output shape {experts_output.shape} does not match shared_experts_output shape {shared_experts_output.shape}'
 
633
  return output_backcast, output_forecast
634
 
635
 
636
+ def postprocess(
637
+ self,
638
+ backcast: torch.Tensor, # [batch_size, seq_len]
639
+ forecast: torch.Tensor, # [batch_size, pred_len]
640
+ output_backcast: torch.Tensor, # [batch_size, seq_len]
641
+ output_forecast: torch.Tensor, # [batch_size, pred_len]
642
+ ):
643
+ """
644
+ Args:
645
+ backcast (torch.Tensor): The previous layer's backcast time series (samples). [batch_size, seq_len]
646
+ forecast (torch.Tensor): The previous layer's forecast time series (samples). [batch_size, pred_len]
647
+ output_backcast (torch.Tensor): The current layer's output backcast time series (samples). [batch_size, seq_len]
648
+ output_forecast (torch.Tensor): The current layer's output forecast time series (samples). [batch_size, pred_len]
649
+ """
650
+ if output_backcast is not None:
651
+ # 25/8/14 @modified by xiaming: replaced RevIN with LayerNorm after the MoE layer
652
+ # Multiplying output_backcast by the input mask hurts performance, so we do not do it here
653
+ output_backcast = self.backcast_layernorm(output_backcast) # LayerNorm
654
+ if self.config.residual_backcast:
655
+ output_backcast = backcast - output_backcast
656
+
657
+ output_backcast[~self.input_mask] = self.config.mask_pad_value # Important! Restore masked time points to mask_pad_value (default: 255.)
658
+
659
+ if self.config.do_expert_forecast and forecast is not None: # The first layer's forecast is None
660
+ output_forecast = forecast + output_forecast
661
+
662
+ return output_backcast, output_forecast
663
+
664
+
665
+ def forward(self, backcast, forecast, rotary_pos_emb):
666
  inputs, probs, residual, routing_map = self.router_and_preprocess(backcast)
667
  experts_output, shared_experts_output = self.experts_compute(inputs, probs, residual, rotary_pos_emb, routing_map)
668
  output_backcast, output_forecast = self.combine(experts_output, shared_experts_output)
 
670
  return output_backcast, output_forecast
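
The postprocess above implements a doubly residual update: each layer's normalized backcast is subtracted from the incoming series, masked points are restored to the pad value, and forecasts are accumulated across layers. A minimal sketch of just that bookkeeping (shapes and values are made up):

import torch

seq_len, pred_len, mask_pad_value = 6, 3, 255.0
backcast = torch.randn(1, seq_len)
input_mask = torch.ones(1, seq_len, dtype=torch.bool)   # True = observed
forecast = torch.zeros(1, pred_len)

layer_backcast = torch.randn(1, seq_len)   # stands in for backcast_layernorm(output_backcast)
layer_forecast = torch.randn(1, pred_len)

backcast = backcast - layer_backcast        # residual_backcast=True branch
backcast[~input_mask] = mask_pad_value      # masked points go back to mask_pad_value
forecast = forecast + layer_forecast        # forecasts accumulate layer by layer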
671
 
672
 
 
673
  class FalconTSTBlock(nn.Module):
674
+ def __init__(self, config, input_layernorm = True):
675
  super().__init__()
676
  self.config = config
677
+
678
+ if input_layernorm:
679
+ self.input_layernorm = RMSNorm(self.config.seq_length)
680
+ else:
681
+ self.input_layernorm = IdentityOp()
682
+
683
  self.layers = nn.ModuleList([
684
+ FalconTSTMoELayer(config, layer_num + 1)
685
+ for layer_num in range(self.config.num_hidden_layers)
686
+ ])
687
+
688
+ def forward(self, x, rotary_pos_emb):
689
  backcast = x
690
  forecast = None
691
+
692
+ input_mask = (backcast != self.config.mask_pad_value)
693
+ backcast = self.input_layernorm(backcast * input_mask)
694
+ backcast[~input_mask] = self.config.mask_pad_value
695
+
696
  for layer in self.layers:
697
+ backcast, forecast = layer(backcast, forecast, rotary_pos_emb)
698
  return backcast,forecast
699
 
700
 
 
719
  if module.padding_idx is not None:
720
  module.weight.data[module.padding_idx].zero_()
721
 
722
+
723
  class FalconTSTModel(FalconTSTPreTrainedModel):
724
  def __init__(self, config: FalconTSTConfig):
725
  super().__init__(config)
726
  self.config = config
727
+ self.seq_length = self.config.seq_length
728
  self.rotary_pos_emb = RotaryEmbedding(
729
+ kv_channels=self.config.kv_channels,
730
+ rotary_base=self.config.rotary_base,
731
+ use_cpu_initialization=self.config.use_cpu_initialization,
732
+ rotary_interleaved=self.config.rotary_interleaved
733
  )
734
  self.decoder = FalconTSTBlock(
735
+ config=config,
736
+ input_layernorm=self.config.block_input_layernorm
737
+ )
738
  if self.config.do_expert_forecast and self.config.heterogeneous_moe_layer:
739
  self.output_layer = IdentityOp()
740
  else:
741
+ self.output_layer = nn.Linear(in_features=self.seq_length,
742
+ out_features=self.config.pred_length,
743
+ bias=self.config.add_bias_linear,)
744
 
745
 
746
  def revin(
 
769
  return input, means, stdev
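
The revin helper's body is elided in this hunk; reversible instance normalization of this kind typically standardizes each series with its own mean and standard deviation over observed points, so the sketch below is an assumption rather than the actual implementation:

import torch

def instance_normalize(x, mask, eps=1e-5):
    # x: [batch_size, seq_len]; mask: True where the point is observed
    masked = x * mask
    counts = mask.sum(dim=1, keepdim=True).clamp(min=1)
    means = masked.sum(dim=1, keepdim=True) / counts
    var = (((masked - means) * mask) ** 2).sum(dim=1, keepdim=True) / counts
    stdev = torch.sqrt(var + eps)
    return (x - means) / stdev, means, stdev   # same triple as the return statement above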
770
 
771
  def forward(self, input, revin):
772
+
 
 
 
 
773
  batch_size, input_len = input.shape
 
774
  # realize varied input length
775
  if input_len > self.seq_length:
776
  input = input[:, -self.seq_length:]
 
790
  rotary_pos_emb = self.rotary_pos_emb(input_len, device=input.device)
791
 
792
  # Step3. Do one-step inference to get mixed forecasts from multiple forecast heads
793
+ # mixed_pred: [batch_size, max(multi_forecast_head)]
794
  mixed_pred = self._inference_step(
795
  input=input,
796
  input_mask=input_mask,
 
823
  rotary_pos_emb,
824
  ):
825
  if self.config.do_base_forecast:
826
+ base_forecast, _ = self.base_output_layer(input * input_mask)
827
  else:
828
  base_forecast = None
829
 
830
  decoder_backcast, decoder_forecast = self.decoder(
831
+ input, # [batch_size, seq_len]
832
  rotary_pos_emb, # [input_len, 1, 1, kv_channels(hidden_size // num_heads)]
833
  )
834
 
 
837
  if self.config.heterogeneous_moe_layer:
838
  decoder_forecast = self.output_layer(decoder_forecast) # IdentityOp
839
  else:
840
+ final_forecast = self.output_layer(decoder_backcast * input_mask)
841
  decoder_forecast = decoder_forecast + final_forecast
842
  else:
843
  # The decoder_backcast contains the mask_pad_val(default:255.)
844
  decoder_forecast, _ = self.output_layer(decoder_backcast * input_mask)
845
+
846
  if self.config.do_base_forecast:
847
  assert base_forecast is not None, f'base_forecast is None'
848
  FalconTST_forecast = base_forecast + decoder_forecast
 
898
 
899
  final_output = final_output[:, :self.config.inference_length]
900
 
901
+ else:
902
+ raise NotImplementedError
 
 
 
 
903
 
904
  assert final_output.shape[1] == self.config.inference_length
905
  return final_output
906
 
907
+
908
+ class FalconTSTForPrediction(FalconTSTPreTrainedModel):
909
  def __init__(self, config: FalconTSTConfig):
910
  super().__init__(config)
911
  self.config = config
912
  self.model = FalconTSTModel(self.config)
913
  self.post_init()
914
 
915
+ @torch.no_grad()
916
+ def predict(
917
  self,
918
+ time_series: torch.Tensor,
919
+ forecast_horizon: int,
920
+ revin: bool = True,
921
+ ) -> torch.Tensor:
922
+ """
923
+ Generates time series forecasts autoregressively.
 
 
 
 
924
 
925
+ Args:
926
+ time_series (torch.Tensor): Input time series data.
927
+ Shape: [batch_size, seq_len] or [batch_size, seq_len, channels].
928
+ forecast_horizon (int): The number of future time steps to predict.
929
 
930
+ Returns:
931
+ torch.Tensor: The forecasted time series. Shape: [batch_size, forecast_horizon, channels].
 
 
 
 
 
 
 
 
 
932
  """
933
+ self.eval()
934
+
935
+ assert time_series.ndim == 2 or time_series.ndim == 3, "Input shape must be [batch, seq_len, channel] or [batch, seq_len]"
936
+ is_multichannel = time_series.ndim == 3
937
+ if is_multichannel:
938
+ batch_size, seq_len, num_channels = time_series.shape
939
+ # [B, L, C] -> [B * C, L]
940
+ input_flat = time_series.permute(0, 2, 1).reshape(batch_size * num_channels, seq_len)
941
+ else:
942
+ batch_size, seq_len = time_series.shape
943
+ num_channels = 1
944
+ input_flat = time_series
945
+
946
+ self.config.inference_length = forecast_horizon
947
+ forecast_flat = self.model(
948
+ input=input_flat,
949
+ revin=revin
950
+ ) # Shape: [B * C, H]
951
 
952
+ if is_multichannel:
953
+ forecast = forecast_flat.reshape(batch_size, num_channels, forecast_horizon)
954
+ forecast = forecast.permute(0, 2, 1).contiguous()
955
+ else:
956
+ forecast = forecast_flat
957
+
958
+ return forecast
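
A hypothetical usage sketch for the predict API above, assuming the checkpoint's auto_map exposes FalconTSTForPrediction through AutoModel; the model path is a placeholder and dtype/device handling is omitted:

import torch
from transformers import AutoModel

model = AutoModel.from_pretrained("path/to/FalconTST", trust_remote_code=True)

history = torch.randn(4, 2880, 3)  # [batch_size, seq_len, channels]
forecast = model.predict(history, forecast_horizon=96, revin=True)
print(forecast.shape)  # torch.Size([4, 96, 3]) per the docstring above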
959
 
 
ts_generation_mixin.py DELETED
@@ -1,89 +0,0 @@
1
- import warnings
2
- from typing import Any, Dict, List, Optional, Union, Callable
3
- import torch
4
- from transformers import GenerationMixin, LogitsProcessorList, StoppingCriteriaList
5
- from transformers.generation import validate_stopping_criteria, EosTokenCriteria
6
- from transformers.generation.utils import (
7
- GenerateNonBeamOutput,
8
- GenerateEncoderDecoderOutput,
9
- GenerateDecoderOnlyOutput,
10
- GenerationConfig,
11
- GenerateOutput,
12
- )
13
- from transformers.utils import ModelOutput
14
-
15
-
16
- class FalconTSTGenerationMixin(GenerationMixin):
17
- @torch.no_grad()
18
- def generate(
19
- self,
20
- inputs: Optional[torch.Tensor] = None,
21
- generation_config: Optional[GenerationConfig] = None,
22
- logits_processor: Optional[LogitsProcessorList] = None,
23
- stopping_criteria: Optional[StoppingCriteriaList] = None,
24
- prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
25
- synced_gpus: Optional[bool] = None,
26
- assistant_model: Optional["PreTrainedModel"] = None,
27
- streamer: Optional["BaseStreamer"] = None,
28
- negative_prompt_ids: Optional[torch.Tensor] = None,
29
- negative_prompt_attention_mask: Optional[torch.Tensor] = None,
30
- revin: Optional[bool] = True,
31
- num_samples: Optional[int] = 1,
32
- **kwargs,
33
- ) -> Union[GenerateOutput, torch.LongTensor]:
34
- """
35
- FalconTST generate function。
36
- """
37
- batch_size = inputs.shape[0]
38
- length = inputs.shape[1]
39
- channel = 1
40
- if len(inputs.shape) == 3:
41
- channel = inputs.shape[2]
42
- inputs = inputs.permute(0, 2, 1).reshape(batch_size * channel, length)
43
- elif len(inputs.shape) > 3:
44
- raise ValueError("Input shape must be [batch, seq_len, channel] or [batch, seq_len]")
45
-
46
- outputs = super().generate(
47
- inputs=inputs,
48
- generation_config=generation_config,
49
- logits_processor=logits_processor,
50
- stopping_criteria=stopping_criteria,
51
- prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
52
- synced_gpus=synced_gpus,
53
- assistant_model=assistant_model,
54
- streamer=streamer,
55
- negative_prompt_ids=negative_prompt_ids,
56
- negative_prompt_attention_mask=negative_prompt_attention_mask,
57
- revin=revin,
58
- **kwargs,
59
- )
60
- pred_len = outputs.shape[1]
61
- outputs = outputs.reshape(batch_size, channel, pred_len)
62
- outputs = outputs.transpose(1, 2).contiguous()
63
- return outputs
64
-
65
- def _greedy_search(
66
- self,
67
- input_ids: torch.Tensor,
68
- logits_processor: Optional[LogitsProcessorList] = None,
69
- stopping_criteria: Optional[StoppingCriteriaList] = None,
70
- max_length: Optional[int] = None,
71
- pad_token_id: Optional[int] = None,
72
- eos_token_id: Optional[Union[int, List[int]]] = None,
73
- output_attentions: Optional[bool] = None,
74
- output_hidden_states: Optional[bool] = None,
75
- output_scores: Optional[bool] = None,
76
- output_logits: Optional[bool] = None,
77
- return_dict_in_generate: Optional[bool] = None,
78
- synced_gpus: bool = False,
79
- streamer: Optional["BaseStreamer"] = None,
80
- **model_kwargs,
81
- ) -> Union[GenerateNonBeamOutput, torch.Tensor]:
82
- input_ids = input_ids.to(self.device)
83
- batch_size, cur_len = input_ids.shape
84
- logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
85
- stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
86
- model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
87
- # stopping_criteria.max_length = input_len + pred_len
88
- outputs = self(**model_inputs, return_dict=True, max_output_length=stopping_criteria.max_length-cur_len)
89
- return outputs