Revisiting AlexNet: Achieving High-Accuracy on CIFAR-10 with Modern Optimization Techniques
Model Description
This repository contains a TensorFlow/Keras implementation of the AlexNet architecture, adapted for and trained from scratch on the CIFAR-10 dataset. The original 2012 design is modernized by replacing its Local Response Normalization (LRN) layers with Batch Normalization, and training is regularized with L2 weight decay, dropout, and on-the-fly data augmentation.
The model was developed in a Kaggle environment, demonstrating a reproducible workflow for achieving high accuracy on a benchmark computer vision task.
Model Details
| Detail | Value |
|---|---|
| Architecture | Modified AlexNet with Batch Normalization |
| Parameters | ~46 million |
| Framework | TensorFlow / Keras |
| Task | Image Classification |
| Paper | Revisiting AlexNet: Achieving High-Accuracy on CIFAR-10 with Modern Optimization Techniques |
| Original AlexNet Paper | ImageNet Classification with Deep Convolutional Neural Networks (Krizhevsky et al., 2012) |
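The card does not list the exact layer configuration. The following is a minimal Keras sketch of what a "Modified AlexNet with Batch Normalization" could look like, assuming the classic AlexNet layout (96-256-384-384-256 convolution filters, two 4096-unit dense layers) with a BatchNormalization layer after each convolution in place of LRN, plus the L2 factor (0.0005) and dropout rate (0.5) from the hyperparameter table. The function name `build_alexnet_bn` is hypothetical. Under these assumptions the model has about 46.8 million parameters, consistent with the ~46 million listed above.

```python
import tensorflow as tf
from tensorflow.keras import layers, regularizers


def build_alexnet_bn(num_classes=10, weight_decay=5e-4, dropout_rate=0.5):
    """Hypothetical AlexNet variant in which BatchNormalization replaces LRN."""
    # Adds weight_decay * sum(w**2) for each regularized kernel to the training loss.
    reg = regularizers.l2(weight_decay)

    def conv_block(x, filters, kernel_size, strides=1, padding="same", pool=False):
        x = layers.Conv2D(filters, kernel_size, strides=strides, padding=padding,
                          kernel_regularizer=reg)(x)
        x = layers.BatchNormalization()(x)  # stands in for the original LRN layer
        x = layers.ReLU()(x)
        if pool:
            # 3x3 windows with stride 2: AlexNet's overlapping max pooling.
            x = layers.MaxPooling2D(pool_size=3, strides=2)(x)
        return x

    inputs = tf.keras.Input(shape=(224, 224, 3))
    x = conv_block(inputs, 96, 11, strides=4, padding="valid", pool=True)  # -> 26x26x96
    x = conv_block(x, 256, 5, pool=True)                                   # -> 12x12x256
    x = conv_block(x, 384, 3)
    x = conv_block(x, 384, 3)
    x = conv_block(x, 256, 3, pool=True)                                   # -> 5x5x256
    x = layers.Flatten()(x)
    for _ in range(2):
        x = layers.Dense(4096, activation="relu", kernel_regularizer=reg)(x)
        x = layers.Dropout(dropout_rate)(x)
    outputs = layers.Dense(num_classes, activation="softmax", kernel_regularizer=reg)(x)
    return tf.keras.Model(inputs, outputs, name="alexnet_bn")


model = build_alexnet_bn()
print(f"{model.count_params():,} parameters")  # about 46.8 million with this layout
```

Almost all of the parameters sit in the first two dense layers (a 6400 x 4096 and a 4096 x 4096 weight matrix), which is why the count stays close to the original AlexNet even though the final layer has only 10 outputs.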
Training Procedure
Data
The model was trained on the CIFAR-10 dataset (60,000 32x32 RGB images in 10 classes, split into 50,000 training and 10,000 test images).
- Preprocessing: All images were upsampled from 32x32 to 224x224, and pixel values were scaled to the [0, 1] range.
- Augmentation: The training data was augmented on the fly with random horizontal flips, random rotations (10%), and random zooms (10%).
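The preprocessing and augmentation steps above can be sketched as a `tf.data` pipeline. This is an illustration, not the author's code: it assumes the "10%" settings correspond to the Keras `RandomRotation(0.1)` and `RandomZoom(0.1)` layers, and `make_dataset` is a hypothetical helper name. Note that in Keras a rotation factor of 0.1 means up to ±10% of a full turn, i.e. ±36 degrees.

```python
import tensorflow as tf
from tensorflow.keras import layers

IMG_SIZE = 224  # target resolution from the model card
AUTOTUNE = tf.data.AUTOTUNE

# Assumed mapping of the card's augmentation settings onto Keras preprocessing layers.
augment = tf.keras.Sequential([
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(0.1),  # factor 0.1 = up to +/-36 degrees
    layers.RandomZoom(0.1),
])


def preprocess(image, label):
    # Scale uint8 pixels to [0, 1], then upsample 32x32 -> 224x224 (bilinear).
    image = tf.cast(image, tf.float32) / 255.0
    image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
    return image, label


def make_dataset(images, labels, batch_size=128, training=False):
    """Hypothetical helper: builds a batched pipeline from uint8 CIFAR-10 arrays."""
    ds = tf.data.Dataset.from_tensor_slices((images, labels))
    if training:
        ds = ds.shuffle(10_000)
    ds = ds.map(preprocess, num_parallel_calls=AUTOTUNE).batch(batch_size)
    if training:
        # Random transforms are drawn per batch, so every epoch sees new variants.
        ds = ds.map(lambda x, y: (augment(x, training=True), y),
                    num_parallel_calls=AUTOTUNE)
    return ds.prefetch(AUTOTUNE)
```

Usage, for example: `(x_train, y_train), _ = tf.keras.datasets.cifar10.load_data()` followed by `train_ds = make_dataset(x_train, y_train, training=True)`. Resizing inside the pipeline, rather than on the whole array up front, keeps only one batch of 224x224 images in memory at a time.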
Hyperparameters
| Hyperparameter | Value |
|---|---|
| Optimizer | Adam |
| Learning Rate | 1e-4 initial, reduced by a ReduceLROnPlateau callback |
| Batch Size | 128 (GPU) / 1024 (TPU) |
| Epochs | Up to 100, with EarlyStopping (patience=10) |
| Regularization | L2 Weight Decay (λ=0.0005), Dropout (rate=0.5) |
| Hardware | Kaggle GPU (2x T4) or TPU (v5e-8) |
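The optimizer and callback settings in the table map onto Keras roughly as follows. This is a sketch under assumptions: the card names ReduceLROnPlateau but not its settings, so `factor=0.5`, `patience=3`, and `min_lr=1e-6` are illustrative values, as is `restore_best_weights=True` on EarlyStopping. It also assumes integer class labels, which is how CIFAR-10 ships. The helper names are hypothetical.

```python
import tensorflow as tf

EPOCHS = 100  # upper bound; EarlyStopping typically ends training earlier


def compile_for_training(model, learning_rate=1e-4):
    """Hypothetical helper: Adam at 1e-4 with integer-label cross-entropy."""
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss="sparse_categorical_crossentropy",  # labels are ints 0-9, not one-hot
        metrics=["accuracy"],
    )
    return model


def training_callbacks():
    """Hypothetical helper returning the two callbacks named in the card."""
    return [
        # factor, patience, and min_lr are assumed values, not taken from the card.
        tf.keras.callbacks.ReduceLROnPlateau(
            monitor="val_loss", factor=0.5, patience=3, min_lr=1e-6),
        # patience=10 comes from the card; restoring the best weights is an assumption.
        tf.keras.callbacks.EarlyStopping(
            monitor="val_loss", patience=10, restore_best_weights=True),
    ]
```

A training run would then look like `model.fit(train_ds, validation_data=val_ds, epochs=EPOCHS, callbacks=training_callbacks())`. Because both callbacks monitor `val_loss`, a validation set is required; with L2 regularization enabled, that loss also includes the weight-decay penalty.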
Evaluation
The model achieved the following performance on the CIFAR-10 test set:
- Test Accuracy: 95.7%
- Test Loss: 0.6143
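To check the test accuracy yourself, the 10,000 test images need the same preprocessing as in training. Resizing them all up front is costly: 10,000 x 224 x 224 x 3 float32 values take about 6.0 GB. The sketch below resizes per batch instead. `evaluate_on_cifar10_test` is a hypothetical helper, and it assumes the model's input is 224x224x3 scaled to [0, 1] as described in this card.

```python
import numpy as np
import tensorflow as tf


def evaluate_on_cifar10_test(model, images, labels, batch_size=256):
    """Hypothetical helper: top-1 accuracy on uint8 CIFAR-10-style arrays."""
    def preprocess(image, label):
        # Same preprocessing as training: scale to [0, 1], upsample to 224x224.
        image = tf.cast(image, tf.float32) / 255.0
        return tf.image.resize(image, (224, 224)), label

    ds = (tf.data.Dataset.from_tensor_slices((images, labels))
          .batch(batch_size)                       # resize one batch at a time
          .map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
          .prefetch(tf.data.AUTOTUNE))
    scores = model.predict(ds, verbose=0)          # shape: (num_images, 10)
    preds = np.argmax(scores, axis=1)
    return float(np.mean(preds == np.asarray(labels).reshape(-1)))
```

Usage: `_, (x_test, y_test) = tf.keras.datasets.cifar10.load_data()` and then `evaluate_on_cifar10_test(model, x_test, y_test)`. For the loss as well as the accuracy, the same dataset can be passed to `model.evaluate` on a compiled model.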
How to Use
The model can be loaded from the Hugging Face Hub for inference. Below is a complete example that loads the model and predicts the class of a sample image. The image is resized to 224x224 and scaled to [0, 1] to match the training preprocessing.
```python
import io

import numpy as np
import requests
import tensorflow as tf
from PIL import Image
from huggingface_hub import from_pretrained_keras

# 1. Load the model from the Hub
# Replace with your actual repo_id
repo_id = "metanthropiclabs/alexnet-cifar10-optimized"
model = from_pretrained_keras(repo_id)

# 2. Define class labels (standard CIFAR-10 index order)
cifar10_labels = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

# 3. Load and preprocess a sample image
# Example: an image of a cat from the web
url = 'https://storage.googleapis.com/petbacker/images/blog/2017/cat-in-a-box.jpg'
response = requests.get(url, timeout=30)
response.raise_for_status()  # fail early on HTTP errors instead of a PIL decode error
# convert("RGB") guarantees 3 channels (drops alpha, expands grayscale)
image = Image.open(io.BytesIO(response.content)).convert("RGB")
image = image.resize((224, 224))  # Must match the model's input size
image_array = np.array(image)

# Scale pixels to [0, 1] as in training, then add a batch dimension
image_array = image_array.astype('float32') / 255.0
image_tensor = tf.expand_dims(image_array, 0)  # Create a batch of 1

# 4. Make a prediction
predictions = model.predict(image_tensor)
predicted_class_index = int(np.argmax(predictions[0]))
predicted_class_name = cifar10_labels[predicted_class_index]
print(f"Predicted Class: {predicted_class_name}")
# Expected output for this image: Predicted Class: cat
```
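A single arg-max label hides how confident the model is. The hypothetical helper below lists the top-k classes for one row of `predictions`. It assumes the model ends in a softmax, so the scores are probabilities; if the model returns logits instead, apply `tf.nn.softmax` first.

```python
import numpy as np

CIFAR10_LABELS = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                  'dog', 'frog', 'horse', 'ship', 'truck']


def top_k_predictions(probs, k=3, labels=CIFAR10_LABELS):
    """Hypothetical helper: return the k most likely (label, probability) pairs."""
    probs = np.asarray(probs, dtype=np.float64).reshape(-1)
    order = np.argsort(probs)[::-1][:k]  # class indices by descending score
    return [(labels[i], float(probs[i])) for i in order]
```

For example, `top_k_predictions(predictions[0])` after step 4 returns the three highest-scoring classes with their probabilities.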