"""
2025.3.17
2025.3.19
4.50.0
0.15.2
__UNSLOTH_VERSIONING__
"""

import os
import importlib.util

if importlib.util.find_spec("unsloth_studio") is None:
    UNSLOTH_STUDIO_ENABLED = False
else:
    UNSLOTH_STUDIO_ENABLED = os.environ.get("UNSLOTH_STUDIO_DISABLED", "0") == "0"
pass

from typing import List, Dict, Tuple, Optional, Any, Callable
import math

import torch
from unsloth_zoo.loss_utils import fused_linear_cross_entropy

if UNSLOTH_STUDIO_ENABLED:
    from unsloth_zoo.loss_utils import fast_linear_cross_entropy

# Keep a reference to the eager SDPA kernel and expose a wrapper that torch.compile
# will not trace (recursive = False leaves functions called inside it compilable).
scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
@torch.compiler.disable(recursive = False)
def disable_compile_scaled_dot_product_attention(*args, **kwargs):
    return scaled_dot_product_attention(*args, **kwargs)
pass

# Options passed to every torch.compile call in this module.
torch_compile_options = {
    'epilogue_fusion'   : True,
    'max_autotune'      : False,
    'shape_padding'     : True,
    'trace.enabled'     : False,
    'triton.cudagraphs' : False,
}

from torch.nn import CrossEntropyLoss

@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
def normal_cross_entropy_loss(self, hidden_states, labels):
    # Project to vocabulary logits and compute the loss in float32 for stability.
    logits = self.lm_head(hidden_states)
    logits = logits.float()

    # Shift so that tokens < n predict n.
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    # Flatten the tokens.
    loss_fct = CrossEntropyLoss()
    shift_logits = shift_logits.view(-1, self.config.vocab_size)
    shift_labels = shift_labels.view(-1)

    # Move labels to the logits device (model parallelism).
    shift_labels = shift_labels.to(shift_logits.device)
    loss = loss_fct(shift_logits, shift_labels)
    return loss, logits
pass
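
# Illustrative sketch (not part of the generated module): how the shift above lines
# up logits and labels for next-token prediction. Tensor shapes are hypothetical.
#
#   import torch
#   from torch.nn import CrossEntropyLoss
#   logits = torch.randn(2, 5, 32000)           # (batch, seq_len, vocab)
#   labels = torch.randint(0, 32000, (2, 5))    # (batch, seq_len)
#   shift_logits = logits[..., :-1, :].reshape(-1, 32000)  # positions 0..3 predict ...
#   shift_labels = labels[..., 1:].reshape(-1)              # ... tokens 1..4
#   loss = CrossEntropyLoss()(shift_logits.float(), shift_labels)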

LOGITS_ERROR_STRING = \
    "Unsloth: Logits are empty from 2024.11 onwards. To get raw logits again, please "\
    'set the environment variable `UNSLOTH_RETURN_LOGITS` to `"1"` BEFORE starting to train ie before `trainer.train()`. For example:\n'\
    "```\nimport os\n"\
    "os.environ['UNSLOTH_RETURN_LOGITS'] = '1'\n"\
    "trainer.train()\n```\n"\
    "No need to restart your console - just add `os.environ['UNSLOTH_RETURN_LOGITS'] = '1'` before trainer.train() and re-run the cell!"

def raise_logits_error(*args, **kwargs): raise NotImplementedError(LOGITS_ERROR_STRING)
def return_none(*args, **kwargs): return None
class EmptyLogits:
    # Placeholder returned in place of logits when UNSLOTH_RETURN_LOGITS is not set:
    # any access raises a helpful error, except `.to(...)`, which returns None.
    def __init__(self): return
    def raise_getattr_error(self, attr): return return_none if attr == "to" else raise_logits_error
    __getitem__ = raise_logits_error
    __getattr__ = raise_getattr_error
    def __repr__(self): return LOGITS_ERROR_STRING
    def __str__ (self): return LOGITS_ERROR_STRING
pass
EMPTY_LOGITS = EmptyLogits()
functions = dir(torch.Tensor)
# Attach a reporter for every tensor dunder to the EMPTY_LOGITS instance; calling one
# prints the attempted method name. Attributes that cannot be set are skipped.
for j, function in enumerate(functions):
    if function.startswith("__") and function.endswith("__"):
        exec(f"def raise_{j}(*args, **kwargs): print('{function}')", globals(), locals())
        try: exec(f"EMPTY_LOGITS.{function} = raise_{j}", globals(), locals())
        except: continue
pass
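
# Illustrative sketch (taken from the error message above): how a user re-enables
# raw logits before training. `trainer` is a hypothetical TRL/transformers Trainer.
#
#   import os
#   os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
#   trainer.train()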

from torch import Tensor
import torch
import torch.nn as nn
from torch.nn import functional as F
from transformers.models.siglip.modeling_siglip import (
    math, warnings, Optional, Tuple, np, torch, nn,
    _calculate_fan_in_and_fan_out, ACT2FN, is_flash_attn_greater_or_equal_2_10,
    torch_int, SiglipTextConfig, SiglipVisionConfig, logger,
)
# Needed by SiglipFlashAttention2_forward below.
from transformers.modeling_flash_attention_utils import _flash_attention_forward

@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
def _trunc_normal_(tensor, mean, std, a, b):
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes the standard normal cumulative distribution function.
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect.",
            stacklevel=2,
        )

    # Values are generated from a truncated uniform distribution, then mapped through
    # the inverse CDF of the normal distribution.
    # Get upper and lower CDF values.
    l = norm_cdf((a - mean) / std)
    u = norm_cdf((b - mean) / std)

    # Uniformly fill the tensor with values from [l, u], then translate to [2l - 1, 2u - 1].
    tensor.uniform_(2 * l - 1, 2 * u - 1)

    # Use the inverse CDF transform to get a truncated standard normal.
    tensor.erfinv_()

    # Transform to the requested mean and std.
    tensor.mul_(std * math.sqrt(2.0))
    tensor.add_(mean)

    # Clamp to ensure the values are in the proper range.
    tensor.clamp_(min=a, max=b)

@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
def trunc_normal_tf_(
    tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0
) -> torch.Tensor:
    """Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\\mathcal{N}(\\text{mean}, \\text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \\leq \\text{mean} \\leq b`.

    NOTE: this 'tf' variant behaves closer to the TensorFlow / JAX implementation, where the
    bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0,
    and the result is subsequently scaled and shifted by the mean and std args.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
    """
    with torch.no_grad():
        # Sample a truncated standard normal on [a, b], then scale and shift.
        _trunc_normal_(tensor, 0, 1.0, a, b)
        tensor.mul_(std).add_(mean)
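
# Illustrative sketch (not part of the generated module): initializing a weight
# with the TF-style truncated normal above. The tensor shape is hypothetical.
#
#   import torch
#   w = torch.empty(768, 768)
#   trunc_normal_tf_(w, std=0.02)   # samples in [-2, 2], then scaled by 0.02 around mean 0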

@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if mode == "fan_in":
        denom = fan_in
    elif mode == "fan_out":
        denom = fan_out
    elif mode == "fan_avg":
        denom = (fan_in + fan_out) / 2

    variance = scale / denom

    if distribution == "truncated_normal":
        # Constant is the stddev of a standard normal truncated to (-2, 2).
        trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
    elif distribution == "normal":
        with torch.no_grad():
            tensor.normal_(std=math.sqrt(variance))
    elif distribution == "uniform":
        bound = math.sqrt(3 * variance)
        with torch.no_grad():
            tensor.uniform_(-bound, bound)
    else:
        raise ValueError(f"invalid distribution {distribution}")

@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
def lecun_normal_(tensor):
    variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")


@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
def default_flax_embed_init(tensor):
    variance_scaling_(tensor, mode="fan_in", distribution="normal")
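
# Illustrative sketch (not part of the generated module): these helpers are the
# usual SigLIP-style weight initializers. Shapes below are hypothetical.
#
#   import torch.nn as nn
#   proj  = nn.Linear(768, 768)
#   embed = nn.Embedding(32000, 768)
#   lecun_normal_(proj.weight)              # fan_in scaling, truncated normal
#   default_flax_embed_init(embed.weight)   # fan_in scaling, normal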

@torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options)
def SiglipVisionEmbeddings_forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
    _, _, height, width = pixel_values.shape
    target_dtype = self.patch_embedding.weight.dtype
    patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
    embeddings = patch_embeds.flatten(2).transpose(1, 2)

    if interpolate_pos_encoding:
        embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
    else:
        embeddings = embeddings + self.position_embedding(self.position_ids)
    return embeddings

class SiglipVisionEmbeddings(nn.Module):
    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        Interpolates the pre-trained position encodings so the model can be used on higher-resolution
        images. Also adapted to support torch.jit tracing and no class embeddings.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        """

        num_patches = embeddings.shape[1]
        num_positions = self.position_embedding.weight.shape[0]

        # Skip interpolation when the resolution matches and we are not tracing.
        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
            return self.position_embedding(self.position_ids)

        patch_pos_embed = self.position_embedding.weight.unsqueeze(0)

        dim = embeddings.shape[-1]

        new_height = height // self.patch_size
        new_width = width // self.patch_size

        sqrt_num_positions = torch_int(num_positions**0.5)
        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            size=(new_height, new_width),
            mode="bicubic",
            align_corners=False,
        )

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return patch_pos_embed

    def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
        return SiglipVisionEmbeddings_forward(self, pixel_values, interpolate_pos_encoding)
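
# Illustrative sketch (not part of the generated module): embedding a batch of images
# at a resolution other than the pre-training one. Config values are hypothetical.
#
#   cfg = SiglipVisionConfig(hidden_size=768, image_size=224, patch_size=16, num_channels=3)
#   emb = SiglipVisionEmbeddings(cfg)
#   pixels = torch.randn(2, 3, 336, 336)                 # larger than 224x224
#   out = emb(pixels, interpolate_pos_encoding=True)     # (2, (336 // 16) ** 2, 768)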

@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
def SiglipTextEmbeddings_forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
) -> torch.Tensor:
    seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
    max_position_embedding = self.position_embedding.weight.shape[0]

    if seq_length > max_position_embedding:
        raise ValueError(
            f"Sequence length must be less than max_position_embeddings (got `sequence length`: "
            f"{seq_length} and max_position_embeddings: {max_position_embedding})"
        )

    if position_ids is None:
        position_ids = self.position_ids[:, :seq_length]

    if inputs_embeds is None:
        inputs_embeds = self.token_embedding(input_ids)

    position_embeddings = self.position_embedding(position_ids)
    embeddings = inputs_embeds + position_embeddings

    return embeddings

class SiglipTextEmbeddings(nn.Module):
    def __init__(self, config: SiglipTextConfig):
        super().__init__()
        embed_dim = config.hidden_size

        self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
        self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)

        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        return SiglipTextEmbeddings_forward(self, input_ids, position_ids, inputs_embeds)
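
# Illustrative sketch (not part of the generated module): embedding a short token
# sequence. Config values and token ids are hypothetical.
#
#   cfg = SiglipTextConfig(vocab_size=32000, hidden_size=768, max_position_embeddings=64)
#   emb = SiglipTextEmbeddings(cfg)
#   ids = torch.randint(0, 32000, (2, 10))
#   out = emb(input_ids=ids)    # (2, 10, 768) = token + position embeddings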

@torch.compiler.disable(recursive = False)
def SiglipAttention_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Input shape: Batch x Time x Channel"""

    batch_size, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)

    k_v_seq_len = key_states.shape[-2]
    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale

    if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
        raise ValueError(
            f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
            f" {attn_weights.size()}"
        )

    if attention_mask is not None:
        if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
            raise ValueError(
                f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}"
            )
        attn_weights = attn_weights + attention_mask

    # Upcast attention to float32 for a numerically stable softmax.
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
    attn_output = torch.matmul(attn_weights, value_states)

    if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim):
        raise ValueError(
            f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
            f" {attn_output.size()}"
        )

    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)

    attn_output = self.out_proj(attn_output)

    return attn_output, attn_weights

class SiglipAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        return SiglipAttention_forward(self, hidden_states, attention_mask, output_attentions)
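
# Illustrative sketch (not part of the generated module): running the eager attention
# block on a dummy patch sequence. Config values are hypothetical.
#
#   cfg = SiglipVisionConfig(hidden_size=768, num_attention_heads=12, attention_dropout=0.0)
#   attn = SiglipAttention(cfg)
#   x = torch.randn(2, 196, 768)     # (batch, num_patches, hidden)
#   out, weights = attn(x)           # out: (2, 196, 768); weights: (2, 12, 196, 196)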

@torch.compiler.disable(recursive = False)
def SiglipFlashAttention2_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.LongTensor] = None,
    output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    output_attentions = False

    batch_size, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    # Flash Attention expects the layout (batch, seq_len, num_heads, head_dim),
    # so unlike the eager path no transpose is needed here.
    query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim)
    key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim)
    value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim)

    dropout_rate = self.dropout if self.training else 0.0

    # With PEFT, layer norms are often kept in float32 for training stability, which can
    # silently upcast the hidden states. Cast back to the expected dtype before the
    # flash attention call, which does not support float32 inputs.
    input_dtype = query_states.dtype
    if input_dtype == torch.float32:
        if torch.is_autocast_enabled():
            target_dtype = torch.get_autocast_gpu_dtype()
        # Handle the case where the model is quantized.
        elif hasattr(self.config, "_pre_quantization_dtype"):
            target_dtype = self.config._pre_quantization_dtype
        else:
            target_dtype = self.q_proj.weight.dtype

        logger.warning_once(
            "The input hidden states seem to be silently cast to float32; this might be because"
            " you have upcast embedding or layer norm layers to float32. We will cast the input back to"
            f" {target_dtype}."
        )

        query_states = query_states.to(target_dtype)
        key_states = key_states.to(target_dtype)
        value_states = value_states.to(target_dtype)

    attn_output = _flash_attention_forward(
        query_states,
        key_states,
        value_states,
        attention_mask,
        q_len,
        dropout=dropout_rate,
        is_causal=self.is_causal,
        use_top_left_mask=self._flash_attn_uses_top_left_mask,
    )

    attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim).contiguous()
    attn_output = self.out_proj(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights

class SiglipFlashAttention2(SiglipAttention):
    """
    SiglipAttention flash attention module. This module inherits from `SiglipAttention`, as the weights of the
    module stay untouched. The only required change is in the forward pass, which needs to correctly call the
    public API of Flash Attention and deal with padding tokens in case the input contains any of them.
    """

    is_causal = False

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # flash_attn < 2.1 generates a top-left aligned causal mask, while a bottom-right
        # alignment is needed here; bottom-right became the default in flash_attn >= 2.1.
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        return SiglipFlashAttention2_forward(self, hidden_states, attention_mask, output_attentions)

@torch.compiler.disable(recursive = False)
def SiglipSdpaAttention_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    if output_attentions: raise RuntimeError('Unsloth: Not supported')

    batch_size, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)

    # SDPA's memory-efficient backend requires contiguous inputs on CUDA when a custom
    # attention mask is passed.
    if query_states.device.type == "cuda" and attention_mask is not None:
        query_states = query_states.contiguous()
        key_states = key_states.contiguous()
        value_states = value_states.contiguous()

    # Only set is_causal for q_len > 1; a causal mask is meaningless for a single query token.
    is_causal = True if self.is_causal and q_len > 1 else False

    attn_output = torch.nn.functional.scaled_dot_product_attention(
        query_states,
        key_states,
        value_states,
        attn_mask=attention_mask,
        dropout_p=self.dropout if self.training else 0.0,
        is_causal=is_causal,
    )

    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.view(batch_size, q_len, self.embed_dim)

    attn_output = self.out_proj(attn_output)

    return attn_output, None

class SiglipSdpaAttention(SiglipAttention):
    """
    Siglip attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
    `SiglipAttention`, as the weights of the module stay untouched. The only changes are on the forward pass,
    to adapt it to the SDPA API.
    """

    is_causal = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        return SiglipSdpaAttention_forward(self, hidden_states, attention_mask, output_attentions)
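
# Illustrative sketch (not part of the generated module): the SDPA variant is a drop-in
# replacement for the eager SiglipAttention; output_attentions is not supported here.
# Config values are hypothetical.
#
#   cfg = SiglipVisionConfig(hidden_size=768, num_attention_heads=12, attention_dropout=0.0)
#   sdpa_attn = SiglipSdpaAttention(cfg)
#   x = torch.randn(2, 196, 768)
#   out, weights = sdpa_attn(x)    # weights is always None on the SDPA path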

@torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options)
def SiglipMLP_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    hidden_states = self.fc1(hidden_states)
    hidden_states = self.activation_fn(hidden_states)
    hidden_states = self.fc2(hidden_states)
    return hidden_states


class SiglipMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return SiglipMLP_forward(self, hidden_states)

@torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options)
def SiglipMultiheadAttentionPoolingHead_forward(self, hidden_state):
    batch_size = hidden_state.shape[0]
    # A single learned probe query attends over all patch tokens.
    probe = self.probe.repeat(batch_size, 1, 1)

    hidden_state = self.attention(probe, hidden_state, hidden_state)[0]

    residual = hidden_state
    hidden_state = self.layernorm(hidden_state)
    hidden_state = residual + self.mlp(hidden_state)

    return hidden_state[:, 0]


class SiglipMultiheadAttentionPoolingHead(nn.Module):
    """Multihead Attention Pooling."""

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()

        self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
        self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True)
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(config)

    def forward(self, hidden_state):
        return SiglipMultiheadAttentionPoolingHead_forward(self, hidden_state)
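
# Illustrative sketch (not part of the generated module): pooling a sequence of patch
# features into a single image embedding. Config values are hypothetical.
#
#   cfg = SiglipVisionConfig(hidden_size=768, num_attention_heads=12,
#                            intermediate_size=3072, hidden_act="gelu_pytorch_tanh",
#                            layer_norm_eps=1e-6)
#   head = SiglipMultiheadAttentionPoolingHead(cfg)
#   features = torch.randn(2, 196, 768)    # encoder output (batch, patches, hidden)
#   pooled = head(features)                # (2, 768)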