# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
import logging
import os
import warnings

import torch
import torch.nn.functional as F
from torch import Tensor, nn

# Fall back to the plain PyTorch attention path when xFormers is not installed.
try:
    from xformers.ops import memory_efficient_attention, unbind

    XFORMERS_AVAILABLE = True
except ImportError:
    warnings.warn("xFormers is not available (Attention)")
    XFORMERS_AVAILABLE = False
class Attention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        norm_layer: nn.Module = nn.LayerNorm,
        qk_norm: bool = False,
        fused_attn: bool = True,  # use F.scaled_dot_product_attention or not
        rope=None,
        kv_cache=False,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        self.fused_attn = fused_attn

        # KV Cache
        self.kv_cache = kv_cache
        self.k_cache: Tensor | None = None
        self.v_cache: Tensor | None = None

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)
        self.rope = rope
    def clear_kv_cache(self):
        """Clear the KV cache. Should be called between different sequences."""
        if self.k_cache is not None:
            if self.k_cache.is_cuda:
                del self.k_cache
                del self.v_cache
                torch.cuda.empty_cache()
            else:
                del self.k_cache
                del self.v_cache
        self.k_cache = None
        self.v_cache = None
    def forward(self, x: Tensor, pos=None, attn_mask: Tensor | None = None) -> Tensor:
        B, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, self.head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        if self.rope is not None:
            q = self.rope(q, pos)
            k = self.rope(k, pos)

        # --- KV cache logic ---
        if self.kv_cache:
            if self.k_cache is not None:
                # Generation mode: concatenate the new keys and values to the cache.
                # The new k and v have a sequence length of N (usually 1).
                k = torch.cat([self.k_cache.cuda(), k], dim=2)
                v = torch.cat([self.v_cache.cuda(), v], dim=2)
            else:
                # First pass (prompt processing): initialize the cache.
                self.k_cache = k.cpu()
                self.v_cache = v.cpu()
                self.k_length = k.shape[2]
                return torch.zeros_like(x)  # Return a zero tensor for the first pass.

        if self.fused_attn:
            x = F.scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=attn_mask,
                dropout_p=self.attn_drop.p if self.training else 0.0,
            )
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            if attn_mask is not None:
                attn = attn.masked_fill(attn_mask, float("-inf"))
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
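

# A minimal sanity-check sketch (not part of the original module): the fused
# SDPA path and the manual softmax path above are meant to be numerically
# equivalent. This helper (hypothetical, added for illustration) compares the
# two paths with shared weights on random input.
def _check_fused_vs_manual(dim: int = 64, num_heads: int = 8, seq_len: int = 4) -> bool:
    x = torch.randn(2, seq_len, dim)
    fused = Attention(dim, num_heads=num_heads, fused_attn=True).eval()
    manual = Attention(dim, num_heads=num_heads, fused_attn=False).eval()
    manual.load_state_dict(fused.state_dict())  # identical weights in both modules
    with torch.no_grad():
        return bool(torch.allclose(fused(x), manual(x), atol=1e-5))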
class MemEffAttention(Attention):
    def forward(self, x: Tensor, attn_bias=None, pos=None, attn_mask=None) -> Tensor:
        assert pos is None
        if not XFORMERS_AVAILABLE:
            if attn_bias is not None:
                raise AssertionError("xFormers is required for using nested tensors")
            return super().forward(x, attn_mask=attn_mask)

        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)

        q, k, v = unbind(qkv, 2)

        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
        x = x.reshape([B, N, C])

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
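

# --- Usage sketch (illustrative assumption, not part of the original file) ---
# Demonstrates the KV-cache protocol above: the first call caches K/V from the
# prompt and returns zeros; a later call attends over the cached prompt plus the
# new token. The cache is stored on CPU and moved to GPU at use time, so this
# example assumes a CUDA device is available.
if __name__ == "__main__" and torch.cuda.is_available():
    attn = Attention(dim=64, num_heads=8, kv_cache=True).cuda().eval()
    prompt = torch.randn(1, 16, 64, device="cuda")  # (batch, tokens, dim)
    _ = attn(prompt)                                 # first pass fills the cache, returns zeros
    token = torch.randn(1, 1, 64, device="cuda")     # one newly generated token
    out = attn(token)                                # attends over 16 cached + 1 new position
    print(out.shape)                                 # torch.Size([1, 1, 64])
    attn.clear_kv_cache()                            # reset before starting a new sequence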