---
license: apache-2.0
inference: false
datasets:
- mnist
pipeline_tag: image-classification
---

# Perceiver IO image classifier (MNIST)

This model is a small Perceiver IO image classifier (907K parameters) trained from scratch on the [MNIST](https://huggingface.co/datasets/mnist)
dataset. It is a [training example](https://github.com/krasserm/perceiver-io/blob/main/docs/training-examples.md#image-classification)
of the [perceiver-io](https://github.com/krasserm/perceiver-io) library.

## Model description

Like [krasserm/perceiver-io-img-clf](https://huggingface.co/krasserm/perceiver-io-img-clf), this model uses 2D
Fourier features for position encoding and cross-attends to individual pixels of an input image. Unlike that model, it
uses repeated cross-attention, a configuration that was described in the original [Perceiver paper](https://arxiv.org/abs/2103.03206)
but dropped in the follow-up [Perceiver IO paper](https://arxiv.org/abs/2107.14795) (see
[building blocks](https://github.com/krasserm/perceiver-io/blob/main/docs/building-blocks.md) for more details).

## Model training

The model was [trained](https://github.com/krasserm/perceiver-io/blob/main/docs/training-examples.md#image-classification)
with randomly initialized weights on the MNIST handwritten digits dataset. Images were normalized and data augmentations
were turned off. Training was done with [PyTorch Lightning](https://www.pytorchlightning.ai/index.html) and the resulting
checkpoint was converted to this 🤗 model with a library-specific [conversion utility](#checkpoint-conversion).
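The exact normalization is defined by the library's image processor. The following is only a sketch of the common approach. It assumes 8-bit grayscale input scaled to [0, 1] and then standardized with the widely used MNIST statistics (mean 0.1307, std 0.3081). Those constants, and the `normalize` helper, are assumptions for illustration.

```python
import numpy as np

# Widely used MNIST training-set statistics (assumed here; the library's
# image processor may use different values)
MNIST_MEAN, MNIST_STD = 0.1307, 0.3081

def normalize(image_uint8):
    # Hypothetical helper: scale 8-bit pixels to [0, 1], then standardize
    x = np.asarray(image_uint8, dtype=np.float32) / 255.0
    return (x - MNIST_MEAN) / MNIST_STD

img = np.zeros((28, 28), dtype=np.uint8)
img[10:18, 12:16] = 255  # a crude vertical stroke
x = normalize(img)
print(x.shape, float(x.min()), float(x.max()))
```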

## Intended use and limitations

The model can be used for MNIST handwritten digit classification.

## Usage examples

To use this model, you first need to [install](https://github.com/krasserm/perceiver-io/blob/main/README.md#installation)
the `perceiver-io` library with the `vision` extra.

```shell
pip install "perceiver-io[vision]"
```

The model can then be used with PyTorch. Either use the model and image processor directly:

```python
from datasets import load_dataset
from transformers import AutoModelForImageClassification, AutoImageProcessor
from perceiver.model.vision import image_classifier  # auto-class registration

repo_id = "krasserm/perceiver-io-img-clf-mnist"

mnist_dataset = load_dataset("mnist", split="test")[:9]

images = mnist_dataset["image"]
labels = mnist_dataset["label"]

model = AutoModelForImageClassification.from_pretrained(repo_id)
processor = AutoImageProcessor.from_pretrained(repo_id)

inputs = processor(images, return_tensors="pt")
logits = model(**inputs).logits

print(f"Labels:      {labels}")
print(f"Predictions: {logits.argmax(dim=-1).numpy().tolist()}")
```
```
Labels:      [7, 2, 1, 0, 4, 1, 4, 9, 5]
Predictions: [7, 2, 1, 0, 4, 1, 4, 9, 5]
```

or use an `image-classification` pipeline:

```python
from datasets import load_dataset
from transformers import pipeline
from perceiver.model.vision import image_classifier  # auto-class registration

repo_id = "krasserm/perceiver-io-img-clf-mnist"

mnist_dataset = load_dataset("mnist", split="test")[:9]

images = mnist_dataset["image"]
labels = mnist_dataset["label"]

classifier = pipeline("image-classification", model=repo_id)
predictions = [pred[0]["label"] for pred in classifier(images)]

print(f"Labels:      {labels}")
print(f"Predictions: {predictions}")
```
```
Labels:      [7, 2, 1, 0, 4, 1, 4, 9, 5]
Predictions: [7, 2, 1, 0, 4, 1, 4, 9, 5]
```

## Checkpoint conversion

The `krasserm/perceiver-io-img-clf-mnist` model has been created from a training checkpoint with: 

```python
from perceiver.model.vision.image_classifier import convert_mnist_classifier_checkpoint

convert_mnist_classifier_checkpoint(
    save_dir="krasserm/perceiver-io-img-clf-mnist",
    ckpt_url="https://martin-krasser.com/perceiver/logs-0.8.0/img_clf/version_0/checkpoints/epoch=025-val_loss=0.065.ckpt",
    push_to_hub=True,
)
```