Change the AMPLIFY API to be compatible with AMPLIFY v1
Browse files
src --> input_ids
pad_mask --> attention_mask
- amplify.py +15 -15
amplify.py
CHANGED
|
@@ -134,13 +134,13 @@ class EncoderBlock(nn.Module):
|
|
| 134 |
|
| 135 |
self.ffn_dropout = nn.Dropout(config.dropout_prob)
|
| 136 |
|
| 137 |
-
def forward(self, x: torch.Tensor,
|
| 138 |
-
attn, contact = self._att_block(self.attention_norm(x),
|
| 139 |
x = x + attn
|
| 140 |
x = x + self._ff_block(self.ffn_norm(x))
|
| 141 |
return x, contact
|
| 142 |
|
| 143 |
-
def _att_block(self, x: torch.Tensor,
|
| 144 |
batch_size, seq_len, _ = x.shape
|
| 145 |
xq, xk, xv = self.q(x), self.k(x), self.v(x)
|
| 146 |
|
|
@@ -154,8 +154,8 @@ class EncoderBlock(nn.Module):
|
|
| 154 |
attn_weights = None
|
| 155 |
if output_attentions:
|
| 156 |
attn_weights = xq.permute(0, 2, 1, 3) @ xk.permute(0, 2, 3, 1) / (xq.size(-1) ** 0.5)
|
| 157 |
-
if
|
| 158 |
-
attn_weights = attn_weights +
|
| 159 |
attn_weights = attn_weights.softmax(-1)
|
| 160 |
|
| 161 |
# Compute the attention using xformers if the tensors are on GPU
|
|
@@ -166,7 +166,7 @@ class EncoderBlock(nn.Module):
|
|
| 166 |
query=xq,
|
| 167 |
key=xk,
|
| 168 |
value=xv,
|
| 169 |
-
attn_bias=
|
| 170 |
p=self.config.dropout_prob if self.training else 0,
|
| 171 |
)
|
| 172 |
else:
|
|
@@ -175,7 +175,7 @@ class EncoderBlock(nn.Module):
|
|
| 175 |
query=xq.transpose(1, 2),
|
| 176 |
key=xk.transpose(1, 2),
|
| 177 |
value=xv.transpose(1, 2),
|
| 178 |
-
attn_mask=
|
| 179 |
dropout_p=self.config.dropout_prob if self.training else 0,
|
| 180 |
).transpose(1, 2)
|
| 181 |
|
|
@@ -249,27 +249,27 @@ class AMPLIFY(AMPLIFYPreTrainedModel):
|
|
| 249 |
return model, tokenizer
|
| 250 |
|
| 251 |
|
| 252 |
-
def forward(self,
|
| 253 |
# Initialize
|
| 254 |
hidden_states, attentions = [], []
|
| 255 |
|
| 256 |
# Expand and repeat: (Batch, Length) -> (Batch, Heads, Length, Length)
|
| 257 |
-
if
|
| 258 |
-
assert
|
| 259 |
-
|
| 260 |
|
| 261 |
# RoPE
|
| 262 |
-
self.freqs_cis = self.freqs_cis.to(
|
| 263 |
-
freqs_cis = self.freqs_cis[:
|
| 264 |
|
| 265 |
# Embedding
|
| 266 |
-
x = self.encoder(
|
| 267 |
if self.config.layer_norm_after_embedding:
|
| 268 |
x = self.layer_norm_1(x)
|
| 269 |
|
| 270 |
# Transformer encoder
|
| 271 |
for layer in self.transformer_encoder:
|
| 272 |
-
x, attn = layer(x,
|
| 273 |
if output_hidden_states:
|
| 274 |
hidden_states.append(x)
|
| 275 |
if output_attentions:
|
|
|
|
| 134 |
|
| 135 |
self.ffn_dropout = nn.Dropout(config.dropout_prob)
|
| 136 |
|
| 137 |
+
def forward(self, x: torch.Tensor, attention_mask: torch.Tensor, freqs_cis: torch.Tensor, output_attentions: bool):
    """Run one pre-norm transformer encoder block over `x`.

    The block applies a residual attention sub-block followed by a residual
    feed-forward sub-block, normalizing the input of each sub-block first
    (pre-norm: `attention_norm` before `_att_block`, `ffn_norm` before
    `_ff_block`).

    Args:
        x: Input hidden states (assumed (batch, seq, dim) — TODO confirm against caller).
        attention_mask: Additive attention mask/bias, forwarded unchanged to `_att_block`.
        freqs_cis: Precomputed rotary (RoPE) frequencies, forwarded to `_att_block`.
        output_attentions: Forwarded to `_att_block`; presumably enables returning
            attention weights as the second value — confirm in `_att_block`.

    Returns:
        Tuple of (updated hidden states, second value from `_att_block`).
    """
    # Attention sub-block with residual connection (input normalized first).
    attn_out, extra = self._att_block(self.attention_norm(x), attention_mask, freqs_cis, output_attentions)
    hidden = x + attn_out
    # Feed-forward sub-block with residual connection (input normalized first).
    hidden = hidden + self._ff_block(self.ffn_norm(hidden))
    return hidden, extra
|
| 142 |
|
| 143 |
+
def _att_block(self, x: torch.Tensor, attention_mask: torch.Tensor, freqs_cis: torch.Tensor, output_attentions: bool):
|
| 144 |
batch_size, seq_len, _ = x.shape
|
| 145 |
xq, xk, xv = self.q(x), self.k(x), self.v(x)
|
| 146 |
|
|
|
|
| 154 |
attn_weights = None
|
| 155 |
if output_attentions:
|
| 156 |
attn_weights = xq.permute(0, 2, 1, 3) @ xk.permute(0, 2, 3, 1) / (xq.size(-1) ** 0.5)
|
| 157 |
+
if attention_mask is not None:
|
| 158 |
+
attn_weights = attn_weights + attention_mask
|
| 159 |
attn_weights = attn_weights.softmax(-1)
|
| 160 |
|
| 161 |
# Compute the attention using xformers if the tensors are on GPU
|
|
|
|
| 166 |
query=xq,
|
| 167 |
key=xk,
|
| 168 |
value=xv,
|
| 169 |
+
attn_bias=attention_mask,
|
| 170 |
p=self.config.dropout_prob if self.training else 0,
|
| 171 |
)
|
| 172 |
else:
|
|
|
|
| 175 |
query=xq.transpose(1, 2),
|
| 176 |
key=xk.transpose(1, 2),
|
| 177 |
value=xv.transpose(1, 2),
|
| 178 |
+
attn_mask=attention_mask,
|
| 179 |
dropout_p=self.config.dropout_prob if self.training else 0,
|
| 180 |
).transpose(1, 2)
|
| 181 |
|
|
|
|
| 249 |
return model, tokenizer
|
| 250 |
|
| 251 |
|
| 252 |
+
def forward(self, input_ids, attention_mask=None, output_hidden_states=False, output_attentions=False):
|
| 253 |
# Initialize
|
| 254 |
hidden_states, attentions = [], []
|
| 255 |
|
| 256 |
# Expand and repeat: (Batch, Length) -> (Batch, Heads, Length, Length)
|
| 257 |
+
if attention_mask is not None:
|
| 258 |
+
assert attention_mask.dtype != torch.bool and 1.0 not in attention_mask, "AMPLIFY expects an additive attention_mask"
|
| 259 |
+
attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).repeat(1, self.config.num_attention_heads, attention_mask.size(-1), 1)
|
| 260 |
|
| 261 |
# RoPE
|
| 262 |
+
self.freqs_cis = self.freqs_cis.to(input_ids.device, non_blocking=True)
|
| 263 |
+
freqs_cis = self.freqs_cis[: input_ids.shape[1]]
|
| 264 |
|
| 265 |
# Embedding
|
| 266 |
+
x = self.encoder(input_ids)
|
| 267 |
if self.config.layer_norm_after_embedding:
|
| 268 |
x = self.layer_norm_1(x)
|
| 269 |
|
| 270 |
# Transformer encoder
|
| 271 |
for layer in self.transformer_encoder:
|
| 272 |
+
x, attn = layer(x, attention_mask, freqs_cis, output_attentions)
|
| 273 |
if output_hidden_states:
|
| 274 |
hidden_states.append(x)
|
| 275 |
if output_attentions:
|