Update blocks.py
Browse files
blocks.py
CHANGED
|
@@ -33,9 +33,9 @@ class MPTBlock(nn.Module):
|
|
| 33 |
|
| 34 |
def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
|
| 35 |
a = self.norm_1(x)
|
| 36 |
-
(b,
|
| 37 |
x = x + self.resid_attn_dropout(b)
|
| 38 |
m = self.norm_2(x)
|
| 39 |
n = self.ffn(m)
|
| 40 |
x = x + self.resid_ffn_dropout(n)
|
| 41 |
-
return (x, past_key_value)
|
|
|
|
| 33 |
|
| 34 |
def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Apply one pre-LayerNorm transformer block: attention then feed-forward.

    Each sub-layer is wrapped as ``x = x + dropout(sublayer(norm(x)))``
    (pre-norm residual connections).

    Args:
        x: Input hidden states. Presumably shaped (batch, seq, d_model) —
            the attention/FFN modules determine the exact layout.
        past_key_value: Cached key/value tensors from previous steps,
            forwarded to (and returned updated by) ``self.attn``.
        attn_bias: Optional additive attention bias, passed through to
            ``self.attn``.
        attention_mask: Optional padding mask, passed through to ``self.attn``.
        is_causal: Whether attention should be causally masked.

    Returns:
        A 3-tuple ``(x, attn_weights, past_key_value)`` — note the return
        annotation reflects the 3-tuple actually returned (the attention
        weights were added to the return value; the old annotation described
        the previous 2-tuple).
    """
    # Pre-norm self-attention sub-layer with residual connection.
    a = self.norm_1(x)
    (b, attn_weights, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal)
    x = x + self.resid_attn_dropout(b)
    # Pre-norm feed-forward sub-layer with residual connection.
    m = self.norm_2(x)
    n = self.ffn(m)
    x = x + self.resid_ffn_dropout(n)
    return (x, attn_weights, past_key_value)
|