---
license: apache-2.0
library_name: mlx-image
tags:
- mlx
- mlx-image
- vision
- image-classification
datasets:
- imagenet-1k
---

# efficientnet_b0

An EfficientNet-B0 model pretrained on ImageNet-1K.

Disclaimer: this is a port of the Torchvision model weights to Apple's MLX framework. The conversion script is in the [mlx-convert-scripts](https://github.com/lextoumbourou/mlx-convert-scripts) repo.

## How to use

```bash
pip install mlx-image
```

To classify an image with this model:

```python
import mlx.core as mx
from mlxim.model import create_model
from mlxim.io import read_rgb
from mlxim.transform import ImageNetTransform
from mlxim.utils.imagenet import IMAGENET2012_CLASSES

# Evaluation-mode ImageNet preprocessing at 224x224
transform = ImageNetTransform(train=False, img_size=224)
x = transform(read_rgb("cat.jpg"))
x = mx.array(x)
x = mx.expand_dims(x, 0)  # add a batch dimension

model = create_model("efficientnet_b0")
model.eval()

logits = model(x)
predicted_idx = mx.argmax(logits, axis=-1).item()
predicted_class = list(IMAGENET2012_CLASSES.values())[predicted_idx]
print(f"Predicted class: {predicted_class}")
```

You can also extract embeddings from the layer before the classification head:

```python
import mlx.core as mx
from mlxim.model import create_model
from mlxim.io import read_rgb
from mlxim.transform import ImageNetTransform

transform = ImageNetTransform(train=False, img_size=224)
x = transform(read_rgb("cat.jpg"))
x = mx.array(x)
x = mx.expand_dims(x, 0)  # add a batch dimension

# First option: num_classes=0 removes the head, so the forward pass returns embeddings
model = create_model("efficientnet_b0", num_classes=0)
model.eval()
embeds = model(x)

# Second option: keep the head and call get_features
model = create_model("efficientnet_b0")
model.eval()
embeds = model.get_features(x)
```
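A common next step is to compare two images by the cosine similarity of their embeddings. The sketch below is an illustration and not part of mlx-image. It uses NumPy and assumes each embedding can be flattened into a single 1-D vector. The helper name `cosine_similarity` is hypothetical.

```python
import numpy as np


def cosine_similarity(a, b):
    """Hypothetical helper: cosine similarity of two embeddings.

    Assumes each embedding can be flattened to a 1-D vector
    (for example, a (1, D) batch holding a single image).
    """
    a = np.asarray(a, dtype=np.float64).ravel()
    b = np.asarray(b, dtype=np.float64).ravel()
    if a.shape != b.shape:
        raise ValueError("embeddings must have the same number of elements")
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    if denom == 0.0:
        raise ValueError("cannot compare an all-zero embedding")
    # Dot product divided by the product of the norms, in [-1, 1]
    return float(np.dot(a, b) / denom)


if __name__ == "__main__":
    # Toy vectors standing in for real model embeddings
    v = [0.5, 1.0, -2.0]
    print(cosine_similarity(v, v))
```

MLX arrays can be converted to NumPy with `np.array(...)`. With two embeddings computed as shown above (for example, hypothetical `embeds_a` and `embeds_b` from two different images), `cosine_similarity(np.array(embeds_a), np.array(embeds_b))` returns a value closer to 1 for images the model considers more alike.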