import torch
import torch.nn as nn
import torch.nn.functional as F


class SelfAttention(nn.Module):
    """Multi-head causal self-attention with a fused QKV projection."""

    def __init__(self, d_model, nhead):
        super().__init__()
        assert d_model % nhead == 0, "d_model must be divisible by nhead"
        self.qkv_proj = nn.Linear(d_model, 3 * d_model)  # fused Q, K, V projection
        self.out_proj = nn.Linear(d_model, d_model)
        self.nhead = nhead
        self.d_model = d_model

    def forward(self, x):
        B, T, C = x.size()  # batch, sequence length, embedding dim
        qkv = self.qkv_proj(x)
        q, k, v = qkv.chunk(3, dim=-1)

        # Split heads: (B, T, C) -> (B, nhead, T, head_dim)
        q = q.view(B, T, self.nhead, C // self.nhead).transpose(1, 2)
        k = k.view(B, T, self.nhead, C // self.nhead).transpose(1, 2)
        v = v.view(B, T, self.nhead, C // self.nhead).transpose(1, 2)

        # Scaled dot-product attention with a causal mask so each position can
        # only attend to itself and earlier positions (required for a decoder).
        scores = torch.matmul(q, k.transpose(-2, -1)) / (C // self.nhead) ** 0.5
        causal_mask = torch.triu(torch.ones(T, T, dtype=torch.bool, device=x.device), diagonal=1)
        scores = scores.masked_fill(causal_mask, float("-inf"))
        attn = torch.softmax(scores, dim=-1)
        out = torch.matmul(attn, v)

        # Merge heads back: (B, nhead, T, head_dim) -> (B, T, C)
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(out)


class FeedForward(nn.Module):
    """Position-wise feed-forward network: expand, ReLU, dropout, project back."""

    def __init__(self, d_model, dim_feedforward):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Dropout(),  # note: nn.Dropout() defaults to p=0.5
            nn.Linear(dim_feedforward, d_model),
        )

    def forward(self, x):
        return self.net(x)


class TransformerBlock(nn.Module):
    """Pre-norm transformer block: LayerNorm before attention and before the FFN,
    with each sub-layer wrapped in a residual connection."""

    def __init__(self, d_model, nhead, dim_feedforward):
        super().__init__()
        self.attn = SelfAttention(d_model, nhead)
        self.ln1 = nn.LayerNorm(d_model)
        self.ffn = FeedForward(d_model, dim_feedforward)
        self.ln2 = nn.LayerNorm(d_model)

    def forward(self, x):
        x = x + self.attn(self.ln1(x))  # attention sub-layer with residual
        x = x + self.ffn(self.ln2(x))   # feed-forward sub-layer with residual
        return x


class EvoDecoder(nn.Module):
    """Decoder-only transformer: token + learned positional embeddings,
    a stack of transformer blocks, and a final projection to vocabulary logits."""

    def __init__(self, vocab_size, d_model=256, nhead=4, num_layers=3, dim_feedforward=512):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(512, d_model)  # learned positions; supports sequences up to 512 tokens
        self.blocks = nn.Sequential(*[
            TransformerBlock(d_model, nhead, dim_feedforward) for _ in range(num_layers)
        ])
        self.ln_f = nn.LayerNorm(d_model)
        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        B, T = x.size()  # x: (B, T) batch of token ids
        tok = self.token_emb(x)
        pos = self.pos_emb(torch.arange(T, device=x.device).unsqueeze(0).expand(B, T))
        x = tok + pos
        x = self.blocks(x)
        x = self.ln_f(x)
        return self.fc_out(x)  # logits of shape (B, T, vocab_size)
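

# Minimal usage sketch (illustrative only): the vocab size, batch size, and
# sequence length below are arbitrary placeholder values, not part of the
# original model definition. It instantiates EvoDecoder and checks that a
# forward pass produces per-token vocabulary logits.
if __name__ == "__main__":
    vocab_size = 1000
    model = EvoDecoder(vocab_size)
    tokens = torch.randint(0, vocab_size, (2, 16))  # (batch=2, seq_len=16), within the 512-position limit
    logits = model(tokens)
    print(logits.shape)  # expected: torch.Size([2, 16, 1000])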