---
license: apache-2.0
inference: false
datasets:
- mnist
pipeline_tag: image-classification
---
# Perceiver IO image classifier (MNIST)
This model is a small Perceiver IO image classifier (907K parameters) trained from scratch on the [MNIST](https://huggingface.co/datasets/mnist)
dataset. It is a [training example](https://github.com/krasserm/perceiver-io/blob/main/docs/training-examples.md#image-classification)
of the [perceiver-io](https://github.com/krasserm/perceiver-io) library.
## Model description
Like [krasserm/perceiver-io-img-clf](https://huggingface.co/krasserm/perceiver-io-img-clf), this model uses 2D
Fourier features for position encoding and cross-attends to individual pixels of an input image. Unlike that model, it
uses repeated cross-attention, a configuration described in the original [Perceiver paper](https://arxiv.org/abs/2103.03206)
but dropped in the follow-up [Perceiver IO paper](https://arxiv.org/abs/2107.14795) (see
[building blocks](https://github.com/krasserm/perceiver-io/blob/main/docs/building-blocks.md) for more details).
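The two ideas above, Fourier position features for every pixel and a latent array that cross-attends to those pixels more than once, can be illustrated with a minimal PyTorch sketch. This is not the `perceiver-io` implementation: `fourier_position_encoding` and `RepeatedCrossAttentionEncoder` are hypothetical names, and all hyperparameters (`num_bands=32`, `max_freq=28.0`, 32 latents of width 64, two blocks) are illustrative assumptions, not this model's configuration. The feature recipe (coordinates in [-1, 1], linearly spaced frequency bands, sin/cos plus raw coordinates) follows the Perceiver paper.

```python
import math

import torch
from torch import nn


def fourier_position_encoding(height: int, width: int, num_bands: int, max_freq: float) -> torch.Tensor:
    """Return a (height * width, 2 + 4 * num_bands) tensor of 2D Fourier position features."""
    # Pixel coordinates scaled to [-1, 1] along each axis.
    ys = torch.linspace(-1.0, 1.0, height)
    xs = torch.linspace(-1.0, 1.0, width)
    pos = torch.stack(torch.meshgrid(ys, xs, indexing="ij"), dim=-1).reshape(-1, 2)  # (H*W, 2)
    # Frequency bands linearly spaced between 1 and max_freq / 2 (Nyquist for max_freq samples).
    freqs = torch.linspace(1.0, max_freq / 2.0, num_bands)
    scaled = pos[:, :, None] * freqs[None, None, :] * math.pi  # (H*W, 2, num_bands)
    features = torch.cat([scaled.sin(), scaled.cos()], dim=-1).reshape(pos.shape[0], -1)
    # Concatenate the raw coordinates with their sin/cos features.
    return torch.cat([pos, features], dim=-1)


class RepeatedCrossAttentionEncoder(nn.Module):
    """Latent array that cross-attends to the input pixels once per block (hypothetical sketch)."""

    def __init__(self, input_dim: int, num_latents: int = 32, latent_dim: int = 64,
                 num_heads: int = 4, num_blocks: int = 2):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(num_latents, latent_dim) * 0.02)
        # One cross-attention and one latent self-attention layer per block; weights are not shared here.
        self.cross_attn = nn.ModuleList(
            nn.MultiheadAttention(latent_dim, num_heads, kdim=input_dim, vdim=input_dim, batch_first=True)
            for _ in range(num_blocks)
        )
        self.self_attn = nn.ModuleList(
            nn.TransformerEncoderLayer(latent_dim, num_heads, dim_feedforward=2 * latent_dim, batch_first=True)
            for _ in range(num_blocks)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: (B, N, input_dim)
        z = self.latents.expand(x.shape[0], -1, -1)  # (B, num_latents, latent_dim)
        for cross, latent_layer in zip(self.cross_attn, self.self_attn):
            attended, _ = cross(z, x, x, need_weights=False)  # latents query the pixels again
            z = latent_layer(z + attended)  # residual connection, then latent self-attention
        return z


# Usage on random 28x28 grayscale images (placeholder values, not MNIST data).
images = torch.rand(2, 28, 28)
pos = fourier_position_encoding(28, 28, num_bands=32, max_freq=28.0)  # (784, 130)
inputs = torch.cat([images.reshape(2, -1, 1), pos.expand(2, -1, -1)], dim=-1)  # (2, 784, 131)
encoder = RepeatedCrossAttentionEncoder(input_dim=inputs.shape[-1]).eval()
head = nn.Linear(64, 10)  # one logit per digit class
with torch.no_grad():
    logits = head(encoder(inputs).mean(dim=1))  # (2, 10)
```

With `num_blocks=1` the sketch reduces to the single cross-attention used in Perceiver IO; each extra block lets the latents re-read the pixels after some latent processing.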
## Model training
The model was [trained](https://github.com/krasserm/perceiver-io/blob/main/docs/training-examples.md#image-classification)
from randomly initialized weights on the MNIST handwritten digits dataset. Images were normalized and data augmentations
were turned off. Training was done with [PyTorch Lightning](https://www.pytorchlightning.ai/index.html), and the resulting
checkpoint was converted to this 🤗 model with a library-specific [conversion utility](#checkpoint-conversion).
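Normalization here means standardizing pixel intensities. The following minimal sketch assumes the commonly cited MNIST training-set statistics (mean 0.1307, std 0.3081); the values actually used for this model are defined by the `perceiver-io` library and its image processor and may differ. `normalize`, `MNIST_MEAN` and `MNIST_STD` are hypothetical names. Because augmentations were turned off, the sketch applies only this deterministic transform.

```python
import torch

# Assumption: commonly cited MNIST statistics, not necessarily the values used to train this model.
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081


def normalize(images_uint8: torch.Tensor) -> torch.Tensor:
    """Scale uint8 pixels to [0, 1], then standardize with the mean and std above."""
    x = images_uint8.float() / 255.0
    return (x - MNIST_MEAN) / MNIST_STD


# A random batch of four 28x28 uint8 images (placeholder data).
batch = torch.randint(0, 256, (4, 28, 28), dtype=torch.uint8)
normalized = normalize(batch)
```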
## Intended use and limitations
The model can be used for MNIST handwritten digit classification.
## Usage examples
To use this model, you first need to [install](https://github.com/krasserm/perceiver-io/blob/main/README.md#installation)
the `perceiver-io` library with the `vision` extra.
```shell
pip install "perceiver-io[vision]"
```
Then the model can be used with PyTorch. Either use the model and image processor directly:
```python
from datasets import load_dataset
from transformers import AutoModelForImageClassification, AutoImageProcessor
from perceiver.model.vision import image_classifier # auto-class registration
repo_id = "krasserm/perceiver-io-img-clf-mnist"
mnist_dataset = load_dataset("mnist", split="test")[:9]
images = mnist_dataset["image"]
labels = mnist_dataset["label"]
model = AutoModelForImageClassification.from_pretrained(repo_id)
processor = AutoImageProcessor.from_pretrained(repo_id)
inputs = processor(images, return_tensors="pt")
logits = model(**inputs).logits
print(f"Labels: {labels}")
print(f"Predictions: {logits.argmax(dim=-1).numpy().tolist()}")
```
```
Labels: [7, 2, 1, 0, 4, 1, 4, 9, 5]
Predictions: [7, 2, 1, 0, 4, 1, 4, 9, 5]
```
or use an `image-classification` pipeline:
```python
from datasets import load_dataset
from transformers import pipeline
from perceiver.model.vision import image_classifier # auto-class registration
repo_id = "krasserm/perceiver-io-img-clf-mnist"
mnist_dataset = load_dataset("mnist", split="test")[:9]
images = mnist_dataset["image"]
labels = mnist_dataset["label"]
classifier = pipeline("image-classification", model=repo_id)
predictions = [pred[0]["label"] for pred in classifier(images)]
print(f"Labels: {labels}")
print(f"Predictions: {predictions}")
```
```
Labels: [7, 2, 1, 0, 4, 1, 4, 9, 5]
Predictions: [7, 2, 1, 0, 4, 1, 4, 9, 5]
```
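The examples above classify only nine test images. To measure accuracy on a larger slice of the test split, the direct model/processor approach can be run in batches. This is a minimal sketch: `count_correct` and `evaluate_accuracy` are hypothetical helpers, not part of `perceiver-io` or 🤗 Transformers, and `batch_size=256` is an arbitrary choice. Gradients are disabled because only inference is performed.

```python
import torch


def count_correct(logits: torch.Tensor, labels: torch.Tensor) -> int:
    """Number of rows whose argmax matches the integer label."""
    return int((logits.argmax(dim=-1) == labels).sum())


@torch.no_grad()  # inference only, no autograd bookkeeping
def evaluate_accuracy(model, processor, images, labels, batch_size: int = 256) -> float:
    """Fraction of correctly classified images, processed in batches of `batch_size`."""
    model.eval()
    correct = 0
    for start in range(0, len(images), batch_size):
        inputs = processor(images[start:start + batch_size], return_tensors="pt")
        logits = model(**inputs).logits
        correct += count_correct(logits, torch.tensor(labels[start:start + batch_size]))
    return correct / len(images)


# Usage with `model` and `processor` from the first example (downloads the MNIST test split):
# from datasets import load_dataset
# subset = load_dataset("mnist", split="test")[:1000]
# print(evaluate_accuracy(model, processor, subset["image"], subset["label"]))
```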
## Checkpoint conversion
The `krasserm/perceiver-io-img-clf-mnist` model has been created from a training checkpoint with:
```python
from perceiver.model.vision.image_classifier import convert_mnist_classifier_checkpoint
convert_mnist_classifier_checkpoint(
    save_dir="krasserm/perceiver-io-img-clf-mnist",
    ckpt_url="https://martin-krasser.com/perceiver/logs-0.8.0/img_clf/version_0/checkpoints/epoch=025-val_loss=0.065.ckpt",
    push_to_hub=True,
)
```
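For context, a checkpoint written by a PyTorch Lightning `Trainer` is a dictionary whose `"state_dict"` entry holds the weights, with keys prefixed by the attribute names used inside the `LightningModule`. The sketch below shows how such weights could be pulled out and un-prefixed. It is a generic illustration, not how `convert_mnist_classifier_checkpoint` works internally; `extract_state_dict` is a hypothetical helper and the `"model."` prefix is an assumption. The synthetic checkpoint only mimics the layout.

```python
import os
import tempfile

import torch


def extract_state_dict(ckpt_path: str, prefix: str = "model.") -> dict:
    """Return the tensors under `prefix` in a Lightning checkpoint, with the prefix stripped."""
    # Lightning checkpoints can hold non-tensor objects (e.g. hyperparameters), hence weights_only=False;
    # only load checkpoints from sources you trust.
    checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    state_dict = checkpoint["state_dict"]
    return {key[len(prefix):]: value for key, value in state_dict.items() if key.startswith(prefix)}


# Usage on a small synthetic checkpoint that mimics the {"state_dict": {...}} layout.
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "example.ckpt")
    torch.save({"state_dict": {"model.head.weight": torch.ones(10, 4), "loss.dummy": torch.zeros(1)}}, path)
    weights = extract_state_dict(path)

print(sorted(weights))  # ['head.weight']
```

The returned dictionary could then be passed to `load_state_dict` of a module whose parameter names match.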