Commit
·
f6c7770
1
Parent(s):
38d695c
added example usage
Browse files
README.md
CHANGED
|
@@ -32,6 +32,41 @@ The following hyperparameters were used during training:
|
|
| 32 |
- optimizer: SGD with momentum = 0.9
|
| 33 |
- num_epochs: 5
|
| 34 |
|
| 35 |
-
###
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
- optimizer: SGD with momentum = 0.9
|
| 33 |
- num_epochs: 5
|
| 34 |
|
| 35 |
+
### Example usage
|
| 36 |
+
```py
|
| 37 |
+
from efficientnet_pytorch import EfficientNet
|
| 38 |
+
import torch
|
| 39 |
+
import torchvision.transforms as transforms
|
| 40 |
|
| 41 |
+
model = EfficientNet.from_name('efficientnet-b7')
|
| 42 |
+
model._fc= torch.nn.Linear(in_features=model._fc.in_features, out_features=len(annotations_map), bias=True)
|
| 43 |
+
model.load_state_dict(torch.load('/content/efficientnetb7_tyrequality_classifier.pth'))
|
| 44 |
+
|
| 45 |
+
model.eval()
|
| 46 |
+
img = Image.open('/content/defective-tires-cause-accidents-min.jpg')
|
| 47 |
+
test_transform = transforms.Compose([
|
| 48 |
+
transforms.Resize(224),
|
| 49 |
+
transforms.CenterCrop(224),
|
| 50 |
+
transforms.ToTensor(),
|
| 51 |
+
transforms.Normalize([0.485, 0.456, 0.406],
|
| 52 |
+
[0.229, 0.224, 0.225])
|
| 53 |
+
])
|
| 54 |
+
input_data = test_transform(img).unsqueeze(0)
|
| 55 |
+
|
| 56 |
+
with torch.no_grad():
|
| 57 |
+
output = model(input_data)
|
| 58 |
+
|
| 59 |
+
_, predicted_class = torch.max(output, 1)
|
| 60 |
+
|
| 61 |
+
probs = torch.nn.functional.softmax(output, dim=1)
|
| 62 |
+
conf, _ = torch.max(probs, 1)
|
| 63 |
+
|
| 64 |
+
print('Predicted Class:', predicted_class.item())
|
| 65 |
+
print('Predicted Label:', id2label[predicted_class.item()])
|
| 66 |
+
print(f'Confidence: {conf.item()*100}%')
|
| 67 |
+
|
| 68 |
+
plt.title(id2label[predicted_class.item()])
|
| 69 |
+
plt.axis("off")
|
| 70 |
+
plt.imshow(img)
|
| 71 |
+
plt.show()
|
| 72 |
+
```
|