Commit 
							
							·
						
						eb977fa
	
1
								Parent(s):
							
							3bda970
								
update lsnet_artist.py in preparation for the lsnet_xl_artist_448 arch
Browse files- lsnet/lsnet_artist.py +73 -23
    	
        lsnet/lsnet_artist.py
    CHANGED
    
    | @@ -1,11 +1,25 @@ | |
|  | |
|  | |
|  | |
|  | |
| 1 | 
             
            import torch
         | 
| 2 | 
             
            import torch.nn as nn
         | 
| 3 | 
            -
            from . | 
| 4 | 
            -
             | 
| 5 | 
            -
            from  | 
| 6 |  | 
| 7 |  | 
| 8 | 
             
            class LSNetArtist(LSNet):
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 9 | 
             
                def __init__(self, 
         | 
| 10 | 
             
                             img_size=224,
         | 
| 11 | 
             
                             patch_size=8,
         | 
| @@ -71,14 +85,20 @@ class LSNetArtist(LSNet): | |
| 71 | 
             
                    x = self.projection(x)
         | 
| 72 | 
             
                    return x
         | 
| 73 |  | 
| 74 | 
            -
                def forward(self, x, return_features=False):
         | 
| 75 | 
             
                    """
         | 
| 76 | 
            -
                     | 
| 77 | 
            -
                    return_features: 是否只返回特征向量(用于聚类)
         | 
| 78 | 
            -
                                    False时返回分类logits(用于分类)
         | 
| 79 |  | 
| 80 | 
            -
                     | 
| 81 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 82 | 
             
                    """
         | 
| 83 | 
             
                    features = self.forward_features(x)
         | 
| 84 |  | 
| @@ -88,23 +108,26 @@ class LSNetArtist(LSNet): | |
| 88 |  | 
| 89 | 
             
                    # 返回分类结果
         | 
| 90 | 
             
                    if self.distillation:
         | 
| 91 | 
            -
                         | 
| 92 | 
             
                        if not self.training:
         | 
| 93 | 
            -
                             | 
| 94 | 
             
                    else:
         | 
| 95 | 
            -
                         | 
| 96 |  | 
| 97 | 
            -
                     | 
|  | |
|  | |
|  | |
| 98 |  | 
| 99 | 
             
                def get_features(self, x):
         | 
| 100 | 
             
                    """
         | 
| 101 | 
            -
                     | 
| 102 | 
             
                    """
         | 
| 103 | 
             
                    return self.forward(x, return_features=True)
         | 
| 104 |  | 
| 105 | 
             
                def classify(self, x):
         | 
| 106 | 
             
                    """
         | 
| 107 | 
            -
                     | 
| 108 | 
             
                    """
         | 
| 109 | 
             
                    return self.forward(x, return_features=False)
         | 
| 110 |  | 
| @@ -129,8 +152,9 @@ default_cfgs_artist = dict( | |
| 129 | 
             
                lsnet_t_artist = _cfg_artist(),
         | 
| 130 | 
             
                lsnet_s_artist = _cfg_artist(),
         | 
| 131 | 
             
                lsnet_b_artist = _cfg_artist(),
         | 
| 132 | 
            -
                lsnet_l_artist = _cfg_artist(),
         | 
| 133 | 
            -
                lsnet_xl_artist = _cfg_artist(),
         | 
|  | |
| 134 | 
             
            )
         | 
| 135 |  | 
| 136 |  | 
| @@ -151,6 +175,7 @@ def _create_lsnet_artist(variant, pretrained=False, **kwargs): | |
| 151 | 
             
            @register_model
         | 
| 152 | 
             
            def lsnet_t_artist(num_classes=1000, distillation=False, pretrained=False, 
         | 
| 153 | 
             
                               feature_dim=None, use_projection=True, **kwargs):
         | 
|  | |
| 154 | 
             
                model = _create_lsnet_artist(
         | 
| 155 | 
             
                    "lsnet_t_artist",
         | 
| 156 | 
             
                    pretrained=pretrained,
         | 
| @@ -171,6 +196,7 @@ def lsnet_t_artist(num_classes=1000, distillation=False, pretrained=False, | |
| 171 | 
             
            @register_model
         | 
| 172 | 
             
            def lsnet_s_artist(num_classes=1000, distillation=False, pretrained=False,
         | 
| 173 | 
             
                               feature_dim=None, use_projection=True, **kwargs):
         | 
|  | |
| 174 | 
             
                model = _create_lsnet_artist(
         | 
| 175 | 
             
                    "lsnet_s_artist",
         | 
| 176 | 
             
                    pretrained=pretrained,
         | 
| @@ -191,6 +217,7 @@ def lsnet_s_artist(num_classes=1000, distillation=False, pretrained=False, | |
| 191 | 
             
            @register_model
         | 
| 192 | 
             
            def lsnet_b_artist(num_classes=1000, distillation=False, pretrained=False,
         | 
| 193 | 
             
                               feature_dim=None, use_projection=True, **kwargs):
         | 
|  | |
| 194 | 
             
                model = _create_lsnet_artist(
         | 
| 195 | 
             
                    "lsnet_b_artist",
         | 
| 196 | 
             
                    pretrained=pretrained,
         | 
| @@ -211,6 +238,7 @@ def lsnet_b_artist(num_classes=1000, distillation=False, pretrained=False, | |
| 211 | 
             
            @register_model
         | 
| 212 | 
             
            def lsnet_l_artist(num_classes=1000, distillation=False, pretrained=False,
         | 
| 213 | 
             
                               feature_dim=None, use_projection=True, **kwargs):
         | 
|  | |
| 214 | 
             
                model = _create_lsnet_artist(
         | 
| 215 | 
             
                    "lsnet_l_artist",
         | 
| 216 | 
             
                    pretrained=pretrained,
         | 
| @@ -218,9 +246,9 @@ def lsnet_l_artist(num_classes=1000, distillation=False, pretrained=False, | |
| 218 | 
             
                    distillation=distillation,
         | 
| 219 | 
             
                    img_size=224,
         | 
| 220 | 
             
                    patch_size=8,
         | 
| 221 | 
            -
                    embed_dim=[160, 320, 480, 640],
         | 
| 222 | 
            -
                    depth=[6, 8, 12, 14],
         | 
| 223 | 
            -
                    num_heads=[4, 4, 4, 4],
         | 
| 224 | 
             
                    feature_dim=feature_dim,
         | 
| 225 | 
             
                    use_projection=use_projection,
         | 
| 226 | 
             
                    **kwargs
         | 
| @@ -231,6 +259,7 @@ def lsnet_l_artist(num_classes=1000, distillation=False, pretrained=False, | |
| 231 | 
             
            @register_model
         | 
| 232 | 
             
            def lsnet_xl_artist(num_classes=1000, distillation=False, pretrained=False,
         | 
| 233 | 
             
                                feature_dim=None, use_projection=True, **kwargs):
         | 
|  | |
| 234 | 
             
                model = _create_lsnet_artist(
         | 
| 235 | 
             
                    "lsnet_xl_artist",
         | 
| 236 | 
             
                    pretrained=pretrained,
         | 
| @@ -238,11 +267,32 @@ def lsnet_xl_artist(num_classes=1000, distillation=False, pretrained=False, | |
| 238 | 
             
                    distillation=distillation,
         | 
| 239 | 
             
                    img_size=224,
         | 
| 240 | 
             
                    patch_size=8,
         | 
| 241 | 
            -
                    embed_dim=[192, 384, 576, 768],
         | 
| 242 | 
            -
                    depth=[8, 12, 16, 20],
         | 
| 243 | 
            -
                    num_heads=[6, 6, 6, 6],
         | 
| 244 | 
             
                    feature_dim=feature_dim,
         | 
| 245 | 
             
                    use_projection=use_projection,
         | 
| 246 | 
             
                    **kwargs
         | 
| 247 | 
             
                )
         | 
| 248 | 
             
                return model
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            """
         | 
| 2 | 
            +
            LSNet for Artist Style Classification and Clustering
         | 
| 3 | 
            +
            支持画师风格的分类和聚类任务
         | 
| 4 | 
            +
            """
         | 
| 5 | 
             
            import torch
         | 
| 6 | 
             
            import torch.nn as nn
         | 
| 7 | 
            +
            from timm.models import build_model_with_cfg, register_model
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            from .lsnet import BN_Linear, Conv2d_BN, LSNet
         | 
| 10 |  | 
| 11 |  | 
| 12 | 
             
            class LSNetArtist(LSNet):
         | 
| 13 | 
            +
                """
         | 
| 14 | 
            +
                LSNet模型用于画师风格分类和聚类
         | 
| 15 | 
            +
                
         | 
| 16 | 
            +
                特点:
         | 
| 17 | 
            +
                - 训练时使用分类头进行监督学习
         | 
| 18 | 
            +
                - 推理时可选择是否使用分类头
         | 
| 19 | 
            +
                - 去掉分类头输出特征向量用于聚类
         | 
| 20 | 
            +
                - 保留分类头可以做风格分类
         | 
| 21 | 
            +
                """
         | 
| 22 | 
            +
                
         | 
| 23 | 
             
                def __init__(self, 
         | 
| 24 | 
             
                             img_size=224,
         | 
| 25 | 
             
                             patch_size=8,
         | 
|  | |
| 85 | 
             
                    x = self.projection(x)
         | 
| 86 | 
             
                    return x
         | 
| 87 |  | 
| 88 | 
            +
                def forward(self, x, return_features=False, return_both=False):
         | 
| 89 | 
             
                    """
         | 
| 90 | 
            +
                    前向传播
         | 
|  | |
|  | |
| 91 |  | 
| 92 | 
            +
                    Args:
         | 
| 93 | 
            +
                        x: 输入图像
         | 
| 94 | 
            +
                        return_features: 是否只返回特征向量(用于聚类)
         | 
| 95 | 
            +
                                       False时返回分类logits(用于分类)
         | 
| 96 | 
            +
                        return_both: 是否同时返回特征向量和分类logits(用于对比损失)
         | 
| 97 | 
            +
                    
         | 
| 98 | 
            +
                    Returns:
         | 
| 99 | 
            +
                        如果return_features=True: 返回特征向量 (batch_size, feature_dim)
         | 
| 100 | 
            +
                        如果return_both=True: 返回 (features, logits)
         | 
| 101 | 
            +
                        如果return_features=False and return_both=False: 返回分类logits (batch_size, num_classes)
         | 
| 102 | 
             
                    """
         | 
| 103 | 
             
                    features = self.forward_features(x)
         | 
| 104 |  | 
|  | |
| 108 |  | 
| 109 | 
             
                    # 返回分类结果
         | 
| 110 | 
             
                    if self.distillation:
         | 
| 111 | 
            +
                        logits = self.head(features), self.head_dist(features)
         | 
| 112 | 
             
                        if not self.training:
         | 
| 113 | 
            +
                            logits = (logits[0] + logits[1]) / 2
         | 
| 114 | 
             
                    else:
         | 
| 115 | 
            +
                        logits = self.head(features)
         | 
| 116 |  | 
| 117 | 
            +
                    if return_both:
         | 
| 118 | 
            +
                        return features, logits
         | 
| 119 | 
            +
                    
         | 
| 120 | 
            +
                    return logits
         | 
| 121 |  | 
| 122 | 
             
                def get_features(self, x):
         | 
| 123 | 
             
                    """
         | 
| 124 | 
            +
                    便捷方法:提取特征向量
         | 
| 125 | 
             
                    """
         | 
| 126 | 
             
                    return self.forward(x, return_features=True)
         | 
| 127 |  | 
| 128 | 
             
                def classify(self, x):
         | 
| 129 | 
             
                    """
         | 
| 130 | 
            +
                    便捷方法:进行分类
         | 
| 131 | 
             
                    """
         | 
| 132 | 
             
                    return self.forward(x, return_features=False)
         | 
| 133 |  | 
|  | |
| 152 | 
             
                lsnet_t_artist = _cfg_artist(),
         | 
| 153 | 
             
                lsnet_s_artist = _cfg_artist(),
         | 
| 154 | 
             
                lsnet_b_artist = _cfg_artist(),
         | 
| 155 | 
            +
                lsnet_l_artist = _cfg_artist(),  # Large model for massive training
         | 
| 156 | 
            +
                lsnet_xl_artist = _cfg_artist(), # Extra Large model for 100k+ classes
         | 
| 157 | 
            +
                lsnet_xl_artist_448 = _cfg_artist(), # Extra Large model with 448x448 input for massive datasets with 50k+ classes
         | 
| 158 | 
             
            )
         | 
| 159 |  | 
| 160 |  | 
|  | |
| 175 | 
             
            @register_model
         | 
| 176 | 
             
            def lsnet_t_artist(num_classes=1000, distillation=False, pretrained=False, 
         | 
| 177 | 
             
                               feature_dim=None, use_projection=True, **kwargs):
         | 
| 178 | 
            +
                """LSNet-T for Artist Style Classification"""
         | 
| 179 | 
             
                model = _create_lsnet_artist(
         | 
| 180 | 
             
                    "lsnet_t_artist",
         | 
| 181 | 
             
                    pretrained=pretrained,
         | 
|  | |
| 196 | 
             
            @register_model
         | 
| 197 | 
             
            def lsnet_s_artist(num_classes=1000, distillation=False, pretrained=False,
         | 
| 198 | 
             
                               feature_dim=None, use_projection=True, **kwargs):
         | 
| 199 | 
            +
                """LSNet-S for Artist Style Classification"""
         | 
| 200 | 
             
                model = _create_lsnet_artist(
         | 
| 201 | 
             
                    "lsnet_s_artist",
         | 
| 202 | 
             
                    pretrained=pretrained,
         | 
|  | |
| 217 | 
             
            @register_model
         | 
| 218 | 
             
            def lsnet_b_artist(num_classes=1000, distillation=False, pretrained=False,
         | 
| 219 | 
             
                               feature_dim=None, use_projection=True, **kwargs):
         | 
| 220 | 
            +
                """LSNet-B for Artist Style Classification"""
         | 
| 221 | 
             
                model = _create_lsnet_artist(
         | 
| 222 | 
             
                    "lsnet_b_artist",
         | 
| 223 | 
             
                    pretrained=pretrained,
         | 
|  | |
| 238 | 
             
            @register_model
         | 
| 239 | 
             
            def lsnet_l_artist(num_classes=1000, distillation=False, pretrained=False,
         | 
| 240 | 
             
                               feature_dim=None, use_projection=True, **kwargs):
         | 
| 241 | 
            +
                """LSNet-L for Artist Style Classification (Large model for massive training)"""
         | 
| 242 | 
             
                model = _create_lsnet_artist(
         | 
| 243 | 
             
                    "lsnet_l_artist",
         | 
| 244 | 
             
                    pretrained=pretrained,
         | 
|  | |
| 246 | 
             
                    distillation=distillation,
         | 
| 247 | 
             
                    img_size=224,
         | 
| 248 | 
             
                    patch_size=8,
         | 
| 249 | 
            +
                    embed_dim=[160, 320, 480, 640],  # 更大的embed_dim
         | 
| 250 | 
            +
                    depth=[6, 8, 12, 14],           # 更深的网络
         | 
| 251 | 
            +
                    num_heads=[4, 4, 4, 4],          # 更多的注意力头
         | 
| 252 | 
             
                    feature_dim=feature_dim,
         | 
| 253 | 
             
                    use_projection=use_projection,
         | 
| 254 | 
             
                    **kwargs
         | 
|  | |
| 259 | 
             
            @register_model
         | 
| 260 | 
             
            def lsnet_xl_artist(num_classes=1000, distillation=False, pretrained=False,
         | 
| 261 | 
             
                                feature_dim=None, use_projection=True, **kwargs):
         | 
| 262 | 
            +
                """LSNet-XL for Artist Style Classification (Extra Large model for massive datasets with 100k+ classes)"""
         | 
| 263 | 
             
                model = _create_lsnet_artist(
         | 
| 264 | 
             
                    "lsnet_xl_artist",
         | 
| 265 | 
             
                    pretrained=pretrained,
         | 
|  | |
| 267 | 
             
                    distillation=distillation,
         | 
| 268 | 
             
                    img_size=224,
         | 
| 269 | 
             
                    patch_size=8,
         | 
| 270 | 
            +
                    embed_dim=[192, 384, 576, 768],  # 超大embed_dim,支持10万+类别
         | 
| 271 | 
            +
                    depth=[8, 12, 16, 20],           # 超深网络,学习复杂特征
         | 
| 272 | 
            +
                    num_heads=[6, 6, 6, 6],           # 更多注意力头
         | 
| 273 | 
             
                    feature_dim=feature_dim,
         | 
| 274 | 
             
                    use_projection=use_projection,
         | 
| 275 | 
             
                    **kwargs
         | 
| 276 | 
             
                )
         | 
| 277 | 
             
                return model
         | 
| 278 | 
            +
             | 
| 279 | 
            +
             | 
| 280 | 
            +
            @register_model
         | 
| 281 | 
            +
            def lsnet_xl_artist_448(num_classes=50000, distillation=False, pretrained=False,
         | 
| 282 | 
            +
                                    feature_dim=None, use_projection=True, **kwargs):
         | 
| 283 | 
            +
                """LSNet-XL-448 for Artist Style Classification (Extra Large model with 448x448 input for massive datasets with 50k+ classes)"""
         | 
| 284 | 
            +
                model = _create_lsnet_artist(
         | 
| 285 | 
            +
                    "lsnet_xl_artist_448",
         | 
| 286 | 
            +
                    pretrained=pretrained,
         | 
| 287 | 
            +
                    num_classes=num_classes, 
         | 
| 288 | 
            +
                    distillation=distillation,
         | 
| 289 | 
            +
                    img_size=448,
         | 
| 290 | 
            +
                    patch_size=8,
         | 
| 291 | 
            +
                    embed_dim=[192, 384, 576, 768],  # same embed_dim as lsnet_xl_artist; this variant defaults to 50k classes
         | 
| 292 | 
            +
                    depth=[8, 12, 16, 20],           # 超深网络,学习复杂特征
         | 
| 293 | 
            +
                    num_heads=[6, 6, 6, 6],           # 更多注意力头
         | 
| 294 | 
            +
                    feature_dim=feature_dim,
         | 
| 295 | 
            +
                    use_projection=use_projection,
         | 
| 296 | 
            +
                    **kwargs
         | 
| 297 | 
            +
                )
         | 
| 298 | 
            +
                return model
         | 
