Attention Branch Network (ABN) — ResNet152 on Imagenette

このモデルは、Attention Branch Network(ABN)を ResNet 系バックボーン上に実装し、torchvision.datasets.Imagenette(10クラス)で学習した画像分類モデルです。ABN により、予測に寄与した領域をアテンションマップとして可視化できます。

Attention Maps

  • 学習対象: Imagenette 10クラス分類(公式 train/val 分割)
  • バックボーン: ResNet18/34/50/101/152(本モデルは ResNet152)
  • 可視化: 原画像とアテンション重畳のペアをグリッド出力(visualize.py

結果(検証)

  • Loss: 0.6205
  • Top-1: 0.9047
  • Top-5: 0.9921
  • Epochs: 90

モデルの説明

ABN は、分類のための Perception Branch に加えて、説明可能性のための Attention Branch を設け、注意マップを生成します。本実装では ResNet の layer3 出力に対して Attention Branch(att_layer4 など)を構築し、得られた注意マップで特徴マップを強調した上で最終分類を行います。

  • 出力: logits(分類用)、内部に model.attention_map(形状: B×1×H×W)を保持
  • 損失: CrossEntropyLoss(att_logits) + CrossEntropyLoss(per_logits) の和
  • AutoClass 対応: trust_remote_code=TrueAutoModelForImageClassification からロード可能

想定される用途

  • 研究用途・可視化用途(画像分類+解釈)
  • 対象データ: Imagenette(10クラス)。他クラス数やデータセットへ適用する場合は再学習が必要

学習・評価データ

  • データセット: torchvision.datasets.Imagenette
  • 分割: 公式 train / val
  • 前処理(train): RandomResizedCrop(224)RandomHorizontalFlip()ToTensor()Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
  • 前処理(val): Resize(256)CenterCrop(224)ToTensor()Normalize(...)

主要ハイパーパラメータ

  • エポック数: 90
  • バッチサイズ: train=64, eval=100
  • 学習率: 0.1
  • 最適化: SGD(momentum=0.9, weight_decay=1e-4)
  • LR スケジュール: エポック 31/61 で gamma=0.1 を乗じるステップ方式(LambdaLR 実装)
  • 乱数シード: 42
  • ワーカー数: 4

使い方(推論)

from transformers import AutoModelForImageClassification
import torchvision.transforms as T
import torch

model = AutoModelForImageClassification.from_pretrained(
    "yukiharada1228/abn-resnet152-imagenette",
    trust_remote_code=True,
)
model.eval()

transform = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# pixel_values: (B,3,224,224)
with torch.no_grad():
    outputs = model(pixel_values=pixel_values)
    logits = outputs.logits
    att_map = model.model.attention_map  # (B,1,H,W)

可視化は visualize.py を使用:

uv run visualize.py --ckpt yukiharada1228/abn-resnet152-imagenette --out-dir outputs --prefix abn

フレームワーク・依存関係

  • Transformers 4.57.0
  • Pytorch 2.8.0+cu128
  • Tokenizers 0.22.1
  • その他: accelerate, numpy, matplotlib, opencv-python, tensorboardx

リポジトリ・コード

ライセンス / 謝辞

  • ライセンス: リポジトリの LICENSE に従います
  • 参考: "Attention Branch Network: Learning of Attention Mechanism for Visual Explanation"(MIT License)
Downloads last month
71
Safetensors
Model size
73.3M params
Tensor type
F32
·
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support

Evaluation results