Minor changes proposal to allow ONNX export
#54, opened by titaiwang03

modeling_mixformer_sequential.py  (CHANGED: +38 -71)
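For context, this is the kind of export these changes are meant to unblock. Below is a minimal sketch using the stock PyTorch exporter; the checkpoint id, output path, opset, and axis names are illustrative assumptions, not part of the PR:

```python
# Sketch only. Assumes a checkpoint that ships this modeling file; the repo id,
# file name, opset and axis names are placeholders for illustration.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "microsoft/phi-1_5"  # assumption: any repo using modeling_mixformer_sequential.py
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True).eval()

enc = tokenizer("def hello_world():", return_tensors="pt")

with torch.no_grad():
    torch.onnx.export(
        model,
        (enc["input_ids"], {"attention_mask": enc["attention_mask"]}),
        "mixformer.onnx",
        input_names=["input_ids", "attention_mask"],
        output_names=["logits"],  # the model also returns past key/values; only the first output is named here
        dynamic_axes={
            "input_ids": {0: "batch", 1: "sequence"},
            "attention_mask": {0: "batch", 1: "sequence"},
            "logits": {0: "batch", 1: "sequence"},
        },
        opset_version=17,
    )
```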
```diff
@@ -117,10 +117,6 @@ def _apply_rotary_emb(
     rotary_seqlen, rotary_dim = cos.shape
     rotary_dim *= 2
 
-    assert rotary_dim <= head_dim
-    assert seqlen <= rotary_seqlen
-    assert cos.shape == sin.shape == (rotary_seqlen, rotary_dim // 2)
-
     x_rot = x[:, :, :, :rotary_dim]
     x_pass = x[:, :, :, rotary_dim:]
 
```
```diff
@@ -141,13 +137,9 @@ def _apply_rotary_emb_kv(
     sin_k: Optional[torch.FloatTensor] = None,
 ) -> torch.FloatTensor:
     _, seqlen, two, _, head_dim = kv.shape
-    assert two == 2
 
     rotary_seqlen, rotary_dim = cos.shape
     rotary_dim *= 2
-    assert rotary_dim <= head_dim
-    assert seqlen <= rotary_seqlen
-    assert cos.shape == sin.shape == (rotary_seqlen, rotary_dim // 2)
 
     k_rot = kv[:, :, 0, :, :rotary_dim]
     k_pass = kv[:, :, 0, :, rotary_dim:]
```
```diff
@@ -175,13 +167,9 @@ def _apply_rotary_emb_qkv(
     sin_k: Optional[torch.FloatTensor] = None,
 ) -> torch.FloatTensor:
     _, seqlen, three, _, head_dim = qkv.shape
-    assert three == 3
 
     rotary_seqlen, rotary_dim = cos.shape
     rotary_dim *= 2
-    assert rotary_dim <= head_dim
-    assert seqlen <= rotary_seqlen
-    assert cos.shape == sin.shape == (rotary_seqlen, rotary_dim // 2)
 
     q_rot = qkv[:, :, 0, :, :rotary_dim]
     q_pass = qkv[:, :, 0, :, rotary_dim:]
```
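The three hunks above drop the same family of input sanity checks from the rotary helpers. They are validation only, so removing them does not change the computation. If one wanted to keep the checks for eager execution while staying out of the exporter's way, a hypothetical helper along these lines (not part of this PR) could skip them under tracing:

```python
import torch

def _check_rotary_inputs(cos, sin, seqlen, head_dim):
    # Hypothetical helper, not in the PR: reinstate the removed sanity checks
    # but skip them while the module is being traced for export.
    if torch.jit.is_tracing():
        return
    rotary_seqlen, half_dim = cos.shape
    assert 2 * half_dim <= head_dim
    assert seqlen <= rotary_seqlen
    assert cos.shape == sin.shape == (rotary_seqlen, half_dim)
```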
```diff
@@ -223,6 +211,7 @@ class RotaryEmbedding(nn.Module):
         scale_base: Optional[float] = None,
         pos_idx_in_fp32: bool = True,
         device: Optional[str] = None,
+        max_position_embeddings=2048,
         **kwargs,
     ) -> None:
         super().__init__()
```
```diff
@@ -248,11 +237,8 @@ class RotaryEmbedding(nn.Module):
         )
         self.register_buffer("scale", scale, persistent=False)
 
-        self._seq_len_cached = 0
-        self._cos_cached = None
-        self._sin_cached = None
-        self._cos_k_cached = None
-        self._sin_k_cached = None
+        # NOTE: initialize cached attributes
+        self._update_cos_sin_cache(seqlen=max_position_embeddings, device=device, dtype=torch.float32)
 
     def _compute_inv_freq(self, device: Optional[str] = None) -> torch.FloatTensor:
         return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
```
```diff
@@ -262,43 +248,36 @@ class RotaryEmbedding(nn.Module):
     ) -> None:
         # Reset the tables if sequence length has been chaned, if we are on a
         # new device or if we are switching from inference mode to training
-        if (
-            seqlen > self._seq_len_cached
-            or self._cos_cached is None
-            or self._cos_cached.device != device
-            or self._cos_cached.dtype != dtype
-            or (self.training and self._cos_cached.is_inference())
-        ):
-            self._seq_len_cached = seqlen
-
-            # fp32 is preferred since the output of `torch.arange` can be quite large
-            # and bf16 would lose a lot of precision
-            if self.pos_idx_in_fp32:
-                t = torch.arange(seqlen, device=device, dtype=torch.float32)
-                if self.inv_freq.dtype != torch.float32:
-                    inv_freq = self._compute_inv_freq(device=device)
-                else:
-                    inv_freq = self.inv_freq
+        self._seq_len_cached = seqlen
+
+        # fp32 is preferred since the output of `torch.arange` can be quite large
+        # and bf16 would lose a lot of precision
+        if self.pos_idx_in_fp32:
+            t = torch.arange(seqlen, device=device, dtype=torch.float32)
+            if self.inv_freq.dtype != torch.float32:
+                inv_freq = self._compute_inv_freq(device=device)
             else:
-                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                 inv_freq = self.inv_freq
+        else:
+            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
+            inv_freq = self.inv_freq
+
+        # `torch.outer` is preferred since `torch.einsum` converts from fp32 to fp16 if used with AMP
+        freqs = torch.outer(t, inv_freq)
+        if self.scale is None:
+            self._cos_cached = torch.cos(freqs).to(dtype)
+            self._sin_cached = torch.sin(freqs).to(dtype)
+        else:
+            power = (
+                torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
+            ) / self.scale_base
+            scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
 
-            # `torch.outer` is preferred since `torch.einsum` converts from fp32 to fp16 if used with AMP
-            freqs = torch.outer(t, inv_freq)
-            if self.scale is None:
-                self._cos_cached = torch.cos(freqs).to(dtype)
-                self._sin_cached = torch.sin(freqs).to(dtype)
-            else:
-                power = (
-                    torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
-                ) / self.scale_base
-                scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
-
-                # Force the scale multiplication to happen in fp32
-                self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
-                self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
-                self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
-                self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
+            # Force the scale multiplication to happen in fp32
+            self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
+            self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
+            self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
+            self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
 
     def forward(
         self,
```
```diff
@@ -309,10 +288,11 @@ class RotaryEmbedding(nn.Module):
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         seqlen = qkv.shape[1]
 
-        if max_seqlen is not None:
-            self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
-        else:
-            self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
+        if seqlen > self._seq_len_cached:
+            if max_seqlen is not None:
+                self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
+            else:
+                self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
 
         if kv is None:
             return _apply_rotary_emb_qkv(qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:])
```
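Taken together, the RotaryEmbedding hunks above precompute the cos/sin tables in `__init__` (sized by the new `max_position_embeddings` argument) and only fall back to `_update_cos_sin_cache` in `forward` when the input outgrows the cache, so a traced graph normally reduces to slicing a fixed buffer. A stripped-down sketch of that pattern, with illustrative names that are not part of the actual class:

```python
import torch
from torch import nn


class TinyRotaryCache(nn.Module):
    # Illustrative reduction of the caching pattern introduced by this PR.
    def __init__(self, dim: int, max_position_embeddings: int = 2048, base: float = 10000.0) -> None:
        super().__init__()
        # Build the full cos/sin tables once, up to the maximum expected length.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
        t = torch.arange(max_position_embeddings, dtype=torch.float32)
        freqs = torch.outer(t, inv_freq)
        self.register_buffer("cos_cached", torch.cos(freqs), persistent=False)
        self.register_buffer("sin_cached", torch.sin(freqs), persistent=False)

    def forward(self, x: torch.Tensor, seqlen_offset: int = 0):
        # No data-dependent rebuild here: the exported graph is a plain slice.
        seqlen = x.shape[1]
        cos = self.cos_cached[seqlen_offset : seqlen_offset + seqlen]
        sin = self.sin_cached[seqlen_offset : seqlen_offset + seqlen]
        return cos, sin
```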
```diff
@@ -336,7 +316,6 @@ class MLP(nn.Module):
         super().__init__()
 
         act_fn = config.activation_function if act_fn is None else act_fn
-        assert act_fn in ACT2FN.keys(), f"`act_fn` must be one of: {ACT2FN.keys()}."
 
         n_inner = getattr(config, "n_inner", None) if n_inner is None else n_inner
         n_inner = n_inner if n_inner is not None else 4 * config.n_embd
```
```diff
@@ -436,7 +415,6 @@ class CrossAttention(nn.Module):
     ) -> torch.FloatTensor:
         batch_size, seqlen_q = q.shape[0], q.shape[1]
         seqlen_k = kv.shape[1]
-        assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
 
         if kv.shape[3] != q.shape[2]:
             kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
```
```diff
@@ -474,14 +452,6 @@ def _find_mha_dims(
     n_head_kv: Optional[int] = None,
     head_dim: Optional[int] = None,
 ) -> Tuple[int, int]:
-    assert all(
-        hasattr(config, attr) for attr in ["n_embd", "n_head"]
-    ), "`config` must have `n_embd` and `n_head` attributes."
-
-    if head_dim is None:
-        assert (
-            config.n_embd % config.n_head == 0
-        ), f"Hidden size ({config.n_embd}) must be divisible by the number of heads ({config.n_head})."
 
     if n_head is None and head_dim is None:
         head_dim = config.n_embd // config.n_head
```
```diff
@@ -491,7 +461,6 @@ def _find_mha_dims(
 
     if n_head_kv is None:
         n_head_kv = getattr(config, "n_head_kv", None) or n_head
-        assert n_head % n_head_kv == 0, "`n_head` must be divisible by `n_head_kv`."
 
     return n_head, n_head_kv, head_dim
 
```
```diff
@@ -515,13 +484,10 @@ def _update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, l
 
     batch_start = inference_params.batch_size_offset
     batch_end = batch_start + kv.shape[0]
-    assert batch_end <= kv_cache.shape[0]
 
     sequence_start = inference_params.seqlen_offset
     sequence_end = sequence_start + kv.shape[1]
-    assert sequence_end <= kv_cache.shape[1]
 
-    assert kv_cache is not None
     kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
     kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
 
```
```diff
@@ -560,7 +526,7 @@ class MHA(nn.Module):
             rotary_cls = FlashRotaryEmbedding if config.flash_rotary else RotaryEmbedding
             if rotary_cls is None:
                 rotary_cls = RotaryEmbedding
-            self.rotary_emb = rotary_cls(self.rotary_emb_dim, **rotary_kwargs)
+            self.rotary_emb = rotary_cls(self.rotary_emb_dim, max_position_embeddings=config.n_positions, **rotary_kwargs)
 
         # MLP
         self.n_head, self.n_head_kv, self.head_dim = _find_mha_dims(config, n_head=n_head, n_head_kv=n_head_kv, head_dim=head_dim)
```
```diff
@@ -632,7 +598,8 @@ class MHA(nn.Module):
         attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
-        if attention_mask is not None and torch.any(~attention_mask.bool()):
+        # TODO: Need an alternative way for dynamic control flow: torch.any(~attention_mask.bool())
+        if attention_mask is not None:
             attention_mask = attention_mask.bool()
         else:
             attention_mask = None
```
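The last hunk's TODO points at the underlying issue: `torch.any(~attention_mask.bool())` produces a value that is only known at runtime, so branching on it is data-dependent control flow that tracing cannot represent (only the branch taken for the example input ends up in the graph). A small self-contained illustration of the difference, with made-up module names:

```python
from typing import Optional

import torch
from torch import nn


class DataDependentBranch(nn.Module):
    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # The condition depends on tensor values: under tracing only the path
        # taken for the example inputs is recorded, so the exported graph
        # silently drops the other branch.
        if torch.any(~mask.bool()):
            return x.masked_fill(~mask.bool().unsqueeze(-1), 0.0)
        return x


class PresenceOnlyBranch(nn.Module):
    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # As in the patched MHA.forward: the condition is a Python-level
        # "was the argument passed" check, resolved before tracing starts,
        # so it never depends on runtime tensor values.
        if mask is not None:
            return x.masked_fill(~mask.bool().unsqueeze(-1), 0.0)
        return x
```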