Minor changes proposed to allow ONNX export

#54
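
The patch below removes `assert` statements and a data-dependent Python branch from the modeling code, and precomputes the rotary cos/sin cache up front. The data-dependent branch in particular is a problem for `torch.onnx.export`, whose tracer has to turn a tensor-valued condition into a Python boolean and then freezes whichever branch the example input happens to take. The snippet below is only an illustrative sketch of that behaviour, not code from this repository; the function and tensors are made up for the example.

import torch

# Stand-in for the original mask check in MHA.forward:
#     if attention_mask is not None and torch.any(~attention_mask.bool()):
# The tensor condition must become a Python bool during tracing, so the branch
# is baked in for this particular input instead of staying dynamic.
def apply_mask(x: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    if torch.any(~attention_mask.bool()):  # data-dependent control flow
        return x.masked_fill(~attention_mask.bool().unsqueeze(-1), 0.0)
    return x

x = torch.randn(1, 4, 8)
attention_mask = torch.tensor([[1, 1, 1, 0]])

# Emits a TracerWarning about converting a tensor to a Python boolean; the
# resulting graph always follows the branch taken for this example input.
traced = torch.jit.trace(apply_mask, (x, attention_mask))
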
Files changed (1)
  1. modeling_mixformer_sequential.py +38 -71
modeling_mixformer_sequential.py CHANGED
@@ -117,10 +117,6 @@ def _apply_rotary_emb(
     rotary_seqlen, rotary_dim = cos.shape
     rotary_dim *= 2
 
-    assert rotary_dim <= head_dim
-    assert seqlen <= rotary_seqlen
-    assert cos.shape == sin.shape == (rotary_seqlen, rotary_dim // 2)
-
     x_rot = x[:, :, :, :rotary_dim]
     x_pass = x[:, :, :, rotary_dim:]
 
@@ -141,13 +137,9 @@ def _apply_rotary_emb_kv(
     sin_k: Optional[torch.FloatTensor] = None,
 ) -> torch.FloatTensor:
     _, seqlen, two, _, head_dim = kv.shape
-    assert two == 2
 
     rotary_seqlen, rotary_dim = cos.shape
     rotary_dim *= 2
-    assert rotary_dim <= head_dim
-    assert seqlen <= rotary_seqlen
-    assert cos.shape == sin.shape == (rotary_seqlen, rotary_dim // 2)
 
     k_rot = kv[:, :, 0, :, :rotary_dim]
     k_pass = kv[:, :, 0, :, rotary_dim:]
@@ -175,13 +167,9 @@ def _apply_rotary_emb_qkv(
     sin_k: Optional[torch.FloatTensor] = None,
 ) -> torch.FloatTensor:
     _, seqlen, three, _, head_dim = qkv.shape
-    assert three == 3
 
     rotary_seqlen, rotary_dim = cos.shape
     rotary_dim *= 2
-    assert rotary_dim <= head_dim
-    assert seqlen <= rotary_seqlen
-    assert cos.shape == sin.shape == (rotary_seqlen, rotary_dim // 2)
 
     q_rot = qkv[:, :, 0, :, :rotary_dim]
     q_pass = qkv[:, :, 0, :, rotary_dim:]
@@ -223,6 +211,7 @@ class RotaryEmbedding(nn.Module):
         scale_base: Optional[float] = None,
         pos_idx_in_fp32: bool = True,
         device: Optional[str] = None,
+        max_position_embeddings=2048,
         **kwargs,
     ) -> None:
         super().__init__()
@@ -248,11 +237,8 @@ class RotaryEmbedding(nn.Module):
         )
         self.register_buffer("scale", scale, persistent=False)
 
-        self._seq_len_cached = 0
-        self._cos_cached = None
-        self._sin_cached = None
-        self._cos_k_cached = None
-        self._sin_k_cached = None
+        # NOTE: initialize cached attributes
+        self._update_cos_sin_cache(seqlen=max_position_embeddings, device=device, dtype=torch.float32)
 
     def _compute_inv_freq(self, device: Optional[str] = None) -> torch.FloatTensor:
         return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
@@ -262,43 +248,36 @@ class RotaryEmbedding(nn.Module):
     ) -> None:
         # Reset the tables if sequence length has been chaned, if we are on a
         # new device or if we are switching from inference mode to training
-        if (
-            seqlen > self._seq_len_cached
-            or self._cos_cached is None
-            or self._cos_cached.device != device
-            or self._cos_cached.dtype != dtype
-            or (self.training and self._cos_cached.is_inference())
-        ):
-            self._seq_len_cached = seqlen
-
-            # fp32 is preferred since the output of `torch.arange` can be quite large
-            # and bf16 would lose a lot of precision
-            if self.pos_idx_in_fp32:
-                t = torch.arange(seqlen, device=device, dtype=torch.float32)
-                if self.inv_freq.dtype != torch.float32:
-                    inv_freq = self._compute_inv_freq(device=device)
-                else:
-                    inv_freq = self.inv_freq
+        self._seq_len_cached = seqlen
+
+        # fp32 is preferred since the output of `torch.arange` can be quite large
+        # and bf16 would lose a lot of precision
+        if self.pos_idx_in_fp32:
+            t = torch.arange(seqlen, device=device, dtype=torch.float32)
+            if self.inv_freq.dtype != torch.float32:
+                inv_freq = self._compute_inv_freq(device=device)
             else:
-                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                 inv_freq = self.inv_freq
+        else:
+            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
+            inv_freq = self.inv_freq
+
+        # `torch.outer` is preferred since `torch.einsum` converts from fp32 to fp16 if used with AMP
+        freqs = torch.outer(t, inv_freq)
+        if self.scale is None:
+            self._cos_cached = torch.cos(freqs).to(dtype)
+            self._sin_cached = torch.sin(freqs).to(dtype)
+        else:
+            power = (
+                torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
+            ) / self.scale_base
+            scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
 
-            # `torch.outer` is preferred since `torch.einsum` converts from fp32 to fp16 if used with AMP
-            freqs = torch.outer(t, inv_freq)
-            if self.scale is None:
-                self._cos_cached = torch.cos(freqs).to(dtype)
-                self._sin_cached = torch.sin(freqs).to(dtype)
-            else:
-                power = (
-                    torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
-                ) / self.scale_base
-                scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
-
-                # Force the scale multiplication to happen in fp32
-                self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
-                self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
-                self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
-                self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
+            # Force the scale multiplication to happen in fp32
+            self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
+            self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
+            self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
+            self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
 
     def forward(
         self,
@@ -309,10 +288,11 @@ class RotaryEmbedding(nn.Module):
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         seqlen = qkv.shape[1]
 
-        if max_seqlen is not None:
-            self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
-        else:
-            self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
+        if seqlen > self._seq_len_cached:
+            if max_seqlen is not None:
+                self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
+            else:
+                self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
 
         if kv is None:
             return _apply_rotary_emb_qkv(qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:])
@@ -336,7 +316,6 @@ class MLP(nn.Module):
         super().__init__()
 
         act_fn = config.activation_function if act_fn is None else act_fn
-        assert act_fn in ACT2FN.keys(), f"`act_fn` must be one of: {ACT2FN.keys()}."
 
         n_inner = getattr(config, "n_inner", None) if n_inner is None else n_inner
         n_inner = n_inner if n_inner is not None else 4 * config.n_embd
@@ -436,7 +415,6 @@ class CrossAttention(nn.Module):
     ) -> torch.FloatTensor:
         batch_size, seqlen_q = q.shape[0], q.shape[1]
         seqlen_k = kv.shape[1]
-        assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
 
         if kv.shape[3] != q.shape[2]:
             kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
@@ -474,14 +452,6 @@ def _find_mha_dims(
     n_head_kv: Optional[int] = None,
     head_dim: Optional[int] = None,
 ) -> Tuple[int, int]:
-    assert all(
-        hasattr(config, attr) for attr in ["n_embd", "n_head"]
-    ), "`config` must have `n_embd` and `n_head` attributes."
-
-    if head_dim is None:
-        assert (
-            config.n_embd % config.n_head == 0
-        ), f"Hidden size ({config.n_embd}) must be divisible by the number of heads ({config.n_head})."
 
     if n_head is None and head_dim is None:
         head_dim = config.n_embd // config.n_head
@@ -491,7 +461,6 @@ def _find_mha_dims(
 
     if n_head_kv is None:
         n_head_kv = getattr(config, "n_head_kv", None) or n_head
-    assert n_head % n_head_kv == 0, "`n_head` must be divisible by `n_head_kv`."
 
    return n_head, n_head_kv, head_dim
 
@@ -515,13 +484,10 @@ def _update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, l
 
    batch_start = inference_params.batch_size_offset
    batch_end = batch_start + kv.shape[0]
-   assert batch_end <= kv_cache.shape[0]
 
    sequence_start = inference_params.seqlen_offset
    sequence_end = sequence_start + kv.shape[1]
-   assert sequence_end <= kv_cache.shape[1]
 
-   assert kv_cache is not None
    kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
    kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
 
@@ -560,7 +526,7 @@ class MHA(nn.Module):
             rotary_cls = FlashRotaryEmbedding if config.flash_rotary else RotaryEmbedding
             if rotary_cls is None:
                 rotary_cls = RotaryEmbedding
-            self.rotary_emb = rotary_cls(self.rotary_emb_dim, **rotary_kwargs)
+            self.rotary_emb = rotary_cls(self.rotary_emb_dim, max_position_embeddings=config.n_positions, **rotary_kwargs)
 
         # MLP
         self.n_head, self.n_head_kv, self.head_dim = _find_mha_dims(config, n_head=n_head, n_head_kv=n_head_kv, head_dim=head_dim)
@@ -632,7 +598,8 @@ class MHA(nn.Module):
         attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
-        if attention_mask is not None and torch.any(~attention_mask.bool()):
+        # TODO: Need an alternative way for dynamic control flow: torch.any(~attention_mask.bool())
+        if attention_mask is not None:
             attention_mask = attention_mask.bool()
         else:
             attention_mask = None
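
With the mask branch gone and the rotary cos/sin cache built up to `config.n_positions` at construction time, the forward pass no longer depends on Python-level decisions made from tensor values, which is what tracing-based export needs. Below is a minimal export sketch under stated assumptions: the checkpoint id is a placeholder, the wrapper module and opset version are illustrative choices rather than part of this PR, and the exact inputs depend on the model's forward signature.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoint id -- substitute the repository this modeling file ships with.
model_id = "org/mixformer-checkpoint"

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True).eval()

inputs = tokenizer("Hello, ONNX!", return_tensors="pt")

# Small wrapper so the exporter sees plain tensors in and a plain tensor out.
class ExportWrapper(torch.nn.Module):
    def __init__(self, model: torch.nn.Module) -> None:
        super().__init__()
        self.model = model

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        return self.model(input_ids=input_ids, attention_mask=attention_mask).logits

with torch.no_grad():
    torch.onnx.export(
        ExportWrapper(model),
        (inputs["input_ids"], inputs["attention_mask"]),
        "mixformer.onnx",
        input_names=["input_ids", "attention_mask"],
        output_names=["logits"],
        dynamic_axes={
            "input_ids": {0: "batch", 1: "sequence"},
            "attention_mask": {0: "batch", 1: "sequence"},
            "logits": {0: "batch", 1: "sequence"},
        },
        opset_version=17,
    )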