---
license: apache-2.0
inference: false
datasets:
- mnist
pipeline_tag: image-classification
---

# Perceiver IO image classifier (MNIST)

This model is a small Perceiver IO image classifier (907K parameters) trained from scratch on the [MNIST](https://huggingface.co/datasets/mnist)
dataset. It is a [training example](https://github.com/krasserm/perceiver-io/blob/main/docs/training-examples.md#image-classification)
of the [perceiver-io](https://github.com/krasserm/perceiver-io) library.

## Model description

Like [krasserm/perceiver-io-img-clf](https://huggingface.co/krasserm/perceiver-io-img-clf), this model uses 2D
Fourier features for position encoding and cross-attends to individual pixels of the input image. Unlike that model,
however, it uses repeated cross-attention, a configuration that was described in the original
[Perceiver paper](https://arxiv.org/abs/2103.03206) but dropped in the follow-up
[Perceiver IO paper](https://arxiv.org/abs/2107.14795) (see
[building blocks](https://github.com/krasserm/perceiver-io/blob/main/docs/building-blocks.md) for more details).
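
The idea behind 2D Fourier-feature position encodings can be sketched roughly as follows. This is a simplified illustration, not the library's actual implementation; the `fourier_position_encoding` helper, the number of frequency bands, and the frequency spacing are all assumptions made for the example:

```python
import numpy as np

def fourier_position_encoding(rows: int, cols: int, num_bands: int = 4, max_freq: float = 10.0) -> np.ndarray:
    """Simplified 2D Fourier-feature position encoding (illustrative only)."""
    # Normalized pixel coordinates in [-1, 1] for each spatial dimension.
    ys = np.linspace(-1.0, 1.0, rows)
    xs = np.linspace(-1.0, 1.0, cols)
    grid = np.stack(np.meshgrid(ys, xs, indexing="ij"), axis=-1)  # (rows, cols, 2)

    # Frequency bands from 1 up to half the maximum frequency.
    freqs = np.linspace(1.0, max_freq / 2.0, num_bands)  # (num_bands,)

    # Project each coordinate onto each frequency band: (rows, cols, 2, num_bands).
    scaled = grid[..., None] * freqs * np.pi

    # Sine and cosine features, flattened per pixel, plus the raw coordinates.
    features = np.concatenate([np.sin(scaled), np.cos(scaled)], axis=-1)
    features = features.reshape(rows, cols, -1)       # (rows, cols, 2 * 2 * num_bands)
    return np.concatenate([features, grid], axis=-1)  # (rows, cols, 2 * (2 * num_bands + 1))

encoding = fourier_position_encoding(28, 28)
print(encoding.shape)  # (28, 28, 18)
```

Each pixel's encoding is concatenated with its channel value before the model cross-attends to it, which is how the latent array learns where each pixel sits in the image.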

## Model training

The model was [trained](https://github.com/krasserm/perceiver-io/blob/main/docs/training-examples.md#image-classification)
with randomly initialized weights on the MNIST handwritten digits dataset. Images were normalized and data augmentations
were turned off. Training was done with [PyTorch Lightning](https://www.pytorchlightning.ai/index.html), and the resulting
checkpoint was converted to this 🤗 model with a library-specific [conversion utility](#checkpoint-conversion).
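
Image normalization of the kind mentioned above can be illustrated with a minimal sketch. The mean and standard deviation below are the commonly quoted MNIST statistics, not values taken from this card, and the `normalize` helper is hypothetical:

```python
import numpy as np

# Commonly used MNIST per-pixel statistics (assumed here for illustration;
# the exact preprocessing used during training is not stated in this card).
MNIST_MEAN, MNIST_STD = 0.1307, 0.3081

def normalize(image: np.ndarray) -> np.ndarray:
    """Scale a uint8 image (H, W) to [0, 1], then standardize it."""
    scaled = image.astype(np.float32) / 255.0
    return (scaled - MNIST_MEAN) / MNIST_STD

image = np.zeros((28, 28), dtype=np.uint8)  # dummy all-black digit
normalized = normalize(image)
print(normalized.shape)  # (28, 28)
```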

## Intended use and limitations

The model can be used for MNIST handwritten digit classification.

## Usage examples

To use this model you first need to [install](https://github.com/krasserm/perceiver-io/blob/main/README.md#installation)
the `perceiver-io` library with extension `vision`.

```shell
pip install perceiver-io[vision]
```

Then the model can be used with PyTorch. Either use the model and image processor directly

```python
from datasets import load_dataset
from transformers import AutoModelForImageClassification, AutoImageProcessor
from perceiver.model.vision import image_classifier  # auto-class registration

repo_id = "krasserm/perceiver-io-img-clf-mnist"

mnist_dataset = load_dataset("mnist", split="test")[:9]

images = mnist_dataset["image"]
labels = mnist_dataset["label"]

model = AutoModelForImageClassification.from_pretrained(repo_id)
processor = AutoImageProcessor.from_pretrained(repo_id)

inputs = processor(images, return_tensors="pt")
logits = model(**inputs).logits

print(f"Labels: {labels}")
print(f"Predictions: {logits.argmax(dim=-1).numpy().tolist()}")
```
```
Labels: [7, 2, 1, 0, 4, 1, 4, 9, 5]
Predictions: [7, 2, 1, 0, 4, 1, 4, 9, 5]
```

or use an `image-classification` pipeline:

```python
from datasets import load_dataset
from transformers import pipeline
from perceiver.model.vision import image_classifier  # auto-class registration

repo_id = "krasserm/perceiver-io-img-clf-mnist"

mnist_dataset = load_dataset("mnist", split="test")[:9]

images = mnist_dataset["image"]
labels = mnist_dataset["label"]

classifier = pipeline("image-classification", model=repo_id)
predictions = [int(pred[0]["label"]) for pred in classifier(images)]

print(f"Labels: {labels}")
print(f"Predictions: {predictions}")
```
```
Labels: [7, 2, 1, 0, 4, 1, 4, 9, 5]
Predictions: [7, 2, 1, 0, 4, 1, 4, 9, 5]
```
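
To check agreement between labels and predictions on more than a handful of samples, a plain accuracy computation suffices. The `accuracy` helper below is a hypothetical convenience, not part of the library:

```python
def accuracy(labels: list, predictions: list) -> float:
    """Fraction of predictions matching the ground-truth labels."""
    assert len(labels) == len(predictions)
    return sum(l == p for l, p in zip(labels, predictions)) / len(labels)

labels = [7, 2, 1, 0, 4, 1, 4, 9, 5]
predictions = [7, 2, 1, 0, 4, 1, 4, 9, 5]
print(accuracy(labels, predictions))  # 1.0
```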

## Checkpoint conversion

The `krasserm/perceiver-io-img-clf-mnist` model has been created from a training checkpoint with:

```python
from perceiver.model.vision.image_classifier import convert_mnist_classifier_checkpoint

convert_mnist_classifier_checkpoint(
    save_dir="krasserm/perceiver-io-img-clf-mnist",
    ckpt_url="https://martin-krasser.com/perceiver/logs-0.8.0/img_clf/version_0/checkpoints/epoch=025-val_loss=0.065.ckpt",
    push_to_hub=True,
)
```