kazemnejad committed
Commit 6ef1fda · verified · 1 Parent(s): 02087c6

Upload CustomDecoderOnlyT5

config.json ADDED
@@ -0,0 +1,37 @@
+ {
+   "architectures": [
+     "CustomDecoderOnlyT5"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_custom_t5.CustomT5Config",
+     "AutoModelForCausalLM": "modeling_custom_t5.CustomDecoderOnlyT5"
+   },
+   "classifier_dropout": 0.0,
+   "d_ff": 16384,
+   "d_kv": 128,
+   "d_model": 1024,
+   "decoder_start_token_id": 0,
+   "dense_act_fn": "relu",
+   "dropout_rate": 0.1,
+   "eos_token_id": 1,
+   "feed_forward_proj": "relu",
+   "initializer_factor": 1.0,
+   "is_decoder": true,
+   "is_encoder_decoder": false,
+   "is_gated_act": false,
+   "layer_norm_epsilon": 1e-06,
+   "model_type": "custom_decoder_only_t5",
+   "n_positions": 1024,
+   "num_decoder_layers": 24,
+   "num_heads": 32,
+   "num_layers": 24,
+   "output_past": true,
+   "pad_token_id": 0,
+   "position_encoding_type": "none",
+   "relative_attention_max_distance": 128,
+   "relative_attention_num_buckets": 32,
+   "torch_dtype": "float32",
+   "transformers_version": "4.31.0",
+   "use_cache": true,
+   "vocab_size": 49152
+ }
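
The `auto_map` block above wires the Hub's Auto classes to the custom Python files shipped with this commit. A minimal loading sketch, assuming the files are used from a Hub repo (the repo id below is a placeholder, not this repository's actual name):

from transformers import AutoConfig, AutoModelForCausalLM

repo_id = "user/custom-decoder-only-t5"  # hypothetical placeholder repo id

# trust_remote_code=True lets transformers import configuration_custom_t5.py
# and modeling_custom_t5.py from the repo to build the custom classes.
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

print(type(model).__name__)           # CustomDecoderOnlyT5
print(config.position_encoding_type)  # "none" for this checkpoint
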
configuration_custom_t5.py ADDED
@@ -0,0 +1,40 @@
+ from transformers import T5Config
+
+ POSITION_ENCODING_REL_T5_BIAS = "t5_relative_bias"
+ POSITION_ENCODING_REL_TRANSFORMER_XL = "transformer_xl_relative_encoding"
+ POSITION_ENCODING_ROTARY = "rotary"
+ POSITION_ENCODING_ROTARY_RERUN = "rotary_rerun"
+ POSITION_ENCODING_ROTARY_NEW = "new_rotary"
+ POSITION_ENCODING_ABS_LEARNED = "abs_learned"
+ POSITION_ENCODING_ABS_SINUSOID = "abs_sinusoid"
+ POSITION_ENCODING_ALiBi = "alibi"
+ POSITION_ENCODING_ALiBi_LEARNED = "alibi_learned"
+ POSITION_ENCODING_NONE = "none"
+ POSITION_ENCODING_NONE_WINDOW = "none_window"
+
+
+ class CustomT5Config(T5Config):
+     model_type = "custom_decoder_only_t5"
+
+     def __init__(
+         self,
+         position_encoding_type=POSITION_ENCODING_REL_T5_BIAS,
+         **kwargs,
+     ):
+         if position_encoding_type not in [
+             POSITION_ENCODING_ALiBi,
+             POSITION_ENCODING_ALiBi_LEARNED,
+             POSITION_ENCODING_ABS_LEARNED,
+             POSITION_ENCODING_ABS_SINUSOID,
+             POSITION_ENCODING_REL_T5_BIAS,
+             POSITION_ENCODING_REL_TRANSFORMER_XL,
+             POSITION_ENCODING_ROTARY,
+             POSITION_ENCODING_ROTARY_NEW,
+             POSITION_ENCODING_NONE,
+             POSITION_ENCODING_NONE_WINDOW,
+         ]:
+             raise ValueError(
+                 f"Invalid position_encoding_type: {position_encoding_type}"
+             )
+         self.position_encoding_type = position_encoding_type
+         super().__init__(**kwargs)
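
The config class only adds a validated `position_encoding_type` on top of `T5Config`. A short sketch of that behavior, assuming configuration_custom_t5.py is importable from the working directory:

from configuration_custom_t5 import CustomT5Config, POSITION_ENCODING_ALiBi

cfg = CustomT5Config(position_encoding_type=POSITION_ENCODING_ALiBi, d_model=1024, num_heads=32)
print(cfg.model_type)              # custom_decoder_only_t5
print(cfg.position_encoding_type)  # alibi

try:
    CustomT5Config(position_encoding_type="not-a-real-option")
except ValueError as err:
    print(err)  # Invalid position_encoding_type: not-a-real-option
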
generation_config.json ADDED
@@ -0,0 +1,7 @@
+ {
+   "_from_model_config": true,
+   "decoder_start_token_id": 0,
+   "eos_token_id": 1,
+   "pad_token_id": 0,
+   "transformers_version": "4.31.0"
+ }
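
These defaults (decoder_start_token_id, eos_token_id, pad_token_id) are what `generate()` falls back to when no explicit generation config is passed. A hedged sketch, reusing the placeholder repo id from above and assuming a tokenizer is published alongside the checkpoint:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "user/custom-decoder-only-t5"  # hypothetical placeholder repo id
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

inputs = tokenizer("3 + 4 =", return_tensors="pt")
with torch.no_grad():
    # eos_token_id=1 and pad_token_id=0 are read from generation_config.json
    out = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(out[0], skip_special_tokens=True))
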
modeling_custom_t5.py ADDED
@@ -0,0 +1,1414 @@
1
+ import logging
2
+ import math
3
+ import os
4
+ from dataclasses import dataclass
5
+ from pathlib import Path
6
+ from typing import Optional, Tuple
7
+
8
+ import torch
9
+ from torch import nn
10
+ from torch.nn import CrossEntropyLoss
11
+ from torch.utils.checkpoint import checkpoint
12
+ from transformers import T5Config
13
+ from transformers.modeling_outputs import (
14
+ BaseModelOutputWithPastAndCrossAttentions,
15
+ )
16
+ from transformers.utils import ModelOutput
17
+ from transformers.utils.model_parallel_utils import get_device_map, assert_device_map
18
+
19
+ from .configuration_custom_t5 import (
20
+ POSITION_ENCODING_REL_T5_BIAS,
21
+ POSITION_ENCODING_REL_TRANSFORMER_XL,
22
+ POSITION_ENCODING_ROTARY,
23
+ POSITION_ENCODING_ROTARY_NEW,
24
+ POSITION_ENCODING_ABS_LEARNED,
25
+ POSITION_ENCODING_ABS_SINUSOID,
26
+ POSITION_ENCODING_ALiBi,
27
+ POSITION_ENCODING_ALiBi_LEARNED,
28
+ POSITION_ENCODING_NONE,
29
+ POSITION_ENCODING_NONE_WINDOW,
30
+ )
31
+ from .modeling_t5 import (
32
+ T5Stack,
33
+ T5PreTrainedModel,
34
+ T5Block,
35
+ T5LayerNorm,
36
+ T5LayerFF,
37
+ T5LayerSelfAttention,
38
+ T5Attention,
39
+ T5LayerCrossAttention,
40
+ )
41
+
42
+ logger = logging.getLogger(__name__)
43
+
44
+
45
+ @dataclass
46
+ class CausalLMOutputWithPastAndLoss(ModelOutput):
47
+ """
48
+ Base class for causal language model (or autoregressive) outputs.
49
+
50
+ Args:
51
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
52
+ Language modeling loss (for next-token prediction).
53
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
54
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
55
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
56
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
57
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
58
+
59
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
60
+ `past_key_values` input) to speed up sequential decoding.
61
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
62
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
63
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
64
+
65
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
66
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
67
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
68
+ sequence_length)`.
69
+
70
+ Attention weights after the attention softmax, used to compute the weighted average in the self-attention
71
+ heads.
72
+ """
73
+
74
+ loss: Optional[torch.FloatTensor] = None
75
+ logits: torch.FloatTensor = None
76
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
77
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
78
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
79
+ non_reduced_loss: Optional[torch.FloatTensor] = None
80
+
81
+
82
+ def fixed_pos_embedding(x, seq_dim=1, seq_len=None):
83
+ dim = x.shape[-1]
84
+ if seq_len is None:
85
+ seq_len = x.shape[seq_dim]
86
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2) / dim))
87
+ sinusoid_inp = (
88
+ torch.einsum("i , j -> i j", torch.arange(seq_len), inv_freq)
89
+ .to(x.device)
90
+ .float()
91
+ )
92
+ return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)
93
+
94
+
95
+ def rotate_every_two(x):
96
+ """
97
+ Example: [a, b, c, d] -> [-b, a, -d, c]
98
+ """
99
+ x1 = x[:, :, :, ::2]
100
+ x2 = x[:, :, :, 1::2]
101
+ x = torch.stack((-x2, x1), axis=-1)
102
+ return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)')
103
+
104
+
105
+ def apply_rotary_pos_emb(x, sincos, offset=0):
106
+ sin, cos = map(
107
+ lambda t: t[None, offset : x.shape[1] + offset, None, :].repeat_interleave(
108
+ 2, 3
109
+ ),
110
+ sincos,
111
+ )
112
+ # einsum notation for lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2)
113
+ return (x * cos) + (rotate_every_two(x) * sin)
114
+
115
+
116
+ def apply_rotary_pos_emb_new(x, sincos, offset=0):
117
+ sin, cos = map(
118
+ lambda t: t[:, :, None, :].repeat_interleave(2, 3),
119
+ sincos,
120
+ )
121
+ # einsum notation for lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2)
122
+ return (x * cos) + (rotate_every_two(x) * sin)
123
+
124
+
125
+ class PositionalEmbedding(nn.Module):
126
+ def __init__(self, demb):
127
+ super().__init__()
128
+
129
+ self.demb = demb
130
+
131
+ inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))
132
+ self.register_buffer("inv_freq", inv_freq)
133
+
134
+ def forward(self, pos_seq, bsz=None):
135
+ sinusoid_inp = torch.ger(pos_seq, self.inv_freq)
136
+ pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)
137
+
138
+ if bsz is not None:
139
+ return pos_emb[None, :, :].expand(bsz, -1, -1)
140
+ else:
141
+ return pos_emb[None, :, :]
142
+
143
+
144
+ class FixedAbsolutePositionalEmbedding(nn.Module):
145
+ def __init__(self, dim):
146
+ super().__init__()
147
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
148
+ t = torch.arange(16384).type_as(inv_freq)
149
+ sinusoid_inp = torch.einsum("i , j -> i j", t, inv_freq)
150
+ emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
151
+ self.embed = nn.Embedding.from_pretrained(emb, freeze=True)
152
+
153
+ def forward(self, position_ids: torch.Tensor):
154
+ return self.embed(position_ids.long())
155
+
156
+
157
+ class FixedRotaryPositionalEmbedding(nn.Module):
158
+ def __init__(
159
+ self, rotary_dim: int, rotary_base: int = 10000, max_position: int = 16384
160
+ ):
161
+ super().__init__()
162
+ # This is an inverse frequency tensor
163
+ # Each dimension has a higher denominator than the previous one
164
+ # So, the frequency will be lower for higher dimensions
165
+ inv_freq = 1.0 / (
166
+ rotary_base ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim)
167
+ ) # [rotary_dim/2]
168
+
169
+ # Now, we create frequencies for each position
170
+ t = torch.arange(max_position, device=inv_freq.device, dtype=inv_freq.dtype)
171
+ freqs = torch.einsum("i,j->ij", t, inv_freq) # [max_position, rotary_dim/2]
172
+
173
+ sins = torch.sin(freqs)
174
+ coss = torch.cos(freqs)
175
+
176
+ emb = torch.cat([sins, coss], dim=-1) # [max_position, rotary_dim]
177
+ self.embed = nn.Embedding.from_pretrained(emb, freeze=True)
178
+
179
+ def forward(self, position_ids: torch.Tensor):
180
+ return self.embed(position_ids.long())
181
+
182
+
183
+ class CustomT5Attention(T5Attention):
184
+ def __init__(self, config: T5Config, has_relative_attention_bias=False):
185
+ super(T5Attention, self).__init__()
186
+ self.is_decoder = config.is_decoder
187
+ self.has_relative_attention_bias = has_relative_attention_bias
188
+
189
+ self.relative_attention_num_buckets = config.relative_attention_num_buckets
190
+ self.d_model = config.d_model
191
+ self.key_value_proj_dim = config.d_kv
192
+ self.d_head = config.d_kv
193
+ self.n_heads = config.num_heads
194
+ self.dropout = config.dropout_rate
195
+ self.inner_dim = self.n_heads * self.key_value_proj_dim
196
+
197
+ # Mesh TensorFlow initialization to avoid scaling before softmax
198
+ self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
199
+ self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
200
+ self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
201
+ self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
202
+
203
+ self.position_encoding_type = getattr(
204
+ config, "position_encoding_type", POSITION_ENCODING_REL_T5_BIAS
205
+ )
206
+
207
+ if self.has_relative_attention_bias:
208
+ self.relative_attention_bias = nn.Embedding(
209
+ self.relative_attention_num_buckets, self.n_heads
210
+ )
211
+
212
+ if self.position_encoding_type == POSITION_ENCODING_REL_TRANSFORMER_XL:
213
+ self.r_r_bias = nn.Parameter(torch.FloatTensor(self.n_heads, self.d_head))
214
+ nn.init.normal_(
215
+ self.r_r_bias, mean=0.0, std=config.initializer_factor * 0.2
216
+ )
217
+ self.r_w_bias = nn.Parameter(torch.FloatTensor(self.n_heads, self.d_head))
218
+ nn.init.normal_(
219
+ self.r_w_bias, mean=0.0, std=config.initializer_factor * 0.2
220
+ )
221
+ self.r = nn.Linear(self.d_model, self.n_heads * self.d_head, bias=False)
222
+ self.r.weight.data.normal_(
223
+ mean=0.0, std=config.initializer_factor * (self.d_model**-0.5)
224
+ )
225
+ self.pos_emb = PositionalEmbedding(self.d_model)
226
+ self.clamp_length = 1000
227
+
228
+ if self.position_encoding_type == POSITION_ENCODING_ROTARY:
229
+ self.rotary_dim = None
230
+ if getattr(config, "rotary_dim", None) is not None:
231
+ self.rotary_dim = config.rotary_dim
232
+ self.rotary_dim = int(0.25 * self.d_head)
233
+
234
+ if self.position_encoding_type == POSITION_ENCODING_ROTARY_NEW:
235
+ # We hardcode the rotary dim to 25 percent of the head dim
236
+ self.rotary_dim = self.d_head // 4
237
+
238
+ self.pruned_heads = set()
239
+ self.gradient_checkpointing = False
240
+
241
+ def _rel_shift(self, x):
242
+ zero_pad_shape = x.size()[:2] + (x.size(2), 1)
243
+ zero_pad = torch.zeros(zero_pad_shape, device=x.device, dtype=x.dtype)
244
+ x_padded = torch.cat([zero_pad, x], dim=3)
245
+ x_padded_shape = x.size()[:2] + (x.size(3) + 1, x.size(2))
246
+ x_padded = x_padded.view(*x_padded_shape)
247
+ x = x_padded[:, :, 1:, :].view_as(x)
248
+ return x
249
+
250
+ def forward(
251
+ self,
252
+ hidden_states,
253
+ mask=None,
254
+ position_bias=None,
255
+ key_value_states=None,
256
+ past_key_value=None,
257
+ layer_head_mask=None,
258
+ query_length=None,
259
+ use_cache=False,
260
+ output_attentions=False,
261
+ ):
262
+ """
263
+ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
264
+ """
265
+ # Input is (batch_size, seq_length, dim)
266
+ # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
267
+ # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
268
+ batch_size, seq_length = hidden_states.shape[:2]
269
+
270
+ real_seq_length = seq_length
271
+
272
+ if past_key_value is not None:
273
+ assert (
274
+ len(past_key_value) == 2
275
+ ), f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states"
276
+ real_seq_length += (
277
+ past_key_value[0].shape[2] if query_length is None else query_length
278
+ )
279
+
280
+ key_length = (
281
+ real_seq_length if key_value_states is None else key_value_states.shape[1]
282
+ )
283
+
284
+ def shape(states):
285
+ """projection"""
286
+ return states.view(
287
+ batch_size, -1, self.n_heads, self.key_value_proj_dim
288
+ ).transpose(1, 2)
289
+
290
+ def unshape(states):
291
+ """reshape"""
292
+ return (
293
+ states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
294
+ )
295
+
296
+ def project(hidden_states, proj_layer, key_value_states, past_key_value):
297
+ """projects hidden states correctly to key/query states"""
298
+ if key_value_states is None:
299
+ # self-attn
300
+ # (batch_size, n_heads, seq_length, dim_per_head)
301
+ hidden_states = shape(proj_layer(hidden_states))
302
+ elif past_key_value is None:
303
+ # cross-attn
304
+ # (batch_size, n_heads, seq_length, dim_per_head)
305
+ hidden_states = shape(proj_layer(key_value_states))
306
+
307
+ if past_key_value is not None:
308
+ if key_value_states is None:
309
+ # self-attn
310
+ # (batch_size, n_heads, key_length, dim_per_head)
311
+ hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
312
+ else:
313
+ # cross-attn
314
+ hidden_states = past_key_value
315
+ return hidden_states
316
+
317
+ # get query states
318
+ query_states = shape(
319
+ self.q(hidden_states)
320
+ ) # (batch_size, n_heads, seq_length, dim_per_head)
321
+
322
+ if self.position_encoding_type in [
323
+ POSITION_ENCODING_ROTARY,
324
+ POSITION_ENCODING_ROTARY_NEW,
325
+ ]:
326
+ key_states = shape(self.k(hidden_states))
327
+ else:
328
+ # get key/value states
329
+ key_states = project(
330
+ hidden_states,
331
+ self.k,
332
+ key_value_states,
333
+ past_key_value[0] if past_key_value is not None else None,
334
+ )
335
+
336
+ value_states = project(
337
+ hidden_states,
338
+ self.v,
339
+ key_value_states,
340
+ past_key_value[1] if past_key_value is not None else None,
341
+ )
342
+
343
+ attention_output_dict = {}
344
+
345
+ if self.position_encoding_type == POSITION_ENCODING_REL_T5_BIAS:
346
+ scores = torch.matmul(query_states, key_states.transpose(3, 2))
347
+ attention_output_dict["scores_before"] = scores
348
+ if position_bias is None:
349
+ if not self.has_relative_attention_bias:
350
+ position_bias = torch.zeros(
351
+ (1, self.n_heads, real_seq_length, key_length),
352
+ device=scores.device,
353
+ dtype=scores.dtype,
354
+ )
355
+ if self.gradient_checkpointing and self.training:
356
+ position_bias.requires_grad = True
357
+ else:
358
+ position_bias = self.compute_bias(real_seq_length, key_length)
359
+
360
+ # if key and values are already calculated
361
+ # we want only the last query position bias
362
+ if past_key_value is not None:
363
+ position_bias = position_bias[:, :, -hidden_states.size(1) :, :]
364
+
365
+ if mask is not None:
366
+ position_bias = (
367
+ position_bias + mask
368
+ ) # (batch_size, n_heads, seq_length, key_length)
369
+
370
+ scores += position_bias
371
+ elif self.position_encoding_type == POSITION_ENCODING_REL_TRANSFORMER_XL:
372
+ if position_bias is None:
373
+ pos_seq = torch.arange(
374
+ real_seq_length - 1,
375
+ -1,
376
+ -1.0,
377
+ device=hidden_states.device,
378
+ dtype=hidden_states.dtype,
379
+ )
380
+ if self.clamp_length > 0:
381
+ pos_seq = pos_seq.clamp_(max=self.clamp_length)
382
+ position_bias = self.pos_emb(pos_seq)
383
+ position_bias = nn.functional.dropout(
384
+ position_bias, p=self.dropout, training=self.training
385
+ )
386
+
387
+ position_embeds = position_bias # position embeds: [1, seq_len, d_model]
388
+
389
+ r_head_k = self.r(position_embeds) # [1, seq_len, n_head*d_head]
390
+ r_head_k = r_head_k.view(
391
+ position_embeds.shape[1], self.n_heads, self.d_head
392
+ ) # [seq_len, n_head, d_head]
393
+
394
+ rw_head_q = query_states + self.r_w_bias[None, :, None, :]
395
+ AC = torch.einsum("bnqd,bnkd->bnqk", (rw_head_q, key_states))
396
+
397
+ rr_head_q = query_states + self.r_r_bias[None, :, None, :]
398
+ BD = torch.einsum("bnid,jnd->bnij", (rr_head_q, r_head_k))
399
+ BD = self._rel_shift(BD)
400
+
401
+ scores = AC + BD
402
+
403
+ if mask is not None:
404
+ scores += mask
405
+ elif self.position_encoding_type == POSITION_ENCODING_ROTARY:
406
+ r_seq_len = hidden_states.shape[1]
407
+ r_offset = 0
408
+
409
+ if past_key_value is not None:
410
+ r_offset = past_key_value[0].shape[2]
411
+ r_seq_len += r_offset
412
+
413
+ query_states = query_states.permute(0, 2, 1, 3)
414
+ key_states = key_states.permute(0, 2, 1, 3)
415
+
416
+ if self.rotary_dim is not None:
417
+ k_rot = key_states[:, :, :, : self.rotary_dim]
418
+ k_pass = key_states[:, :, :, self.rotary_dim :]
419
+
420
+ q_rot = query_states[:, :, :, : self.rotary_dim]
421
+ q_pass = query_states[:, :, :, self.rotary_dim :]
422
+
423
+ sincos = fixed_pos_embedding(k_rot, 1, seq_len=r_seq_len)
424
+ k_rot = apply_rotary_pos_emb(k_rot, sincos, offset=r_offset)
425
+ q_rot = apply_rotary_pos_emb(q_rot, sincos, offset=r_offset)
426
+
427
+ if output_attentions:
428
+ scores_pass = torch.matmul(
429
+ q_pass.permute(0, 2, 1, 3),
430
+ k_pass.permute(0, 2, 1, 3).transpose(3, 2),
431
+ )
432
+ attention_output_dict["scores_pass"] = scores_pass
433
+
434
+ scores_rot = torch.matmul(
435
+ q_rot.permute(0, 2, 1, 3),
436
+ k_rot.permute(0, 2, 1, 3).transpose(3, 2),
437
+ )
438
+ attention_output_dict["scores_rot"] = scores_rot
439
+
440
+ key_states = torch.cat([k_rot, k_pass], dim=-1)
441
+ query_states = torch.cat([q_rot, q_pass], dim=-1)
442
+ else:
443
+ sincos = fixed_pos_embedding(key_states, 1, seq_len=r_seq_len)
444
+ key_states = apply_rotary_pos_emb(key_states, sincos, offset=r_offset)
445
+ query_states = apply_rotary_pos_emb(
446
+ query_states, sincos, offset=r_offset
447
+ )
448
+
449
+ query_states = query_states.permute(0, 2, 1, 3)
450
+ key_states = key_states.permute(0, 2, 1, 3)
451
+
452
+ if past_key_value is not None:
453
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
454
+
455
+ scores = torch.matmul(
456
+ query_states, key_states.transpose(3, 2)
457
+ ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
458
+ if mask is not None:
459
+ scores += mask # (batch_size, n_heads, seq_length, key_length)
460
+
461
+ elif self.position_encoding_type == POSITION_ENCODING_ROTARY_NEW:
462
+ r_seq_len = hidden_states.shape[1]
463
+ r_offset = 0
464
+
465
+ if past_key_value is not None:
466
+ r_offset = past_key_value[0].shape[2]
467
+ r_seq_len += r_offset
468
+
469
+ query_states = query_states.permute(0, 2, 1, 3)
470
+ key_states = key_states.permute(0, 2, 1, 3)
471
+
472
+ if self.rotary_dim is not None:
473
+ k_rot = key_states[:, :, :, : self.rotary_dim]
474
+ k_pass = key_states[:, :, :, self.rotary_dim :]
475
+
476
+ q_rot = query_states[:, :, :, : self.rotary_dim]
477
+ q_pass = query_states[:, :, :, self.rotary_dim :]
478
+
479
+ sincos = position_bias
480
+ # sincos is just vector created by torch.cat([sin, cos], dim=-1)
481
+ # so we can just split it in half
482
+ sin = sincos[:, :, : self.rotary_dim // 2]
483
+ cos = sincos[:, :, self.rotary_dim // 2 :]
484
+
485
+ # We don't need to pass offset here, because we already used
486
+ # position_ids to retrieve correct sin and cos vectors
487
+ k_rot = apply_rotary_pos_emb_new(k_rot, (sin, cos))
488
+ q_rot = apply_rotary_pos_emb_new(q_rot, (sin, cos))
489
+
490
+ key_states = torch.cat([k_rot, k_pass], dim=-1)
491
+ query_states = torch.cat([q_rot, q_pass], dim=-1)
492
+ else:
493
+ raise ValueError("rotary_dim is None")
494
+
495
+ query_states = query_states.permute(0, 2, 1, 3)
496
+ key_states = key_states.permute(0, 2, 1, 3)
497
+
498
+ if past_key_value is not None:
499
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
500
+
501
+ scores = torch.matmul(
502
+ query_states, key_states.transpose(3, 2)
503
+ ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
504
+ if mask is not None:
505
+ scores += mask # (batch_size, n_heads, seq_length, key_length)
506
+ elif self.position_encoding_type == POSITION_ENCODING_ALiBi:
507
+ scores = torch.matmul(query_states, key_states.transpose(3, 2))
508
+ attention_output_dict["scores_before"] = scores
509
+
510
+ alibi = position_bias
511
+ alibi = alibi.view(batch_size, self.n_heads, 1, key_length)
512
+
513
+ # if key and values are already calculated
514
+ # we want only the last query position bias
515
+ if past_key_value is not None:
516
+ alibi = alibi[:, :, -hidden_states.size(1) :, :]
517
+
518
+ if mask is not None:
519
+ alibi = alibi + mask # (batch_size, n_heads, seq_length, key_length)
520
+
521
+ scores += alibi
522
+ else:
523
+ assert (
524
+ self.position_encoding_type == POSITION_ENCODING_NONE
525
+ ), f"Unknown position encoding type: {self.position_encoding_type}"
526
+ scores = torch.matmul(
527
+ query_states, key_states.transpose(3, 2)
528
+ ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
529
+ if mask is not None:
530
+ scores += mask # (batch_size, n_heads, seq_length, key_length)
531
+
532
+ attention_output_dict["scores"] = scores
533
+
534
+ attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
535
+ scores
536
+ ) # (batch_size, n_heads, seq_length, key_length)
537
+ attn_weights = nn.functional.dropout(
538
+ attn_weights, p=self.dropout, training=self.training
539
+ ) # (batch_size, n_heads, seq_length, key_length)
540
+
541
+ # Mask heads if we want to
542
+ if layer_head_mask is not None:
543
+ attn_weights = attn_weights * layer_head_mask
544
+
545
+ attention_output_dict["probs"] = attn_weights
546
+
547
+ attn_output = unshape(
548
+ torch.matmul(attn_weights, value_states)
549
+ ) # (batch_size, seq_length, dim)
550
+ attn_output = self.o(attn_output)
551
+
552
+ present_key_value_state = (
553
+ (key_states, value_states) if (self.is_decoder and use_cache) else None
554
+ )
555
+ outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)
556
+
557
+ if output_attentions:
558
+ outputs = outputs + (attention_output_dict,)
559
+ return outputs
560
+
561
+
562
+ class CustomT5LayerSelfAttention(T5LayerSelfAttention):
563
+ def __init__(self, config, has_relative_attention_bias=False):
564
+ super(T5LayerSelfAttention, self).__init__()
565
+ self.SelfAttention = CustomT5Attention(
566
+ config, has_relative_attention_bias=has_relative_attention_bias
567
+ )
568
+ self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
569
+ self.dropout = nn.Dropout(config.dropout_rate)
570
+
571
+
572
+ class CustomT5Block(T5Block):
573
+ def __init__(self, config, has_relative_attention_bias=False):
574
+ super(T5Block, self).__init__()
575
+ self.is_decoder = config.is_decoder
576
+ assert self.is_decoder
577
+ self.layer = nn.ModuleList()
578
+ self.layer.append(
579
+ CustomT5LayerSelfAttention(
580
+ config, has_relative_attention_bias=has_relative_attention_bias
581
+ )
582
+ )
583
+ if self.is_decoder:
584
+ self.layer.append(T5LayerCrossAttention(config))
585
+
586
+ self.layer.append(T5LayerFF(config))
587
+
588
+
589
+ def _make_causal_mask(
590
+ input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
591
+ ) -> torch.BoolTensor:
592
+ """
593
+ Make causal mask used for self-attention.
594
+ """
595
+ batch_size, target_length = input_ids_shape
596
+ mask = torch.empty(
597
+ (target_length, target_length + past_key_values_length),
598
+ dtype=torch.bool,
599
+ device=device,
600
+ )
601
+ # ONNX doesn't support `torch.Tensor.triu` properly, thus we use this workaround
602
+ seq_ids = torch.arange(target_length, device=device)
603
+ mask[:, past_key_values_length:] = seq_ids[:, None] < seq_ids[None, :]
604
+
605
+ if past_key_values_length > 0:
606
+ mask[:, :past_key_values_length] = False
607
+
608
+ expanded_mask = mask[None, None, :, :].expand(
609
+ batch_size, 1, target_length, target_length + past_key_values_length
610
+ )
611
+ return expanded_mask
612
+
613
+
614
+ def _expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:
615
+ """
616
+ Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`.
617
+ """
618
+ batch_size, src_length = mask.shape
619
+ tgt_length = tgt_length if tgt_length is not None else src_length
620
+
621
+ expanded_mask = ~(mask[:, None, None, :].to(torch.bool))
622
+ return expanded_mask.expand(batch_size, 1, tgt_length, src_length)
623
+
624
+
625
+ def build_alibi_tensor(
626
+ attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype
627
+ ) -> torch.Tensor:
628
+ """
629
+ Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it
630
+ relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value
631
+ `softmax(l+a) = softmax(l)`. Based on
632
+ https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
633
+ TODO @thomasw21 this doesn't work as nicely due to the masking strategy, and so masking varies slightly.
634
+ Args:
635
+ Returns tensor shaped (batch_size * num_heads, 1, max_seq_len)
636
+ attention_mask (`torch.Tensor`):
637
+ Token-wise attention mask, this should be of shape (batch_size, max_seq_len).
638
+ num_heads (`int`, *required*):
639
+ number of heads
640
+ dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`):
641
+ dtype of the output tensor
642
+ """
643
+ if len(attention_mask.shape) == 2:
644
+ batch_size, seq_length = attention_mask.shape
645
+ elif len(attention_mask.shape) == 3:
646
+ batch_size, _, seq_length = attention_mask.shape
647
+ closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
648
+ base = torch.tensor(
649
+ 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))),
650
+ device=attention_mask.device,
651
+ dtype=torch.float32,
652
+ )
653
+ powers = torch.arange(
654
+ 1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32
655
+ )
656
+ slopes = torch.pow(base, powers)
657
+
658
+ if closest_power_of_2 != num_heads:
659
+ extra_base = torch.tensor(
660
+ 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))),
661
+ device=attention_mask.device,
662
+ dtype=torch.float32,
663
+ )
664
+ num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
665
+ extra_powers = torch.arange(
666
+ 1,
667
+ 1 + 2 * num_remaining_heads,
668
+ 2,
669
+ device=attention_mask.device,
670
+ dtype=torch.int32,
671
+ )
672
+ slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
673
+
674
+ # Note: alibi will be added to the attention bias that will be applied to the query-key product of attention
675
+ # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
676
+ # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
677
+ # => the query_length dimension will then be broadcasted correctly
678
+ # This is more or less identical to T5's relative position bias:
679
+ # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
680
+ arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
681
+ alibi = slopes[..., None] * arange_tensor
682
+ return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
683
+
684
+
685
+ class CustomT5Stack(T5Stack):
686
+ def __init__(self, config, embed_tokens=None):
687
+ super(T5Stack, self).__init__(config)
688
+
689
+ self.embed_tokens = embed_tokens
690
+ self.is_decoder = config.is_decoder
691
+ self.position_encoding_type = getattr(
692
+ config, "position_encoding_type", POSITION_ENCODING_REL_T5_BIAS
693
+ )
694
+
695
+ logger.info(f"position_encoding_type: {self.position_encoding_type}")
696
+
697
+ self.block = nn.ModuleList(
698
+ [
699
+ CustomT5Block(config, has_relative_attention_bias=bool(i == 0))
700
+ for i in range(config.num_layers)
701
+ ]
702
+ )
703
+ self.final_layer_norm = T5LayerNorm(
704
+ config.d_model, eps=config.layer_norm_epsilon
705
+ )
706
+ self.dropout = nn.Dropout(config.dropout_rate)
707
+
708
+ if self.position_encoding_type == POSITION_ENCODING_ABS_LEARNED:
709
+ self.wpe = nn.Embedding(2048, config.d_model)
710
+ parent_dir = Path(os.path.dirname(os.path.abspath(__file__)))
711
+ learned_embed_file = parent_dir / "gpt_neo_125m_pos_embed.npy"
712
+ if learned_embed_file.exists():
713
+ logger.info(
714
+ "Loading position embedding from {}".format(learned_embed_file)
715
+ )
716
+ import numpy as np
717
+
718
+ weight = np.load(str(learned_embed_file))
719
+ self.wpe.weight.data.copy_(torch.from_numpy(weight))
720
+ self.wpe.weight.requires_grad = False
721
+ else:
722
+ self.wpe.weight.data.normal_(
723
+ mean=0.0, std=config.initializer_factor * 1.0
724
+ )
725
+
726
+ if self.position_encoding_type == POSITION_ENCODING_ABS_SINUSOID:
727
+ self.wpe = FixedAbsolutePositionalEmbedding(config.d_model)
728
+
729
+ if self.position_encoding_type == POSITION_ENCODING_ROTARY_NEW:
730
+ # Rotary dim is X percentage of d_head
731
+ # Right now, we just hardcode X here following:
732
+ # https://github.com/huggingface/transformers/blob/v4.26.0/src/transformers/models/gpt_neox/configuration_gpt_neox.py
733
+ rotary_dim = int(config.d_kv * 0.25)
734
+ self.fixed_rotary_embedding = FixedRotaryPositionalEmbedding(
735
+ rotary_dim, max_position=4096
736
+ )
737
+
738
+ if self.position_encoding_type in [
739
+ POSITION_ENCODING_ALiBi,
740
+ POSITION_ENCODING_ALiBi_LEARNED,
741
+ ]:
742
+ maxpos = 2048
743
+ attn_heads = config.num_heads
744
+ if self.position_encoding_type == POSITION_ENCODING_ALiBi_LEARNED:
745
+ self.learned_logslopes = nn.Parameter(
746
+ torch.log(torch.Tensor(self.get_slopes(attn_heads)))
747
+ )
748
+ else:
749
+ slopes = torch.Tensor(self.get_slopes(attn_heads))
750
+ alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(
751
+ maxpos
752
+ ).unsqueeze(0).unsqueeze(0).expand(attn_heads, -1, -1)
753
+ alibi = alibi.view(attn_heads, 1, maxpos)
754
+ self.register_buffer("alibi", alibi)
755
+
756
+ # Initialize weights and apply final processing
757
+ self.post_init()
758
+ # Model parallel
759
+ self.model_parallel = False
760
+ self.device_map = None
761
+ self.gradient_checkpointing = False
762
+
763
+ self.window_size = 80 # only used for none_windowed
764
+
765
+ def _alibi_prepare_attn_mask(
766
+ self,
767
+ attention_mask: torch.Tensor,
768
+ input_shape: Tuple[int, int],
769
+ past_key_values_length: int,
770
+ ) -> torch.BoolTensor:
771
+ # create causal mask
772
+ # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
773
+ combined_attention_mask = None
774
+ device = attention_mask.device
775
+ _, src_length = input_shape
776
+
777
+ if src_length > 1:
778
+ combined_attention_mask = _make_causal_mask(
779
+ input_shape,
780
+ device=device,
781
+ past_key_values_length=past_key_values_length,
782
+ )
783
+
784
+ # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
785
+ expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
786
+ combined_attention_mask = (
787
+ expanded_attn_mask
788
+ if combined_attention_mask is None
789
+ else expanded_attn_mask | combined_attention_mask
790
+ )
791
+
792
+ return combined_attention_mask
793
+
794
+ def get_slopes(self, n):
795
+ def get_slopes_power_of_2(n):
796
+ start = 2 ** (-(2 ** -(math.log2(n) - 3)))
797
+ ratio = start
798
+ return [start * ratio**i for i in range(n)]
799
+
800
+ if math.log2(n).is_integer():
801
+ return get_slopes_power_of_2(
802
+ n
803
+ ) # In the paper, we only train models that have 2^a heads for some a. This function has
804
+ else: # some good properties that only occur when the input is a power of 2. To maintain that even
805
+ closest_power_of_2 = 2 ** math.floor(
806
+ math.log2(n)
807
+ ) # when the number of heads is not a power of 2, we use this workaround.
808
+ return (
809
+ get_slopes_power_of_2(closest_power_of_2)
810
+ + self.get_slopes(2 * closest_power_of_2)[0::2][
811
+ : n - closest_power_of_2
812
+ ]
813
+ )
814
+
815
+ def forward(
816
+ self,
817
+ input_ids=None,
818
+ attention_mask=None,
819
+ encoder_hidden_states=None,
820
+ encoder_attention_mask=None,
821
+ inputs_embeds=None,
822
+ head_mask=None,
823
+ cross_attn_head_mask=None,
824
+ past_key_values=None,
825
+ use_cache=None,
826
+ output_attentions=None,
827
+ output_hidden_states=None,
828
+ position_ids=None,
829
+ return_dict=None,
830
+ ):
831
+ # Model parallel
832
+ if self.model_parallel:
833
+ torch.cuda.set_device(self.first_device)
834
+ self.embed_tokens = self.embed_tokens.to(self.first_device)
835
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
836
+ output_attentions = (
837
+ output_attentions
838
+ if output_attentions is not None
839
+ else self.config.output_attentions
840
+ )
841
+ output_hidden_states = (
842
+ output_hidden_states
843
+ if output_hidden_states is not None
844
+ else self.config.output_hidden_states
845
+ )
846
+ return_dict = (
847
+ return_dict if return_dict is not None else self.config.use_return_dict
848
+ )
849
+
850
+ if input_ids is not None and inputs_embeds is not None:
851
+ err_msg_prefix = "decoder_" if self.is_decoder else ""
852
+ raise ValueError(
853
+ f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
854
+ )
855
+ elif input_ids is not None:
856
+ input_shape = input_ids.size()
857
+ input_ids = input_ids.view(-1, input_shape[-1])
858
+ elif inputs_embeds is not None:
859
+ input_shape = inputs_embeds.size()[:-1]
860
+ else:
861
+ err_msg_prefix = "decoder_" if self.is_decoder else ""
862
+ raise ValueError(
863
+ f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds"
864
+ )
865
+
866
+ if inputs_embeds is None:
867
+ assert (
868
+ self.embed_tokens is not None
869
+ ), "You have to initialize the model with valid token embeddings"
870
+ inputs_embeds = self.embed_tokens(input_ids)
871
+
872
+ if self.position_encoding_type in [
873
+ POSITION_ENCODING_ABS_LEARNED,
874
+ POSITION_ENCODING_ABS_SINUSOID,
875
+ ]:
876
+ if position_ids is not None:
877
+ position_ids = position_ids.view(-1, input_shape[-1])
878
+
879
+ if past_key_values is None:
880
+ past_length = 0
881
+ else:
882
+ past_length = past_key_values[0][0].size(-2)
883
+
884
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
885
+ if position_ids is None:
886
+ position_ids = torch.arange(
887
+ past_length,
888
+ input_shape[-1] + past_length,
889
+ dtype=torch.long,
890
+ device=device,
891
+ )
892
+ position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
893
+
894
+ position_embeds = self.wpe(position_ids)
895
+ inputs_embeds += position_embeds
896
+
897
+ batch_size, seq_length = input_shape
898
+
899
+ # `position_bias` is just a tensor that is passed to all attention layers
900
+ position_bias = None
901
+
902
+ # required mask seq length can be calculated via length of past
903
+ mask_seq_length = (
904
+ past_key_values[0][0].shape[2] + seq_length
905
+ if past_key_values is not None
906
+ else seq_length
907
+ )
908
+
909
+ if use_cache is True:
910
+ assert (
911
+ self.is_decoder
912
+ ), f"`use_cache` can only be set to `True` if {self} is used as a decoder"
913
+
914
+ if attention_mask is None:
915
+ attention_mask = torch.ones(batch_size, mask_seq_length).to(
916
+ inputs_embeds.device
917
+ )
918
+ if (
919
+ self.is_decoder
920
+ and encoder_attention_mask is None
921
+ and encoder_hidden_states is not None
922
+ ):
923
+ encoder_seq_length = encoder_hidden_states.shape[1]
924
+ encoder_attention_mask = torch.ones(
925
+ batch_size,
926
+ encoder_seq_length,
927
+ device=inputs_embeds.device,
928
+ dtype=torch.long,
929
+ )
930
+
931
+ if self.position_encoding_type == POSITION_ENCODING_ROTARY_NEW:
932
+ if position_ids is not None:
933
+ position_ids = position_ids.view(-1, input_shape[-1])
934
+
935
+ if past_key_values is None:
936
+ past_length = 0
937
+ else:
938
+ past_length = past_key_values[0][0].size(-2)
939
+
940
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
941
+ if position_ids is None:
942
+ position_ids = torch.arange(
943
+ past_length,
944
+ input_shape[-1] + past_length,
945
+ dtype=torch.long,
946
+ device=device,
947
+ )
948
+ position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
949
+
950
+ sinusoidal_pos = self.fixed_rotary_embedding(position_ids)
951
+ position_bias = sinusoidal_pos
952
+
953
+ # initialize past_key_values with `None` if past does not exist
954
+ if past_key_values is None:
955
+ past_key_values = [None] * len(self.block)
956
+
957
+ if self.position_encoding_type == POSITION_ENCODING_NONE_WINDOW:
958
+ indices = torch.arange(seq_length, device=inputs_embeds.device)
959
+ causal_mask = indices[:, None] >= indices
960
+ window_mask = (
961
+ (indices.unsqueeze(0) - indices.unsqueeze(0).T)
962
+ .abs()
963
+ .less(self.window_size)
964
+ )
965
+ causal_mask = causal_mask & window_mask
966
+ attention_mask = causal_mask.int()
967
+
968
+ # Repeat the mask for each sample in the batch
969
+ attention_mask = attention_mask[None, :, :].expand(
970
+ batch_size, seq_length, seq_length
971
+ )
972
+
973
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
974
+ # ourselves in which case we just need to make it broadcastable to all heads.
975
+ extended_attention_mask = self.get_extended_attention_mask(
976
+ attention_mask, input_shape, inputs_embeds.device
977
+ )
978
+
979
+ if self.position_encoding_type == POSITION_ENCODING_ALiBi:
980
+ num_heads = self.config.num_heads
981
+ if len(attention_mask.shape) == 3:
982
+ # We need to make a default attention mask
983
+ alibi_attention_mask = torch.ones(batch_size, mask_seq_length).to(
984
+ inputs_embeds.device
985
+ )
986
+ else:
987
+ alibi_attention_mask = attention_mask
988
+
989
+ alibi = build_alibi_tensor(
990
+ alibi_attention_mask, num_heads, dtype=inputs_embeds.dtype
991
+ )
992
+ position_bias = alibi
993
+ del alibi_attention_mask
994
+
995
+ if self.position_encoding_type in [POSITION_ENCODING_ALiBi_LEARNED]:
996
+ if not hasattr(self, "alibi"):
997
+ maxpos = 2048
998
+ attn_heads = self.config.num_heads
999
+ slopes = self.learned_logslopes.exp()
1000
+ alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(
1001
+ maxpos, device=slopes.device
1002
+ ).unsqueeze(0).unsqueeze(0).expand(attn_heads, -1, -1)
1003
+ alibi = alibi.view(attn_heads, 1, maxpos)
1004
+ else:
1005
+ alibi = self.alibi
1006
+
1007
+ alibi = alibi.unsqueeze(0).repeat(batch_size, 1, 1, 1)
1008
+ alibi = alibi[:, :, :, : attention_mask.shape[-1]]
1009
+ alibi = alibi.repeat(1, 1, extended_attention_mask.shape[2], 1)
1010
+ extended_attention_mask = torch.where(
1011
+ extended_attention_mask == 0,
1012
+ alibi,
1013
+ extended_attention_mask.repeat(1, self.config.num_heads, 1, 1),
1014
+ )
1015
+
1016
+ # If a 2D or 3D attention mask is provided for the cross-attention
1017
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
1018
+ if self.is_decoder and encoder_hidden_states is not None:
1019
+ (
1020
+ encoder_batch_size,
1021
+ encoder_sequence_length,
1022
+ _,
1023
+ ) = encoder_hidden_states.size()
1024
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
1025
+ if encoder_attention_mask is None:
1026
+ encoder_attention_mask = torch.ones(
1027
+ encoder_hidden_shape, device=inputs_embeds.device
1028
+ )
1029
+ encoder_extended_attention_mask = self.invert_attention_mask(
1030
+ encoder_attention_mask
1031
+ )
1032
+ else:
1033
+ encoder_extended_attention_mask = None
1034
+
1035
+ # Prepare head mask if needed
1036
+ head_mask = self.get_head_mask(head_mask, self.config.num_layers)
1037
+ cross_attn_head_mask = self.get_head_mask(
1038
+ cross_attn_head_mask, self.config.num_layers
1039
+ )
1040
+ present_key_value_states = () if use_cache else None
1041
+ all_hidden_states = () if output_hidden_states else None
1042
+ all_attentions = () if output_attentions else None
1043
+ all_cross_attentions = () if (output_attentions and self.is_decoder) else None
1044
+ # position_bias = None
1045
+ encoder_decoder_position_bias = None
1046
+
1047
+ hidden_states = self.dropout(inputs_embeds)
1048
+
1049
+ for i, (layer_module, past_key_value) in enumerate(
1050
+ zip(self.block, past_key_values)
1051
+ ):
1052
+ layer_head_mask = head_mask[i]
1053
+ cross_attn_layer_head_mask = cross_attn_head_mask[i]
1054
+ # Model parallel
1055
+ if self.model_parallel:
1056
+ torch.cuda.set_device(hidden_states.device)
1057
+ # Ensure that attention_mask is always on the same device as hidden_states
1058
+ if attention_mask is not None:
1059
+ attention_mask = attention_mask.to(hidden_states.device)
1060
+ if position_bias is not None:
1061
+ position_bias = position_bias.to(hidden_states.device)
1062
+ if encoder_hidden_states is not None:
1063
+ encoder_hidden_states = encoder_hidden_states.to(
1064
+ hidden_states.device
1065
+ )
1066
+ if encoder_extended_attention_mask is not None:
1067
+ encoder_extended_attention_mask = (
1068
+ encoder_extended_attention_mask.to(hidden_states.device)
1069
+ )
1070
+ if encoder_decoder_position_bias is not None:
1071
+ encoder_decoder_position_bias = encoder_decoder_position_bias.to(
1072
+ hidden_states.device
1073
+ )
1074
+ if layer_head_mask is not None:
1075
+ layer_head_mask = layer_head_mask.to(hidden_states.device)
1076
+ if cross_attn_layer_head_mask is not None:
1077
+ cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(
1078
+ hidden_states.device
1079
+ )
1080
+ if output_hidden_states:
1081
+ all_hidden_states = all_hidden_states + (hidden_states,)
1082
+
1083
+ if self.gradient_checkpointing and self.training:
1084
+ if use_cache:
1085
+ logger.warning(
1086
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1087
+ )
1088
+ use_cache = False
1089
+
1090
+ def create_custom_forward(module):
1091
+ def custom_forward(*inputs):
1092
+ return tuple(module(*inputs, use_cache, output_attentions))
1093
+
1094
+ return custom_forward
1095
+
1096
+ layer_outputs = checkpoint(
1097
+ create_custom_forward(layer_module),
1098
+ hidden_states,
1099
+ extended_attention_mask,
1100
+ position_bias,
1101
+ encoder_hidden_states,
1102
+ encoder_extended_attention_mask,
1103
+ encoder_decoder_position_bias,
1104
+ layer_head_mask,
1105
+ cross_attn_layer_head_mask,
1106
+ None, # past_key_value is always None with gradient checkpointing
1107
+ )
1108
+ else:
1109
+ layer_outputs = layer_module(
1110
+ hidden_states,
1111
+ attention_mask=extended_attention_mask,
1112
+ position_bias=position_bias,
1113
+ encoder_hidden_states=encoder_hidden_states,
1114
+ encoder_attention_mask=encoder_extended_attention_mask,
1115
+ encoder_decoder_position_bias=encoder_decoder_position_bias,
1116
+ layer_head_mask=layer_head_mask,
1117
+ cross_attn_layer_head_mask=cross_attn_layer_head_mask,
1118
+ past_key_value=past_key_value,
1119
+ use_cache=use_cache,
1120
+ output_attentions=output_attentions,
1121
+ )
1122
+
1123
+ # layer_outputs is a tuple with:
1124
+ # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
1125
+ if use_cache is False:
1126
+ layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
1127
+
1128
+ hidden_states, present_key_value_state = layer_outputs[:2]
1129
+
1130
+ # We share the position biases between the layers - the first layer stores them
1131
+ # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
1132
+ # (cross-attention position bias), (cross-attention weights)
1133
+ position_bias = layer_outputs[2]
1134
+ if self.is_decoder and encoder_hidden_states is not None:
1135
+ encoder_decoder_position_bias = layer_outputs[
1136
+ 4 if output_attentions else 3
1137
+ ]
1138
+ # append next layer key value states
1139
+ if use_cache:
1140
+ present_key_value_states = present_key_value_states + (
1141
+ present_key_value_state,
1142
+ )
1143
+
1144
+ if output_attentions:
1145
+ all_attentions = all_attentions + (layer_outputs[3],)
1146
+ if self.is_decoder:
1147
+ all_cross_attentions = all_cross_attentions + (None,)
1148
+
1149
+ # Model Parallel: If it's the last layer for that device, put things on the next device
1150
+ if self.model_parallel:
1151
+ for k, v in self.device_map.items():
1152
+ if i == v[-1] and "cuda:" + str(k) != self.last_device:
1153
+ hidden_states = hidden_states.to("cuda:" + str(k + 1))
1154
+
1155
+ hidden_states = self.final_layer_norm(hidden_states)
1156
+ hidden_states = self.dropout(hidden_states)
1157
+
1158
+ # Add last layer
1159
+ if output_hidden_states:
1160
+ all_hidden_states = all_hidden_states + (hidden_states,)
1161
+
1162
+ if not return_dict:
1163
+ return tuple(
1164
+ v
1165
+ for v in [
1166
+ hidden_states,
1167
+ present_key_value_states,
1168
+ all_hidden_states,
1169
+ all_attentions,
1170
+ all_cross_attentions,
1171
+ ]
1172
+ if v is not None
1173
+ )
1174
+ return BaseModelOutputWithPastAndCrossAttentions(
1175
+ last_hidden_state=hidden_states,
1176
+ past_key_values=present_key_value_states,
1177
+ hidden_states=all_hidden_states,
1178
+ attentions=all_attentions,
1179
+ cross_attentions=all_cross_attentions,
1180
+ )
1181
+
1182
+
1183
+ class CustomDecoderOnlyT5(T5PreTrainedModel):
1184
+ _keys_to_ignore_on_load_missing = [
1185
+ r"decoder\.embed_tokens\.weight",
1186
+ r"encoder",
1187
+ r"lm_head\.weight",
1188
+ ]
1189
+ _keys_to_ignore_on_load_unexpected = [
1190
+ r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
1191
+ ]
1192
+
1193
+ def __init__(
1194
+ self,
1195
+ config=None,
1196
+ output_non_reduced_loss: bool = False,
1197
+ **kwargs,
1198
+ ):
1199
+ assert config is not None
1200
+ config.is_decoder = True
1201
+ config.is_encoder_decoder = False
1202
+
1203
+ assert (
1204
+ config.position_encoding_type is not None
1205
+ ), "Position encoding type must be set"
1206
+
1207
+ self.output_non_reduced_loss = output_non_reduced_loss
1208
+ self.main_input_name = "input_ids"
1209
+
1210
+ super().__init__(config)
1211
+
1212
+ self.model_dim = config.d_model
1213
+
1214
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
1215
+ self.decoder = CustomT5Stack(config, self.shared)
1216
+
1217
+ self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
1218
+
1219
+ # Initialize weights and apply final processing
1220
+ self.post_init()
1221
+
1222
+ # Model parallel
1223
+ self.model_parallel = False
1224
+ self.device_map = None
1225
+ #
1226
+ cross_attention_params = [
1227
+ p
1228
+ for n, p in self.decoder.named_parameters()
1229
+ if n.startswith("block.") and ".layer.1." in n
1230
+ ]
1231
+ for param in cross_attention_params:
1232
+ param.requires_grad = False
1233
+
1234
+ # self.handle_tokenizer(tokenizer)
1235
+
1236
+ def get_decoder(self):
1237
+ return self.decoder
1238
+
1239
+ def parallelize(self, device_map=None):
1240
+ self.device_map = (
1241
+ get_device_map(len(self.decoder.block), range(torch.cuda.device_count()))
1242
+ if device_map is None
1243
+ else device_map
1244
+ )
1245
+ assert_device_map(self.device_map, len(self.decoder.block))
1247
+ self.decoder.parallelize(self.device_map)
1248
+ self.lm_head = self.lm_head.to(self.decoder.first_device)
1249
+ self.model_parallel = True
1250
+
1251
+ def deparallelize(self):
1253
+ self.decoder.deparallelize()
1255
+ self.decoder = self.decoder.to("cpu")
1256
+ self.lm_head = self.lm_head.to("cpu")
1257
+ self.model_parallel = False
1258
+ self.device_map = None
1259
+ torch.cuda.empty_cache()
1260
+
1261
+ def get_input_embeddings(self):
1262
+ return self.shared
1263
+
1264
+ def set_input_embeddings(self, new_embeddings):
1265
+ self.shared = new_embeddings
1266
+ self.decoder.set_input_embeddings(new_embeddings)
1267
+
1268
+ def set_output_embeddings(self, new_embeddings):
1269
+ self.lm_head = new_embeddings
1270
+
1271
+ def get_output_embeddings(self):
1272
+ return self.lm_head
1273
+
1274
+ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
1275
+ token_type_ids = kwargs.get("token_type_ids", None)
1276
+ # only last token for inputs_ids if past is defined in kwargs
1277
+ if past:
1278
+ input_ids = input_ids[:, -1].unsqueeze(-1)
1279
+ if token_type_ids is not None:
1280
+ token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
1281
+
1282
+ attention_mask = kwargs.get("attention_mask", None)
1283
+ position_ids = kwargs.get("position_ids", None)
1284
+
1285
+ if attention_mask is not None and position_ids is None:
1286
+ # create position_ids on the fly for batch generation
1287
+ position_ids = attention_mask.long().cumsum(-1) - 1
1288
+ position_ids.masked_fill_(attention_mask == 0, 1)
1289
+ if past:
1290
+ position_ids = position_ids[:, -1].unsqueeze(-1)
1291
+ else:
1292
+ position_ids = None
1293
+
1294
+ return {
1295
+ "input_ids": input_ids,
1296
+ "past_key_values": past,
1297
+ "use_cache": kwargs.get("use_cache"),
1298
+ "attention_mask": attention_mask,
1299
+ "token_type_ids": token_type_ids,
1300
+ "position_ids": position_ids,
1301
+ }
1302
+
1303
+ def forward(
1304
+ self,
1305
+ input_ids=None,
1306
+ past_key_values=None,
1307
+ attention_mask=None,
1308
+ token_type_ids=None,
1309
+ position_ids=None,
1310
+ head_mask=None,
1311
+ inputs_embeds=None,
1312
+ labels=None,
1313
+ use_cache=None,
1314
+ output_attentions=None,
1315
+ output_hidden_states=None,
1316
+ return_dict=None,
1317
+ ):
1318
+ return_dict = (
1319
+ return_dict if return_dict is not None else self.config.use_return_dict
1320
+ )
1321
+
1322
+ if self.model_parallel:
1323
+ torch.cuda.set_device(self.decoder.first_device)
1324
+
1325
+ if self.model_parallel:
1326
+ torch.cuda.set_device(self.decoder.first_device)
1327
+ if input_ids is not None:
1328
+ input_ids = input_ids.to(self.decoder.first_device)
1329
+ if attention_mask is not None:
1330
+ attention_mask = attention_mask.to(self.decoder.first_device)
1331
+
1332
+ transformer_outputs = self.decoder(
1333
+ input_ids=input_ids,
1334
+ attention_mask=attention_mask,
1335
+ inputs_embeds=inputs_embeds,
1336
+ past_key_values=past_key_values,
1337
+ position_ids=position_ids,
1338
+ encoder_hidden_states=None,
1339
+ encoder_attention_mask=None,
1340
+ head_mask=head_mask,
1341
+ cross_attn_head_mask=None,
1342
+ use_cache=use_cache,
1343
+ output_attentions=output_attentions,
1344
+ output_hidden_states=output_hidden_states,
1345
+ return_dict=return_dict,
1346
+ )
1347
+ hidden_states = transformer_outputs[0]
1348
+
1349
+ if self.config.tie_word_embeddings:
1350
+ # Rescale output before projecting on vocab
1351
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
1352
+ hidden_states = hidden_states * (self.model_dim**-0.5)
1353
+
1354
+ lm_logits = self.lm_head(hidden_states)
1355
+
1356
+ loss = None
1357
+ non_reduced_loss = None
1358
+ if labels is not None:
1359
+ # Compute loss in fp32 to match with mesh-tf version
1360
+ # https://github.com/EleutherAI/gpt-neo/blob/89ce74164da2fb16179106f54e2269b5da8db333/models/gpt2/gpt2.py#L179
1361
+ lm_logits = lm_logits.to(torch.float32)
1362
+
1363
+ # Shift so that tokens < n predict n
1364
+ shift_logits = lm_logits[..., :-1, :].contiguous()
1365
+ shift_labels = labels[..., 1:].contiguous()
1366
+ # Flatten the tokens
1367
+ loss_fct = CrossEntropyLoss()
1368
+ loss = loss_fct(
1369
+ shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
1370
+ )
1371
+
1372
+ lm_logits = lm_logits.to(hidden_states.dtype)
1373
+ loss = loss.to(hidden_states.dtype)
1374
+
1375
+ if self.output_non_reduced_loss:
1376
+ loss_fct = CrossEntropyLoss(reduction="none")
1377
+ non_reduced_loss = loss_fct(
1378
+ shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
1379
+ )
1380
+
1381
+ # Reshape to [batch_size, seq_length - 1]
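+ # and keep only the loss at the last position of each sequence, giving shape [batch_size, 1]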
1382
+ non_reduced_loss = non_reduced_loss.view(
1383
+ shift_labels.shape[0], shift_labels.shape[1]
1384
+ )[:, -1].view(-1, 1)
1385
+
1386
+ if not return_dict:
1387
+ output = (lm_logits,) + transformer_outputs[1:]
1388
+ return ((loss,) + output) if loss is not None else output
1389
+
1390
+ return CausalLMOutputWithPastAndLoss(
1391
+ loss=loss,
1392
+ logits=lm_logits,
1393
+ past_key_values=transformer_outputs.past_key_values,
1394
+ hidden_states=transformer_outputs.hidden_states,
1395
+ attentions=transformer_outputs.attentions,
1396
+ non_reduced_loss=non_reduced_loss,
1397
+ )
1398
+
1399
+ @staticmethod
1400
+ def _reorder_cache(
1401
+ past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
1402
+ ) -> Tuple[Tuple[torch.Tensor]]:
1403
+ """
1404
+ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
1405
+ [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
1406
+ beam_idx at every generation step.
1407
+ """
1408
+ return tuple(
1409
+ tuple(
1410
+ past_state.index_select(0, beam_idx.to(past_state.device))
1411
+ for past_state in layer_past
1412
+ )
1413
+ for layer_past in past
1414
+ )
modeling_t5.py ADDED
@@ -0,0 +1,1821 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch T5 model."""
16
+
17
+
18
+ import copy
19
+ import math
20
+ import os
21
+
22
+ import torch
23
+ from torch import nn
24
+ from torch.nn import CrossEntropyLoss
25
+ from torch.utils.checkpoint import checkpoint
26
+ from transformers import T5Config
27
+ from transformers.activations import ACT2FN
28
+ from transformers.file_utils import (
29
+ DUMMY_INPUTS,
30
+ DUMMY_MASK,
31
+ add_start_docstrings,
32
+ add_start_docstrings_to_model_forward,
33
+ is_torch_fx_proxy,
34
+ replace_return_docstrings,
35
+ )
36
+ from transformers.modeling_outputs import (
37
+ BaseModelOutput,
38
+ BaseModelOutputWithPastAndCrossAttentions,
39
+ Seq2SeqLMOutput,
40
+ Seq2SeqModelOutput,
41
+ )
42
+ from transformers.modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
43
+ from transformers.utils import logging
44
+ from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
45
+
46
+ logger = logging.get_logger(__name__)
47
+
48
+ _CONFIG_FOR_DOC = "T5Config"
49
+ _TOKENIZER_FOR_DOC = "T5Tokenizer"
50
+ _CHECKPOINT_FOR_DOC = "t5-small"
51
+
52
+ ####################################################
53
+ # This dict contains ids and associated url
54
+ # for the pretrained weights provided with the models
55
+ ####################################################
56
+ T5_PRETRAINED_MODEL_ARCHIVE_LIST = [
57
+ "t5-small",
58
+ "t5-base",
59
+ "t5-large",
60
+ "t5-3b",
61
+ "t5-11b",
62
+ # See all T5 models at https://huggingface.co/models?filter=t5
63
+ ]
64
+
65
+
66
+ ####################################################
67
+ # This is a conversion method from TF 1.0 to PyTorch
68
+ # More details: https://medium.com/huggingface/from-tensorflow-to-pytorch-265f40ef2a28
69
+ ####################################################
70
+ def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
71
+ """Load tf checkpoints in a pytorch model."""
72
+ try:
73
+ import re
74
+
75
+ import numpy as np
76
+ import tensorflow as tf
77
+ except ImportError:
78
+ logger.error(
79
+ "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
80
+ "https://www.tensorflow.org/install/ for installation instructions."
81
+ )
82
+ raise
83
+ tf_path = os.path.abspath(tf_checkpoint_path)
84
+ logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
85
+ # Load weights from TF model
86
+ init_vars = tf.train.list_variables(tf_path)
87
+ names = []
88
+ tf_weights = {}
89
+ for name, shape in init_vars:
90
+ logger.info(f"Loading TF weight {name} with shape {shape}")
91
+ array = tf.train.load_variable(tf_path, name)
92
+ names.append(name)
93
+ tf_weights[name] = array
94
+
95
+ for txt_name in names:
96
+ name = txt_name.split("/")
97
+ # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v
98
+ # which are not required for using pretrained model
99
+ if any(
100
+ n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
101
+ for n in name
102
+ ):
103
+ logger.info(f"Skipping {'/'.join(name)}")
104
+ tf_weights.pop(txt_name, None)
105
+ continue
106
+ if "_slot_" in name[-1]:
107
+ logger.info(f"Skipping {'/'.join(name)}")
108
+ tf_weights.pop(txt_name, None)
109
+ continue
110
+ pointer = model
111
+ array = tf_weights[txt_name]
112
+
113
+ for m_name in name:
114
+ if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
115
+ scope_names = re.split(r"_(\d+)", m_name)
116
+ else:
117
+ scope_names = [m_name]
118
+ if scope_names[0] in ["kernel", "scale", "embedding"]:
119
+ pointer = getattr(pointer, "weight")
120
+ elif scope_names[0] == "self_attention":
121
+ pointer = getattr(pointer, "layer")
122
+ pointer = pointer[0]
123
+ elif scope_names[0] == "enc_dec_attention":
124
+ pointer = getattr(pointer, "layer")
125
+ pointer = pointer[1]
126
+ elif scope_names[0] == "dense_relu_dense":
127
+ pointer = getattr(pointer, "layer")
128
+ pointer = pointer[2]
129
+ elif scope_names[0] == "rms_norm":
130
+ if hasattr(pointer, "layer_norm"):
131
+ pointer = getattr(pointer, "layer_norm")
132
+ elif hasattr(pointer, "final_layer_norm"):
133
+ pointer = getattr(pointer, "final_layer_norm")
134
+ elif scope_names[0] == "scale":
135
+ pointer = getattr(pointer, "weight")
136
+ elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
137
+ pointer = getattr(pointer, "bias")
138
+ elif scope_names[0] == "squad":
139
+ pointer = getattr(pointer, "classifier")
140
+ elif scope_names[0] == "decoder" and name[1] == "logits":
141
+ continue
142
+ elif scope_names[0] == "logits":
143
+ pointer = getattr(pointer, "lm_head")
144
+ elif scope_names[0] == "wi" and len(scope_names) > 1 and scope_names[1].isdigit():
145
+ pointer = getattr(pointer, f"wi_{scope_names[1]}")
146
+ continue
147
+ else:
148
+ try:
149
+ pointer = getattr(pointer, scope_names[0])
150
+ except AttributeError:
151
+ logger.info(f"Skipping {'/'.join(name)}")
152
+ continue
153
+ if len(scope_names) >= 2:
154
+ num = int(scope_names[1])
155
+ pointer = pointer[num]
156
+ if scope_names[0] not in ["kernel", "scale", "embedding"]:
157
+ pointer = getattr(pointer, "weight")
158
+ if scope_names[0] != "embedding":
159
+ logger.info(f"Transposing numpy weight of shape {array.shape} for {name}")
160
+ array = np.transpose(array)
161
+ try:
162
+ assert (
163
+ pointer.shape == array.shape
164
+ ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
165
+ except AssertionError as e:
166
+ e.args += (pointer.shape, array.shape)
167
+ raise
168
+ logger.info(f"Initialize PyTorch weight {name}")
169
+ pointer.data = torch.from_numpy(array.astype(np.float32))
170
+ tf_weights.pop(txt_name, None)
171
+
172
+ logger.info(f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}.")
173
+ return model
174
+
175
+
176
+ ####################################################
177
+ # PyTorch Models are constructed by sub-classing
178
+ # - torch.nn.Module for the layers and
179
+ # - PreTrainedModel for the models (it-self a sub-class of nn.Module)
180
+ ####################################################
181
+ PARALLELIZE_DOCSTRING = r"""
182
+ This is an experimental feature and is subject to change at a moment's notice.
183
+
184
+ Uses a device map to distribute attention modules of the model across several devices. If no device map is given,
185
+ it will evenly distribute blocks across all devices.
186
+
187
+ Args:
188
+ device_map (`Dict[int, list]`, optional, defaults to None):
189
+ A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
190
+ automatically mapped to the first device (for esoteric reasons). That means that the first device should
191
+ have fewer attention modules mapped to it than other devices. For reference, the t5 models have the
192
+ following number of attention modules:
193
+
194
+ - t5-small: 6
195
+ - t5-base: 12
196
+ - t5-large: 24
197
+ - t5-3b: 24
198
+ - t5-11b: 24
199
+
200
+ Example:
201
+
202
+ ```python
203
+ # Here is an example of a device map on a machine with 4 GPUs using t5-3b, which has a total of 24 attention modules:
204
+ model = T5ForConditionalGeneration.from_pretrained("t5-3b")
205
+ device_map = {
206
+ 0: [0, 1, 2],
207
+ 1: [3, 4, 5, 6, 7, 8, 9],
208
+ 2: [10, 11, 12, 13, 14, 15, 16],
209
+ 3: [17, 18, 19, 20, 21, 22, 23],
210
+ }
211
+ model.parallelize(device_map)
212
+ ```
213
+ """
214
+ DEPARALLELIZE_DOCSTRING = r"""
215
+ Moves the model to cpu from a model parallel state.
216
+
217
+ Example:
218
+
219
+ ```python
220
+ # On a 4 GPU machine with t5-3b:
221
+ model = T5ForConditionalGeneration.from_pretrained("t5-3b")
222
+ device_map = {
223
+ 0: [0, 1, 2],
224
+ 1: [3, 4, 5, 6, 7, 8, 9],
225
+ 2: [10, 11, 12, 13, 14, 15, 16],
226
+ 3: [17, 18, 19, 20, 21, 22, 23],
227
+ }
228
+ model.parallelize(device_map) # Splits the model across several devices
229
+ model.deparallelize() # Puts the model back on CPU and frees memory by calling torch.cuda.empty_cache()
230
+ ```
231
+ """
232
+
233
+
234
+ class T5LayerNorm(nn.Module):
235
+ def __init__(self, hidden_size, eps=1e-6):
236
+ """
237
+ Construct a layernorm module in the T5 style: no bias and no subtraction of the mean.
238
+ """
239
+ super().__init__()
240
+ self.weight = nn.Parameter(torch.ones(hidden_size))
241
+ self.variance_epsilon = eps
242
+
243
+ def forward(self, hidden_states):
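+ # this is RMSNorm: scale by the root mean square only, with no mean subtraction and no bias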
244
+ # layer norm should always be calculated in float32
245
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
246
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
247
+
248
+ # convert into half-precision if necessary
249
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
250
+ hidden_states = hidden_states.to(self.weight.dtype)
251
+
252
+ return self.weight * hidden_states
253
+
254
+
255
+ class T5DenseReluDense(nn.Module):
256
+ def __init__(self, config):
257
+ super().__init__()
258
+ self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
259
+ self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
260
+ self.dropout = nn.Dropout(config.dropout_rate)
261
+
262
+ def forward(self, hidden_states):
263
+ hidden_states = self.wi(hidden_states)
264
+ hidden_states = nn.functional.relu(hidden_states)
265
+ hidden_states = self.dropout(hidden_states)
266
+ hidden_states = self.wo(hidden_states)
267
+ return hidden_states
268
+
269
+
270
+ class T5DenseGatedGeluDense(nn.Module):
271
+ def __init__(self, config):
272
+ super().__init__()
273
+ self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
274
+ self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
275
+ self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
276
+ self.dropout = nn.Dropout(config.dropout_rate)
277
+ self.gelu_act = ACT2FN["gelu_new"]
278
+
279
+ def forward(self, hidden_states):
280
+ hidden_gelu = self.gelu_act(self.wi_0(hidden_states))
281
+ hidden_linear = self.wi_1(hidden_states)
282
+ hidden_states = hidden_gelu * hidden_linear
283
+ hidden_states = self.dropout(hidden_states)
284
+ hidden_states = self.wo(hidden_states)
285
+ return hidden_states
286
+
287
+
288
+ class T5LayerFF(nn.Module):
289
+ def __init__(self, config):
290
+ super().__init__()
291
+ if config.feed_forward_proj == "relu":
292
+ self.DenseReluDense = T5DenseReluDense(config)
293
+ elif config.feed_forward_proj == "gated-gelu":
294
+ self.DenseReluDense = T5DenseGatedGeluDense(config)
295
+ else:
296
+ raise ValueError(
297
+ f"{self.config.feed_forward_proj} is not supported. Choose between `relu` and `gated-gelu`"
298
+ )
299
+
300
+ self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
301
+ self.dropout = nn.Dropout(config.dropout_rate)
302
+
303
+ def forward(self, hidden_states):
304
+ forwarded_states = self.layer_norm(hidden_states)
305
+ forwarded_states = self.DenseReluDense(forwarded_states)
306
+ hidden_states = hidden_states + self.dropout(forwarded_states)
307
+ return hidden_states
308
+
309
+
310
+ class T5Attention(nn.Module):
311
+ def __init__(self, config: T5Config, has_relative_attention_bias=False):
312
+ super().__init__()
313
+ self.is_decoder = config.is_decoder
314
+ self.has_relative_attention_bias = has_relative_attention_bias
315
+
316
+ self.relative_attention_num_buckets = config.relative_attention_num_buckets
317
+ self.d_model = config.d_model
318
+ self.key_value_proj_dim = config.d_kv
319
+ self.n_heads = config.num_heads
320
+ self.dropout = config.dropout_rate
321
+ self.inner_dim = self.n_heads * self.key_value_proj_dim
322
+
323
+ # Mesh TensorFlow initialization to avoid scaling before softmax
324
+ self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
325
+ self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
326
+ self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
327
+ self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
328
+
329
+ if self.has_relative_attention_bias:
330
+ self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
331
+ self.pruned_heads = set()
332
+ self.gradient_checkpointing = False
333
+
334
+ def prune_heads(self, heads):
335
+ if len(heads) == 0:
336
+ return
337
+ heads, index = find_pruneable_heads_and_indices(
338
+ heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads
339
+ )
340
+ # Prune linear layers
341
+ self.q = prune_linear_layer(self.q, index)
342
+ self.k = prune_linear_layer(self.k, index)
343
+ self.v = prune_linear_layer(self.v, index)
344
+ self.o = prune_linear_layer(self.o, index, dim=1)
345
+ # Update hyper params
346
+ self.n_heads = self.n_heads - len(heads)
347
+ self.inner_dim = self.key_value_proj_dim * self.n_heads
348
+ self.pruned_heads = self.pruned_heads.union(heads)
349
+
350
+ @staticmethod
351
+ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
352
+ """
353
+ Adapted from Mesh Tensorflow:
354
+ https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
355
+
356
+ Translate relative position to a bucket number for relative attention. The relative position is defined as
357
+ memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
358
+ position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
359
+ small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
360
+ positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
361
+ This should allow for more graceful generalization to longer sequences than the model has been trained on.
362
+
363
+ Args:
364
+ relative_position: an int32 Tensor
365
+ bidirectional: a boolean - whether the attention is bidirectional
366
+ num_buckets: an integer
367
+ max_distance: an integer
368
+
369
+ Returns:
370
+ a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
371
+ """
372
+ relative_buckets = 0
373
+ if bidirectional:
374
+ num_buckets //= 2
375
+ relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
376
+ relative_position = torch.abs(relative_position)
377
+ else:
378
+ relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
379
+ # now relative_position is in the range [0, inf)
380
+
381
+ # half of the buckets are for exact increments in positions
382
+ max_exact = num_buckets // 2
383
+ is_small = relative_position < max_exact
384
+
385
+ # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
386
+ relative_position_if_large = max_exact + (
387
+ torch.log(relative_position.float() / max_exact)
388
+ / math.log(max_distance / max_exact)
389
+ * (num_buckets - max_exact)
390
+ ).to(torch.long)
391
+ relative_position_if_large = torch.min(
392
+ relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
393
+ )
394
+
395
+ relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
396
+ return relative_buckets
397
+
398
+ def compute_bias(self, query_length, key_length):
399
+ """Compute binned relative position bias"""
400
+ context_position = torch.arange(
401
+ query_length, dtype=torch.long, device=self.relative_attention_bias.weight.device
402
+ )[:, None]
403
+ memory_position = torch.arange(
404
+ key_length, dtype=torch.long, device=self.relative_attention_bias.weight.device
405
+ )[None, :]
406
+ relative_position = memory_position - context_position # shape (query_length, key_length)
407
+ relative_position_bucket = self._relative_position_bucket(
408
+ relative_position, # shape (query_length, key_length)
409
+ bidirectional=(not self.is_decoder),
410
+ num_buckets=self.relative_attention_num_buckets,
411
+ )
412
+ values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads)
413
+ values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
414
+ return values
415
+
416
+ def forward(
417
+ self,
418
+ hidden_states,
419
+ mask=None,
420
+ key_value_states=None,
421
+ position_bias=None,
422
+ past_key_value=None,
423
+ layer_head_mask=None,
424
+ query_length=None,
425
+ use_cache=False,
426
+ output_attentions=False,
427
+ ):
428
+ """
429
+ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
430
+ """
431
+ # Input is (batch_size, seq_length, dim)
432
+ # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
433
+ # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
434
+ batch_size, seq_length = hidden_states.shape[:2]
435
+
436
+ real_seq_length = seq_length
437
+
438
+ if past_key_value is not None:
439
+ assert (
440
+ len(past_key_value) == 2
441
+ ), f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
442
+ real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length
443
+
444
+ key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
445
+
446
+ def shape(states):
447
+ """projection"""
448
+ return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
449
+
450
+ def unshape(states):
451
+ """reshape"""
452
+ return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
453
+
454
+ def project(hidden_states, proj_layer, key_value_states, past_key_value):
455
+ """projects hidden states correctly to key/query states"""
456
+ if key_value_states is None:
457
+ # self-attn
458
+ # (batch_size, n_heads, seq_length, dim_per_head)
459
+ hidden_states = shape(proj_layer(hidden_states))
460
+ elif past_key_value is None:
461
+ # cross-attn
462
+ # (batch_size, n_heads, seq_length, dim_per_head)
463
+ hidden_states = shape(proj_layer(key_value_states))
464
+
465
+ if past_key_value is not None:
466
+ if key_value_states is None:
467
+ # self-attn
468
+ # (batch_size, n_heads, key_length, dim_per_head)
469
+ hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
470
+ else:
471
+ # cross-attn
472
+ hidden_states = past_key_value
473
+ return hidden_states
474
+
475
+ # get query states
476
+ query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head)
477
+
478
+ # get key/value states
479
+ key_states = project(
480
+ hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
481
+ )
482
+ value_states = project(
483
+ hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
484
+ )
485
+
486
+ # compute scores
487
+ scores = torch.matmul(
488
+ query_states, key_states.transpose(3, 2)
489
+ ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
490
+
491
+ if position_bias is None:
492
+ if not self.has_relative_attention_bias:
493
+ position_bias = torch.zeros(
494
+ (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
495
+ )
496
+ if self.gradient_checkpointing and self.training:
497
+ position_bias.requires_grad = True
498
+ else:
499
+ position_bias = self.compute_bias(real_seq_length, key_length)
500
+
501
+ # if key and values are already calculated
502
+ # we want only the last query position bias
503
+ if past_key_value is not None:
504
+ position_bias = position_bias[:, :, -hidden_states.size(1) :, :]
505
+
506
+ if mask is not None:
507
+ position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length)
508
+
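+ # note: scores are not divided by sqrt(d_kv); T5 folds that scale into the query initialization instead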
509
+ scores += position_bias
510
+ attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
511
+ scores
512
+ ) # (batch_size, n_heads, seq_length, key_length)
513
+ attn_weights = nn.functional.dropout(
514
+ attn_weights, p=self.dropout, training=self.training
515
+ ) # (batch_size, n_heads, seq_length, key_length)
516
+
517
+ # Mask heads if we want to
518
+ if layer_head_mask is not None:
519
+ attn_weights = attn_weights * layer_head_mask
520
+
521
+ attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim)
522
+ attn_output = self.o(attn_output)
523
+
524
+ present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
525
+ outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)
526
+
527
+ if output_attentions:
528
+ outputs = outputs + (attn_weights,)
529
+ return outputs
530
+
531
+
532
+ class T5LayerSelfAttention(nn.Module):
533
+ def __init__(self, config, has_relative_attention_bias=False):
534
+ super().__init__()
535
+ self.SelfAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias)
536
+ self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
537
+ self.dropout = nn.Dropout(config.dropout_rate)
538
+
539
+ def forward(
540
+ self,
541
+ hidden_states,
542
+ attention_mask=None,
543
+ position_bias=None,
544
+ layer_head_mask=None,
545
+ past_key_value=None,
546
+ use_cache=False,
547
+ output_attentions=False,
548
+ ):
549
+ normed_hidden_states = self.layer_norm(hidden_states)
550
+ attention_output = self.SelfAttention(
551
+ normed_hidden_states,
552
+ mask=attention_mask,
553
+ position_bias=position_bias,
554
+ layer_head_mask=layer_head_mask,
555
+ past_key_value=past_key_value,
556
+ use_cache=use_cache,
557
+ output_attentions=output_attentions,
558
+ )
559
+ hidden_states = hidden_states + self.dropout(attention_output[0])
560
+ outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
561
+ return outputs
562
+
563
+
564
+ class T5LayerCrossAttention(nn.Module):
565
+ def __init__(self, config):
566
+ super().__init__()
567
+ self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False)
568
+ self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
569
+ self.dropout = nn.Dropout(config.dropout_rate)
570
+
571
+ def forward(
572
+ self,
573
+ hidden_states,
574
+ key_value_states,
575
+ attention_mask=None,
576
+ position_bias=None,
577
+ layer_head_mask=None,
578
+ past_key_value=None,
579
+ use_cache=False,
580
+ query_length=None,
581
+ output_attentions=False,
582
+ ):
583
+ normed_hidden_states = self.layer_norm(hidden_states)
584
+ attention_output = self.EncDecAttention(
585
+ normed_hidden_states,
586
+ mask=attention_mask,
587
+ key_value_states=key_value_states,
588
+ position_bias=position_bias,
589
+ layer_head_mask=layer_head_mask,
590
+ past_key_value=past_key_value,
591
+ use_cache=use_cache,
592
+ query_length=query_length,
593
+ output_attentions=output_attentions,
594
+ )
595
+ layer_output = hidden_states + self.dropout(attention_output[0])
596
+ outputs = (layer_output,) + attention_output[1:] # add attentions if we output them
597
+ return outputs
598
+
599
+
600
+ class T5Block(nn.Module):
601
+ def __init__(self, config, has_relative_attention_bias=False):
602
+ super().__init__()
603
+ self.is_decoder = config.is_decoder
604
+ self.layer = nn.ModuleList()
605
+ self.layer.append(T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias))
606
+ if self.is_decoder:
607
+ self.layer.append(T5LayerCrossAttention(config))
608
+
609
+ self.layer.append(T5LayerFF(config))
610
+
611
+ def forward(
612
+ self,
613
+ hidden_states,
614
+ attention_mask=None,
615
+ position_bias=None,
616
+ encoder_hidden_states=None,
617
+ encoder_attention_mask=None,
618
+ encoder_decoder_position_bias=None,
619
+ layer_head_mask=None,
620
+ cross_attn_layer_head_mask=None,
621
+ past_key_value=None,
622
+ use_cache=False,
623
+ output_attentions=False,
624
+ return_dict=True,
625
+ ):
626
+
627
+ if past_key_value is not None:
628
+ assert self.is_decoder, "Only decoder can use `past_key_values`"
629
+ expected_num_past_key_values = 2 if encoder_hidden_states is None else 4
630
+
631
+ if len(past_key_value) != expected_num_past_key_values:
632
+ raise ValueError(
633
+ f"There should be {expected_num_past_key_values} past states. "
634
+ f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
635
+ f"Got {len(past_key_value)} past key / value states"
636
+ )
637
+
638
+ self_attn_past_key_value = past_key_value[:2]
639
+ cross_attn_past_key_value = past_key_value[2:]
640
+ else:
641
+ self_attn_past_key_value, cross_attn_past_key_value = None, None
642
+
643
+ self_attention_outputs = self.layer[0](
644
+ hidden_states,
645
+ attention_mask=attention_mask,
646
+ position_bias=position_bias,
647
+ layer_head_mask=layer_head_mask,
648
+ past_key_value=self_attn_past_key_value,
649
+ use_cache=use_cache,
650
+ output_attentions=output_attentions,
651
+ )
652
+ hidden_states, present_key_value_state = self_attention_outputs[:2]
653
+ attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights
654
+
655
+ # clamp inf values to enable fp16 training
656
+ if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
657
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
658
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
659
+
660
+ do_cross_attention = self.is_decoder and encoder_hidden_states is not None
661
+ if do_cross_attention:
662
+ # the actual query length is unknown for cross attention
663
+ # if using past key value states. Need to inject it here
664
+ if present_key_value_state is not None:
665
+ query_length = present_key_value_state[0].shape[2]
666
+ else:
667
+ query_length = None
668
+
669
+ cross_attention_outputs = self.layer[1](
670
+ hidden_states,
671
+ key_value_states=encoder_hidden_states,
672
+ attention_mask=encoder_attention_mask,
673
+ position_bias=encoder_decoder_position_bias,
674
+ layer_head_mask=cross_attn_layer_head_mask,
675
+ past_key_value=cross_attn_past_key_value,
676
+ query_length=query_length,
677
+ use_cache=use_cache,
678
+ output_attentions=output_attentions,
679
+ )
680
+ hidden_states = cross_attention_outputs[0]
681
+
682
+ # clamp inf values to enable fp16 training
683
+ if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
684
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
685
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
686
+
687
+ # Combine self attn and cross attn key value states
688
+ if present_key_value_state is not None:
689
+ present_key_value_state = present_key_value_state + cross_attention_outputs[1]
690
+
691
+ # Keep cross-attention outputs and relative position weights
692
+ attention_outputs = attention_outputs + cross_attention_outputs[2:]
693
+
694
+ # Apply Feed Forward layer
695
+ hidden_states = self.layer[-1](hidden_states)
696
+
697
+ # clamp inf values to enable fp16 training
698
+ if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
699
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
700
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
701
+
702
+ outputs = (hidden_states,)
703
+
704
+ if use_cache:
705
+ outputs = outputs + (present_key_value_state,) + attention_outputs
706
+ else:
707
+ outputs = outputs + attention_outputs
708
+
709
+ return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
710
+
711
+
712
+ class T5PreTrainedModel(PreTrainedModel):
713
+ """
714
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
715
+ models.
716
+ """
717
+
718
+ config_class = T5Config
719
+ load_tf_weights = load_tf_weights_in_t5
720
+ base_model_prefix = "transformer"
721
+ is_parallelizable = True
722
+ supports_gradient_checkpointing = True
723
+
724
+ @property
725
+ def dummy_inputs(self):
726
+ input_ids = torch.tensor(DUMMY_INPUTS)
727
+ input_mask = torch.tensor(DUMMY_MASK)
728
+ dummy_inputs = {
729
+ "decoder_input_ids": input_ids,
730
+ "input_ids": input_ids,
731
+ "decoder_attention_mask": input_mask,
732
+ }
733
+ return dummy_inputs
734
+
735
+ def _init_weights(self, module):
736
+ """Initialize the weights"""
737
+ factor = self.config.initializer_factor # Used for testing weights initialization
738
+ if isinstance(module, T5LayerNorm):
739
+ module.weight.data.fill_(factor * 1.0)
740
+ elif isinstance(module, (T5Model, T5ForConditionalGeneration, T5EncoderModel)):
741
+ # Mesh TensorFlow embeddings initialization
742
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
743
+ module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
744
+ elif isinstance(module, T5DenseReluDense):
745
+ # Mesh TensorFlow FF initialization
746
+ # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
747
+ # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89
748
+ module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
749
+ if hasattr(module.wi, "bias") and module.wi.bias is not None:
750
+ module.wi.bias.data.zero_()
751
+ module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
752
+ if hasattr(module.wo, "bias") and module.wo.bias is not None:
753
+ module.wo.bias.data.zero_()
754
+ elif isinstance(module, T5DenseGatedGeluDense):
755
+ module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
756
+ if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None:
757
+ module.wi_0.bias.data.zero_()
758
+ module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
759
+ if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None:
760
+ module.wi_1.bias.data.zero_()
761
+ module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
762
+ if hasattr(module.wo, "bias") and module.wo.bias is not None:
763
+ module.wo.bias.data.zero_()
764
+ elif isinstance(module, T5Attention):
765
+ # Mesh TensorFlow attention initialization to avoid scaling before softmax
766
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
767
+ d_model = self.config.d_model
768
+ key_value_proj_dim = self.config.d_kv
769
+ n_heads = self.config.num_heads
770
+ module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5))
771
+ module.k.weight.data.normal_(mean=0.0, std=factor * (d_model ** -0.5))
772
+ module.v.weight.data.normal_(mean=0.0, std=factor * (d_model ** -0.5))
773
+ module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5))
774
+ if module.has_relative_attention_bias:
775
+ module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
776
+
777
+ def _set_gradient_checkpointing(self, module, value=False):
778
+ if isinstance(module, (T5Attention, T5Stack)):
779
+ module.gradient_checkpointing = value
780
+
781
+ def _shift_right(self, input_ids):
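+ # prepend decoder_start_token_id and drop the last token, turning labels into teacher-forced decoder inputs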
782
+ decoder_start_token_id = self.config.decoder_start_token_id
783
+ pad_token_id = self.config.pad_token_id
784
+
785
+ assert (
786
+ decoder_start_token_id is not None
787
+ ), "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. See T5 docs for more information"
788
+
789
+ # shift inputs to the right
790
+ if is_torch_fx_proxy(input_ids):
791
+ # Item assignment is not supported natively for proxies.
792
+ shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id)
793
+ shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
794
+ else:
795
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
796
+ shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
797
+ shifted_input_ids[..., 0] = decoder_start_token_id
798
+
799
+ assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
800
+ # replace possible -100 values in labels by `pad_token_id`
801
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
802
+
803
+ assert torch.all(shifted_input_ids >= 0).item(), "Verify that `shifted_input_ids` has only positive values"
804
+
805
+ return shifted_input_ids
806
+
807
+
808
+ class T5Stack(T5PreTrainedModel):
809
+ def __init__(self, config, embed_tokens=None):
810
+ super().__init__(config)
811
+
812
+ self.embed_tokens = embed_tokens
813
+ self.is_decoder = config.is_decoder
814
+
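+ # only the first block owns the relative attention bias; later blocks reuse the position_bias it computes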
815
+ self.block = nn.ModuleList(
816
+ [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)]
817
+ )
818
+ self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
819
+ self.dropout = nn.Dropout(config.dropout_rate)
820
+
821
+ # Initialize weights and apply final processing
822
+ self.post_init()
823
+ # Model parallel
824
+ self.model_parallel = False
825
+ self.device_map = None
826
+ self.gradient_checkpointing = False
827
+
828
+ @add_start_docstrings(PARALLELIZE_DOCSTRING)
829
+ def parallelize(self, device_map=None):
830
+ # Check validity of device_map
831
+ self.device_map = (
832
+ get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map
833
+ )
834
+ assert_device_map(self.device_map, len(self.block))
835
+ self.model_parallel = True
836
+ self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
837
+ self.last_device = "cuda:" + str(max(self.device_map.keys()))
838
+ # Load onto devices
839
+ for k, v in self.device_map.items():
840
+ for layer in v:
841
+ cuda_device = "cuda:" + str(k)
842
+ self.block[layer] = self.block[layer].to(cuda_device)
843
+
844
+ # Set embed_tokens to first layer
845
+ self.embed_tokens = self.embed_tokens.to(self.first_device)
846
+ # Set final layer norm to last device
847
+ self.final_layer_norm = self.final_layer_norm.to(self.last_device)
848
+
849
+ @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
850
+ def deparallelize(self):
851
+ self.model_parallel = False
852
+ self.device_map = None
853
+ self.first_device = "cpu"
854
+ self.last_device = "cpu"
855
+ for i in range(len(self.block)):
856
+ self.block[i] = self.block[i].to("cpu")
857
+ self.embed_tokens = self.embed_tokens.to("cpu")
858
+ self.final_layer_norm = self.final_layer_norm.to("cpu")
859
+ torch.cuda.empty_cache()
860
+
861
+ def get_input_embeddings(self):
862
+ return self.embed_tokens
863
+
864
+ def set_input_embeddings(self, new_embeddings):
865
+ self.embed_tokens = new_embeddings
866
+
867
+ def forward(
868
+ self,
869
+ input_ids=None,
870
+ attention_mask=None,
871
+ encoder_hidden_states=None,
872
+ encoder_attention_mask=None,
873
+ inputs_embeds=None,
874
+ head_mask=None,
875
+ cross_attn_head_mask=None,
876
+ past_key_values=None,
877
+ use_cache=None,
878
+ output_attentions=None,
879
+ output_hidden_states=None,
880
+ return_dict=None,
881
+ ):
882
+ # Model parallel
883
+ if self.model_parallel:
884
+ torch.cuda.set_device(self.first_device)
885
+ self.embed_tokens = self.embed_tokens.to(self.first_device)
886
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
887
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
888
+ output_hidden_states = (
889
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
890
+ )
891
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
892
+
893
+ if input_ids is not None and inputs_embeds is not None:
894
+ err_msg_prefix = "decoder_" if self.is_decoder else ""
895
+ raise ValueError(
896
+ f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
897
+ )
898
+ elif input_ids is not None:
899
+ input_shape = input_ids.size()
900
+ input_ids = input_ids.view(-1, input_shape[-1])
901
+ elif inputs_embeds is not None:
902
+ input_shape = inputs_embeds.size()[:-1]
903
+ else:
904
+ err_msg_prefix = "decoder_" if self.is_decoder else ""
905
+ raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")
906
+
907
+ if inputs_embeds is None:
908
+ assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
909
+ inputs_embeds = self.embed_tokens(input_ids)
910
+
911
+ batch_size, seq_length = input_shape
912
+
913
+ # required mask seq length can be calculated via length of past
914
+ mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length
915
+
916
+ if use_cache is True:
917
+ assert self.is_decoder, f"`use_cache` can only be set to `True` if {self} is used as a decoder"
918
+
919
+ if attention_mask is None:
920
+ attention_mask = torch.ones(batch_size, mask_seq_length).to(inputs_embeds.device)
921
+ if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
922
+ encoder_seq_length = encoder_hidden_states.shape[1]
923
+ encoder_attention_mask = torch.ones(
924
+ batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long
925
+ )
926
+
927
+ # initialize past_key_values with `None` if past does not exist
928
+ if past_key_values is None:
929
+ past_key_values = [None] * len(self.block)
930
+
931
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
932
+ # ourselves in which case we just need to make it broadcastable to all heads.
933
+ extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, inputs_embeds.device)
934
+
935
+ # If a 2D or 3D attention mask is provided for the cross-attention
936
+ # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
937
+ if self.is_decoder and encoder_hidden_states is not None:
938
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
939
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
940
+ if encoder_attention_mask is None:
941
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
942
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
943
+ else:
944
+ encoder_extended_attention_mask = None
945
+
946
+ # Prepare head mask if needed
947
+ head_mask = self.get_head_mask(head_mask, self.config.num_layers)
948
+ cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
949
+ present_key_value_states = () if use_cache else None
950
+ all_hidden_states = () if output_hidden_states else None
951
+ all_attentions = () if output_attentions else None
952
+ all_cross_attentions = () if (output_attentions and self.is_decoder) else None
953
+ position_bias = None
954
+ encoder_decoder_position_bias = None
955
+
956
+ hidden_states = self.dropout(inputs_embeds)
957
+
958
+ for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
959
+ layer_head_mask = head_mask[i]
960
+ cross_attn_layer_head_mask = cross_attn_head_mask[i]
961
+ # Model parallel
962
+ if self.model_parallel:
963
+ torch.cuda.set_device(hidden_states.device)
964
+ # Ensure that attention_mask is always on the same device as hidden_states
965
+ if attention_mask is not None:
966
+ attention_mask = attention_mask.to(hidden_states.device)
967
+ if position_bias is not None:
968
+ position_bias = position_bias.to(hidden_states.device)
969
+ if encoder_hidden_states is not None:
970
+ encoder_hidden_states = encoder_hidden_states.to(hidden_states.device)
971
+ if encoder_extended_attention_mask is not None:
972
+ encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device)
973
+ if encoder_decoder_position_bias is not None:
974
+ encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device)
975
+ if layer_head_mask is not None:
976
+ layer_head_mask = layer_head_mask.to(hidden_states.device)
977
+ if cross_attn_layer_head_mask is not None:
978
+ cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device)
979
+ if output_hidden_states:
980
+ all_hidden_states = all_hidden_states + (hidden_states,)
981
+
982
+ if self.gradient_checkpointing and self.training:
983
+ if use_cache:
984
+ logger.warning(
985
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
986
+ )
987
+ use_cache = False
988
+
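+ # wrap the layer so use_cache and output_attentions are baked in; checkpoint() re-runs this wrapper with only tensor inputs during the backward pass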
989
+ def create_custom_forward(module):
990
+ def custom_forward(*inputs):
991
+ return tuple(module(*inputs, use_cache, output_attentions))
992
+
993
+ return custom_forward
994
+
995
+ layer_outputs = checkpoint(
996
+ create_custom_forward(layer_module),
997
+ hidden_states,
998
+ extended_attention_mask,
999
+ position_bias,
1000
+ encoder_hidden_states,
1001
+ encoder_extended_attention_mask,
1002
+ encoder_decoder_position_bias,
1003
+ layer_head_mask,
1004
+ cross_attn_layer_head_mask,
1005
+ None, # past_key_value is always None with gradient checkpointing
1006
+ )
1007
+ else:
1008
+ layer_outputs = layer_module(
1009
+ hidden_states,
1010
+ attention_mask=extended_attention_mask,
1011
+ position_bias=position_bias,
1012
+ encoder_hidden_states=encoder_hidden_states,
1013
+ encoder_attention_mask=encoder_extended_attention_mask,
1014
+ encoder_decoder_position_bias=encoder_decoder_position_bias,
1015
+ layer_head_mask=layer_head_mask,
1016
+ cross_attn_layer_head_mask=cross_attn_layer_head_mask,
1017
+ past_key_value=past_key_value,
1018
+ use_cache=use_cache,
1019
+ output_attentions=output_attentions,
1020
+ )
1021
+
1022
+ # layer_outputs is a tuple with:
1023
+ # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
1024
+ if use_cache is False:
1025
+ layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
1026
+
1027
+ hidden_states, present_key_value_state = layer_outputs[:2]
1028
+
1029
+ # We share the position biases between the layers - the first layer stores them
1030
+ # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
1031
+ # (cross-attention position bias), (cross-attention weights)
1032
+ position_bias = layer_outputs[2]
1033
+ if self.is_decoder and encoder_hidden_states is not None:
1034
+ encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
1035
+ # append next layer key value states
1036
+ if use_cache:
1037
+ present_key_value_states = present_key_value_states + (present_key_value_state,)
1038
+
1039
+ if output_attentions:
1040
+ all_attentions = all_attentions + (layer_outputs[3],)
1041
+ if self.is_decoder:
1042
+ all_cross_attentions = all_cross_attentions + (layer_outputs[5],)
1043
+
1044
+ # Model Parallel: If it's the last layer for that device, put things on the next device
1045
+ if self.model_parallel:
1046
+ for k, v in self.device_map.items():
1047
+ if i == v[-1] and "cuda:" + str(k) != self.last_device:
1048
+ hidden_states = hidden_states.to("cuda:" + str(k + 1))
1049
+
1050
+ hidden_states = self.final_layer_norm(hidden_states)
1051
+ hidden_states = self.dropout(hidden_states)
1052
+
1053
+ # Add last layer
1054
+ if output_hidden_states:
1055
+ all_hidden_states = all_hidden_states + (hidden_states,)
1056
+
1057
+ if not return_dict:
1058
+ return tuple(
1059
+ v
1060
+ for v in [
1061
+ hidden_states,
1062
+ present_key_value_states,
1063
+ all_hidden_states,
1064
+ all_attentions,
1065
+ all_cross_attentions,
1066
+ ]
1067
+ if v is not None
1068
+ )
1069
+ return BaseModelOutputWithPastAndCrossAttentions(
1070
+ last_hidden_state=hidden_states,
1071
+ past_key_values=present_key_value_states,
1072
+ hidden_states=all_hidden_states,
1073
+ attentions=all_attentions,
1074
+ cross_attentions=all_cross_attentions,
1075
+ )
1076
+
1077
+
1078
+ T5_START_DOCSTRING = r"""
1079
+
1080
+ The T5 model was proposed in [Exploring the Limits of Transfer Learning with a Unified Text-to-Text
1081
+ Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan
1082
+ Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder-decoder transformer pre-trained in a
1083
+ text-to-text denoising generative setting.
1084
+
1085
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
1086
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
1087
+ etc.)
1088
+
1089
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
1090
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
1091
+ and behavior.
1092
+
1093
+ Parameters:
1094
+ config ([`T5Config`]): Model configuration class with all the parameters of the model.
1095
+ Initializing with a config file does not load the weights associated with the model, only the
1096
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
1097
+ """
1098
+
1099
+ T5_INPUTS_DOCSTRING = r"""
1100
+ Args:
1101
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1102
+ Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
1103
+ should be able to pad the inputs on both the right and the left.
1104
+
1105
+ Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
1106
+ [`PreTrainedTokenizer.__call__`] for details.
1107
+
1108
+ [What are input IDs?](../glossary#input-ids)
1109
+
1110
+ To know more on how to prepare `input_ids` for pretraining, take a look at [T5 Training](./t5#training).
1111
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
1112
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1113
+
1114
+ - 1 for tokens that are **not masked**,
1115
+ - 0 for tokens that are **masked**.
1116
+
1117
+ [What are attention masks?](../glossary#attention-mask)
1118
+ decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
1119
+ Indices of decoder input sequence tokens in the vocabulary.
1120
+
1121
+ Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
1122
+ [`PreTrainedTokenizer.__call__`] for details.
1123
+
1124
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
1125
+
1126
+ T5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
1127
+ is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
1128
+
1129
+ To know more on how to prepare `decoder_input_ids` for pretraining, take a look at [T5
1130
+ Training](./t5#training).
1131
+ decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
1132
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
1133
+ be used by default.
1134
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
1135
+ Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in `[0,
1136
+ 1]`:
1137
+
1138
+ - 1 indicates the head is **not masked**,
1139
+ - 0 indicates the head is **masked**.
1140
+
1141
+ decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
1142
+ Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0,
1143
+ 1]`:
1144
+
1145
+ - 1 indicates the head is **not masked**,
1146
+ - 0 indicates the head is **masked**.
1147
+
1148
+ cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
1149
+ Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
1150
+ `[0, 1]`:
1151
+
1152
+ - 1 indicates the head is **not masked**,
1153
+ - 0 indicates the head is **masked**.
1154
+
1155
+ encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
1156
+ Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*)
1157
+ `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at
1158
+ the output of the last layer of the encoder. Used in the cross-attention of the decoder.
1159
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
1160
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
1161
+
1162
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
1163
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
1164
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
1165
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1166
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
1167
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
1168
+ model's internal embedding lookup matrix.
1169
+ decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
1170
+ Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
1171
+ representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
1172
+ input (see `past_key_values`). This is useful if you want more control over how to convert
1173
+ `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
1174
+
1175
+ If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
1176
+ of `inputs_embeds`.
1177
+
1178
+ use_cache (`bool`, *optional*):
1179
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1180
+ `past_key_values`).
1181
+
1182
+ output_attentions (`bool`, *optional*):
1183
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1184
+ tensors for more detail.
1185
+ output_hidden_states (`bool`, *optional*):
1186
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1187
+ more detail.
1188
+ return_dict (`bool`, *optional*):
1189
+ Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
1190
+ """
1191
+
1192
+ T5_ENCODER_INPUTS_DOCSTRING = r"""
1193
+ Args:
1194
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1195
+ Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
1196
+ should be able to pad the inputs on both the right and the left.
1197
+
1198
+ Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
1199
+ [`PreTrainedTokenizer.__call__`] for detail.
1200
+
1201
+ To learn more about how to prepare `input_ids` for pretraining, take a look at [T5 Training](./t5#training).
1202
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
1203
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1204
+
1205
+ - 1 for tokens that are **not masked**,
1206
+ - 0 for tokens that are **masked**.
1207
+
1208
+ [What are attention masks?](../glossary#attention-mask)
1209
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
1210
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
1211
+
1212
+ - 1 indicates the head is **not masked**,
1213
+ - 0 indicates the head is **masked**.
1214
+
1215
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1216
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
1217
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
1218
+ model's internal embedding lookup matrix.
1219
+ output_attentions (`bool`, *optional*):
1220
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1221
+ tensors for more detail.
1222
+ output_hidden_states (`bool`, *optional*):
1223
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1224
+ more detail.
1225
+ return_dict (`bool`, *optional*):
1226
+ Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
1227
+ """
1228
+
1229
+ # Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
1230
+ __HEAD_MASK_WARNING_MSG = """
1231
+ The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
1232
+ `decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
1233
+ If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers,
1234
+ num_heads)`.
1235
+ """
1236
+
1237
+
1238
+ @add_start_docstrings(
1239
+ "The bare T5 Model transformer outputting raw hidden-states without any specific head on top.",
1240
+ T5_START_DOCSTRING,
1241
+ )
1242
+ class T5Model(T5PreTrainedModel):
1243
+ _keys_to_ignore_on_load_missing = [
1244
+ r"encoder\.embed_tokens\.weight",
1245
+ r"decoder\.embed_tokens\.weight",
1246
+ ]
1247
+ _keys_to_ignore_on_load_unexpected = [
1248
+ r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
1249
+ ]
1250
+
1251
+ def __init__(self, config: T5Config):
1252
+ super().__init__(config)
1253
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
1254
+
1255
+ encoder_config = copy.deepcopy(config)
1256
+ encoder_config.is_decoder = False
1257
+ encoder_config.use_cache = False
1258
+ encoder_config.is_encoder_decoder = False
1259
+ self.encoder = T5Stack(encoder_config, self.shared)
1260
+
1261
+ decoder_config = copy.deepcopy(config)
1262
+ decoder_config.is_decoder = True
1263
+ decoder_config.is_encoder_decoder = False
1264
+ decoder_config.num_layers = config.num_decoder_layers
1265
+ self.decoder = T5Stack(decoder_config, self.shared)
1266
+
1267
+ # Initialize weights and apply final processing
1268
+ self.post_init()
1269
+
1270
+ # Model parallel
1271
+ self.model_parallel = False
1272
+ self.device_map = None
1273
+
1274
+ @add_start_docstrings(PARALLELIZE_DOCSTRING)
1275
+ def parallelize(self, device_map=None):
1276
+ self.device_map = (
1277
+ get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
1278
+ if device_map is None
1279
+ else device_map
1280
+ )
1281
+ assert_device_map(self.device_map, len(self.encoder.block))
1282
+ self.encoder.parallelize(self.device_map)
1283
+ self.decoder.parallelize(self.device_map)
1284
+ self.model_parallel = True
1285
+
1286
+ @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
1287
+ def deparallelize(self):
1288
+ self.encoder.deparallelize()
1289
+ self.decoder.deparallelize()
1290
+ self.encoder = self.encoder.to("cpu")
1291
+ self.decoder = self.decoder.to("cpu")
1292
+ self.model_parallel = False
1293
+ self.device_map = None
1294
+ torch.cuda.empty_cache()
1295
+
1296
+ def get_input_embeddings(self):
1297
+ return self.shared
1298
+
1299
+ def set_input_embeddings(self, new_embeddings):
1300
+ self.shared = new_embeddings
1301
+ self.encoder.set_input_embeddings(new_embeddings)
1302
+ self.decoder.set_input_embeddings(new_embeddings)
1303
+
1304
+ def get_encoder(self):
1305
+ return self.encoder
1306
+
1307
+ def get_decoder(self):
1308
+ return self.decoder
1309
+
1310
+ def _prune_heads(self, heads_to_prune):
1311
+ """
1312
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
1313
+ class PreTrainedModel
1314
+ """
1315
+ for layer, heads in heads_to_prune.items():
1316
+ self.encoder.layer[layer].attention.prune_heads(heads)
1317
+
1318
+ @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
1319
+ @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
1320
+ def forward(
1321
+ self,
1322
+ input_ids=None,
1323
+ attention_mask=None,
1324
+ decoder_input_ids=None,
1325
+ decoder_attention_mask=None,
1326
+ head_mask=None,
1327
+ decoder_head_mask=None,
1328
+ cross_attn_head_mask=None,
1329
+ encoder_outputs=None,
1330
+ past_key_values=None,
1331
+ inputs_embeds=None,
1332
+ decoder_inputs_embeds=None,
1333
+ use_cache=None,
1334
+ output_attentions=None,
1335
+ output_hidden_states=None,
1336
+ return_dict=None,
1337
+ ):
1338
+ r"""
1339
+ Returns:
1340
+
1341
+ Example:
1342
+
1343
+ ```python
1344
+ >>> from transformers import T5Tokenizer, T5Model
1345
+
1346
+ >>> tokenizer = T5Tokenizer.from_pretrained("t5-small")
1347
+ >>> model = T5Model.from_pretrained("t5-small")
1348
+
1349
+ >>> input_ids = tokenizer(
1350
+ ... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
1351
+ ... ).input_ids  # Batch size 1
1352
+ >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1
1353
+
1354
+ >>> # forward pass
1355
+ >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
1356
+ >>> last_hidden_states = outputs.last_hidden_state
1357
+ ```"""
1358
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1359
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1360
+
1361
+ # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
1362
+ if head_mask is not None and decoder_head_mask is None:
1363
+ if self.config.num_layers == self.config.num_decoder_layers:
1364
+ decoder_head_mask = head_mask
1365
+
1366
+ # Encode if needed (training, first prediction pass)
1367
+ if encoder_outputs is None:
1368
+ encoder_outputs = self.encoder(
1369
+ input_ids=input_ids,
1370
+ attention_mask=attention_mask,
1371
+ inputs_embeds=inputs_embeds,
1372
+ head_mask=head_mask,
1373
+ output_attentions=output_attentions,
1374
+ output_hidden_states=output_hidden_states,
1375
+ return_dict=return_dict,
1376
+ )
1377
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
1378
+ encoder_outputs = BaseModelOutput(
1379
+ last_hidden_state=encoder_outputs[0],
1380
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
1381
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
1382
+ )
1383
+
1384
+ hidden_states = encoder_outputs[0]
1385
+ if self.model_parallel:
1386
+ torch.cuda.set_device(self.decoder.first_device)
1387
+ # Set device for model parallelism
1388
+ if self.model_parallel:
1389
+ torch.cuda.set_device(self.decoder.first_device)
1390
+ hidden_states = hidden_states.to(self.decoder.first_device)
1391
+ if decoder_input_ids is not None:
1392
+ decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
1393
+ if attention_mask is not None:
1394
+ attention_mask = attention_mask.to(self.decoder.first_device)
1395
+ if decoder_attention_mask is not None:
1396
+ decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)
1397
+
1398
+ # Decode
1399
+ decoder_outputs = self.decoder(
1400
+ input_ids=decoder_input_ids,
1401
+ attention_mask=decoder_attention_mask,
1402
+ inputs_embeds=decoder_inputs_embeds,
1403
+ past_key_values=past_key_values,
1404
+ encoder_hidden_states=hidden_states,
1405
+ encoder_attention_mask=attention_mask,
1406
+ head_mask=decoder_head_mask,
1407
+ cross_attn_head_mask=cross_attn_head_mask,
1408
+ use_cache=use_cache,
1409
+ output_attentions=output_attentions,
1410
+ output_hidden_states=output_hidden_states,
1411
+ return_dict=return_dict,
1412
+ )
1413
+
1414
+ if not return_dict:
1415
+ return decoder_outputs + encoder_outputs
1416
+
1417
+ return Seq2SeqModelOutput(
1418
+ last_hidden_state=decoder_outputs.last_hidden_state,
1419
+ past_key_values=decoder_outputs.past_key_values,
1420
+ decoder_hidden_states=decoder_outputs.hidden_states,
1421
+ decoder_attentions=decoder_outputs.attentions,
1422
+ cross_attentions=decoder_outputs.cross_attentions,
1423
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
1424
+ encoder_hidden_states=encoder_outputs.hidden_states,
1425
+ encoder_attentions=encoder_outputs.attentions,
1426
+ )
1427
+
1428
+
1429
+ @add_start_docstrings("""T5 Model with a `language modeling` head on top.""", T5_START_DOCSTRING)
1430
+ class T5ForConditionalGeneration(T5PreTrainedModel):
1431
+ _keys_to_ignore_on_load_missing = [
1432
+ r"encoder\.embed_tokens\.weight",
1433
+ r"decoder\.embed_tokens\.weight",
1434
+ r"lm_head\.weight",
1435
+ ]
1436
+ _keys_to_ignore_on_load_unexpected = [
1437
+ r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
1438
+ ]
1439
+
1440
+ def __init__(self, config):
1441
+ super().__init__(config)
1442
+ self.model_dim = config.d_model
1443
+
1444
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
1445
+
1446
+ encoder_config = copy.deepcopy(config)
1447
+ encoder_config.is_decoder = False
1448
+ encoder_config.use_cache = False
1449
+ encoder_config.is_encoder_decoder = False
1450
+ self.encoder = T5Stack(encoder_config, self.shared)
1451
+
1452
+ decoder_config = copy.deepcopy(config)
1453
+ decoder_config.is_decoder = True
1454
+ decoder_config.is_encoder_decoder = False
1455
+ decoder_config.num_layers = config.num_decoder_layers
1456
+ self.decoder = T5Stack(decoder_config, self.shared)
1457
+
1458
+ self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
1459
+
1460
+ # Initialize weights and apply final processing
1461
+ self.post_init()
1462
+
1463
+ # Model parallel
1464
+ self.model_parallel = False
1465
+ self.device_map = None
1466
+
1467
+ @add_start_docstrings(PARALLELIZE_DOCSTRING)
1468
+ def parallelize(self, device_map=None):
1469
+ self.device_map = (
1470
+ get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
1471
+ if device_map is None
1472
+ else device_map
1473
+ )
1474
+ assert_device_map(self.device_map, len(self.encoder.block))
1475
+ self.encoder.parallelize(self.device_map)
1476
+ self.decoder.parallelize(self.device_map)
1477
+ self.lm_head = self.lm_head.to(self.decoder.first_device)
1478
+ self.model_parallel = True
1479
+
1480
+ @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
1481
+ def deparallelize(self):
1482
+ self.encoder.deparallelize()
1483
+ self.decoder.deparallelize()
1484
+ self.encoder = self.encoder.to("cpu")
1485
+ self.decoder = self.decoder.to("cpu")
1486
+ self.lm_head = self.lm_head.to("cpu")
1487
+ self.model_parallel = False
1488
+ self.device_map = None
1489
+ torch.cuda.empty_cache()
1490
+
1491
+ def get_input_embeddings(self):
1492
+ return self.shared
1493
+
1494
+ def set_input_embeddings(self, new_embeddings):
1495
+ self.shared = new_embeddings
1496
+ self.encoder.set_input_embeddings(new_embeddings)
1497
+ self.decoder.set_input_embeddings(new_embeddings)
1498
+
1499
+ def set_output_embeddings(self, new_embeddings):
1500
+ self.lm_head = new_embeddings
1501
+
1502
+ def get_output_embeddings(self):
1503
+ return self.lm_head
1504
+
1505
+ def get_encoder(self):
1506
+ return self.encoder
1507
+
1508
+ def get_decoder(self):
1509
+ return self.decoder
1510
+
1511
+ @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
1512
+ @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
1513
+ def forward(
1514
+ self,
1515
+ input_ids=None,
1516
+ attention_mask=None,
1517
+ decoder_input_ids=None,
1518
+ decoder_attention_mask=None,
1519
+ head_mask=None,
1520
+ decoder_head_mask=None,
1521
+ cross_attn_head_mask=None,
1522
+ encoder_outputs=None,
1523
+ past_key_values=None,
1524
+ inputs_embeds=None,
1525
+ decoder_inputs_embeds=None,
1526
+ labels=None,
1527
+ use_cache=None,
1528
+ output_attentions=None,
1529
+ output_hidden_states=None,
1530
+ return_dict=None,
1531
+ ):
1532
+ r"""
1533
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1534
+ Labels for computing the sequence-to-sequence language modeling loss. Indices should be in `[-100, 0, ...,
1535
+ config.vocab_size - 1]`. All labels set to `-100` are ignored (masked); the loss is only computed for
1536
+ labels in `[0, ..., config.vocab_size - 1]`.
1537
+
1538
+ Returns:
1539
+
1540
+ Examples:
1541
+
1542
+ ```python
1543
+ >>> from transformers import T5Tokenizer, T5ForConditionalGeneration
1544
+
1545
+ >>> tokenizer = T5Tokenizer.from_pretrained("t5-small")
1546
+ >>> model = T5ForConditionalGeneration.from_pretrained("t5-small")
1547
+
1548
+ >>> # training
1549
+ >>> input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
1550
+ >>> labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids
1551
+ >>> outputs = model(input_ids=input_ids, labels=labels)
1552
+ >>> loss = outputs.loss
1553
+ >>> logits = outputs.logits
1554
+
1555
+ >>> # inference
1556
+ >>> input_ids = tokenizer(
1557
+ ... "summarize: studies have shown that owning a dog is good for you", return_tensors="pt"
1558
+ ... ).input_ids  # Batch size 1
1559
+ >>> outputs = model.generate(input_ids)
1560
+ >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
1561
+ >>> # studies have shown that owning a dog is good for you.
1562
+ ```"""
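+ # Labels-preparation sketch (a common pattern, not something this file enforces): padding positions
+ # are usually set to -100 so the CrossEntropyLoss below skips them, e.g.
+ #   labels = tokenizer(target_text, return_tensors="pt").input_ids
+ #   labels[labels == tokenizer.pad_token_id] = -100
+ # When `labels` are given and no decoder inputs are passed, `self._shift_right(labels)` below builds
+ # `decoder_input_ids` automatically.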
1563
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1564
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1565
+
1566
+ # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
1567
+ if head_mask is not None and decoder_head_mask is None:
1568
+ if self.config.num_layers == self.config.num_decoder_layers:
1569
+ decoder_head_mask = head_mask
1570
+
1571
+ # Encode if needed (training, first prediction pass)
1572
+ if encoder_outputs is None:
1573
+ # Convert encoder inputs in embeddings if needed
1574
+ encoder_outputs = self.encoder(
1575
+ input_ids=input_ids,
1576
+ attention_mask=attention_mask,
1577
+ inputs_embeds=inputs_embeds,
1578
+ head_mask=head_mask,
1579
+ output_attentions=output_attentions,
1580
+ output_hidden_states=output_hidden_states,
1581
+ return_dict=return_dict,
1582
+ )
1583
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
1584
+ encoder_outputs = BaseModelOutput(
1585
+ last_hidden_state=encoder_outputs[0],
1586
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
1587
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
1588
+ )
1589
+
1590
+ hidden_states = encoder_outputs[0]
1591
+
1592
+ if self.model_parallel:
1593
+ torch.cuda.set_device(self.decoder.first_device)
1594
+
1595
+ if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
1596
+ # get decoder inputs from shifting lm labels to the right
1597
+ decoder_input_ids = self._shift_right(labels)
1598
+
1599
+ # Set device for model parallelism
1600
+ if self.model_parallel:
1601
+ torch.cuda.set_device(self.decoder.first_device)
1602
+ hidden_states = hidden_states.to(self.decoder.first_device)
1603
+ if decoder_input_ids is not None:
1604
+ decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
1605
+ if attention_mask is not None:
1606
+ attention_mask = attention_mask.to(self.decoder.first_device)
1607
+ if decoder_attention_mask is not None:
1608
+ decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)
1609
+
1610
+ # Decode
1611
+ decoder_outputs = self.decoder(
1612
+ input_ids=decoder_input_ids,
1613
+ attention_mask=decoder_attention_mask,
1614
+ inputs_embeds=decoder_inputs_embeds,
1615
+ past_key_values=past_key_values,
1616
+ encoder_hidden_states=hidden_states,
1617
+ encoder_attention_mask=attention_mask,
1618
+ head_mask=decoder_head_mask,
1619
+ cross_attn_head_mask=cross_attn_head_mask,
1620
+ use_cache=use_cache,
1621
+ output_attentions=output_attentions,
1622
+ output_hidden_states=output_hidden_states,
1623
+ return_dict=return_dict,
1624
+ )
1625
+
1626
+ sequence_output = decoder_outputs[0]
1627
+
1628
+ # Set device for model parallelism
1629
+ if self.model_parallel:
1630
+ torch.cuda.set_device(self.encoder.first_device)
1631
+ self.lm_head = self.lm_head.to(self.encoder.first_device)
1632
+ sequence_output = sequence_output.to(self.lm_head.weight.device)
1633
+
1634
+ if self.config.tie_word_embeddings:
1635
+ # Rescale output before projecting on vocab
1636
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
1637
+ sequence_output = sequence_output * (self.model_dim ** -0.5)
1638
+
1639
+ lm_logits = self.lm_head(sequence_output)
1640
+
1641
+ loss = None
1642
+ if labels is not None:
1643
+ loss_fct = CrossEntropyLoss(ignore_index=-100)
1644
+ loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
1645
+ # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
1646
+
1647
+ if not return_dict:
1648
+ output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
1649
+ return ((loss,) + output) if loss is not None else output
1650
+
1651
+ return Seq2SeqLMOutput(
1652
+ loss=loss,
1653
+ logits=lm_logits,
1654
+ past_key_values=decoder_outputs.past_key_values,
1655
+ decoder_hidden_states=decoder_outputs.hidden_states,
1656
+ decoder_attentions=decoder_outputs.attentions,
1657
+ cross_attentions=decoder_outputs.cross_attentions,
1658
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
1659
+ encoder_hidden_states=encoder_outputs.hidden_states,
1660
+ encoder_attentions=encoder_outputs.attentions,
1661
+ )
1662
+
1663
+ def prepare_inputs_for_generation(
1664
+ self,
1665
+ input_ids,
1666
+ past=None,
1667
+ attention_mask=None,
1668
+ head_mask=None,
1669
+ decoder_head_mask=None,
1670
+ cross_attn_head_mask=None,
1671
+ use_cache=None,
1672
+ encoder_outputs=None,
1673
+ **kwargs
1674
+ ):
1675
+
1676
+ # cut decoder_input_ids if past is used
1677
+ if past is not None:
1678
+ input_ids = input_ids[:, -1:]
1679
+
1680
+ return {
1681
+ "decoder_input_ids": input_ids,
1682
+ "past_key_values": past,
1683
+ "encoder_outputs": encoder_outputs,
1684
+ "attention_mask": attention_mask,
1685
+ "head_mask": head_mask,
1686
+ "decoder_head_mask": decoder_head_mask,
1687
+ "cross_attn_head_mask": cross_attn_head_mask,
1688
+ "use_cache": use_cache,
1689
+ }
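+ # Note on how this hook is used (names below are the hook's own): during generation the returned
+ # dict is forwarded as keyword arguments to `forward`; once a cache is supplied via `past`, only the
+ # newest token survives the `input_ids[:, -1:]` slice above, so earlier positions are served from the
+ # cached key/value states instead of being recomputed.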
1690
+
1691
+ def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
1692
+ return self._shift_right(labels)
1693
+
1694
+ def _reorder_cache(self, past, beam_idx):
1695
+ # if decoder past is not included in output
1696
+ # speedy decoding is disabled and no need to reorder
1697
+ if past is None:
1698
+ logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
1699
+ return past
1700
+
1701
+ reordered_decoder_past = ()
1702
+ for layer_past_states in past:
1703
+ # get the correct batch idx from the layer past batch dim
1704
+ # (each cached state has its batch dimension at dim 0, which `index_select` operates on below)
1705
+ reordered_layer_past_states = ()
1706
+ for layer_past_state in layer_past_states:
1707
+ # need to set correct `past` for each of the four key / value states
1708
+ reordered_layer_past_states = reordered_layer_past_states + (
1709
+ layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)),
1710
+ )
1711
+
1712
+ assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
1713
+ assert len(reordered_layer_past_states) == len(layer_past_states)
1714
+
1715
+ reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
1716
+ return reordered_decoder_past
1717
+
1718
+
1719
+ @add_start_docstrings(
1720
+ "The bare T5 Model transformer outputting encoder's raw hidden-states without any specific head on top.",
1721
+ T5_START_DOCSTRING,
1722
+ )
1723
+ class T5EncoderModel(T5PreTrainedModel):
1724
+ authorized_missing_keys = [
1725
+ r"encoder\.embed_tokens\.weight",
1726
+ ]
1727
+
1728
+ def __init__(self, config: T5Config):
1729
+ super().__init__(config)
1730
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
1731
+
1732
+ encoder_config = copy.deepcopy(config)
1733
+ encoder_config.use_cache = False
1734
+ encoder_config.is_encoder_decoder = False
1735
+ self.encoder = T5Stack(encoder_config, self.shared)
1736
+
1737
+ # Initialize weights and apply final processing
1738
+ self.post_init()
1739
+
1740
+ # Model parallel
1741
+ self.model_parallel = False
1742
+ self.device_map = None
1743
+
1744
+ @add_start_docstrings(PARALLELIZE_DOCSTRING)
1745
+ def parallelize(self, device_map=None):
1746
+ self.device_map = (
1747
+ get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
1748
+ if device_map is None
1749
+ else device_map
1750
+ )
1751
+ assert_device_map(self.device_map, len(self.encoder.block))
1752
+ self.encoder.parallelize(self.device_map)
1753
+ self.model_parallel = True
1754
+
1755
+ @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
1756
+ def deparallelize(self):
1757
+ self.encoder.deparallelize()
1758
+ self.encoder = self.encoder.to("cpu")
1759
+ self.model_parallel = False
1760
+ self.device_map = None
1761
+ torch.cuda.empty_cache()
1762
+
1763
+ def get_input_embeddings(self):
1764
+ return self.shared
1765
+
1766
+ def set_input_embeddings(self, new_embeddings):
1767
+ self.shared = new_embeddings
1768
+ self.encoder.set_input_embeddings(new_embeddings)
1769
+
1770
+ def get_encoder(self):
1771
+ return self.encoder
1772
+
1773
+ def _prune_heads(self, heads_to_prune):
1774
+ """
1775
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
1776
+ class PreTrainedModel
1777
+ """
1778
+ for layer, heads in heads_to_prune.items():
1779
+ self.encoder.layer[layer].attention.prune_heads(heads)
1780
+
1781
+ @add_start_docstrings_to_model_forward(T5_ENCODER_INPUTS_DOCSTRING)
1782
+ @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
1783
+ def forward(
1784
+ self,
1785
+ input_ids=None,
1786
+ attention_mask=None,
1787
+ head_mask=None,
1788
+ inputs_embeds=None,
1789
+ output_attentions=None,
1790
+ output_hidden_states=None,
1791
+ return_dict=None,
1792
+ ):
1793
+ r"""
1794
+ Returns:
1795
+
1796
+ Example:
1797
+
1798
+ ```python
1799
+ >>> from transformers import T5Tokenizer, T5EncoderModel
1800
+
1801
+ >>> tokenizer = T5Tokenizer.from_pretrained("t5-small")
1802
+ >>> model = T5EncoderModel.from_pretrained("t5-small")
1803
+ >>> input_ids = tokenizer(
1804
+ ... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
1805
+ ... ).input_ids  # Batch size 1
1806
+ >>> outputs = model(input_ids=input_ids)
1807
+ >>> last_hidden_states = outputs.last_hidden_state
1808
+ ```"""
1809
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1810
+
1811
+ encoder_outputs = self.encoder(
1812
+ input_ids=input_ids,
1813
+ attention_mask=attention_mask,
1814
+ inputs_embeds=inputs_embeds,
1815
+ head_mask=head_mask,
1816
+ output_attentions=output_attentions,
1817
+ output_hidden_states=output_hidden_states,
1818
+ return_dict=return_dict,
1819
+ )
1820
+
1821
+ return encoder_outputs
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:545ebe914ae9b6650ea3f63c491db19c158495c96baa98fbb8cfd4364b453e6b
3
+ size 6845513341