Commit · 4694579
1 Parent(s): cba2f63

Update modelling_RW.py

Ensure use of the 40b file as reference.

modelling_RW.py  +4 -4
modelling_RW.py  CHANGED
@@ -52,10 +52,11 @@ class RotaryEmbedding(torch.nn.Module):
 
     def __init__(
         self,
-        head_dim: int,
+        config,
         base=10000,
-        use_cache=False,
     ):
+        head_dim = config.head_dim
+        self.use_cache = config.use_cache
         super().__init__()
         inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
         self.register_buffer("inv_freq", inv_freq, persistent=False)

@@ -64,7 +65,6 @@ class RotaryEmbedding(torch.nn.Module):
         self.batch_size_cached = None
         self.cos_cached: torch.Tensor | None = None
         self.sin_cached: torch.Tensor | None = None
-        self.use_cache = use_cache
 
     def cos_sin(
         self,

@@ -184,7 +184,7 @@ class Attention(nn.Module):
             f" {self.num_heads})."
         )
 
-        self.maybe_rotary = RotaryEmbedding(config.head_dim) if config.rotary else lambda q, k: (q, k)
+        self.maybe_rotary = RotaryEmbedding(config) if config.rotary else lambda q, k: (q, k)
 
         # Layer-wise attention scaling
         self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
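For context, a minimal runnable sketch of RotaryEmbedding.__init__ as it reads after this commit. The SimpleNamespace stand-in is illustrative only; the real call sites pass the model's config object, and head_dim and use_cache are the only fields this constructor reads.

# Sketch of the post-commit constructor, with a hypothetical stand-in config.
from types import SimpleNamespace

import torch


class RotaryEmbedding(torch.nn.Module):
    def __init__(self, config, base=10000):
        # Reading plain values off the config before super().__init__() is
        # fine: nn.Module only guards Parameter and Module assignment.
        head_dim = config.head_dim
        self.use_cache = config.use_cache
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)


config = SimpleNamespace(head_dim=64, use_cache=False, rotary=True)
rope = RotaryEmbedding(config)
print(rope.inv_freq.shape)  # torch.Size([32]) -- one inverse frequency per even channel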
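The Attention hunk is the matching call-site update: maybe_rotary now forwards the whole config when rotary embeddings are enabled, and stays an identity passthrough otherwise. A sketch of that dispatch, reusing the RotaryEmbedding sketch above (the helper name is hypothetical; the real code assigns the expression inline in Attention.__init__):

def build_maybe_rotary(config):
    # Either a rotary embedding module or a no-op on (q, k).
    return RotaryEmbedding(config) if config.rotary else (lambda q, k: (q, k))


rope_cfg = SimpleNamespace(head_dim=64, use_cache=False, rotary=True)
assert isinstance(build_maybe_rotary(rope_cfg), RotaryEmbedding)

no_rope_cfg = SimpleNamespace(head_dim=64, use_cache=False, rotary=False)
maybe_rotary = build_maybe_rotary(no_rope_cfg)
q = torch.zeros(1, 4, 64)
k = torch.zeros(1, 4, 64)
q, k = maybe_rotary(q, k)  # identity when config.rotary is False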