Spaces:
Runtime error
Runtime error
| from pathlib import Path | |
| import numpy as np | |
| import os, shutil | |
| import matplotlib.pyplot as plt | |
| from PIL import Image | |
| from tqdm.auto import tqdm | |
| import torch | |
| import torchvision | |
| from torch.utils.data import DataLoader | |
| from torchvision.datasets import ImageFolder | |
| from torchvision.transforms import transforms | |
| import torch.optim as optim | |
| from torchvision.models import resnet50, ResNet50_Weights | |
# Preprocessing shared by every image that is fed to the backbone:
# resize to a fixed 224x224 input, then convert the PIL image to a tensor.
_preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
]
transform = transforms.Compose(_preprocessing_steps)
import os
import tarfile
import urllib.request

# Remote archive holding the carpet images, and where it lands locally.
DATASET_URL = "https://www.mydrive.ch/shares/38536/3830184030e49fe74747669442f0f282/download/420937484-1629951672/carpet.tar.xz"
DATASET_ARCHIVE = "carpet.tar.xz"
# Directory created by extracting the archive into the working directory;
# the rest of the script reads images from below it.
DATASET_DIR = "carpet"

# Fetch and unpack the dataset only when it is not already on disk, so a
# restart does not repeat the download and extraction.
if not os.path.isdir(DATASET_DIR):
    if not os.path.isfile(DATASET_ARCHIVE):
        # Download under a temporary name and rename on success, so an
        # interrupted transfer is never mistaken for a complete archive.
        partial_archive = DATASET_ARCHIVE + ".part"
        urllib.request.urlretrieve(DATASET_URL, partial_archive)
        os.replace(partial_archive, DATASET_ARCHIVE)

    with tarfile.open(DATASET_ARCHIVE) as archive:
        if hasattr(tarfile, "data_filter"):
            # The archive comes from the network: the "data" filter refuses
            # members that would be written outside the target directory.
            archive.extractall(".", filter="data")
        else:
            # Older Python without extraction filters: plain extraction.
            archive.extractall(".")
class resnet_feature_extractor(torch.nn.Module):
    """Frozen pretrained ResNet-50 that turns images into patch feature vectors.

    Forward hooks capture the outputs of the last blocks of ``layer2`` and
    ``layer3``. ``forward`` smooths both maps, brings them to the spatial
    size of the first one, concatenates them along the channel axis and
    returns one feature row per spatial location.
    """

    def __init__(self):
        super().__init__()
        self.model = resnet50(weights=ResNet50_Weights.DEFAULT)

        # The backbone is only used for inference: eval mode, no gradients.
        self.model.eval()
        for param in self.model.parameters():
            param.requires_grad = False

        # Filled by the hooks below; reset at the start of every forward pass.
        self.features = []

        def hook(module, input, output) -> None:
            """Append the hooked layer's output feature map to ``self.features``."""
            self.features.append(output)

        self.model.layer2[-1].register_forward_hook(hook)
        self.model.layer3[-1].register_forward_hook(hook)

    def forward(self, input):
        """Return the patch features of ``input``.

        Args:
            input: Batch of images, indexed as (batch, channels, height, width).

        Returns:
            A 2-D tensor with one row per spatial location of the first
            hooked feature map (rows of the first sample come first) and one
            column per concatenated feature channel.
        """
        self.features = []
        with torch.no_grad():
            _ = self.model(input)  # Only the hooked activations are needed.

        # Every map is resized to the spatial size of the first hooked map.
        target_size = self.features[0].shape[-2:]

        # Functional pooling: no new sub-modules are registered per call.
        resized_maps = [
            torch.nn.functional.adaptive_avg_pool2d(
                torch.nn.functional.avg_pool2d(fmap, 3, stride=1),
                target_size,
            )
            for fmap in self.features
        ]
        patch = torch.cat(resized_maps, 1)  # Merge along the channel axis.

        # Move channels last before flattening so that each row is a single
        # location of a single sample, also for batch sizes above one.
        return patch.permute(0, 2, 3, 1).reshape(-1, patch.shape[1])
# Smoke test: push one test image through the backbone and print the
# shape of the resulting patch-feature matrix.
image = transform(Image.open(r'carpet/test/color/000.png')).unsqueeze(0)

backbone = resnet_feature_extractor()
feature = backbone(image)

print(feature.shape)
# Build the memory bank: the patch features of every image found in the
# defect-free training folder, stacked row-wise into one tensor.
folder_path = Path(r'carpet/train/good')
patch_features = []

with torch.no_grad():
    for pth in tqdm(folder_path.iterdir(), leave=False):
        data = transform(Image.open(pth)).unsqueeze(0)
        features = backbone(data)
        patch_features.append(features.detach().cpu())

memory_bank = torch.cat(patch_features, dim=0)
# Score every defect-free training image against the memory bank and derive
# a statistical threshold (mean + 2 standard deviations) from those scores.
y_score = []
folder_path = Path(r'carpet/train/good')

with torch.no_grad():
    for pth in tqdm(folder_path.iterdir(), leave=False):
        data = transform(Image.open(pth)).unsqueeze(0)
        features = backbone(data)

        # Distance from each patch of this image to its nearest memory-bank row.
        dist_score = torch.cdist(features, memory_bank, p=2.0).min(dim=1).values

        # Image-level score: the largest of those nearest-neighbour distances.
        s_star = torch.max(dist_score)
        y_score.append(s_star.cpu().numpy())

# NOTE(review): memory_bank appears to be built from this same folder without
# subsampling, so each image presumably finds its own patches at distance ~0
# and this threshold is likely degenerate - confirm. It is replaced later by
# the F1-based threshold.
best_threshold = np.mean(y_score) + 2 * np.std(y_score)

# Histogram of the training scores with the threshold drawn as a red line.
plt.hist(y_score, bins=50)
plt.vlines(x=best_threshold, ymin=0, ymax=30, color='r')
plt.show()
# Score every test image; y_true is 0 for the 'good' folder and 1 for each
# defect folder.
y_score = []
y_true = []

for classes in ['color', 'good', 'cut', 'hole', 'metal_contamination', 'thread']:
    folder_path = Path(r'carpet/test/{}'.format(classes))
    # All files in this folder share the label implied by the folder name.
    label = 0 if classes == 'good' else 1

    with torch.no_grad():
        for pth in tqdm(folder_path.iterdir(), leave=False):
            test_image = transform(Image.open(pth)).unsqueeze(0)
            features = backbone(test_image)

            # Distance from each patch to its nearest memory-bank row.
            dist_score = torch.cdist(features, memory_bank, p=2.0).min(dim=1).values

            # Image-level score: the largest nearest-neighbour distance.
            s_star = torch.max(dist_score)

            y_score.append(s_star.cpu().numpy())
            y_true.append(label)

# Histogram of the scores of the defective images only, with the current
# threshold drawn as a red line.
y_score_nok = [score for score, true in zip(y_score, y_true) if true == 1]
plt.hist(y_score_nok, bins=50)
plt.vlines(x=best_threshold, ymin=0, ymax=30, color='r')
plt.show()
# Anomaly heatmap for a single test image.
test_image = transform(Image.open(r'carpet/test/color/000.png')).unsqueeze(0)
features = backbone(test_image)

# Nearest memory-bank distance per patch, and the image-level maximum.
distances = torch.cdist(features, memory_bank, p=2.0)
nearest = torch.min(distances, dim=1)
dist_score, dist_score_idxs = nearest.values, nearest.indices
s_star = dist_score.max()

# Lay the per-patch distances out on the 28x28 patch grid, then upscale by
# bilinear interpolation to the 224x224 input resolution.
segm_map = dist_score.view(1, 1, 28, 28)
segm_map = torch.nn.functional.interpolate(segm_map, size=(224, 224), mode='bilinear')

plt.imshow(segm_map.cpu().squeeze(), cmap='jet')
from sklearn.metrics import roc_auc_score, roc_curve, confusion_matrix, ConfusionMatrixDisplay, f1_score

# Work on explicit arrays: the element-wise threshold comparisons below must
# not depend on how a Python list happens to compare with a NumPy scalar.
true_labels = np.asarray(y_true)
scores = np.asarray(y_score, dtype=float)

# Calculate AUC-ROC score
auc_roc_score = roc_auc_score(true_labels, scores)
print("AUC-ROC Score:", auc_roc_score)

# Plot ROC curve
fpr, tpr, thresholds = roc_curve(true_labels, scores)
plt.figure()
plt.plot(fpr, tpr, color='darkorange', lw=2, label='ROC curve (area = %0.2f)' % auc_roc_score)
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic (ROC) Curve')
plt.legend(loc="lower right")
plt.show()

# F1 score of the prediction obtained at each candidate ROC threshold.
# zero_division=0 scores a threshold that predicts no positives as 0
# instead of emitting a warning.
f1_scores = [
    f1_score(true_labels, (scores >= threshold).astype(int), zero_division=0)
    for threshold in thresholds
]

# Select the best threshold based on F1 score
best_threshold = thresholds[np.argmax(f1_scores)]
print(f'best_threshold = {best_threshold}')

# Generate confusion matrix
cm = confusion_matrix(true_labels, (scores >= best_threshold).astype(int))
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=['OK', 'NOK'])
disp.plot()
plt.show()

# Put the feature extractor in evaluation mode before it is used for serving.
backbone.eval()
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| import matplotlib.pyplot as plt | |
| import io | |
| # ----------------- | |
def detect_fault(uploaded_image):
    """Score an uploaded image against the memory bank and render the result.

    Args:
        uploaded_image: PIL image supplied by the Gradio input component.

    Returns:
        A PIL image of a three-panel figure: the resized input, the anomaly
        heatmap (titled with the score relative to ``best_threshold`` and the
        OK / Not OK verdict) and a binary fault segmentation map.

    Raises:
        gr.Error: If no image was provided.
    """
    import math

    if uploaded_image is None:
        # Submitting with an empty input would otherwise fail inside the
        # preprocessing with an unhelpful message.
        raise gr.Error("Please upload an image before running detection.")

    # Convert uploaded image; force three channels as the backbone expects.
    test_image = transform(uploaded_image.convert("RGB")).unsqueeze(0)

    with torch.no_grad():
        features = backbone(test_image)

        # Distance from each patch to its nearest memory-bank row.
        distances = torch.cdist(features, memory_bank, p=2.0)
        dist_score = distances.min(dim=1).values

        # Image-level score: the largest nearest-neighbour distance.
        s_star = dist_score.max()

        # Lay the per-patch distances out on their square grid and upscale
        # to the resolution of the preprocessed input.
        side = math.isqrt(dist_score.numel())
        segm_map = dist_score.view(1, 1, side, side)
        segm_map = torch.nn.functional.interpolate(
            segm_map,
            size=tuple(test_image.shape[-2:]),
            mode='bilinear'
        ).cpu().squeeze().numpy()

    y_score_image = float(s_star)
    y_pred_image = int(y_score_image >= best_threshold)
    class_label = ['Image Is OK', 'Image is Not OK']

    # --- Plot results ---
    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    try:
        # Original image
        axs[0].imshow(test_image.squeeze().permute(1, 2, 0).cpu().numpy())
        axs[0].set_title("Original Image")
        axs[0].axis("off")

        # Heatmap
        axs[1].imshow(segm_map, cmap='jet')
        axs[1].set_title(f"Anomaly Score: {y_score_image / best_threshold:0.4f}\nPrediction: {class_label[y_pred_image]}")
        axs[1].axis("off")

        # Segmentation map: pixels clearly above the decision threshold.
        axs[2].imshow((segm_map > best_threshold*1.25), cmap='gray')
        axs[2].set_title("Fault Segmentation Map")
        axs[2].axis("off")

        # Save this figure (not whichever one is "current") to a PNG buffer.
        buf = io.BytesIO()
        fig.savefig(buf, format="png")
    finally:
        # Always release the figure, even if plotting fails.
        plt.close(fig)

    buf.seek(0)
    result_image = Image.open(buf)
    result_image.load()  # Decode now rather than lazily from the buffer.
    return result_image
# Gradio front end: a single image goes in, the rendered detection figure
# comes out.
image_input = gr.Image(type="pil", label="Upload Image")
image_output = gr.Image(type="pil", label="Detection Result")

demo = gr.Interface(
    fn=detect_fault,
    inputs=image_input,
    outputs=image_output,
    title="Fault Detection in Images",
    description="Upload an image and the model will detect if there are any faults and show the segmentation map.",
)

# Start the web app only when the file is run as a script.
if __name__ == "__main__":
    demo.launch()