Upload wkv.py

wkv.py (CHANGED)
This change reworks wkv.py. It adds fused_addcmul_rwkv7 to the rwkvfla imports (falling back to torch_addcmul_rwkv7 when the triton-backed ops are unavailable) and uses it for the token-shift mixing in Rwkv_Tmix_x070.forward; it introduces the gate mix parameter x_g and its initialization; it threads output_final_state, cu_seqlens, and head_first through apply_wkv7_state, selecting chunk_rwkv7 during training and fused_recurrent_rwkv7 otherwise on non-CPU devices; it adds typing and transformers.cache_utils.Cache imports and reworks Rwkv7Attention.forward to take sequence_mask, past_key_value, and use_cache; and it rewraps a number of long lines in Rwkv_Tmix_x060 and Rwkv6Attention. The updated hunks are shown below; lines marked with "+" are added.
@@ -6,6 +6,8 @@ import math
 import torch.nn as nn
 from torch.nn import functional as F
 from .configuration_rwkv_hybrid import RwkvHybridConfig
+from typing import TYPE_CHECKING, Optional
+from transformers.cache_utils import Cache

 try:
     import triton
@@ -13,6 +15,7 @@ try:
         fused_recurrent_rwkv7,
         chunk_rwkv7,
         native_recurrent_rwkv7,
+        fused_addcmul_rwkv7,
     )  # pylint: disable=C0411
     from rwkvfla.ops.rwkv6 import (
         fused_recurrent_rwkv6,
@@ -22,11 +25,13 @@ try:
 except ImportError:
     from rwkvfla.ops.rwkv7 import native_recurrent_rwkv7  # pylint: disable=C0411
     from rwkvfla.ops.rwkv6 import native_recurrent_rwkv6
+    from rwkvfla.ops.rwkv7 import torch_addcmul_rwkv7

     fused_recurrent_rwkv7 = native_recurrent_rwkv7
     chunk_rwkv7 = native_recurrent_rwkv7
     chunk_rwkv6 = native_recurrent_rwkv6
     fused_recurrent_rwkv6 = native_recurrent_rwkv6
+    fused_addcmul_rwkv7 = torch_addcmul_rwkv7


 class Rwkv_Tmix_x070(nn.Module):
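Note: fused_addcmul_rwkv7 and its torch_addcmul_rwkv7 fallback are rwkvfla kernels; based on how they are called later in this file, they are assumed to fuse the addcmul-style token-shift mix. A minimal PyTorch sketch of that assumed computation (hypothetical helper name, not part of rwkvfla):

import torch

def addcmul_mix_sketch(hidden, xx, *mix_params):
    # hidden: (B, T, C); xx: time-shift difference, same shape; each mix parameter: (1, 1, C)
    # one mixed tensor per parameter: hidden + xx * p, which is what the fused kernel is assumed to return
    return tuple(torch.addcmul(hidden, xx, p) for p in mix_params)

# e.g. xr, xw, xk, xv, xa, xg = addcmul_mix_sketch(hidden, xx, x_r, x_w, x_k, x_v, x_a, x_g)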
@@ -50,8 +55,7 @@ class Rwkv_Tmix_x070(nn.Module):
         self.x_k = nn.Parameter(torch.Tensor(1, 1, args.hidden_size))
         self.x_v = nn.Parameter(torch.Tensor(1, 1, args.hidden_size))
         self.x_a = nn.Parameter(torch.Tensor(1, 1, args.hidden_size))
+
         D_DECAY_LORA = 64
         D_AAA_LORA = 64
         D_MV_LORA = 32
@@ -70,6 +74,7 @@ class Rwkv_Tmix_x070(nn.Module):
         self.v0 = nn.Parameter(torch.Tensor(1, 1, args.hidden_size))

         if self.args.wkv_has_gate:
+            self.x_g = nn.Parameter(torch.Tensor(1, 1, args.hidden_size))
             self.g1 = nn.Parameter(torch.Tensor(args.hidden_size, D_GATE_LORA))
             self.g2 = nn.Parameter(torch.Tensor(D_GATE_LORA, args.hidden_size))

@@ -78,7 +83,8 @@ class Rwkv_Tmix_x070(nn.Module):
         self.r_k = nn.Parameter(torch.Tensor(H, N))

         self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
+        self.receptance = nn.Linear(
+            args.hidden_size, args.hidden_size, bias=False)
         self.key = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
         self.value = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
         self.output = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
@@ -90,7 +96,8 @@ class Rwkv_Tmix_x070(nn.Module):

     def post_init(self):
         with torch.no_grad():
+            ratio_0_to_1 = self.layer_id / \
+                (self.args.num_hidden_layers - 1)  # 0 to 1
             ratio_1_to_almost0 = 1.0 - (
                 self.layer_id / self.args.num_hidden_layers
             )  # 1 to ~0
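For a concrete feel of the layer-dependent ratios initialized above, a quick check (assuming num_hidden_layers = 24 purely for illustration):

num_hidden_layers = 24
for layer_id in (0, 11, 23):
    ratio_0_to_1 = layer_id / (num_hidden_layers - 1)           # 0.0, ~0.48, 1.0
    ratio_1_to_almost0 = 1.0 - (layer_id / num_hidden_layers)   # 1.0, ~0.54, ~0.04
    print(layer_id, ratio_0_to_1, ratio_1_to_almost0)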
@@ -99,39 +106,48 @@ class Rwkv_Tmix_x070(nn.Module):
             for i in range(self.args.hidden_size):
                 ddd[0, 0, i] = i / self.args.hidden_size

+            nn.init.constant_(
+                self.x_r, 1.0 - torch.pow(ddd, 0.2 * ratio_1_to_almost0))
+            nn.init.constant_(
+                self.x_w, 1.0 - torch.pow(ddd, 0.9 * ratio_1_to_almost0))
             nn.init.constant_(
                 self.x_k,
+                1.0 - (torch.pow(ddd, 0.9 * ratio_1_to_almost0) +
+                       0.4 * ratio_0_to_1),
             )
             nn.init.constant_(
                 self.x_v,
+                1.0 - (torch.pow(ddd, 0.4 * ratio_1_to_almost0) +
+                       0.6 * ratio_0_to_1),
             )
+            nn.init.constant_(
+                self.x_a, 1.0 - torch.pow(ddd, 0.9 * ratio_1_to_almost0))
+

             def ortho_init(x, scale):
                 shape = x.shape
                 original_dtype = x.dtype
                 x_fp32 = x.float()
                 if len(shape) == 2:
+                    gain = math.sqrt(shape[0] / shape[1]
+                                     ) if shape[0] > shape[1] else 1
                     nn.init.orthogonal_(x_fp32, gain=gain * scale)
                 elif len(shape) == 3:
+                    gain = math.sqrt(shape[1] / shape[2]
+                                     ) if shape[1] > shape[2] else 1
                     for i in range(shape[0]):
                         nn.init.orthogonal_(x_fp32[i], gain=gain * scale)
                 else:
+                    raise ValueError(
+                        "ortho_init only supports 2D or 3D tensors")
                 x.data.copy_(x_fp32.to(original_dtype))
                 return x

             D_DECAY_LORA = 64
             nn.init.zeros_(self.w1)
             self.w2 = nn.Parameter(
+                ortho_init(torch.zeros(
+                    D_DECAY_LORA, self.args.hidden_size), 0.1)
             )

             decay_speed = torch.ones(self.args.hidden_size)
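The gain rule in ortho_init above only raises the gain when the first dimension exceeds the second; a small standalone illustration with arbitrary shapes:

import math
import torch
import torch.nn as nn

w = torch.zeros(2048, 64)      # shape[0] > shape[1], so gain = sqrt(2048 / 64)
gain = math.sqrt(w.shape[0] / w.shape[1]) if w.shape[0] > w.shape[1] else 1
nn.init.orthogonal_(w, gain=gain * 0.1)   # scale = 0.1, matching the w2 / g2 initializations above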
@@ -161,8 +177,11 @@ class Rwkv_Tmix_x070(nn.Module):
             if self.args.wkv_has_gate:
                 nn.init.zeros_(self.g1)
                 self.g2 = nn.Parameter(
+                    ortho_init(torch.zeros(
+                        D_GATE_LORA, self.args.hidden_size), 0.1)
                 )
+                nn.init.constant_(
+                    self.x_g, 1.0 - torch.pow(ddd, 0.2 * ratio_1_to_almost0))

             nn.init.constant_(self.k_k, 0.85)
             nn.init.constant_(self.k_a, 1.0)
@@ -177,77 +196,68 @@ class Rwkv_Tmix_x070(nn.Module):
                 nn.init.ones_(self.ln_x.weight)
                 nn.init.zeros_(self.ln_x.bias)

+    def apply_wkv7_state(self, r, k, v, w, a, b, s,
+                         output_final_state,
+                         cu_seqlens,
+                         head_first
+                         ):

         if r.device.type == "cpu":
+            r, w, k, v, a, b = map(lambda x: rearrange(x, 'b l (h d) -> b h l d', h=self.n_head), (r, w, k, v, a, b))
             o, state = native_recurrent_rwkv7(
+                r=r, k=k, v=v, w=w,
+                a=a, b=b,
                 scale=1.0,
                 initial_state=s.transpose(-1, -2),
                 output_final_state=True,
                 head_first=True,
             )
             state = state.transpose(-1, -2)
+            x = rearrange(o, "b h l d -> b l (h d)")
         else:
+            r, w, k, v, a, b = map(lambda x: rearrange(x, 'b l (h d) -> b l h d', h=self.n_head), (r, w, k, v, a, b))
+            wkv7_func = chunk_rwkv7 if self.training else fused_recurrent_rwkv7
+            o, state = wkv7_func(
+                r=r, k=k, v=v, w=w,
+                a=a, b=b,
                 scale=1.0,
                 initial_state=s,
+                output_final_state=output_final_state,
+                cu_seqlens=cu_seqlens,
+                head_first=head_first,
             )
+            x = rearrange(o, "b l h d -> b l (h d)")
         return x, state

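The two branches above differ mainly in tensor layout: the CPU path feeds the native kernel head-first (b h l d), while the GPU path keeps tokens first (b l h d) and picks chunk_rwkv7 during training. A small layout illustration using the same einops rearrange already used by wkv.py (shapes are arbitrary):

import torch
from einops import rearrange

B, T, H, D = 2, 5, 4, 8
x = torch.randn(B, T, H * D)                              # flat (b, l, h*d), as passed in
head_first = rearrange(x, "b l (h d) -> b h l d", h=H)    # CPU / native_recurrent_rwkv7 layout
token_first = rearrange(x, "b l (h d) -> b l h d", h=H)   # GPU / chunk or fused layout
assert rearrange(head_first, "b h l d -> b l (h d)").equal(x)
assert rearrange(token_first, "b l h d -> b l (h d)").equal(x)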
+    def forward(
+        self,
+        hidden_states,
+        last_state: TimeMixState,
+        sequence_mask: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = False,
+        cu_seqlens: Optional[torch.Tensor] = None,
+        **kwargs
+    ):
+        if sequence_mask is not None:
+            hidden_states = hidden_states.mul(
+                sequence_mask[:, -hidden_states.shape[-2]:, None])
+
         shift_state = last_state.shift_state
+        B, T, C = hidden_states.size()
+
         if shift_state is not None:
+            xx = torch.concat((shift_state.unsqueeze(
+                1), hidden_states[:, :-1]), dim=1) - hidden_states
         else:
+            xx = self.time_shift(hidden_states) - hidden_states

+        lx = hidden_states[:, -1]
+
+        if self.args.wkv_has_gate:
+            xr, xw, xk, xv, xa, xg = fused_addcmul_rwkv7(
+                hidden_states, xx, self.x_r, self.x_w, self.x_k, self.x_v, self.x_a, self.x_g)
+        else:
+            xr, xw, xk, xv, xa, _ = fused_addcmul_rwkv7(hidden_states, xx, self.x_r, self.x_w, self.x_k, self.x_v, self.x_a)

         r = self.receptance(xr)
         w = (
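xx above is the previous-token-minus-current-token difference: a cached shift_state supplies the token before the current chunk, and otherwise nn.ZeroPad2d((0, 0, 1, -1)) shifts the sequence down by one with zero history. A toy illustration of the no-cache path:

import torch
import torch.nn as nn

time_shift = nn.ZeroPad2d((0, 0, 1, -1))   # shift one step along T, zero-padding the first position
x = torch.arange(12.0).view(1, 4, 3)       # (B=1, T=4, C=3)
xx = time_shift(x) - x
# xx[:, 0] == -x[:, 0] (no history), and xx[:, t] == x[:, t - 1] - x[:, t] for t > 0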
@@ -269,11 +279,11 @@ class Rwkv_Tmix_x070(nn.Module):
         if self.args.wkv_has_gate:
             g = torch.sigmoid(xg @ self.g1) @ self.g2
         kk = k * self.k_k
+        kk = F.normalize(kk.view(B, T, self.n_head, -1), dim=-1, p=2.0).view(B, T, C)
         k = k * (1 + (a - 1) * self.k_a)

         wkv_state = last_state.wkv_state
+        hidden_states, wkv_state = self.apply_wkv7_state(
             r,
             k,
             v,
@@ -281,17 +291,22 @@ class Rwkv_Tmix_x070(nn.Module):
             -kk,
             (kk * a),
             s=wkv_state,
+            output_final_state=use_cache,
+            cu_seqlens=cu_seqlens,
+            head_first=False
         )
         if self.args.wkv_has_group_norm:
+            hidden_states = self.ln_x(
+                hidden_states.view(B * T, C)).view(B, T, C)
+        hidden_states = hidden_states + (
+            (r.view(B, T, self.n_head, -1) * k.view(B, T, self.n_head, -1) * self.r_k).sum(
                 dim=-1, keepdim=True
             )
+            * v.view(B, T, self.n_head, -1)
         ).view(B, T, C)
+        hidden_states = self.output(
+            hidden_states * g) if self.args.wkv_has_gate else self.output(hidden_states)
+        return hidden_states, TimeMixState(lx, wkv_state)


 class Rwkv7Attention(nn.Module):
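kk is normalized per head before being handed to the kernel as the -kk / (kk * a) pair; a shape-level illustration of that step (sizes are arbitrary):

import torch
import torch.nn.functional as F

B, T, H, N = 2, 7, 4, 16
C = H * N
kk = torch.randn(B, T, C)
kk = F.normalize(kk.view(B, T, H, -1), dim=-1, p=2.0).view(B, T, C)
# each head-sized slice of kk now has unit L2 norm:
print(kk.view(B, T, H, N).norm(dim=-1)[0, 0])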
@@ -299,24 +314,43 @@ class Rwkv7Attention(nn.Module):
         super().__init__()
         self.args = args
         self.layer_idx = layer_id
+        self.time_mixer = Rwkv_Tmix_x070(
+            args, layer_id, update_v_first, get_v_first)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        sequence_mask: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Cache] = None,
+        use_cache: Optional[bool] = False,
+        output_attentions: Optional[bool] = False,
+        **kwargs
+    ):
+        if sequence_mask is not None:
+            assert len(sequence_mask.shape) == 2, (
+                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
+                "for padding purposes (0 indicating padding). "
+                "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
+            )
+        batch_size, token_length, _ = hidden_states.shape

         if past_key_value is not None and len(past_key_value) > self.layer_idx:
             last_state = past_key_value[self.layer_idx][0]
         else:
             last_state = self.init_state(
+                batch_size, hidden_states.device, hidden_states.dtype
             )

+        attn_output, states = self.time_mixer(hidden_states=hidden_states,
+                                              last_state=last_state.time_mix_state,
+                                              sequence_mask=sequence_mask,
+                                              use_cache=use_cache,
+                                              **kwargs)
         last_state.time_mix_state = states

         if past_key_value is not None:
             past_key_value.update(token_length, last_state, self.layer_idx)
+
         return attn_output, None

     def init_state(self, batch_size, device, dtype) -> BlockState:
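The assert above pins down the mask contract: sequence_mask is a 0/1 matrix of shape [batch_size, seq_len], with 0 marking padding, and Rwkv_Tmix_x070.forward simply multiplies it into the hidden states. A hedged sketch of building such a mask from per-sample lengths (names here are illustrative, not part of the diff):

import torch

B, T, C = 2, 6, 8
lengths = torch.tensor([6, 4])                                           # sample 1 has two padded positions
sequence_mask = (torch.arange(T)[None, :] < lengths[:, None]).float()    # (B, T) of 0/1
hidden_states = torch.randn(B, T, C)
hidden_states = hidden_states.mul(sequence_mask[:, -hidden_states.shape[-2]:, None])
# padded positions are zeroed before the time-mix, as in the module above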
@@ -357,9 +391,12 @@ class Rwkv_Tmix_x060(nn.Module):
             ddd[0, 0, i] = i / args.hidden_size

         # fancy time_mix
+        self.time_maa_x = nn.Parameter(
+            1.0 - torch.pow(ddd, ratio_1_to_almost0))
+        self.time_maa_w = nn.Parameter(
+            1.0 - torch.pow(ddd, ratio_1_to_almost0))
+        self.time_maa_k = nn.Parameter(
+            1.0 - torch.pow(ddd, ratio_1_to_almost0))
         self.time_maa_v = nn.Parameter(
             1.0 - (torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
         )
@@ -377,7 +414,8 @@ class Rwkv_Tmix_x060(nn.Module):
             torch.zeros(args.hidden_size, D_MIX_LORA * 5)
         )
         self.time_maa_w2 = nn.Parameter(
+            torch.zeros(5, D_MIX_LORA,
+                        args.hidden_size).uniform_(-0.01, 0.01)
         )

         # fancy time_decay
@@ -386,7 +424,8 @@ class Rwkv_Tmix_x060(nn.Module):
             decay_speed[n] = -6 + 5 * (n / (args.head_size - 1)) ** (
                 0.7 + 1.3 * ratio_0_to_1
             )
+        self.time_decay = nn.Parameter(
+            decay_speed.reshape(1, 1, args.head_size))

         D_DECAY_LORA = 64
         if args.hidden_size == 4096:
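The reshaped decay schedule spans roughly -6 to -1 across the head dimension; a quick numeric check of the endpoints (head_size and ratio_0_to_1 chosen purely for illustration):

head_size, ratio_0_to_1 = 64, 0.5
exponent = 0.7 + 1.3 * ratio_0_to_1                                  # 1.35
first = -6 + 5 * (0 / (head_size - 1)) ** exponent                   # -6.0 at n = 0
last = -6 + 5 * ((head_size - 1) / (head_size - 1)) ** exponent      # -1.0 at n = head_size - 1
print(first, last)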
@@ -401,13 +440,16 @@ class Rwkv_Tmix_x060(nn.Module):
         tmp = torch.zeros(args.head_size)
         for n in range(args.head_size):
             zigzag = ((n + 1) % 3 - 1) * 0.1
+            tmp[n] = ratio_0_to_1 * \
+                (1 - (n / (args.head_size - 1))) + zigzag

+        self.time_faaaa = nn.Parameter(
+            tmp.reshape(self.n_head, self.head_size))
         # self.time_state = nn.Parameter(torch.zeros(self.n_head, self.head_size, self.head_size))

         self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
+        self.receptance = nn.Linear(
+            args.hidden_size, args.head_size, bias=False)
         self.key = nn.Linear(args.hidden_size, args.head_size, bias=False)

         self.value = nn.Linear(args.hidden_size, args.head_size, bias=False)
@@ -416,7 +458,8 @@ class Rwkv_Tmix_x060(nn.Module):

         if self.args.wkv_has_group_norm:
             self.ln_x = nn.GroupNorm(
+                self.n_head, args.head_size, eps=(
+                    1e-5) * (args.head_size_divisor**2)
             )

     def post_init(self):
@@ -433,7 +476,8 @@ class Rwkv_Tmix_x060(nn.Module):
         lx = x[:, -1]

         xxx = x + xx * self.time_maa_x
+        xxx = torch.tanh(xxx @ self.time_maa_w1).view(B *
+                                                      T, 5, -1).transpose(0, 1)
         xxx = torch.bmm(xxx, self.time_maa_w2).view(5, B, T, -1)
         mw, mk, mv, mr, mg = xxx.unbind(dim=0)

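A shape walkthrough of the five-way token-shift LoRA above (D_MIX_LORA and the other sizes are illustrative):

import torch

B, T, C, D_MIX_LORA = 2, 5, 64, 32
time_maa_w1 = torch.zeros(C, D_MIX_LORA * 5)
time_maa_w2 = torch.zeros(5, D_MIX_LORA, C).uniform_(-0.01, 0.01)
xxx = torch.randn(B, T, C)                                               # plays the role of x + xx * time_maa_x
xxx = torch.tanh(xxx @ time_maa_w1).view(B * T, 5, -1).transpose(0, 1)   # (5, B*T, D_MIX_LORA)
xxx = torch.bmm(xxx, time_maa_w2).view(5, B, T, -1)                      # (5, B, T, C)
mw, mk, mv, mr, mg = xxx.unbind(dim=0)                                   # one (B, T, C) mix per stream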
@@ -461,10 +505,7 @@ class Rwkv_Tmix_x060(nn.Module):
         return x, TimeMixState(lx, wkv_state)

     def apply_wkv6_state(self, B, T, C, H, r, k, v, w, u, s):
+        r, w, k, v = map(lambda x: rearrange(x, 'b l (h d) -> b h l d', h=self.n_head), (r, w, k, v))

         if r.device.type == "cpu":
             wkv6_func = native_recurrent_rwkv6
@@ -504,7 +545,8 @@ class Rwkv6Attention(nn.Module):
             last_state = past_key_value[self.layer_idx][0]
             if last_state is None:
                 wkv_states = torch.zeros(
+                    (B, self.args.num_wkv_heads,
+                     self.args.head_size, self.args.head_size),
                     device=attn_output.device,
                     dtype=torch.float32,
                 )
@@ -514,7 +556,8 @@ class Rwkv6Attention(nn.Module):
                 time_state = TimeMixState(token_shift, wkv_states)
                 channel_state = None
                 last_state = BlockState(time_state, channel_state)
+        attn_output, states = self.time_mixer(
+            attn_output, last_state.time_mix_state)
         last_state.time_mix_state = states

         if past_key_value is not None:
|