krasserm commited on
Commit
dfde5b2
·
1 Parent(s): 9d45ef8

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +108 -0
README.md ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ inference: false
4
+ datasets:
5
+ - mnist
6
+ pipeline_tag: image-classification
7
+ ---
8
+
9
+ # Perceiver IO image classifier (IMDb)
10
+
11
+ This model is a small Perceiver IO image classifier (907K parameters) trained from scratch on the [MNIST](https://huggingface.co/datasets/mnist)
12
+ dataset. It is a [training example](https://github.com/krasserm/perceiver-io/blob/main/docs/training-examples.md#image-classification)
13
+ of the [perceiver-io](https://github.com/krasserm/perceiver-io) library.
14
+
15
+ ## Model description
16
+
17
+ Like [krasserm/perceiver-io-img-clf](https://huggingface.co/krasserm/perceiver-io-img-clf) this model also uses 2D
18
+ Fourier features for position encoding and cross-attends to individual pixels of an input image but uses repeated
19
+ cross-attention, a configuration that was described in the original [Perceiver paper](https://arxiv.org/abs/2103.03206)
20
+ which has been dropped in the follow-up [Perceiver IO paper](https://arxiv.org/abs/2107.14795) (see
21
+ [building blocks](https://github.com/krasserm/perceiver-io/blob/main/docs/building-blocks.md) for more details).
22
+
23
+ ## Model training
24
+
25
+ The model was [trained](https://github.com/krasserm/perceiver-io/blob/main/docs/training-examples.md#image-classification)
26
+ with randomly initialized weights on the MNIST handwritten digits dataset. Images were normalized, data augmentations
27
+ were turned off. Training was done with [PyTorch Lightning](https://www.pytorchlightning.ai/index.html) and the resulting
28
+ checkpoint was converted to this 🤗 model with a library-specific [conversion utility](#checkpoint-conversion).
29
+
30
+ ## Intended use and limitations
31
+
32
+ The model can be used for MNIST handwritten digit classification.
33
+
34
+ ## Usage examples
35
+
36
+ To use this model you first need to [install](https://github.com/krasserm/perceiver-io/blob/main/README.md#installation)
37
+ the `perceiver-io` library with extension `vision`.
38
+
39
+ ```shell
40
+ pip install perceiver-io[vision]
41
+ ```
42
+
43
+ Then the model can be used with PyTorch. Either use the model and image processor directly
44
+
45
+ ```python
46
+ from datasets import load_dataset
47
+ from transformers import AutoModelForImageClassification, AutoImageProcessor
48
+ from perceiver.model.vision import image_classifier # auto-class registration
49
+
50
+ repo_id = "krasserm/perceiver-io-img-clf-mnist"
51
+
52
+ mnist_dataset = load_dataset("mnist", split="test")[:9]
53
+
54
+ images = mnist_dataset["image"]
55
+ labels = mnist_dataset["label"]
56
+
57
+ model = AutoModelForImageClassification.from_pretrained(repo_id)
58
+ processor = AutoImageProcessor.from_pretrained(repo_id)
59
+
60
+ inputs = processor(images, return_tensors="pt")
61
+ logits = model(**inputs).logits
62
+
63
+ print(f"Labels: {labels}")
64
+ print(f"Predictions: {logits.argmax(dim=-1).numpy().tolist()}")
65
+ ```
66
+ ```
67
+ Labels: [7, 2, 1, 0, 4, 1, 4, 9, 5]
68
+ Predictions: [7, 2, 1, 0, 4, 1, 4, 9, 5]
69
+ ```
70
+
71
+ or use an `image-classification` pipeline:
72
+
73
+ ```python
74
+ from datasets import load_dataset
75
+ from transformers import pipeline
76
+ from perceiver.model.vision import image_classifier # auto-class registration
77
+
78
+ repo_id = "krasserm/perceiver-io-img-clf-mnist"
79
+
80
+ mnist_dataset = load_dataset("mnist", split="test")[:9]
81
+
82
+ images = mnist_dataset["image"]
83
+ labels = mnist_dataset["label"]
84
+
85
+ classifier = pipeline("image-classification", model=repo_id)
86
+ predictions = [int(pred[0]["label"]) for pred in classifier(images)]
87
+
88
+ print(f"Labels: {labels}")
89
+ print(f"Predictions: {predictions}")
90
+ ```
91
+ ```
92
+ Labels: [7, 2, 1, 0, 4, 1, 4, 9, 5]
93
+ Predictions: [7, 2, 1, 0, 4, 1, 4, 9, 5]
94
+ ```
95
+
96
+ ## Checkpoint conversion
97
+
98
+ The `krasserm/perceiver-io-img-clf-mnist` model has been created from a training checkpoint with:
99
+
100
+ ```python
101
+ from perceiver.model.vision.image_classifier import convert_mnist_classifier_checkpoint
102
+
103
+ convert_mnist_classifier_checkpoint(
104
+ save_dir="krasserm/perceiver-io-img-clf-mnist",
105
+ ckpt_url="https://martin-krasser.com/perceiver/logs-0.8.0/img_clf/version_0/checkpoints/epoch=025-val_loss=0.065.ckpt",
106
+ push_to_hub=True,
107
+ )
108
+ ```