davidhd commited on
Commit
84a1166
·
verified ·
1 Parent(s): 8774fb1

Change AMPLIFY API to be compatible with AMPLIFY v1

Browse files

src --> input_ids
pad_mask --> attention_mask

Files changed (1) hide show
  1. amplify.py +15 -15
amplify.py CHANGED
@@ -134,13 +134,13 @@ class EncoderBlock(nn.Module):
134
 
135
  self.ffn_dropout = nn.Dropout(config.dropout_prob)
136
 
137
- def forward(self, x: torch.Tensor, pad_mask: torch.Tensor, freqs_cis: torch.Tensor, output_attentions: bool):
138
- attn, contact = self._att_block(self.attention_norm(x), pad_mask, freqs_cis, output_attentions)
139
  x = x + attn
140
  x = x + self._ff_block(self.ffn_norm(x))
141
  return x, contact
142
 
143
- def _att_block(self, x: torch.Tensor, pad_mask: torch.Tensor, freqs_cis: torch.Tensor, output_attentions: bool):
144
  batch_size, seq_len, _ = x.shape
145
  xq, xk, xv = self.q(x), self.k(x), self.v(x)
146
 
@@ -154,8 +154,8 @@ class EncoderBlock(nn.Module):
154
  attn_weights = None
155
  if output_attentions:
156
  attn_weights = xq.permute(0, 2, 1, 3) @ xk.permute(0, 2, 3, 1) / (xq.size(-1) ** 0.5)
157
- if pad_mask is not None:
158
- attn_weights = attn_weights + pad_mask
159
  attn_weights = attn_weights.softmax(-1)
160
 
161
  # Compute the attention using xformers if the tensors are on GPU
@@ -166,7 +166,7 @@ class EncoderBlock(nn.Module):
166
  query=xq,
167
  key=xk,
168
  value=xv,
169
- attn_bias=pad_mask,
170
  p=self.config.dropout_prob if self.training else 0,
171
  )
172
  else:
@@ -175,7 +175,7 @@ class EncoderBlock(nn.Module):
175
  query=xq.transpose(1, 2),
176
  key=xk.transpose(1, 2),
177
  value=xv.transpose(1, 2),
178
- attn_mask=pad_mask,
179
  dropout_p=self.config.dropout_prob if self.training else 0,
180
  ).transpose(1, 2)
181
 
@@ -249,27 +249,27 @@ class AMPLIFY(AMPLIFYPreTrainedModel):
249
  return model, tokenizer
250
 
251
 
252
- def forward(self, src, pad_mask=None, output_hidden_states=False, output_attentions=False):
253
  # Initialize
254
  hidden_states, attentions = [], []
255
 
256
  # Expand and repeat: (Batch, Length) -> (Batch, Heads, Length, Length)
257
- if pad_mask is not None:
258
- assert pad_mask.dtype != torch.bool and 1.0 not in pad_mask, "AMPLIFY expects an additive pad_mask"
259
- pad_mask = pad_mask.unsqueeze(1).unsqueeze(1).repeat(1, self.config.num_attention_heads, pad_mask.size(-1), 1)
260
 
261
  # RoPE
262
- self.freqs_cis = self.freqs_cis.to(src.device, non_blocking=True)
263
- freqs_cis = self.freqs_cis[: src.shape[1]]
264
 
265
  # Embedding
266
- x = self.encoder(src)
267
  if self.config.layer_norm_after_embedding:
268
  x = self.layer_norm_1(x)
269
 
270
  # Transformer encoder
271
  for layer in self.transformer_encoder:
272
- x, attn = layer(x, pad_mask, freqs_cis, output_attentions)
273
  if output_hidden_states:
274
  hidden_states.append(x)
275
  if output_attentions:
 
134
 
135
  self.ffn_dropout = nn.Dropout(config.dropout_prob)
136
 
137
+ def forward(self, x: torch.Tensor, attention_mask: torch.Tensor, freqs_cis: torch.Tensor, output_attentions: bool):
138
+ attn, contact = self._att_block(self.attention_norm(x), attention_mask, freqs_cis, output_attentions)
139
  x = x + attn
140
  x = x + self._ff_block(self.ffn_norm(x))
141
  return x, contact
142
 
143
+ def _att_block(self, x: torch.Tensor, attention_mask: torch.Tensor, freqs_cis: torch.Tensor, output_attentions: bool):
144
  batch_size, seq_len, _ = x.shape
145
  xq, xk, xv = self.q(x), self.k(x), self.v(x)
146
 
 
154
  attn_weights = None
155
  if output_attentions:
156
  attn_weights = xq.permute(0, 2, 1, 3) @ xk.permute(0, 2, 3, 1) / (xq.size(-1) ** 0.5)
157
+ if attention_mask is not None:
158
+ attn_weights = attn_weights + attention_mask
159
  attn_weights = attn_weights.softmax(-1)
160
 
161
  # Compute the attention using xformers if the tensors are on GPU
 
166
  query=xq,
167
  key=xk,
168
  value=xv,
169
+ attn_bias=attention_mask,
170
  p=self.config.dropout_prob if self.training else 0,
171
  )
172
  else:
 
175
  query=xq.transpose(1, 2),
176
  key=xk.transpose(1, 2),
177
  value=xv.transpose(1, 2),
178
+ attn_mask=attention_mask,
179
  dropout_p=self.config.dropout_prob if self.training else 0,
180
  ).transpose(1, 2)
181
 
 
249
  return model, tokenizer
250
 
251
 
252
+ def forward(self, input_ids, attention_mask=None, output_hidden_states=False, output_attentions=False):
253
  # Initialize
254
  hidden_states, attentions = [], []
255
 
256
  # Expand and repeat: (Batch, Length) -> (Batch, Heads, Length, Length)
257
+ if attention_mask is not None:
258
+ assert attention_mask.dtype != torch.bool and 1.0 not in attention_mask, "AMPLIFY expects an additive attention_mask"
259
+ attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).repeat(1, self.config.num_attention_heads, attention_mask.size(-1), 1)
260
 
261
  # RoPE
262
+ self.freqs_cis = self.freqs_cis.to(input_ids.device, non_blocking=True)
263
+ freqs_cis = self.freqs_cis[: input_ids.shape[1]]
264
 
265
  # Embedding
266
+ x = self.encoder(input_ids)
267
  if self.config.layer_norm_after_embedding:
268
  x = self.layer_norm_1(x)
269
 
270
  # Transformer encoder
271
  for layer in self.transformer_encoder:
272
+ x, attn = layer(x, attention_mask, freqs_cis, output_attentions)
273
  if output_hidden_states:
274
  hidden_states.append(x)
275
  if output_attentions: