mwirth7 committed
Commit 055d923 · verified · 1 parent: b31ecc1

Bird-MAE-Huge

Files changed (3)
  1. config.json +4 -3
  2. configuration_bird_mae.py +55 -0
  3. modeling_bird_mae.py +516 -0
config.json CHANGED
@@ -4,18 +4,19 @@
      ],
      "attn_drop_rate": 0.0,
      "auto_map": {
-       "AutoConfig": "config.BirdMAEConfig",
-       "AutoModel": "model.BirdMAEModel"
+       "AutoConfig": "configuration_bird_mae.BirdMAEConfig",
+       "AutoModel": "modeling_bird_mae.BirdMAEModel"
      },
      "depth": 32,
      "drop_path_rate": 0.0,
      "drop_rate": 0.0,
      "embed_dim": 1280,
+     "global_pool": "mean",
      "img_size_x": 512,
      "img_size_y": 128,
      "in_chans": 1,
      "init_values": null,
-     "mlp_ratio": 4.0,
+     "mlp_ratio": 4,
      "norm_layer_eps": 1e-06,
      "num_heads": 16,
      "num_patches": 256,
configuration_bird_mae.py ADDED
@@ -0,0 +1,55 @@
+ from transformers import PretrainedConfig
+ from typing import Literal
+
+
+ class BirdMAEConfig(PretrainedConfig):
+     """Bird-MAE configuration. The defaults correspond to the Bird-MAE-Base model from the original paper."""
+     _auto_class = "AutoConfig"
+
+     def __init__(
+         self,
+         img_size_x: int = 512,
+         img_size_y: int = 128,
+         patch_size: int = 16,
+         in_chans: int = 1,
+         embed_dim: int = 768,
+         depth: int = 12,
+         num_heads: int = 12,
+         mlp_ratio: int = 4,
+         pos_trainable: bool = False,
+         qkv_bias: bool = True,
+         qk_norm: bool = False,
+         init_values: float | None = None,
+         drop_rate: float = 0.0,
+         norm_layer_eps: float = 1e-6,
+         global_pool: Literal["cls", "mean"] | None = "mean",
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+
+         self.img_size_x = img_size_x
+         self.img_size_y = img_size_y
+         self.patch_size = patch_size
+         self.in_chans = in_chans
+         self.embed_dim = embed_dim
+         self.depth = depth
+         self.num_heads = num_heads
+         self.mlp_ratio = mlp_ratio
+         self.pos_trainable = pos_trainable
+
+         self.qkv_bias = qkv_bias
+         self.qk_norm = qk_norm
+         self.init_values = init_values
+         self.drop_rate = drop_rate
+         # A single drop_rate is reused for every dropout variant of the backbone.
+         self.pos_drop_rate = drop_rate
+         self.attn_drop_rate = drop_rate
+         self.drop_path_rate = drop_rate
+         self.proj_drop_rate = drop_rate
+         self.norm_layer_eps = norm_layer_eps
+         self.global_pool = global_pool
+
+         # Calculated properties (useful for initializing the model)
+         self.num_patches_x = img_size_x // patch_size
+         self.num_patches_y = img_size_y // patch_size
+         self.num_patches = self.num_patches_x * self.num_patches_y
+         self.num_tokens = self.num_patches + 1
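
For reference, the config.json above implies the Huge variant is just this config with a larger backbone; an illustrative sketch (values taken from the diff, not part of the commit):

huge_config = BirdMAEConfig(embed_dim=1280, depth=32, num_heads=16)  # Huge-sized backbone per config.json
assert huge_config.num_patches == 256   # (512 // 16) * (128 // 16)
assert huge_config.num_tokens == 257    # patches plus the [CLS] token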
modeling_bird_mae.py ADDED
@@ -0,0 +1,516 @@
+ import numpy as np
+ import collections
+ from itertools import repeat
+ from functools import partial
+ from typing import Optional, Literal
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from transformers import PreTrainedModel
+ from transformers.utils import logging
+ from transformers.modeling_outputs import BaseModelOutput, SequenceClassifierOutput
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss
+
+ from .configuration_bird_mae import BirdMAEConfig
+
+ logger = logging.get_logger(__name__)
+
+
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+     """
+     embed_dim: output dimension for each position
+     pos: a list of positions to be encoded: size (M,)
+     out: (M, D)
+     """
+     assert embed_dim % 2 == 0
+     omega = np.arange(embed_dim // 2, dtype=np.float32)
+     omega /= embed_dim / 2.
+     omega = 1. / 10000**omega  # (D/2,)
+
+     pos = pos.reshape(-1)  # (M,)
+     out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product
+
+     emb_sin = np.sin(out)  # (M, D/2)
+     emb_cos = np.cos(out)  # (M, D/2)
+
+     emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
+     return emb
+
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+     assert embed_dim % 2 == 0
+
+     # use half of the dimensions to encode grid_h, the other half for grid_w
+     emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
+     emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
+
+     emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
+     return emb
+
+ def get_2d_sincos_pos_embed_flexible(embed_dim, grid_size, cls_token=False):
+     """
+     grid_size: (height, width) of the patch grid
+     return:
+     pos_embed: [grid_size[0]*grid_size[1], embed_dim] or [1+grid_size[0]*grid_size[1], embed_dim] (w/ or w/o cls_token)
+     """
+     grid_h = np.arange(grid_size[0], dtype=np.float32)  # e.g. grid_size[0] = 8
+     grid_w = np.arange(grid_size[1], dtype=np.float32)  # e.g. grid_size[1] = 32
+     grid = np.meshgrid(grid_w, grid_h)  # here w goes first
+     grid = np.stack(grid, axis=0)  # (2, 8, 32)
+
+     grid = grid.reshape([2, 1, grid_size[0], grid_size[1]])  # (2, 1, 8, 32)
+     pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+     if cls_token:
+         pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
+     return pos_embed  # (num_patches (+1 if cls_token)) x embed_dim
+
+ # From timm.models.layers
+ class DropPath(nn.Module):
+     """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+     """
+     def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
+         super(DropPath, self).__init__()
+         self.drop_prob = drop_prob
+         self.scale_by_keep = scale_by_keep
+
+     def forward(self, x):
+         if self.drop_prob == 0. or not self.training:
+             return x
+         keep_prob = 1 - self.drop_prob
+         shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+         random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+         if keep_prob > 0.0 and self.scale_by_keep:
+             random_tensor.div_(keep_prob)
+         return x * random_tensor
+
+ def _ntuple(n):
+     def parse(x):
+         if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
+             return tuple(x)
+         return tuple(repeat(x, n))
+     return parse
+
+ class Mlp(nn.Module):
+     """ MLP as used in Vision Transformer, MLP-Mixer and related networks
+     """
+     def __init__(
+         self,
+         in_features,
+         hidden_features=None,
+         out_features=None,
+         act_layer=nn.GELU,
+         norm_layer=None,
+         bias=True,
+         drop=0.,
+         use_conv=False,
+     ):
+         super().__init__()
+         out_features = out_features or in_features
+         hidden_features = hidden_features or in_features
+         bias = _ntuple(2)(bias)
+         drop_probs = _ntuple(2)(drop)
+         linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
+
+         self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
+         self.act = act_layer()
+         self.drop1 = nn.Dropout(drop_probs[0])
+         self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
+         self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
+         self.drop2 = nn.Dropout(drop_probs[1])
+
+     def forward(self, x):
+         x = self.fc1(x)
+         x = self.act(x)
+         x = self.drop1(x)
+         x = self.norm(x)
+         x = self.fc2(x)
+         x = self.drop2(x)
+         return x
+
+ # Modified from timm.models.vision_transformer
+ class Attention(nn.Module):
+     """Standard Multi-head Self Attention module with QKV projection.
+
+     This module implements the standard multi-head attention mechanism used in transformers.
+     It supports both the fused attention implementation (scaled_dot_product_attention) for
+     efficiency when available, and a manual implementation otherwise. The module includes
+     options for QK normalization, attention dropout, and projection dropout.
+     """
+     fused_attn: bool = True
+
+     def __init__(
+         self,
+         dim: int,
+         num_heads: int = 8,
+         qkv_bias: bool = False,
+         qk_norm: bool = False,
+         scale_norm: bool = False,
+         proj_bias: bool = True,
+         attn_drop: float = 0.,
+         proj_drop: float = 0.,
+         norm_layer: nn.Module = None,
+     ) -> None:
+         """Initialize the Attention module.
+
+         Args:
+             dim: Input dimension of the token embeddings
+             num_heads: Number of attention heads
+             qkv_bias: Whether to use bias in the query, key, value projections
+             qk_norm: Whether to apply normalization to query and key vectors
+             scale_norm: Whether to normalize the attention output before the output projection
+             proj_bias: Whether to use bias in the output projection
+             attn_drop: Dropout rate applied to the attention weights
+             proj_drop: Dropout rate applied after the output projection
+             norm_layer: Normalization layer constructor for QK/scale normalization if enabled
+         """
+         super().__init__()
+         assert dim % num_heads == 0, 'dim should be divisible by num_heads'
+         if qk_norm or scale_norm:
+             assert norm_layer is not None, 'norm_layer must be provided if qk_norm or scale_norm is True'
+         self.num_heads = num_heads
+         self.head_dim = dim // num_heads
+         self.scale = self.head_dim ** -0.5
+
+         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+         self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+         self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+         self.attn_drop = nn.Dropout(attn_drop)
+         self.norm = norm_layer(dim) if scale_norm else nn.Identity()
+         self.proj = nn.Linear(dim, dim, bias=proj_bias)
+         self.proj_drop = nn.Dropout(proj_drop)
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         attn_mask: torch.Tensor = None,
+         output_attentions: bool = False,
+     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+         B, N, C = x.shape
+         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+         q, k, v = qkv.unbind(0)
+         q, k = self.q_norm(q), self.k_norm(k)
+
+         attn_weights = None
+
+         if self.fused_attn and not output_attentions:
+             x = F.scaled_dot_product_attention(
+                 q, k, v,
+                 attn_mask=attn_mask,
+                 dropout_p=self.attn_drop.p if self.training else 0.,
+             )
+         else:
+             q = q * self.scale
+             attn = q @ k.transpose(-2, -1)
+             attn = attn if attn_mask is None else attn + attn_mask
+             attn_weights = attn.softmax(dim=-1)
+             x = self.attn_drop(attn_weights) @ v
+
+         x = x.transpose(1, 2).reshape(B, N, C)
+         x = self.norm(x)
+         x = self.proj(x)
+         x = self.proj_drop(x)
+
+         return x, attn_weights
+
+
+ # From timm.models.vision_transformer
+ class Block(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         num_heads: int,
+         mlp_ratio: float = 4.,
+         qkv_bias: bool = False,
+         qk_norm: bool = False,
+         proj_drop: float = 0.,
+         attn_drop: float = 0.,
+         init_values: float = None,
+         drop_path: float = 0.,
+         act_layer: nn.Module = nn.GELU,
+         norm_layer: nn.Module = nn.LayerNorm,
+         mlp_layer: nn.Module = Mlp,
+     ) -> None:
+         super().__init__()
+         self.norm1 = norm_layer(dim)
+         self.attn = Attention(
+             dim,
+             num_heads=num_heads,
+             qkv_bias=qkv_bias,
+             qk_norm=qk_norm,
+             attn_drop=attn_drop,
+             proj_drop=proj_drop,
+             norm_layer=norm_layer,
+         )
+         self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+         self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+         self.norm2 = norm_layer(dim)
+         self.mlp = mlp_layer(
+             in_features=dim,
+             hidden_features=int(dim * mlp_ratio),
+             act_layer=act_layer,
+             drop=proj_drop,
+         )
+         self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+         self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         output_attentions: bool = False,
+         attn_mask: torch.Tensor = None,
+     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+         # Attention branch, expanded from the usual one-liner
+         # (x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x)))))
+         # so that the attention weights can be returned when output_attentions is set.
+         x_skip = x
+         x = self.norm1(x)
+         x, att = self.attn(x, output_attentions=output_attentions, attn_mask=attn_mask)
+         x = self.ls1(x)
+         x = self.drop_path1(x)
+         x = x + x_skip
+         # MLP branch
+         x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
+         return x, att
+
+ # From timm.models.vision_transformer
+ class LayerScale(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         init_values: float = 1e-5,
+         inplace: bool = False,
+     ) -> None:
+         super().__init__()
+         self.inplace = inplace
+         self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return x.mul_(self.gamma) if self.inplace else x * self.gamma
+
+
+ class PatchEmbed_org(nn.Module):
+     """ Image to Patch Embedding
+     """
+     def __init__(self,
+                  img_size: int | tuple[int, ...] = 224,
+                  patch_size: int | tuple[int, ...] = 16,
+                  in_chans=3,
+                  embed_dim=768):
+         super().__init__()
+         img_size: tuple[int, int] = _ntuple(2)(img_size)      # (target_length, n_mels) as in AudioMAE, e.g. (512, 128)
+         patch_size: tuple[int, int] = _ntuple(2)(patch_size)
+         num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
+         self.patch_hw = (img_size[1] // patch_size[1], img_size[0] // patch_size[0])  # patches per axis, e.g. (8, 32)
+         self.img_size = img_size
+         self.patch_size = patch_size
+         self.num_patches = num_patches
+
+         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+     def forward(self, x):
+         B, C, H, W = x.shape   # batch size, channels, height, width
+         x = self.proj(x)       # (1, 1, 512, 128) -> (1, 768, 32, 8): batch, embed_dim, patch grid
+         x = x.flatten(2)       # (1, 768, 32, 8) -> (1, 768, 256)
+         x = x.transpose(1, 2)  # (1, 768, 256) -> (1, 256, 768)
+         return x
+
+ # --- END OF NECESSARY TIMM/Custom internal module definitions ---
+
+
+ class BirdMAEPreTrainedModel(PreTrainedModel):
+     config_class = BirdMAEConfig
+     base_model_prefix = "model"
+
+     def _init_weights(self, module):
+         if isinstance(module, nn.Linear):
+             nn.init.normal_(module.weight, std=.02)
+             if module.bias is not None:
+                 nn.init.constant_(module.bias, 0)
+         elif isinstance(module, nn.LayerNorm):
+             nn.init.constant_(module.weight, 1.0)
+             nn.init.constant_(module.bias, 0)
+         elif isinstance(module, nn.Conv2d):
+             w = module.weight.data
+             nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+
+
333
+ _auto_class = "AutoModel"
334
+ #_keys_to_ignore_on_load_missing = ["fc_norm.weight", "fc_norm.bias"]
335
+
336
+ def __init__(self, config: BirdMAEConfig):
337
+ super().__init__(config)
338
+
339
+ self.patch_embed = PatchEmbed_org(
340
+ img_size=(config.img_size_x, config.img_size_y), # (512, 128)
341
+ patch_size=config.patch_size,
342
+ in_chans=config.in_chans,
343
+ embed_dim=config.embed_dim
344
+ )
345
+
346
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, config.embed_dim))
347
+
348
+ self.pos_embed = nn.Parameter(
349
+ torch.zeros(1, config.num_patches + 1, config.embed_dim),
350
+ requires_grad=config.pos_trainable
351
+ )
352
+
353
+ if self.pos_embed.data.shape[1] == config.num_patches + 1:
354
+ pos_embed_np = get_2d_sincos_pos_embed_flexible(
355
+ self.pos_embed.shape[-1], # embedding dim
356
+ self.patch_embed.patch_hw, # (8, 32) for a 128x512 image with 16x16 patches
357
+ cls_token=True
358
+ )
359
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed_np).float().unsqueeze(0))
360
+ else:
361
+ logger.warning("Positional embedding shape mismatch. Will not initialize sin-cos pos embed.")
362
+
363
+ dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.depth)]
364
+ self.blocks = nn.ModuleList([
365
+ Block(
366
+ dim=config.embed_dim,
367
+ num_heads=config.num_heads,
368
+ mlp_ratio=config.mlp_ratio,
369
+ qkv_bias=config.qkv_bias,
370
+ qk_norm=config.qk_norm,
371
+ init_values=config.init_values,
372
+ proj_drop=config.proj_drop_rate,
373
+ attn_drop=config.attn_drop_rate,
374
+ drop_path=dpr[i],
375
+ #norm_layer=nn.LayerNorm(config.embed_dim, eps=config.norm_layer_eps)
376
+ norm_layer=partial(nn.LayerNorm, eps=config.norm_layer_eps)
377
+ )
378
+ for i in range(config.depth)
379
+ ])
380
+
381
+ self.pos_drop = nn.Dropout(p=config.pos_drop_rate)
382
+ self.norm = nn.LayerNorm(config.embed_dim, eps=config.norm_layer_eps) #norm_layer(config.embed_dim)
383
+ self.fc_norm = nn.LayerNorm(config.embed_dim, eps=config.norm_layer_eps) #norm_layer(config.embed_dim)
384
+ self.global_pool = config.global_pool
385
+
386
+ nn.init.trunc_normal_(self.cls_token, std=.02)
387
+
388
+ def forward(
389
+ self,
390
+ input_values : torch.Tensor,
391
+ attention_mask: torch.Tensor = None,
392
+ output_attentions: bool = None,
393
+ output_hidden_states: bool = None,
394
+ return_dict: bool = None,
395
+ ) -> tuple | BaseModelOutput:
396
+ if len(input_values.shape) == 3:
397
+ input_values = input_values.unsqueeze(0)
398
+
399
+ output_attentions = output_attentions or self.config.output_attentions
400
+
401
+ output_hidden_states = output_hidden_states or self.config.output_hidden_states
402
+ return_dict = return_dict or self.config.use_return_dict
403
+
404
+ B, C, X, Y = input_values.shape
405
+ assert X == self.config.img_size_x, f"Expected image_size_x={self.config.img_size_x} but was {X}."
406
+ assert Y == self.config.img_size_y, f"Expected image_size_y={self.config.img_size_y} but was {Y}."
407
+
408
+ x = self.patch_embed(input_values)
409
+
410
+ x = x + self.pos_embed[:, 1:, :]
411
+ cls_token = self.cls_token + self.pos_embed[:, :1, :]
412
+ cls_tokens = cls_token.expand(B, -1, -1)
413
+ x = torch.cat((cls_tokens, x), dim=1)
414
+ x = self.pos_drop(x)
415
+
416
+ all_hidden_states = (x,) if output_hidden_states else None
417
+ all_self_attns = () if output_attentions else None
418
+
419
+ for blk in self.blocks:
420
+ x, self_attn_weights = blk(x, output_attentions=output_attentions, attn_mask=attention_mask)
421
+ if output_hidden_states:
422
+ all_hidden_states += (x,)
423
+ if output_attentions:
424
+ all_self_attns += (self_attn_weights,)
425
+
426
+ if self.global_pool is None:
427
+ pooled_output = x
428
+ elif self.global_pool == "mean":
429
+ x = x[:, 1:, :].mean(dim=1)
430
+ pooled_output = self.fc_norm(x)
431
+ elif self.global_pool == "cls":
432
+ x = self.norm(x)
433
+ pooled_output = x[:, 0]
434
+ else:
435
+ raise ValueError(f"Invalid global pool type: {self.global_pool}")
436
+
437
+ if not return_dict:
438
+ return (pooled_output,) + (all_hidden_states if output_hidden_states else ()) + (None,)
439
+
440
+ return BaseModelOutput(
441
+ last_hidden_state=pooled_output,
442
+ hidden_states=all_hidden_states,
443
+ attentions=all_self_attns
444
+ )
445
+
446
+
+ class BirdMAEForSequenceClassification(BirdMAEPreTrainedModel):
+     _auto_class = "AutoModelForSequenceClassification"
+
+     def __init__(self, config: BirdMAEConfig, head_type: Literal["linear", "ppnet"] = "linear"):
+         super().__init__(config)
+         self.num_labels = self.config.num_labels
+         self.head_type = head_type
+         self.model = BirdMAEModel(config)
+         if head_type == "linear":
+             self.head = nn.Linear(config.embed_dim, self.num_labels, bias=False)
+         elif head_type == "ppnet":
+             raise NotImplementedError("The ppnet head is not implemented yet.")
+         else:
+             raise NotImplementedError(f"{head_type=} is not supported.")
+
+     def forward(
+         self,
+         input_values: torch.Tensor,
+         attention_mask: torch.Tensor = None,
+         labels: torch.Tensor = None,
+         output_attentions: bool = None,
+         output_hidden_states: bool = None,
+         return_dict: bool = None,
+     ):
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+
+         output = self.model(
+             input_values,
+             attention_mask=attention_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         # The backbone exposes the pooled embedding as its first element / last_hidden_state.
+         hidden_state = output[0]
+         logits = self.head(hidden_state)
+
+         loss = None
+         if labels is not None:
+             labels = labels.to(logits.device)
+             if self.config.problem_type is None:
+                 if self.num_labels == 1:
+                     raise NotImplementedError(f"Setting num_labels={self.num_labels} indicates a regression task, which is not supported.")
+                 elif self.num_labels > 1 and labels.shape != logits.shape:
+                     self.config.problem_type = "single_label_classification"
+                 else:
+                     self.config.problem_type = "multi_label_classification"
+
+             if self.config.problem_type == "single_label_classification":
+                 loss_fct = CrossEntropyLoss()
+                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+             elif self.config.problem_type == "multi_label_classification":
+                 loss_fct = BCEWithLogitsLoss()
+                 loss = loss_fct(logits, labels.float())
+
+         if not return_dict:
+             output = (logits,) + output[1:]
+             return ((loss,) + output) if loss is not None else output
+
+         return SequenceClassifierOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=output.hidden_states,
+             attentions=output.attentions,
+         )
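
Continuing from the AutoModel sketch after the config.json diff, a forward pass through this encoder looks roughly like the following; shapes follow the defaults in the diffs above (512x128 single-channel spectrogram, mean pooling, embed_dim 1280 for the Huge config), and this is illustrative only, not part of the commit:

import torch

spectrogram = torch.randn(1, 1, 512, 128)  # (batch, in_chans, img_size_x, img_size_y)
with torch.no_grad():
    outputs = model(spectrogram)           # model loaded via AutoModel.from_pretrained(..., trust_remote_code=True)
embedding = outputs.last_hidden_state      # mean-pooled patch embedding, shape (1, 1280) for the Huge config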