Fix _prepare_4d_causal_attention_mask_for_sdpa
Browse files- modeling_plamo.py +98 -2
modeling_plamo.py
CHANGED
|
@@ -6,13 +6,109 @@ from torch import nn
|
|
| 6 |
from torch.nn import functional as F
|
| 7 |
from transformers import AutoTokenizer, PretrainedConfig, PreTrainedModel
|
| 8 |
from transformers.modeling_attn_mask_utils import (
|
|
|
|
| 9 |
_prepare_4d_causal_attention_mask,
|
| 10 |
-
_prepare_4d_causal_attention_mask_for_sdpa,
|
| 11 |
)
|
| 12 |
from transformers.modeling_outputs import BaseModelOutputWithPast
|
| 13 |
from transformers.tokenization_utils_base import BatchEncoding
|
| 14 |
|
| 15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
def _swiglu(h: torch.Tensor) -> torch.Tensor:
|
| 17 |
h0, h1 = h.chunk(2, dim=-1)
|
| 18 |
return torch.nn.functional.silu(h0) * h1
|
|
@@ -817,7 +913,7 @@ class ModifiedAttention(Attention):
|
|
| 817 |
|
| 818 |
|
| 819 |
PLAMO_ATTENTION_CLASSES = {
|
| 820 |
-
"sdpa":
|
| 821 |
}
|
| 822 |
|
| 823 |
|
|
|
|
| 6 |
from torch.nn import functional as F
|
| 7 |
from transformers import AutoTokenizer, PretrainedConfig, PreTrainedModel
|
| 8 |
from transformers.modeling_attn_mask_utils import (
|
| 9 |
+
AttentionMaskConverter,
|
| 10 |
_prepare_4d_causal_attention_mask,
|
|
|
|
| 11 |
)
|
| 12 |
from transformers.modeling_outputs import BaseModelOutputWithPast
|
| 13 |
from transformers.tokenization_utils_base import BatchEncoding
|
| 14 |
|
| 15 |
|
| 16 |
+
# From: https://github.com/McGill-NLP/llm2vec/blob/main/llm2vec/models/attn_mask_utils.py
|
| 17 |
+
def _prepare_4d_causal_attention_mask_for_sdpa(
|
| 18 |
+
attention_mask: Optional[torch.Tensor],
|
| 19 |
+
input_shape: Union[torch.Size, Tuple, List],
|
| 20 |
+
inputs_embeds: torch.Tensor,
|
| 21 |
+
past_key_values_length: int,
|
| 22 |
+
sliding_window: Optional[int] = None,
|
| 23 |
+
):
|
| 24 |
+
"""
|
| 25 |
+
Prepares the correct `attn_mask` argument to be used by `torch.nn.functional.scaled_dot_product_attention`.
|
| 26 |
+
|
| 27 |
+
In case no token is masked in the `attention_mask` argument, we simply set it to `None` for the cases `query_length == 1` and
|
| 28 |
+
`key_value_length == query_length`, and rely instead on SDPA `is_causal` argument to use causal/non-causal masks,
|
| 29 |
+
allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed).
|
| 30 |
+
"""
|
| 31 |
+
attn_mask_converter = AttentionMaskConverter(
|
| 32 |
+
is_causal=False, sliding_window=sliding_window
|
| 33 |
+
) # is_causal=True in original implementation
|
| 34 |
+
|
| 35 |
+
key_value_length = input_shape[-1] + past_key_values_length
|
| 36 |
+
batch_size, query_length = input_shape
|
| 37 |
+
|
| 38 |
+
# torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
|
| 39 |
+
# used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
|
| 40 |
+
# TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
|
| 41 |
+
is_tracing = (
|
| 42 |
+
torch.jit.is_tracing()
|
| 43 |
+
or isinstance(inputs_embeds, torch.fx.Proxy)
|
| 44 |
+
or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
if attention_mask is not None:
|
| 48 |
+
# 4d mask is passed through
|
| 49 |
+
if len(attention_mask.shape) == 4:
|
| 50 |
+
expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
|
| 51 |
+
if tuple(attention_mask.shape) != expected_shape:
|
| 52 |
+
raise ValueError(
|
| 53 |
+
f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
|
| 54 |
+
)
|
| 55 |
+
else:
|
| 56 |
+
# if the 4D mask has correct shape - invert it and fill with negative infinity
|
| 57 |
+
inverted_mask = 1.0 - attention_mask.to(inputs_embeds.dtype)
|
| 58 |
+
attention_mask = inverted_mask.masked_fill(
|
| 59 |
+
inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
|
| 60 |
+
)
|
| 61 |
+
return attention_mask
|
| 62 |
+
|
| 63 |
+
elif not is_tracing and torch.all(attention_mask == 1):
|
| 64 |
+
if query_length == 1:
|
| 65 |
+
# For query_length == 1, causal attention and bi-directional attention are the same.
|
| 66 |
+
attention_mask = None
|
| 67 |
+
elif key_value_length == query_length:
|
| 68 |
+
attention_mask = None
|
| 69 |
+
else:
|
| 70 |
+
# Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation
|
| 71 |
+
# may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
|
| 72 |
+
# Reference: https://github.com/pytorch/pytorch/issues/108108
|
| 73 |
+
pass
|
| 74 |
+
elif query_length > 1 and key_value_length != query_length:
|
| 75 |
+
# See the comment above (https://github.com/pytorch/pytorch/issues/108108).
|
| 76 |
+
# Ugly: we set it to True here to dispatch in the following controlflow to `to_causal_4d`.
|
| 77 |
+
attention_mask = True
|
| 78 |
+
elif is_tracing:
|
| 79 |
+
raise ValueError(
|
| 80 |
+
'Attention using SDPA can not be traced with torch.jit.trace when no attention_mask is provided. To solve this issue, please either load your model with the argument `attn_implementation="eager"` or pass an attention_mask input when tracing the model.'
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
if attention_mask is None:
|
| 84 |
+
expanded_4d_mask = None
|
| 85 |
+
elif attention_mask is True:
|
| 86 |
+
expanded_4d_mask = attn_mask_converter.to_causal_4d(
|
| 87 |
+
input_shape[0],
|
| 88 |
+
input_shape[-1],
|
| 89 |
+
key_value_length,
|
| 90 |
+
dtype=inputs_embeds.dtype,
|
| 91 |
+
device=inputs_embeds.device,
|
| 92 |
+
)
|
| 93 |
+
else:
|
| 94 |
+
expanded_4d_mask = attn_mask_converter.to_4d(
|
| 95 |
+
attention_mask,
|
| 96 |
+
input_shape[-1],
|
| 97 |
+
dtype=inputs_embeds.dtype,
|
| 98 |
+
key_value_length=key_value_length,
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
# Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
|
| 102 |
+
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
| 103 |
+
# Details: https://github.com/pytorch/pytorch/issues/110213
|
| 104 |
+
if not is_tracing and expanded_4d_mask.device.type == "cuda":
|
| 105 |
+
expanded_4d_mask = AttentionMaskConverter._unmask_unattended(
|
| 106 |
+
expanded_4d_mask, min_dtype=torch.finfo(inputs_embeds.dtype).min
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
return expanded_4d_mask
|
| 110 |
+
|
| 111 |
+
|
| 112 |
def _swiglu(h: torch.Tensor) -> torch.Tensor:
|
| 113 |
h0, h1 = h.chunk(2, dim=-1)
|
| 114 |
return torch.nn.functional.silu(h0) * h1
|
|
|
|
| 913 |
|
| 914 |
|
| 915 |
PLAMO_ATTENTION_CLASSES = {
|
| 916 |
+
"sdpa": ModifiedAttention,
|
| 917 |
}
|
| 918 |
|
| 919 |
|