worstchan committed
Commit 25d88d3 · verified · 1 Parent(s): ac7cbcd

Upload folder using huggingface_hub

Files changed (6)
  1. config.json +38 -0
  2. configuration_eat.py +66 -0
  3. eat_model.py +99 -0
  4. model.safetensors +3 -0
  5. model_core.py +224 -0
  6. modeling_eat.py +18 -0
config.json ADDED
@@ -0,0 +1,38 @@
+ {
+   "activation_dropout": 0.0,
+   "architectures": [
+     "EATModel"
+   ],
+   "auto_map": {
+     "AutoModel": "modeling_eat.EATModel",
+     "AutoConfig": "configuration_eat.EATConfig"
+   },
+   "attn_drop_rate": 0.0,
+   "depth": 12,
+   "drop_rate": 0.0,
+   "embed_dim": 768,
+   "end_drop_path_rate": 0.0,
+   "fixed_positions": true,
+   "img_size": [
+     1024,
+     128
+   ],
+   "in_chans": 1,
+   "layer_norm_first": false,
+   "max_length": 768,
+   "mel_bins": 128,
+   "mlp_ratio": 4.0,
+   "model_type": "eat",
+   "model_variant": "pretrain",
+   "norm_affine": true,
+   "norm_eps": 1e-06,
+   "num_classes": 527,
+   "num_heads": 12,
+   "patch_size": 16,
+   "post_mlp_drop": 0.0,
+   "qkv_bias": true,
+   "start_drop_path_rate": 0.0,
+   "stride": 16,
+   "torch_dtype": "float32",
+   "transformers_version": "4.51.3"
+ }
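
The `auto_map` entries above are what let `transformers` resolve the custom classes in this upload via `trust_remote_code`. A minimal loading sketch (the repo id below is a placeholder, not stated in this commit; the input is a single 128-bin, 1024-frame mel spectrogram as implied by `img_size`/`in_chans`):

import torch
from transformers import AutoModel

repo_id = "worstchan/EAT-base"   # placeholder id; substitute the actual Hub repo
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)   # resolves to modeling_eat.EATModel

mel = torch.randn(1, 1, 1024, 128)                # (batch, in_chans, frames, mel_bins)
with torch.no_grad():
    tokens = model.extract_features(mel)
print(tokens.shape)                               # torch.Size([1, 513, 768]): 1 cls-style token + 64*8 patch tokens
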
configuration_eat.py ADDED
@@ -0,0 +1,66 @@
+ # configuration_eat.py
+
+ from transformers import PretrainedConfig
+
+ class EATConfig(PretrainedConfig):
+     model_type = "eat"
+
+     def __init__(
+         self,
+         embed_dim=768,
+         depth=12,
+         num_heads=12,
+         patch_size=16,
+         stride=16,
+         in_chans=1,
+         mel_bins=128,
+         max_length=768,
+         num_classes=527,
+         model_variant="pretrain",  # or "finetune"
+
+         mlp_ratio=4.0,
+         qkv_bias=True,
+         drop_rate=0.0,
+         attn_drop_rate=0.0,
+         activation_dropout=0.0,
+         post_mlp_drop=0.0,
+         start_drop_path_rate=0.0,
+         end_drop_path_rate=0.0,
+
+         layer_norm_first=False,
+         norm_eps=1e-6,
+         norm_affine=True,
+         fixed_positions=True,
+
+         img_size=(1024, 128),  # (target_length, mel_bins)
+
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+
+         self.embed_dim = embed_dim
+         self.depth = depth
+         self.num_heads = num_heads
+         self.patch_size = patch_size
+         self.stride = stride
+         self.in_chans = in_chans
+         self.mel_bins = mel_bins
+         self.max_length = max_length
+         self.num_classes = num_classes
+         self.model_variant = model_variant
+
+         self.mlp_ratio = mlp_ratio
+         self.qkv_bias = qkv_bias
+         self.drop_rate = drop_rate
+         self.attn_drop_rate = attn_drop_rate
+         self.activation_dropout = activation_dropout
+         self.post_mlp_drop = post_mlp_drop
+         self.start_drop_path_rate = start_drop_path_rate
+         self.end_drop_path_rate = end_drop_path_rate
+
+         self.layer_norm_first = layer_norm_first
+         self.norm_eps = norm_eps
+         self.norm_affine = norm_affine
+         self.fixed_positions = fixed_positions
+
+         self.img_size = img_size
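
For reference, the geometry implied by the defaults: a 1024x128 spectrogram cut into 16x16 patches gives a 64x8 grid, i.e. 512 patch tokens (plus the one learned token prepended in eat_model.py). A quick check, as a sketch run from the repo directory:

from configuration_eat import EATConfig

cfg = EATConfig()
grid = (cfg.img_size[0] // cfg.patch_size, cfg.img_size[1] // cfg.patch_size)
print(cfg.model_type, grid, grid[0] * grid[1])   # eat (64, 8) 512
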
eat_model.py ADDED
@@ -0,0 +1,99 @@
+ import torch
+ import torch.nn as nn
+ from timm.models.layers import trunc_normal_
+ from functools import partial
+ import numpy as np
+ from .model_core import (  # relative import so the Hub's remote-code loader can resolve this sibling file
+     PatchEmbed_new,
+     get_2d_sincos_pos_embed_flexible,
+     FixedPositionalEncoder,
+     AltBlock
+ )
+
+ class EAT(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.mode = config.model_variant  # "pretrain" or "finetune"
+
+         # === Embedding / Encoder ===
+         self.local_encoder = PatchEmbed_new(
+             img_size=config.img_size,
+             patch_size=config.patch_size,
+             in_chans=config.in_chans,
+             embed_dim=config.embed_dim,
+             stride=config.stride
+         )
+
+         self.extra_tokens = nn.Parameter(torch.zeros(1, 1, config.embed_dim))
+         self.pos_drop = nn.Dropout(p=config.drop_rate, inplace=True)
+         trunc_normal_(self.extra_tokens, std=.02)
+
+         self.fixed_positional_encoder = (
+             FixedPositionalEncoder(self.build_sincos_pos_embed()) if config.fixed_positions else None
+         )
+
+         norm_layer = partial(nn.LayerNorm, eps=config.norm_eps, elementwise_affine=config.norm_affine)
+         dpr = np.linspace(config.start_drop_path_rate, config.end_drop_path_rate, config.depth)
+         self.blocks = nn.ModuleList([
+             AltBlock(config.embed_dim, config.num_heads, config.mlp_ratio,
+                      qkv_bias=config.qkv_bias, drop=config.drop_rate,
+                      attn_drop=config.attn_drop_rate, mlp_drop=config.activation_dropout,
+                      post_mlp_drop=config.post_mlp_drop, drop_path=dpr[i],
+                      norm_layer=norm_layer, layer_norm_first=config.layer_norm_first,
+                      ffn_targets=True)
+             for i in range(config.depth)
+         ])
+
+         self.pre_norm = norm_layer(config.embed_dim)
+
+         # === Head (for finetune) ===
+         if self.mode == "finetune":
+             self.fc_norm = nn.LayerNorm(config.embed_dim)
+             self.head = nn.Linear(config.embed_dim, config.num_classes, bias=True)
+         else:
+             self.head = nn.Identity()
+
+         self.apply(self._init_weights)
+
+     def build_sincos_pos_embed(self):
+         W = self.config.mel_bins // self.config.patch_size
+         max_length = self.config.max_length
+         embed_dim = self.config.embed_dim
+         pos_embed = nn.Parameter(torch.zeros(1, max_length * W, embed_dim), requires_grad=False)
+         emb = get_2d_sincos_pos_embed_flexible(embed_dim, (max_length, W), cls_token=False)
+         pos_embed.data.copy_(torch.from_numpy(emb).float().unsqueeze(0))
+         return pos_embed
+
+     def _init_weights(self, m):
+         if isinstance(m, nn.Linear):
+             trunc_normal_(m.weight, std=.02)
+             if m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+         elif isinstance(m, nn.LayerNorm):
+             nn.init.constant_(m.bias, 0)
+             nn.init.constant_(m.weight, 1.0)
+
+     def encode(self, x):
+         B = x.shape[0]
+         x = self.local_encoder(x)
+         if self.fixed_positional_encoder is not None:
+             x = x + self.fixed_positional_encoder(x, None)[:, :x.size(1), :]
+         x = torch.cat((self.extra_tokens.expand(B, -1, -1), x), dim=1)
+         x = self.pre_norm(x)
+         x = self.pos_drop(x)
+         for blk in self.blocks:
+             x, _ = blk(x)
+         return x
+
+     def forward(self, x):
+         x = self.encode(x)
+         if self.mode == "finetune":
+             x = x[:, 0]  # use cls token
+             x = self.fc_norm(x)
+             x = self.head(x)
+         return x
+
+     def extract_features(self, x):
+         x = self.encode(x)
+         return x
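
Switching `model_variant` to "finetune" adds `fc_norm` plus a 527-way classification head on the cls token; since this upload stores the pretraining variant, those head weights would be newly initialized. A sketch (the repo id is again a placeholder):

import torch
from transformers import AutoModel

clf = AutoModel.from_pretrained(
    "worstchan/EAT-base",            # placeholder id
    trust_remote_code=True,
    model_variant="finetune",        # overrides the value stored in config.json
)
logits = clf(torch.randn(2, 1, 1024, 128))
print(logits.shape)                   # torch.Size([2, 527])
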
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8623072d09aac4f3ad1168b4fed3a24e4f68fe1da25b9fe733375efb237e5f48
+ size 359905840
model_core.py ADDED
@@ -0,0 +1,224 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import numpy as np
+ from timm.models.layers import to_2tuple
+
+ class PatchEmbed_new(nn.Module):
+     """ Flexible Image to Patch Embedding
+     """
+     def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, stride=16):
+         super().__init__()
+         img_size = to_2tuple(img_size)
+         patch_size = to_2tuple(patch_size)
+         stride = to_2tuple(stride)
+
+         self.img_size = img_size
+         self.patch_size = patch_size
+
+         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride)  # with overlapped patches
+
+     def forward(self, x):
+         x = self.proj(x)
+         x = x.flatten(2).transpose(1, 2)
+         return x
+
+
+ def get_2d_sincos_pos_embed_flexible(embed_dim, grid_size, cls_token=False):
+     """
+     grid_size: (grid_h, grid_w) tuple giving the grid height and width
+     return:
+     pos_embed: [grid_h*grid_w, embed_dim] or [1+grid_h*grid_w, embed_dim] (w/ or w/o cls_token)
+     """
+     grid_h = np.arange(grid_size[0], dtype=np.float32)
+     grid_w = np.arange(grid_size[1], dtype=np.float32)
+     grid = np.meshgrid(grid_w, grid_h)  # here w goes first
+     grid = np.stack(grid, axis=0)
+
+     grid = grid.reshape([2, 1, grid_size[0], grid_size[1]])
+     pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+     if cls_token:
+         pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
+     return pos_embed
+
+
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+     assert embed_dim % 2 == 0
+
+     # use half of dimensions to encode grid_h
+     emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
+     emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
+
+     emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
+     return emb
+
+
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+     """
+     embed_dim: output dimension for each position
+     pos: a list of positions to be encoded: size (M,)
+     out: (M, D)
+     """
+     assert embed_dim % 2 == 0
+     omega = np.arange(embed_dim // 2, dtype=np.float32)
+     omega /= embed_dim / 2.0
+     omega = 1.0 / 10000 ** omega  # (D/2,)
+
+     pos = pos.reshape(-1)  # (M,)
+     out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
+
+     emb_sin = np.sin(out)  # (M, D/2)
+     emb_cos = np.cos(out)  # (M, D/2)
+
+     emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
+     return emb
+
+
+ class FixedPositionalEncoder(nn.Module):
+     def __init__(self, pos_embed):
+         super().__init__()
+         self.positions = pos_embed
+
+     def forward(self, x, padding_mask):
+         return self.positions
+
+
+ class AltBlock(nn.Module):
+     def __init__(
+         self,
+         dim,
+         num_heads,
+         mlp_ratio=4.0,
+         qkv_bias=False,
+         qk_scale=None,
+         drop=0.0,
+         attn_drop=0.0,
+         mlp_drop=0.0,
+         post_mlp_drop=0.0,
+         drop_path=0.0,
+         act_layer=nn.GELU,
+         norm_layer=nn.LayerNorm,
+         layer_norm_first=True,
+         ffn_targets=False,
+         cosine_attention=False,
+     ):
+         super().__init__()
+
+         self.layer_norm_first = layer_norm_first
+         self.ffn_targets = ffn_targets
+
+         from timm.models.vision_transformer import DropPath, Mlp
+
+         self.norm1 = norm_layer(dim)
+         self.attn = AltAttention(
+             dim,
+             num_heads=num_heads,
+             qkv_bias=qkv_bias,
+             qk_scale=qk_scale,
+             attn_drop=attn_drop,
+             proj_drop=drop,
+             cosine_attention=cosine_attention,
+         )
+
+         self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+         self.norm2 = norm_layer(dim)
+         mlp_hidden_dim = int(dim * mlp_ratio)
+         self.mlp = Mlp(
+             in_features=dim,
+             hidden_features=mlp_hidden_dim,
+             act_layer=act_layer,
+             drop=mlp_drop,
+         )
+         self.post_mlp_dropout = nn.Dropout(post_mlp_drop, inplace=False)
+
+     def forward(self, x, padding_mask=None, alibi_bias=None):
+         if self.layer_norm_first:
+             x = x + self.drop_path(self.attn(self.norm1(x), padding_mask, alibi_bias))
+             r = x = self.mlp(self.norm2(x))
+             t = x
+             x = r + self.drop_path(self.post_mlp_dropout(x))
+             if not self.ffn_targets:
+                 t = x
+         else:
+             x = x + self.drop_path(self.attn(x, padding_mask, alibi_bias))
+             r = x = self.norm1(x)
+             x = self.mlp(x)
+             t = x
+             x = self.norm2(r + self.drop_path(self.post_mlp_dropout(x)))
+             if not self.ffn_targets:
+                 t = x
+
+         return x, t
+
+
+ class AltAttention(nn.Module):
+     def __init__(
+         self,
+         dim,
+         num_heads=8,
+         qkv_bias=False,
+         qk_scale=None,
+         attn_drop=0.0,
+         proj_drop=0.0,
+         cosine_attention=False,
+     ):
+         super().__init__()
+         self.num_heads = num_heads
+         head_dim = dim // num_heads
+         self.scale = qk_scale or head_dim ** -0.5
+
+         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+         self.attn_drop = nn.Dropout(attn_drop)
+         self.proj = nn.Linear(dim, dim)
+         self.proj_drop = nn.Dropout(proj_drop)
+
+         self.cosine_attention = cosine_attention
+
+         if cosine_attention:
+             self.logit_scale = nn.Parameter(
+                 torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True
+             )
+
+     def forward(self, x, padding_mask=None, alibi_bias=None):
+         B, N, C = x.shape
+         qkv = (
+             self.qkv(x)
+             .reshape(B, N, 3, self.num_heads, C // self.num_heads)
+             .permute(2, 0, 3, 1, 4)  # qkv x B x H x L x D
+         )
+         q, k, v = (
+             qkv[0],
+             qkv[1],
+             qkv[2],
+         )  # make torchscript happy (cannot use tensor as tuple)
+
+         dtype = q.dtype
+
+         if self.cosine_attention:
+             # cosine attention
+             attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
+             logit_scale = torch.clamp(
+                 self.logit_scale, max=torch.log(torch.tensor(1.0 / 0.01))
+             ).exp()
+             attn = attn * logit_scale
+         else:
+             q = q * self.scale
+             attn = q @ k.transpose(-2, -1)
+
+         if alibi_bias is not None:
+             attn = attn.type_as(alibi_bias)
+             attn[:, : alibi_bias.size(1)] += alibi_bias
+
+         if padding_mask is not None and padding_mask.any():
+             attn = attn.masked_fill(
+                 padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
+                 float("-inf"),
+             )
+
+         attn = attn.softmax(dim=-1, dtype=torch.float32).to(dtype=dtype)
+         attn = self.attn_drop(attn)
+         x = (attn @ v).transpose(1, 2)
+         x = x.reshape(B, N, C)
+         x = self.proj(x)
+         x = self.proj_drop(x)
+         return x
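
A shape check for the pieces above (a sketch, run from the repo directory): PatchEmbed_new turns a (B, 1, 1024, 128) spectrogram into 64*8 = 512 tokens, and the sincos table built in eat_model.py covers a (max_length, mel_bins // patch_size) = (768, 8) grid.

import torch
from model_core import PatchEmbed_new, get_2d_sincos_pos_embed_flexible

patch = PatchEmbed_new(img_size=(1024, 128), patch_size=16, in_chans=1, embed_dim=768, stride=16)
print(patch(torch.randn(2, 1, 1024, 128)).shape)      # torch.Size([2, 512, 768])

pos = get_2d_sincos_pos_embed_flexible(768, (768, 8), cls_token=False)
print(pos.shape)                                       # (6144, 768); the encoder slices the first 512 rows for a 1024-frame input
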
modeling_eat.py ADDED
@@ -0,0 +1,18 @@
+ # modeling_eat.py
+
+ from transformers import PreTrainedModel
+ from .configuration_eat import EATConfig  # relative imports let the Hub's remote-code loader fetch these sibling files
+ from .eat_model import EAT
+
+ class EATModel(PreTrainedModel):
+     config_class = EATConfig
+
+     def __init__(self, config: EATConfig):
+         super().__init__(config)
+         self.model = EAT(config)
+
+     def forward(self, *args, **kwargs):
+         return self.model(*args, **kwargs)
+
+     def extract_features(self, x):
+         return self.model.extract_features(x)
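
The upload does not include a feature extractor, so producing the (1, 1, 1024, 128) input is left to the user. One plausible front end is torchaudio's Kaldi-compatible fbank; every parameter below (sample rate, window, dither, normalization) is an assumption for illustration, not something stated in this commit:

import torch
import torchaudio

wav, sr = torchaudio.load("clip.wav")                      # placeholder path
wav = torchaudio.functional.resample(wav, sr, 16000)       # assuming a 16 kHz front end
mel = torchaudio.compliance.kaldi.fbank(
    wav,
    htk_compat=True,
    sample_frequency=16000,
    use_energy=False,
    window_type="hanning",
    num_mel_bins=128,                                      # matches "mel_bins": 128
    dither=0.0,
    frame_shift=10,
)                                                          # (frames, 128)

frames = mel.shape[0]
if frames < 1024:                                          # pad/crop to the 1024-frame img_size
    mel = torch.nn.functional.pad(mel, (0, 0, 0, 1024 - frames))
else:
    mel = mel[:1024]

mel = (mel - mel.mean()) / (mel.std() + 1e-6)              # stand-in normalization; the training stats are not in this diff
mel = mel[None, None]                                      # (1, 1, 1024, 128), ready for EATModel.extract_features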