Bird-MAE-Huge
Browse files- config.json +4 -3
- configuration_bird_mae.py +55 -0
- modeling_bird_mae.py +516 -0
config.json
CHANGED
|
@@ -4,18 +4,19 @@
|
|
| 4 |
],
|
| 5 |
"attn_drop_rate": 0.0,
|
| 6 |
"auto_map": {
|
| 7 |
-
"AutoConfig": "
|
| 8 |
-
"AutoModel": "
|
| 9 |
},
|
| 10 |
"depth": 32,
|
| 11 |
"drop_path_rate": 0.0,
|
| 12 |
"drop_rate": 0.0,
|
| 13 |
"embed_dim": 1280,
|
|
|
|
| 14 |
"img_size_x": 512,
|
| 15 |
"img_size_y": 128,
|
| 16 |
"in_chans": 1,
|
| 17 |
"init_values": null,
|
| 18 |
-
"mlp_ratio": 4
|
| 19 |
"norm_layer_eps": 1e-06,
|
| 20 |
"num_heads": 16,
|
| 21 |
"num_patches": 256,
|
|
|
|
| 4 |
],
|
| 5 |
"attn_drop_rate": 0.0,
|
| 6 |
"auto_map": {
|
| 7 |
+
"AutoConfig": "configuration_bird_mae.BirdMAEConfig",
|
| 8 |
+
"AutoModel": "modeling_bird_mae.BirdMAEModel"
|
| 9 |
},
|
| 10 |
"depth": 32,
|
| 11 |
"drop_path_rate": 0.0,
|
| 12 |
"drop_rate": 0.0,
|
| 13 |
"embed_dim": 1280,
|
| 14 |
+
"global_pool": "mean",
|
| 15 |
"img_size_x": 512,
|
| 16 |
"img_size_y": 128,
|
| 17 |
"in_chans": 1,
|
| 18 |
"init_values": null,
|
| 19 |
+
"mlp_ratio": 4,
|
| 20 |
"norm_layer_eps": 1e-06,
|
| 21 |
"num_heads": 16,
|
| 22 |
"num_patches": 256,
|
configuration_bird_mae.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import PretrainedConfig
|
| 2 |
+
from typing import Literal
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class BirdMAEConfig(PretrainedConfig):
|
| 6 |
+
"""This represents the Bird-MAE-Base config from the original paper"""
|
| 7 |
+
_auto_class = "AutoConfig"
|
| 8 |
+
|
| 9 |
+
def __init__(
|
| 10 |
+
self,
|
| 11 |
+
img_size_x: int = 512,
|
| 12 |
+
img_size_y: int = 128,
|
| 13 |
+
patch_size: int = 16,
|
| 14 |
+
in_chans: int = 1,
|
| 15 |
+
embed_dim: int = 768,
|
| 16 |
+
depth: int = 12,
|
| 17 |
+
num_heads: int = 12,
|
| 18 |
+
mlp_ratio: int = 4,
|
| 19 |
+
pos_trainable: bool = False,
|
| 20 |
+
qkv_bias: bool = True,
|
| 21 |
+
qk_norm: bool = False,
|
| 22 |
+
init_values: float = None,
|
| 23 |
+
drop_rate: float = 0.0,
|
| 24 |
+
norm_layer_eps: float = 1e-6,
|
| 25 |
+
global_pool: Literal["cls", "mean"] | None = "mean",
|
| 26 |
+
**kwargs
|
| 27 |
+
):
|
| 28 |
+
super().__init__(**kwargs)
|
| 29 |
+
|
| 30 |
+
self.img_size_x = img_size_x
|
| 31 |
+
self.img_size_y = img_size_y
|
| 32 |
+
self.patch_size = patch_size
|
| 33 |
+
self.in_chans = in_chans
|
| 34 |
+
self.embed_dim = embed_dim
|
| 35 |
+
self.depth = depth
|
| 36 |
+
self.num_heads = num_heads
|
| 37 |
+
self.mlp_ratio = mlp_ratio
|
| 38 |
+
self.pos_trainable = pos_trainable
|
| 39 |
+
|
| 40 |
+
self.qkv_bias = qkv_bias
|
| 41 |
+
self.qk_norm = qk_norm
|
| 42 |
+
self.init_values = init_values
|
| 43 |
+
self.drop_rate = drop_rate
|
| 44 |
+
self.pos_drop_rate = drop_rate
|
| 45 |
+
self.attn_drop_rate = drop_rate
|
| 46 |
+
self.drop_path_rate = drop_rate
|
| 47 |
+
self.proj_drop_rate = drop_rate
|
| 48 |
+
self.norm_layer_eps = norm_layer_eps
|
| 49 |
+
self.global_pool = global_pool
|
| 50 |
+
|
| 51 |
+
# Calculated properties (useful for initializing the model)
|
| 52 |
+
self.num_patches_x = img_size_x // patch_size
|
| 53 |
+
self.num_patches_y = img_size_y // patch_size
|
| 54 |
+
self.num_patches = self.num_patches_x * self.num_patches_y
|
| 55 |
+
self.num_tokens = self.num_patches + 1
|
modeling_bird_mae.py
ADDED
|
@@ -0,0 +1,516 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import collections
|
| 3 |
+
from itertools import repeat
|
| 4 |
+
from functools import partial
|
| 5 |
+
from typing import Optional, Literal
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from transformers import PreTrainedModel
|
| 11 |
+
from transformers.utils import logging
|
| 12 |
+
from transformers.modeling_outputs import BaseModelOutput, SequenceClassifierOutput
|
| 13 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss
|
| 14 |
+
|
| 15 |
+
from .configuration_bird_mae import BirdMAEConfig
|
| 16 |
+
|
| 17 |
+
logger = logging.get_logger(__name__)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
| 21 |
+
"""
|
| 22 |
+
embed_dim: output dimension for each position
|
| 23 |
+
pos: a list of positions to be encoded: size (M,)
|
| 24 |
+
out: (M, D)
|
| 25 |
+
"""
|
| 26 |
+
assert embed_dim % 2 == 0
|
| 27 |
+
omega = np.arange(embed_dim // 2, dtype=np.float32)
|
| 28 |
+
omega /= embed_dim / 2.
|
| 29 |
+
omega = 1. / 10000**omega # (D/2,)
|
| 30 |
+
|
| 31 |
+
pos = pos.reshape(-1) # (M,)
|
| 32 |
+
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
|
| 33 |
+
|
| 34 |
+
emb_sin = np.sin(out) # (M, D/2)
|
| 35 |
+
emb_cos = np.cos(out) # (M, D/2)
|
| 36 |
+
|
| 37 |
+
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
| 38 |
+
return emb
|
| 39 |
+
|
| 40 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
| 41 |
+
assert embed_dim % 2 == 0
|
| 42 |
+
|
| 43 |
+
# use half of dimensions to encode grid_h
|
| 44 |
+
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
| 45 |
+
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
| 46 |
+
|
| 47 |
+
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
| 48 |
+
return emb
|
| 49 |
+
|
| 50 |
+
def get_2d_sincos_pos_embed_flexible(embed_dim, grid_size, cls_token=False):
|
| 51 |
+
"""
|
| 52 |
+
grid_size: int of the grid height and width
|
| 53 |
+
return:
|
| 54 |
+
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
| 55 |
+
"""
|
| 56 |
+
grid_h = np.arange(grid_size[0], dtype=np.float32) # grid size[0] = 8
|
| 57 |
+
grid_w = np.arange(grid_size[1], dtype=np.float32) # grid size[1] = 32
|
| 58 |
+
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
| 59 |
+
grid = np.stack(grid, axis=0) # 2,8,32
|
| 60 |
+
|
| 61 |
+
grid = grid.reshape([2, 1, grid_size[0], grid_size[1]]) # 2,1,8.32
|
| 62 |
+
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
| 63 |
+
if cls_token:
|
| 64 |
+
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
|
| 65 |
+
return pos_embed # 267 (+cls) x 1024 (feature dim)
|
| 66 |
+
|
| 67 |
+
# From timm.models.layers
|
| 68 |
+
class DropPath(nn.Module):
|
| 69 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
| 70 |
+
"""
|
| 71 |
+
def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
|
| 72 |
+
super(DropPath, self).__init__()
|
| 73 |
+
self.drop_prob = drop_prob
|
| 74 |
+
self.scale_by_keep = scale_by_keep
|
| 75 |
+
|
| 76 |
+
def forward(self, x):
|
| 77 |
+
if self.drop_prob == 0. or not self.training:
|
| 78 |
+
return x
|
| 79 |
+
keep_prob = 1 - self.drop_prob
|
| 80 |
+
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
| 81 |
+
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
|
| 82 |
+
if keep_prob > 0.0 and self.scale_by_keep:
|
| 83 |
+
random_tensor.div_(keep_prob)
|
| 84 |
+
return x * random_tensor
|
| 85 |
+
|
| 86 |
+
def _ntuple(n):
|
| 87 |
+
def parse(x):
|
| 88 |
+
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
|
| 89 |
+
return tuple(x)
|
| 90 |
+
return tuple(repeat(x, n))
|
| 91 |
+
return parse
|
| 92 |
+
|
| 93 |
+
class Mlp(nn.Module):
|
| 94 |
+
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
|
| 95 |
+
"""
|
| 96 |
+
def __init__(
|
| 97 |
+
self,
|
| 98 |
+
in_features,
|
| 99 |
+
hidden_features=None,
|
| 100 |
+
out_features=None,
|
| 101 |
+
act_layer=nn.GELU,
|
| 102 |
+
norm_layer=None,
|
| 103 |
+
bias=True,
|
| 104 |
+
drop=0.,
|
| 105 |
+
use_conv=False,
|
| 106 |
+
):
|
| 107 |
+
super().__init__()
|
| 108 |
+
out_features = out_features or in_features
|
| 109 |
+
hidden_features = hidden_features or in_features
|
| 110 |
+
bias = _ntuple(2)(bias)
|
| 111 |
+
drop_probs = _ntuple(2)(drop)
|
| 112 |
+
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
|
| 113 |
+
|
| 114 |
+
self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
|
| 115 |
+
self.act = act_layer()
|
| 116 |
+
self.drop1 = nn.Dropout(drop_probs[0])
|
| 117 |
+
self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
|
| 118 |
+
self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
|
| 119 |
+
self.drop2 = nn.Dropout(drop_probs[1])
|
| 120 |
+
|
| 121 |
+
def forward(self, x):
|
| 122 |
+
x = self.fc1(x)
|
| 123 |
+
x = self.act(x)
|
| 124 |
+
x = self.drop1(x)
|
| 125 |
+
x = self.norm(x)
|
| 126 |
+
x = self.fc2(x)
|
| 127 |
+
x = self.drop2(x)
|
| 128 |
+
return x
|
| 129 |
+
|
| 130 |
+
# Modified from timm.models.vision_transformer
|
| 131 |
+
class Attention(nn.Module):
|
| 132 |
+
"""Standard Multi-head Self Attention module with QKV projection.
|
| 133 |
+
|
| 134 |
+
This module implements the standard multi-head attention mechanism used in transformers.
|
| 135 |
+
It supports both the fused attention implementation (scaled_dot_product_attention) for
|
| 136 |
+
efficiency when available, and a manual implementation otherwise. The module includes
|
| 137 |
+
options for QK normalization, attention dropout, and projection dropout.
|
| 138 |
+
"""
|
| 139 |
+
fused_attn: bool = True
|
| 140 |
+
|
| 141 |
+
def __init__(
|
| 142 |
+
self,
|
| 143 |
+
dim: int,
|
| 144 |
+
num_heads: int = 8,
|
| 145 |
+
qkv_bias: bool = False,
|
| 146 |
+
qk_norm: bool = False,
|
| 147 |
+
scale_norm: bool = False,
|
| 148 |
+
proj_bias: bool = True,
|
| 149 |
+
attn_drop: float = 0.,
|
| 150 |
+
proj_drop: float = 0.,
|
| 151 |
+
norm_layer: nn.Module = None,
|
| 152 |
+
) -> None:
|
| 153 |
+
"""Initialize the Attention module.
|
| 154 |
+
|
| 155 |
+
Args:
|
| 156 |
+
dim: Input dimension of the token embeddings
|
| 157 |
+
num_heads: Number of attention heads
|
| 158 |
+
qkv_bias: Whether to use bias in the query, key, value projections
|
| 159 |
+
qk_norm: Whether to apply normalization to query and key vectors
|
| 160 |
+
proj_bias: Whether to use bias in the output projection
|
| 161 |
+
attn_drop: Dropout rate applied to the attention weights
|
| 162 |
+
proj_drop: Dropout rate applied after the output projection
|
| 163 |
+
norm_layer: Normalization layer constructor for QK normalization if enabled
|
| 164 |
+
"""
|
| 165 |
+
super().__init__()
|
| 166 |
+
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
|
| 167 |
+
if qk_norm or scale_norm:
|
| 168 |
+
assert norm_layer is not None, 'norm_layer must be provided if qk_norm or scale_norm is True'
|
| 169 |
+
self.num_heads = num_heads
|
| 170 |
+
self.head_dim = dim // num_heads
|
| 171 |
+
self.scale = self.head_dim ** -0.5
|
| 172 |
+
|
| 173 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 174 |
+
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
| 175 |
+
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
| 176 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
| 177 |
+
self.norm = norm_layer(dim) if scale_norm else nn.Identity()
|
| 178 |
+
self.proj = nn.Linear(dim, dim, bias=proj_bias)
|
| 179 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 180 |
+
|
| 181 |
+
def forward(
|
| 182 |
+
self,
|
| 183 |
+
x: torch.Tensor,
|
| 184 |
+
attn_mask: torch.Tensor = None,
|
| 185 |
+
output_attentions: bool = False,
|
| 186 |
+
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
| 187 |
+
B, N, C = x.shape
|
| 188 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
|
| 189 |
+
q, k, v = qkv.unbind(0)
|
| 190 |
+
q, k = self.q_norm(q), self.k_norm(k)
|
| 191 |
+
|
| 192 |
+
attn_weights = None
|
| 193 |
+
|
| 194 |
+
if self.fused_attn and not output_attentions:
|
| 195 |
+
x = F.scaled_dot_product_attention(
|
| 196 |
+
q, k, v,
|
| 197 |
+
attn_mask=attn_mask,
|
| 198 |
+
dropout_p=self.attn_drop.p if self.training else 0.,
|
| 199 |
+
)
|
| 200 |
+
else:
|
| 201 |
+
q = q * self.scale
|
| 202 |
+
attn = q @ k.transpose(-2, -1)
|
| 203 |
+
attn = attn if attn_mask is None else attn + attn_mask
|
| 204 |
+
attn_weights = attn.softmax(dim=-1)
|
| 205 |
+
x = self.attn_drop(attn_weights) @ v
|
| 206 |
+
|
| 207 |
+
x = x.transpose(1, 2).reshape(B, N, C)
|
| 208 |
+
x = self.norm(x)
|
| 209 |
+
x = self.proj(x)
|
| 210 |
+
x = self.proj_drop(x)
|
| 211 |
+
|
| 212 |
+
return x, attn_weights
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
# From timm.models.vision_transformer
|
| 216 |
+
class Block(nn.Module):
|
| 217 |
+
def __init__(
|
| 218 |
+
self,
|
| 219 |
+
dim: int,
|
| 220 |
+
num_heads: int,
|
| 221 |
+
mlp_ratio: float = 4.,
|
| 222 |
+
qkv_bias: bool = False,
|
| 223 |
+
qk_norm: bool = False,
|
| 224 |
+
proj_drop: float = 0.,
|
| 225 |
+
attn_drop: float = 0.,
|
| 226 |
+
init_values: float = None,
|
| 227 |
+
drop_path: float = 0.,
|
| 228 |
+
act_layer: nn.Module = nn.GELU,
|
| 229 |
+
norm_layer: nn.Module = nn.LayerNorm,
|
| 230 |
+
mlp_layer: nn.Module = Mlp,
|
| 231 |
+
) -> None:
|
| 232 |
+
super().__init__()
|
| 233 |
+
self.norm1 = norm_layer(dim)
|
| 234 |
+
self.attn = Attention(
|
| 235 |
+
dim,
|
| 236 |
+
num_heads=num_heads,
|
| 237 |
+
qkv_bias=qkv_bias,
|
| 238 |
+
qk_norm=qk_norm,
|
| 239 |
+
attn_drop=attn_drop,
|
| 240 |
+
proj_drop=proj_drop,
|
| 241 |
+
norm_layer=norm_layer,
|
| 242 |
+
)
|
| 243 |
+
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
| 244 |
+
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
| 245 |
+
|
| 246 |
+
self.norm2 = norm_layer(dim)
|
| 247 |
+
self.mlp = mlp_layer(
|
| 248 |
+
in_features=dim,
|
| 249 |
+
hidden_features=int(dim * mlp_ratio),
|
| 250 |
+
act_layer=act_layer,
|
| 251 |
+
drop=proj_drop,
|
| 252 |
+
)
|
| 253 |
+
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
| 254 |
+
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
| 255 |
+
|
| 256 |
+
def forward(self, x: torch.Tensor,
|
| 257 |
+
output_attentions: bool = False,
|
| 258 |
+
attn_mask: torch.Tensor = None
|
| 259 |
+
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
| 260 |
+
#x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
|
| 261 |
+
x_skip = x
|
| 262 |
+
x = self.norm1(x)
|
| 263 |
+
x, att = self.attn(x, output_attentions=output_attentions, attn_mask=attn_mask)
|
| 264 |
+
x = self.ls1(x)
|
| 265 |
+
x = self.drop_path1(x)
|
| 266 |
+
x += x_skip
|
| 267 |
+
x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
|
| 268 |
+
return x, att
|
| 269 |
+
|
| 270 |
+
# From timm.models.vision_transformer
|
| 271 |
+
class LayerScale(nn.Module):
|
| 272 |
+
def __init__(
|
| 273 |
+
self,
|
| 274 |
+
dim: int,
|
| 275 |
+
init_values: float = 1e-5,
|
| 276 |
+
inplace: bool = False,
|
| 277 |
+
) -> None:
|
| 278 |
+
super().__init__()
|
| 279 |
+
self.inplace = inplace
|
| 280 |
+
self.gamma = nn.Parameter(init_values * torch.ones(dim))
|
| 281 |
+
|
| 282 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 283 |
+
return x.mul_(self.gamma) if self.inplace else x * self.gamma
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
class PatchEmbed_org(nn.Module):
|
| 287 |
+
""" Image to Patch Embedding
|
| 288 |
+
"""
|
| 289 |
+
def __init__(self,
|
| 290 |
+
img_size: int | tuple[int, ...] = 224,
|
| 291 |
+
patch_size: int | tuple[int, ...] = 16,
|
| 292 |
+
in_chans=3,
|
| 293 |
+
embed_dim=768):
|
| 294 |
+
super().__init__()
|
| 295 |
+
img_size: tuple[int,int] = _ntuple(2)(img_size) # audio mae used: (target_length x 128) --> not sure why tbh
|
| 296 |
+
patch_size: tuple[int,int] = _ntuple(2)(patch_size)
|
| 297 |
+
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
|
| 298 |
+
self.patch_hw = (img_size[1] // patch_size[1], img_size[0] // patch_size[0]) # number of patches height/width = 8/32
|
| 299 |
+
self.img_size = img_size
|
| 300 |
+
self.patch_size = patch_size
|
| 301 |
+
self.num_patches = num_patches
|
| 302 |
+
|
| 303 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
| 304 |
+
|
| 305 |
+
def forward(self, x):
|
| 306 |
+
B, C, H, W = x.shape #batch size, channels, height, width --> apparently sth else is expected???
|
| 307 |
+
x = self.proj(x) # 1, 1, 512, 128 -> 1, 768, 32, 8 (batch, 768 channel, 32 height, 8 width)
|
| 308 |
+
x = x.flatten(2) # 1, 768, 32, 8 -> 1, 768, 256
|
| 309 |
+
x = x.transpose(1, 2) # 1, 768, 256 -> 1, 256, 768
|
| 310 |
+
return x
|
| 311 |
+
|
| 312 |
+
# --- END OF NECESSARY TIMM/Custom internal module definitions ---
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
class BirdMAEPreTrainedModel(PreTrainedModel):
|
| 316 |
+
config_class = BirdMAEConfig
|
| 317 |
+
base_model_prefix = "model"
|
| 318 |
+
|
| 319 |
+
def _init_weights(self, module):
|
| 320 |
+
if isinstance(module, nn.Linear):
|
| 321 |
+
nn.init.normal_(module.weight, std=.02)
|
| 322 |
+
if module.bias is not None:
|
| 323 |
+
nn.init.constant_(module.bias, 0)
|
| 324 |
+
elif isinstance(module, nn.LayerNorm):
|
| 325 |
+
nn.init.constant_(module.weight, 1.0)
|
| 326 |
+
nn.init.constant_(module.bias, 0)
|
| 327 |
+
elif isinstance(module, nn.Conv2d):
|
| 328 |
+
w = module.weight.data
|
| 329 |
+
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
class BirdMAEModel(BirdMAEPreTrainedModel):
|
| 333 |
+
_auto_class = "AutoModel"
|
| 334 |
+
#_keys_to_ignore_on_load_missing = ["fc_norm.weight", "fc_norm.bias"]
|
| 335 |
+
|
| 336 |
+
def __init__(self, config: BirdMAEConfig):
|
| 337 |
+
super().__init__(config)
|
| 338 |
+
|
| 339 |
+
self.patch_embed = PatchEmbed_org(
|
| 340 |
+
img_size=(config.img_size_x, config.img_size_y), # (512, 128)
|
| 341 |
+
patch_size=config.patch_size,
|
| 342 |
+
in_chans=config.in_chans,
|
| 343 |
+
embed_dim=config.embed_dim
|
| 344 |
+
)
|
| 345 |
+
|
| 346 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, config.embed_dim))
|
| 347 |
+
|
| 348 |
+
self.pos_embed = nn.Parameter(
|
| 349 |
+
torch.zeros(1, config.num_patches + 1, config.embed_dim),
|
| 350 |
+
requires_grad=config.pos_trainable
|
| 351 |
+
)
|
| 352 |
+
|
| 353 |
+
if self.pos_embed.data.shape[1] == config.num_patches + 1:
|
| 354 |
+
pos_embed_np = get_2d_sincos_pos_embed_flexible(
|
| 355 |
+
self.pos_embed.shape[-1], # embedding dim
|
| 356 |
+
self.patch_embed.patch_hw, # (8, 32) for a 128x512 image with 16x16 patches
|
| 357 |
+
cls_token=True
|
| 358 |
+
)
|
| 359 |
+
self.pos_embed.data.copy_(torch.from_numpy(pos_embed_np).float().unsqueeze(0))
|
| 360 |
+
else:
|
| 361 |
+
logger.warning("Positional embedding shape mismatch. Will not initialize sin-cos pos embed.")
|
| 362 |
+
|
| 363 |
+
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.depth)]
|
| 364 |
+
self.blocks = nn.ModuleList([
|
| 365 |
+
Block(
|
| 366 |
+
dim=config.embed_dim,
|
| 367 |
+
num_heads=config.num_heads,
|
| 368 |
+
mlp_ratio=config.mlp_ratio,
|
| 369 |
+
qkv_bias=config.qkv_bias,
|
| 370 |
+
qk_norm=config.qk_norm,
|
| 371 |
+
init_values=config.init_values,
|
| 372 |
+
proj_drop=config.proj_drop_rate,
|
| 373 |
+
attn_drop=config.attn_drop_rate,
|
| 374 |
+
drop_path=dpr[i],
|
| 375 |
+
#norm_layer=nn.LayerNorm(config.embed_dim, eps=config.norm_layer_eps)
|
| 376 |
+
norm_layer=partial(nn.LayerNorm, eps=config.norm_layer_eps)
|
| 377 |
+
)
|
| 378 |
+
for i in range(config.depth)
|
| 379 |
+
])
|
| 380 |
+
|
| 381 |
+
self.pos_drop = nn.Dropout(p=config.pos_drop_rate)
|
| 382 |
+
self.norm = nn.LayerNorm(config.embed_dim, eps=config.norm_layer_eps) #norm_layer(config.embed_dim)
|
| 383 |
+
self.fc_norm = nn.LayerNorm(config.embed_dim, eps=config.norm_layer_eps) #norm_layer(config.embed_dim)
|
| 384 |
+
self.global_pool = config.global_pool
|
| 385 |
+
|
| 386 |
+
nn.init.trunc_normal_(self.cls_token, std=.02)
|
| 387 |
+
|
| 388 |
+
def forward(
|
| 389 |
+
self,
|
| 390 |
+
input_values : torch.Tensor,
|
| 391 |
+
attention_mask: torch.Tensor = None,
|
| 392 |
+
output_attentions: bool = None,
|
| 393 |
+
output_hidden_states: bool = None,
|
| 394 |
+
return_dict: bool = None,
|
| 395 |
+
) -> tuple | BaseModelOutput:
|
| 396 |
+
if len(input_values.shape) == 3:
|
| 397 |
+
input_values = input_values.unsqueeze(0)
|
| 398 |
+
|
| 399 |
+
output_attentions = output_attentions or self.config.output_attentions
|
| 400 |
+
|
| 401 |
+
output_hidden_states = output_hidden_states or self.config.output_hidden_states
|
| 402 |
+
return_dict = return_dict or self.config.use_return_dict
|
| 403 |
+
|
| 404 |
+
B, C, X, Y = input_values.shape
|
| 405 |
+
assert X == self.config.img_size_x, f"Expected image_size_x={self.config.img_size_x} but was {X}."
|
| 406 |
+
assert Y == self.config.img_size_y, f"Expected image_size_y={self.config.img_size_y} but was {Y}."
|
| 407 |
+
|
| 408 |
+
x = self.patch_embed(input_values)
|
| 409 |
+
|
| 410 |
+
x = x + self.pos_embed[:, 1:, :]
|
| 411 |
+
cls_token = self.cls_token + self.pos_embed[:, :1, :]
|
| 412 |
+
cls_tokens = cls_token.expand(B, -1, -1)
|
| 413 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
| 414 |
+
x = self.pos_drop(x)
|
| 415 |
+
|
| 416 |
+
all_hidden_states = (x,) if output_hidden_states else None
|
| 417 |
+
all_self_attns = () if output_attentions else None
|
| 418 |
+
|
| 419 |
+
for blk in self.blocks:
|
| 420 |
+
x, self_attn_weights = blk(x, output_attentions=output_attentions, attn_mask=attention_mask)
|
| 421 |
+
if output_hidden_states:
|
| 422 |
+
all_hidden_states += (x,)
|
| 423 |
+
if output_attentions:
|
| 424 |
+
all_self_attns += (self_attn_weights,)
|
| 425 |
+
|
| 426 |
+
if self.global_pool is None:
|
| 427 |
+
pooled_output = x
|
| 428 |
+
elif self.global_pool == "mean":
|
| 429 |
+
x = x[:, 1:, :].mean(dim=1)
|
| 430 |
+
pooled_output = self.fc_norm(x)
|
| 431 |
+
elif self.global_pool == "cls":
|
| 432 |
+
x = self.norm(x)
|
| 433 |
+
pooled_output = x[:, 0]
|
| 434 |
+
else:
|
| 435 |
+
raise ValueError(f"Invalid global pool type: {self.global_pool}")
|
| 436 |
+
|
| 437 |
+
if not return_dict:
|
| 438 |
+
return (pooled_output,) + (all_hidden_states if output_hidden_states else ()) + (None,)
|
| 439 |
+
|
| 440 |
+
return BaseModelOutput(
|
| 441 |
+
last_hidden_state=pooled_output,
|
| 442 |
+
hidden_states=all_hidden_states,
|
| 443 |
+
attentions=all_self_attns
|
| 444 |
+
)
|
| 445 |
+
|
| 446 |
+
|
| 447 |
+
class BirdMAEForSequenceClassification(BirdMAEPreTrainedModel):
|
| 448 |
+
_auto_class = "AutoModelForSequenceClassification"
|
| 449 |
+
def __init__(self, config: BirdMAEConfig, head_type: Literal["linear", "ppnet"]):
|
| 450 |
+
super().__init__(config)
|
| 451 |
+
self.num_labels = self.config.num_labels
|
| 452 |
+
self.head_type = head_type
|
| 453 |
+
self.model = BirdMAEModel(config)
|
| 454 |
+
if head_type == "linear":
|
| 455 |
+
self.head = nn.Linear(config.embed_dim, self.num_labels, bias=False)
|
| 456 |
+
elif head_type == "ppnet":
|
| 457 |
+
pass
|
| 458 |
+
else:
|
| 459 |
+
raise NotImplementedError(f"{head_type=} is not supported.")
|
| 460 |
+
|
| 461 |
+
def forward(self,
|
| 462 |
+
input_values: torch.Tensor,
|
| 463 |
+
attention_mask: torch.Tensor = None,
|
| 464 |
+
labels: torch.Tensor = None,
|
| 465 |
+
output_attentions: bool = None,
|
| 466 |
+
output_hidden_states: bool = None,
|
| 467 |
+
return_dict: bool = None):
|
| 468 |
+
return_dict = return_dict or self.config.return_dict
|
| 469 |
+
output_attentions = output_attentions or self.config.output_attentions
|
| 470 |
+
output_hidden_states = output_hidden_states or self.config.output_hidden_states
|
| 471 |
+
|
| 472 |
+
output = self.model(input_values,
|
| 473 |
+
attention_mask=attention_mask,
|
| 474 |
+
output_attentions=output_attentions,
|
| 475 |
+
output_hidden_states=output_hidden_states,
|
| 476 |
+
return_dict=return_dict)
|
| 477 |
+
|
| 478 |
+
hidden_state = output[0]
|
| 479 |
+
logits = self.head(hidden_state)
|
| 480 |
+
|
| 481 |
+
loss = None
|
| 482 |
+
if labels is not None:
|
| 483 |
+
labels = labels.to(logits.device)
|
| 484 |
+
if self.config.problem_type is None:
|
| 485 |
+
if self.num_labels == 1:
|
| 486 |
+
raise NotImplementedError(f"Setting num_labels={self.num_labels} indicates a regression task, which is not supported.")
|
| 487 |
+
elif self.num_labels > 1 and labels.shape != logits.shape:
|
| 488 |
+
self.config.problem_type = "single_label_classification"
|
| 489 |
+
else:
|
| 490 |
+
self.config.problem_type = "multi_label_classification"
|
| 491 |
+
|
| 492 |
+
if self.config.problem_type == "single_label_classification":
|
| 493 |
+
loss_fct = CrossEntropyLoss()
|
| 494 |
+
loss = loss_fct(
|
| 495 |
+
logits.view(-1, self.num_labels), labels.view(-1)
|
| 496 |
+
)
|
| 497 |
+
elif self.config.problem_type == "multi_label_classification":
|
| 498 |
+
loss_fct = BCEWithLogitsLoss()
|
| 499 |
+
loss = loss_fct(logits, labels.float())
|
| 500 |
+
|
| 501 |
+
if not return_dict:
|
| 502 |
+
output = (logits,) + output[1:]
|
| 503 |
+
return ((loss,) + output) if loss is not None else output
|
| 504 |
+
|
| 505 |
+
return SequenceClassifierOutput(
|
| 506 |
+
loss=loss,
|
| 507 |
+
logits=logits,
|
| 508 |
+
hidden_states=output.hidden_states,
|
| 509 |
+
attentions=output.attentions,
|
| 510 |
+
)
|
| 511 |
+
|
| 512 |
+
|
| 513 |
+
|
| 514 |
+
|
| 515 |
+
|
| 516 |
+
|