Attention Branch Network (ABN) — ResNet152 on Imagenette
このモデルは、Attention Branch Network(ABN)を ResNet 系バックボーン上に実装し、torchvision.datasets.Imagenette(10クラス)で学習した画像分類モデルです。ABN により、予測に寄与した領域をアテンションマップとして可視化できます。
- 学習対象: Imagenette 10クラス分類(公式
train/val分割) - バックボーン: ResNet18/34/50/101/152(本モデルは ResNet152)
- 可視化: 原画像とアテンション重畳のペアをグリッド出力(
visualize.py)
結果(検証)
- Loss: 0.6205
- Top-1: 0.9047
- Top-5: 0.9921
- Epochs: 90
モデルの説明
ABN は、分類のための Perception Branch に加えて、説明可能性のための Attention Branch を設け、注意マップを生成します。本実装では ResNet の layer3 出力に対して Attention Branch(att_layer4 など)を構築し、得られた注意マップで特徴マップを強調した上で最終分類を行います。
- 出力:
logits(分類用)、内部にmodel.attention_map(形状: B×1×H×W)を保持 - 損失:
CrossEntropyLoss(att_logits) + CrossEntropyLoss(per_logits)の和 - AutoClass 対応:
trust_remote_code=TrueでAutoModelForImageClassificationからロード可能
想定される用途
- 研究用途・可視化用途(画像分類+解釈)
- 対象データ: Imagenette(10クラス)。他クラス数やデータセットへ適用する場合は再学習が必要
学習・評価データ
- データセット:
torchvision.datasets.Imagenette - 分割: 公式
train/val - 前処理(train):
RandomResizedCrop(224)、RandomHorizontalFlip()、ToTensor()、Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]) - 前処理(val):
Resize(256)、CenterCrop(224)、ToTensor()、Normalize(...)
主要ハイパーパラメータ
- エポック数: 90
- バッチサイズ: train=64, eval=100
- 学習率: 0.1
- 最適化: SGD(momentum=0.9, weight_decay=1e-4)
- LR スケジュール: エポック 31/61 で
gamma=0.1を乗じるステップ方式(LambdaLR実装) - 乱数シード: 42
- ワーカー数: 4
使い方(推論)
from transformers import AutoModelForImageClassification
import torchvision.transforms as T
import torch
model = AutoModelForImageClassification.from_pretrained(
"yukiharada1228/abn-resnet152-imagenette",
trust_remote_code=True,
)
model.eval()
transform = T.Compose([
T.Resize(256),
T.CenterCrop(224),
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# pixel_values: (B,3,224,224)
with torch.no_grad():
outputs = model(pixel_values=pixel_values)
logits = outputs.logits
att_map = model.model.attention_map # (B,1,H,W)
可視化は visualize.py を使用:
uv run visualize.py --ckpt yukiharada1228/abn-resnet152-imagenette --out-dir outputs --prefix abn
フレームワーク・依存関係
- Transformers 4.57.0
- Pytorch 2.8.0+cu128
- Tokenizers 0.22.1
- その他: accelerate, numpy, matplotlib, opencv-python, tensorboardx
リポジトリ・コード
- 実装コード: https://github.com/yukiharada1228/attention-branch-network
- 学習・可視化スクリプト、モデル定義、使用方法の詳細は上記リポジトリを参照
ライセンス / 謝辞
- ライセンス: リポジトリの
LICENSEに従います - 参考: "Attention Branch Network: Learning of Attention Mechanism for Visual Explanation"(MIT License)
- Downloads last month
- 71
Evaluation results
- Top-1 Accuracy on Imagenetteself-reported0.905
- Top-5 Accuracy on Imagenetteself-reported0.992
- Validation Loss on Imagenetteself-reported0.621
