---
license: apache-2.0
library_name: mlx-image
tags:
- mlx
- mlx-image
- vision
- image-classification
datasets:
- imagenet-1k
---

# efficientnet_b0

An EfficientNet-B0 model, pretrained on ImageNet-1K.

Disclaimer: this is a port of the Torchvision model weights to the Apple MLX framework.

See the [mlx-convert-scripts](https://github.com/lextoumbourou/mlx-convert-scripts) repo for the conversion script used.

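As a rough illustration of one step such a port typically involves: Torchvision stores 2D convolution weights as `(out_channels, in_channels, kH, kW)`, while MLX's `nn.Conv2d` expects channels-last weights, `(out_channels, kH, kW, in_channels)`. The sketch below shows only that transposition, using NumPy on a dummy array. The helper name `torch_conv_to_mlx` is hypothetical, and this is not the linked conversion script, which may handle this and other details differently.

```python
import numpy as np


def torch_conv_to_mlx(weight: np.ndarray) -> np.ndarray:
    """Reorder a conv weight from (O, I, kH, kW) to (O, kH, kW, I).

    Hypothetical helper for illustration only; not taken from the
    mlx-convert-scripts repo.
    """
    if weight.ndim != 4:
        raise ValueError(f"expected a 4D conv weight, got shape {weight.shape}")
    return np.transpose(weight, (0, 2, 3, 1))


# Dummy weight: 8 filters, 3 input channels, 5x5 kernel (values are arbitrary).
torch_weight = np.arange(8 * 3 * 5 * 5, dtype=np.float32).reshape(8, 3, 5, 5)
mlx_weight = torch_conv_to_mlx(torch_weight)
print(mlx_weight.shape)
```

Linear layer weights use the same `(out_features, in_features)` layout in both frameworks, so under this assumption only the convolution kernels need reordering.
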
## How to use

```bash
pip install mlx-image
```

Here is how to use this model for image classification:

```python
import mlx.core as mx
from mlxim.model import create_model
from mlxim.io import read_rgb
from mlxim.transform import ImageNetTransform
from mlxim.utils.imagenet import IMAGENET2012_CLASSES

# Load the image and apply the ImageNet evaluation preprocessing
transform = ImageNetTransform(train=False, img_size=224)
x = transform(read_rgb("cat.jpg"))
x = mx.array(x)
x = mx.expand_dims(x, 0)  # add a batch dimension

# Create the model and switch it to inference mode
model = create_model("efficientnet_b0")
model.eval()

# Run the forward pass and map the highest logit to its class name
logits = model(x)
predicted_idx = mx.argmax(logits, axis=-1).item()
predicted_class = list(IMAGENET2012_CLASSES.values())[predicted_idx]

print(f"Predicted class: {predicted_class}")
```

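If the single top prediction is not enough, the logits can be converted to probabilities and ranked. The sketch below is plain Python so that it runs without MLX. The helper `top_k_predictions` and the toy labels are hypothetical, and it assumes you pass one row of logits as a Python list together with a list of class names in index order.

```python
import math


def top_k_predictions(logits, labels, k=5):
    """Return the k most likely (label, probability) pairs for one row of logits."""
    if len(logits) != len(labels):
        raise ValueError("logits and labels must have the same length")
    # Subtract the max before exponentiating for numerical stability.
    max_logit = max(logits)
    exps = [math.exp(v - max_logit) for v in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Rank class indices by probability, highest first.
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(labels[i], probs[i]) for i in ranked[:k]]


# Toy example with made-up logits and labels.
toy_logits = [2.0, 0.5, 3.0, -1.0]
toy_labels = ["tabby", "tiger cat", "Egyptian cat", "lynx"]
for label, prob in top_k_predictions(toy_logits, toy_labels, k=2):
    print(f"{label}: {prob:.3f}")
```

With the classification example above, the call would look like `top_k_predictions(logits[0].tolist(), list(IMAGENET2012_CLASSES.values()))`.
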
You can also use the embeddings from the layer before the classification head:

```python
import mlx.core as mx
from mlxim.model import create_model
from mlxim.io import read_rgb
from mlxim.transform import ImageNetTransform

# Load the image and apply the ImageNet evaluation preprocessing
transform = ImageNetTransform(train=False, img_size=224)
x = transform(read_rgb("cat.jpg"))
x = mx.array(x)
x = mx.expand_dims(x, 0)  # add a batch dimension

# First option: create the model without a classification head
model = create_model("efficientnet_b0", num_classes=0)
model.eval()

embeds = model(x)

# Second option: keep the full model and request the features directly
model = create_model("efficientnet_b0")
model.eval()

embeds = model.get_features(x)
```

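A common use for such embeddings is comparing images. The sketch below computes the cosine similarity between two embedding vectors in plain Python, so it runs without MLX. `cosine_similarity` is a hypothetical helper, and the short made-up vectors stand in for real embeddings; it assumes each embedding has already been flattened to a list of floats (for example via `.tolist()` on a one-dimensional array).

```python
import math


def cosine_similarity(a, b):
    """Cosine similarity between two equal-length vectors of floats."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("cosine similarity is undefined for zero vectors")
    return dot / (norm_a * norm_b)


# Made-up vectors standing in for flattened image embeddings.
embedding_cat = [0.2, 0.9, 0.1, 0.4]
embedding_kitten = [0.25, 0.8, 0.05, 0.5]
embedding_truck = [0.9, 0.1, 0.7, 0.0]

print(cosine_similarity(embedding_cat, embedding_kitten))
print(cosine_similarity(embedding_cat, embedding_truck))
```

Values close to 1 indicate vectors pointing in nearly the same direction, which is why the two similar toy vectors score higher than the dissimilar pair.
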