# sailrecon/layers/attention.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
import logging
import os
import warnings
import torch
import torch.nn.functional as F
from torch import Tensor, nn

# xFormers is optional; when it is missing, MemEffAttention falls back to the
# parent Attention.forward implementation.
try:
    from xformers.ops import memory_efficient_attention, unbind

    XFORMERS_AVAILABLE = True
except ImportError:
    XFORMERS_AVAILABLE = False


class Attention(nn.Module):
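    """Multi-head self-attention with optional QK normalization, rotary position
    embeddings (RoPE), and a simple key/value cache (see ``forward``)."""
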
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = True,
proj_bias: bool = True,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
norm_layer: nn.Module = nn.LayerNorm,
qk_norm: bool = False,
fused_attn: bool = True, # use F.scaled_dot_product_attention or not
rope=None,
kv_cache=False,
) -> None:
super().__init__()
assert dim % num_heads == 0, "dim should be divisible by num_heads"
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim**-0.5
self.fused_attn = fused_attn
# KV Cache
self.kv_cache = kv_cache
self.k_cache: Tensor | None = None
self.v_cache: Tensor | None = None
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim, bias=proj_bias)
self.proj_drop = nn.Dropout(proj_drop)
self.rope = rope

    def clear_kv_cache(self):
        """Clear the KV cache. Should be called between different sequences."""
        was_on_gpu = self.k_cache is not None and self.k_cache.is_cuda
        self.k_cache = None
        self.v_cache = None
        if was_on_gpu:
            # Return the freed cache blocks to the CUDA caching allocator.
            torch.cuda.empty_cache()
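
    # KV-cache protocol (active when ``kv_cache=True``):
    #   * First call with an empty cache: K/V of the input tokens are stored on the
    #     CPU and a zero tensor is returned; no attention is computed.
    #   * Subsequent calls: the cached K/V are prepended to the current chunk's K/V,
    #     so the current tokens attend to the cached tokens plus themselves; the
    #     cache itself is not updated by these calls.
    # Call ``clear_kv_cache`` before processing a new sequence.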
def forward(self, x: Tensor, pos=None, attn_mask: Tensor | None = None) -> Tensor:
B, N, C = x.shape
qkv = (
self.qkv(x)
.reshape(B, N, 3, self.num_heads, self.head_dim)
.permute(2, 0, 3, 1, 4)
)
q, k, v = qkv.unbind(0)
q, k = self.q_norm(q), self.k_norm(k)
if self.rope is not None:
q = self.rope(q, pos)
k = self.rope(k, pos)
# --- KV Cache Logic ---
if self.kv_cache:
            if self.k_cache is not None:
                # Generation mode: concatenate the cached keys/values ahead of the
                # current chunk's keys/values along the sequence dimension.
                # ``.to(k.device)`` moves the CPU-resident cache onto the compute device.
                k = torch.cat([self.k_cache.to(k.device), k], dim=2)
                v = torch.cat([self.v_cache.to(v.device), v], dim=2)
            else:
                # This is the first pass (prompt processing). Initialize the cache on
                # the CPU and skip the attention computation for this call.
                self.k_cache = k.cpu()
                self.v_cache = v.cpu()
                self.k_length = k.shape[2]
                return torch.zeros_like(x)  # Zero placeholder output for the cache-filling pass
if self.fused_attn:
x = F.scaled_dot_product_attention(
q,
k,
v,
attn_mask=attn_mask,
dropout_p=self.attn_drop.p if self.training else 0.0,
)
else:
q = q * self.scale
attn = q @ k.transpose(-2, -1)
            if attn_mask is not None:
                # Match F.scaled_dot_product_attention semantics: boolean masks flag
                # positions that may be attended to; floating-point masks are additive.
                if attn_mask.dtype == torch.bool:
                    attn = attn.masked_fill(~attn_mask, float("-inf"))
                else:
                    attn = attn + attn_mask
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = attn @ v
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x


class MemEffAttention(Attention):
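    """``Attention`` variant that uses xFormers' ``memory_efficient_attention`` when
    the library is available and otherwise falls back to the parent implementation."""
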
def forward(self, x: Tensor, attn_bias=None, pos=None, attn_mask=None) -> Tensor:
assert pos is None
if not XFORMERS_AVAILABLE:
if attn_bias is not None:
raise AssertionError("xFormers is required for using nested tensors")
return super().forward(x, attn_mask=attn_mask)
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
q, k, v = unbind(qkv, 2)
x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
x = x.reshape([B, N, C])
x = self.proj(x)
x = self.proj_drop(x)
return x
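

if __name__ == "__main__":
    # Minimal smoke test sketching how the KV cache is driven. The embedding size,
    # head count, and token counts below are illustrative and are not taken from
    # any SAIL-Recon configuration.
    torch.manual_seed(0)

    attn = Attention(dim=64, num_heads=4, kv_cache=True)
    anchor_tokens = torch.randn(1, 16, 64)  # first pass: fills the cache, returns zeros
    query_tokens = torch.randn(1, 4, 64)  # second pass: attends to cached + current tokens

    out_first = attn(anchor_tokens)
    out_query = attn(query_tokens)
    attn.clear_kv_cache()

    print(out_first.abs().max().item())  # 0.0 -- the cache-filling pass returns zeros
    print(tuple(out_query.shape))  # (1, 4, 64)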