aliasthebone committed on
Commit
eb977fa
·
1 Parent(s): 3bda970

Update lsnet_artist.py in preparation for the lsnet_xl_artist_448 arch

Browse files
Files changed (1) hide show
  1. lsnet/lsnet_artist.py +73 -23
lsnet/lsnet_artist.py CHANGED
@@ -1,11 +1,25 @@
 
 
 
 
1
  import torch
2
  import torch.nn as nn
3
- from .lsnet import LSNet, Conv2d_BN, BN_Linear
4
- from timm.models import register_model
5
- from timm.models import build_model_with_cfg
6
 
7
 
8
  class LSNetArtist(LSNet):
 
 
 
 
 
 
 
 
 
 
9
  def __init__(self,
10
  img_size=224,
11
  patch_size=8,
@@ -71,14 +85,20 @@ class LSNetArtist(LSNet):
71
  x = self.projection(x)
72
  return x
73
 
74
- def forward(self, x, return_features=False):
75
  """
76
- x: 输入图像
77
- return_features: 是否只返回特征向量(用于聚类)
78
- False时返回分类logits(用于分类)
79
 
80
- 如果return_features=True: 返回特征向量 (batch_size, feature_dim)
81
- 如果return_features=False: 返回分类logits (batch_size, num_classes)
 
 
 
 
 
 
 
 
82
  """
83
  features = self.forward_features(x)
84
 
@@ -88,23 +108,26 @@ class LSNetArtist(LSNet):
88
 
89
  # 返回分类结果
90
  if self.distillation:
91
- x = self.head(features), self.head_dist(features)
92
  if not self.training:
93
- x = (x[0] + x[1]) / 2
94
  else:
95
- x = self.head(features)
96
 
97
- return x
 
 
 
98
 
99
  def get_features(self, x):
100
  """
101
- 提取特征向量
102
  """
103
  return self.forward(x, return_features=True)
104
 
105
  def classify(self, x):
106
  """
107
- 进行分类
108
  """
109
  return self.forward(x, return_features=False)
110
 
@@ -129,8 +152,9 @@ default_cfgs_artist = dict(
129
  lsnet_t_artist = _cfg_artist(),
130
  lsnet_s_artist = _cfg_artist(),
131
  lsnet_b_artist = _cfg_artist(),
132
- lsnet_l_artist = _cfg_artist(),
133
- lsnet_xl_artist = _cfg_artist(),
 
134
  )
135
 
136
 
@@ -151,6 +175,7 @@ def _create_lsnet_artist(variant, pretrained=False, **kwargs):
151
  @register_model
152
  def lsnet_t_artist(num_classes=1000, distillation=False, pretrained=False,
153
  feature_dim=None, use_projection=True, **kwargs):
 
154
  model = _create_lsnet_artist(
155
  "lsnet_t_artist",
156
  pretrained=pretrained,
@@ -171,6 +196,7 @@ def lsnet_t_artist(num_classes=1000, distillation=False, pretrained=False,
171
  @register_model
172
  def lsnet_s_artist(num_classes=1000, distillation=False, pretrained=False,
173
  feature_dim=None, use_projection=True, **kwargs):
 
174
  model = _create_lsnet_artist(
175
  "lsnet_s_artist",
176
  pretrained=pretrained,
@@ -191,6 +217,7 @@ def lsnet_s_artist(num_classes=1000, distillation=False, pretrained=False,
191
  @register_model
192
  def lsnet_b_artist(num_classes=1000, distillation=False, pretrained=False,
193
  feature_dim=None, use_projection=True, **kwargs):
 
194
  model = _create_lsnet_artist(
195
  "lsnet_b_artist",
196
  pretrained=pretrained,
@@ -211,6 +238,7 @@ def lsnet_b_artist(num_classes=1000, distillation=False, pretrained=False,
211
  @register_model
212
  def lsnet_l_artist(num_classes=1000, distillation=False, pretrained=False,
213
  feature_dim=None, use_projection=True, **kwargs):
 
214
  model = _create_lsnet_artist(
215
  "lsnet_l_artist",
216
  pretrained=pretrained,
@@ -218,9 +246,9 @@ def lsnet_l_artist(num_classes=1000, distillation=False, pretrained=False,
218
  distillation=distillation,
219
  img_size=224,
220
  patch_size=8,
221
- embed_dim=[160, 320, 480, 640],
222
- depth=[6, 8, 12, 14],
223
- num_heads=[4, 4, 4, 4],
224
  feature_dim=feature_dim,
225
  use_projection=use_projection,
226
  **kwargs
@@ -231,6 +259,7 @@ def lsnet_l_artist(num_classes=1000, distillation=False, pretrained=False,
231
  @register_model
232
  def lsnet_xl_artist(num_classes=1000, distillation=False, pretrained=False,
233
  feature_dim=None, use_projection=True, **kwargs):
 
234
  model = _create_lsnet_artist(
235
  "lsnet_xl_artist",
236
  pretrained=pretrained,
@@ -238,11 +267,32 @@ def lsnet_xl_artist(num_classes=1000, distillation=False, pretrained=False,
238
  distillation=distillation,
239
  img_size=224,
240
  patch_size=8,
241
- embed_dim=[192, 384, 576, 768],
242
- depth=[8, 12, 16, 20],
243
- num_heads=[6, 6, 6, 6],
244
  feature_dim=feature_dim,
245
  use_projection=use_projection,
246
  **kwargs
247
  )
248
  return model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LSNet for Artist Style Classification and Clustering
3
+ 支持画师风格的分类和聚类任务
4
+ """
5
  import torch
6
  import torch.nn as nn
7
+ from timm.models import build_model_with_cfg, register_model
8
+
9
+ from .lsnet import BN_Linear, Conv2d_BN, LSNet
10
 
11
 
12
  class LSNetArtist(LSNet):
13
+ """
14
+ LSNet模型用于画师风格分类和聚类
15
+
16
+ 特点:
17
+ - 训练时使用分类头进行监督学习
18
+ - 推理时可选择是否使用分类头
19
+ - 去掉分类头输出特征向量用于聚类
20
+ - 保留分类头可以做风格分类
21
+ """
22
+
23
  def __init__(self,
24
  img_size=224,
25
  patch_size=8,
 
85
  x = self.projection(x)
86
  return x
87
 
88
+ def forward(self, x, return_features=False, return_both=False):
89
  """
90
+ 前向传播
 
 
91
 
92
+ Args:
93
+ x: 输入图像
94
+ return_features: 是否只返回特征向量(用于聚类)
95
+ False时返回分类logits(用于分类)
96
+ return_both: 是否同时返回特征向量和分类logits(用于对比损失)
97
+
98
+ Returns:
99
+ 如果return_features=True: 返回特征向量 (batch_size, feature_dim)
100
+ 如果return_both=True: 返回 (features, logits)
101
+ 如果return_features=False and return_both=False: 返回分类logits (batch_size, num_classes)
102
  """
103
  features = self.forward_features(x)
104
 
 
108
 
109
  # 返回分类结果
110
  if self.distillation:
111
+ logits = self.head(features), self.head_dist(features)
112
  if not self.training:
113
+ logits = (logits[0] + logits[1]) / 2
114
  else:
115
+ logits = self.head(features)
116
 
117
+ if return_both:
118
+ return features, logits
119
+
120
+ return logits
121
 
122
  def get_features(self, x):
123
  """
124
+ 便捷方法:提取特征向量
125
  """
126
  return self.forward(x, return_features=True)
127
 
128
  def classify(self, x):
129
  """
130
+ 便捷方法:进行分类
131
  """
132
  return self.forward(x, return_features=False)
133
 
 
152
  lsnet_t_artist = _cfg_artist(),
153
  lsnet_s_artist = _cfg_artist(),
154
  lsnet_b_artist = _cfg_artist(),
155
+ lsnet_l_artist = _cfg_artist(), # Large model for massive training
156
+ lsnet_xl_artist = _cfg_artist(), # Extra Large model for 100k+ classes
157
+ lsnet_xl_artist_448 = _cfg_artist(), # Extra Large model with 448x448 input for massive datasets with 50k+ classes
158
  )
159
 
160
 
 
175
  @register_model
176
  def lsnet_t_artist(num_classes=1000, distillation=False, pretrained=False,
177
  feature_dim=None, use_projection=True, **kwargs):
178
+ """LSNet-T for Artist Style Classification"""
179
  model = _create_lsnet_artist(
180
  "lsnet_t_artist",
181
  pretrained=pretrained,
 
196
  @register_model
197
  def lsnet_s_artist(num_classes=1000, distillation=False, pretrained=False,
198
  feature_dim=None, use_projection=True, **kwargs):
199
+ """LSNet-S for Artist Style Classification"""
200
  model = _create_lsnet_artist(
201
  "lsnet_s_artist",
202
  pretrained=pretrained,
 
217
  @register_model
218
  def lsnet_b_artist(num_classes=1000, distillation=False, pretrained=False,
219
  feature_dim=None, use_projection=True, **kwargs):
220
+ """LSNet-B for Artist Style Classification"""
221
  model = _create_lsnet_artist(
222
  "lsnet_b_artist",
223
  pretrained=pretrained,
 
238
  @register_model
239
  def lsnet_l_artist(num_classes=1000, distillation=False, pretrained=False,
240
  feature_dim=None, use_projection=True, **kwargs):
241
+ """LSNet-L for Artist Style Classification (Large model for massive training)"""
242
  model = _create_lsnet_artist(
243
  "lsnet_l_artist",
244
  pretrained=pretrained,
 
246
  distillation=distillation,
247
  img_size=224,
248
  patch_size=8,
249
+ embed_dim=[160, 320, 480, 640], # 更大的embed_dim
250
+ depth=[6, 8, 12, 14], # 更深的网络
251
+ num_heads=[4, 4, 4, 4], # 更多的注意力头
252
  feature_dim=feature_dim,
253
  use_projection=use_projection,
254
  **kwargs
 
259
  @register_model
260
  def lsnet_xl_artist(num_classes=1000, distillation=False, pretrained=False,
261
  feature_dim=None, use_projection=True, **kwargs):
262
+ """LSNet-XL for Artist Style Classification (Extra Large model for massive datasets with 100k+ classes)"""
263
  model = _create_lsnet_artist(
264
  "lsnet_xl_artist",
265
  pretrained=pretrained,
 
267
  distillation=distillation,
268
  img_size=224,
269
  patch_size=8,
270
+ embed_dim=[192, 384, 576, 768], # 超大embed_dim,支持10万+类别
271
+ depth=[8, 12, 16, 20], # 超深网络,学习复杂特征
272
+ num_heads=[6, 6, 6, 6], # 更多注意力头
273
  feature_dim=feature_dim,
274
  use_projection=use_projection,
275
  **kwargs
276
  )
277
  return model
278
+
279
+
280
@register_model
def lsnet_xl_artist_448(num_classes=50000, distillation=False, pretrained=False,
                        feature_dim=None, use_projection=True, **kwargs):
    """Build the ``lsnet_xl_artist_448`` model variant.

    LSNet-XL-448 for artist style classification: the extra-large
    configuration with a 448x448 input, intended for massive datasets
    with 50k+ classes.

    Args:
        num_classes: Forwarded to ``_create_lsnet_artist`` (default 50000).
        distillation: Forwarded to ``_create_lsnet_artist``.
        pretrained: Forwarded to ``_create_lsnet_artist``.
        feature_dim: Forwarded to ``_create_lsnet_artist``.
        use_projection: Forwarded to ``_create_lsnet_artist``.
        **kwargs: Extra keyword arguments forwarded unchanged. A key that
            duplicates one of the fixed settings below raises ``TypeError``
            at the call site.

    Returns:
        The object returned by ``_create_lsnet_artist`` for this variant.
    """
    # Fixed architecture hyper-parameters for this variant; the per-stage
    # lists are the same values used by ``lsnet_xl_artist``, only the input
    # size differs (448 instead of 224).
    arch_settings = dict(
        img_size=448,
        patch_size=8,
        embed_dim=[192, 384, 576, 768],
        depth=[8, 12, 16, 20],
        num_heads=[6, 6, 6, 6],
    )
    return _create_lsnet_artist(
        "lsnet_xl_artist_448",
        pretrained=pretrained,
        num_classes=num_classes,
        distillation=distillation,
        **arch_settings,
        feature_dim=feature_dim,
        use_projection=use_projection,
        **kwargs,
    )