faldeus0092 commited on
Commit
f6c7770
·
1 Parent(s): 38d695c

added example usage

Browse files
Files changed (1) hide show
  1. README.md +37 -2
README.md CHANGED
@@ -32,6 +32,41 @@ The following hyperparameters were used during training:
32
  - optimizer: SGD with momentum = 0.9
33
  - num_epochs: 5
34
 
35
- ### Training results
 
 
 
 
36
 
37
- See at [Weights and Biases](https://wandb.ai/faldeus0092/efficientnetb7_tyrequality_classifier/runs/1z5mnxps/overview?workspace=user-faldeus0092)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  - optimizer: SGD with momentum = 0.9
33
  - num_epochs: 5
34
 
35
+ ### Example usage
36
+ ```py
37
+ from efficientnet_pytorch import EfficientNet
38
+ import torch
39
+ import torchvision.transforms as transforms
40
 
41
+ model = EfficientNet.from_name('efficientnet-b7')
42
+ model._fc= torch.nn.Linear(in_features=model._fc.in_features, out_features=len(annotations_map), bias=True)
43
+ model.load_state_dict(torch.load('/content/efficientnetb7_tyrequality_classifier.pth'))
44
+
45
+ model.eval()
46
+ img = Image.open('/content/defective-tires-cause-accidents-min.jpg')
47
+ test_transform = transforms.Compose([
48
+ transforms.Resize(224),
49
+ transforms.CenterCrop(224),
50
+ transforms.ToTensor(),
51
+ transforms.Normalize([0.485, 0.456, 0.406],
52
+ [0.229, 0.224, 0.225])
53
+ ])
54
+ input_data = test_transform(img).unsqueeze(0)
55
+
56
+ with torch.no_grad():
57
+ output = model(input_data)
58
+
59
+ _, predicted_class = torch.max(output, 1)
60
+
61
+ probs = torch.nn.functional.softmax(output, dim=1)
62
+ conf, _ = torch.max(probs, 1)
63
+
64
+ print('Predicted Class:', predicted_class.item())
65
+ print('Predicted Label:', id2label[predicted_class.item()])
66
+ print(f'Confidence: {conf.item()*100}%')
67
+
68
+ plt.title(id2label[predicted_class.item()])
69
+ plt.axis("off")
70
+ plt.imshow(img)
71
+ plt.show()
72
+ ```