from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from tqdm.auto import tqdm

import torch
from torchvision import transforms
from torchvision.models import resnet50, ResNet50_Weights

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

# Download and extract the MVTec AD 'carpet' category
import urllib.request
urllib.request.urlretrieve(
    "https://www.mydrive.ch/shares/38536/3830184030e49fe74747669442f0f282/download/420937484-1629951672/carpet.tar.xz",
    "carpet.tar.xz"
)

import tarfile
with tarfile.open('carpet.tar.xz') as f:
    f.extractall('.')


class resnet_feature_extractor(torch.nn.Module):
    def __init__(self):
        """This class extracts the feature maps from a pretrained ResNet-50 model."""
        super(resnet_feature_extractor, self).__init__()
        self.model = resnet50(weights=ResNet50_Weights.DEFAULT)
        self.model.eval()

        for param in self.model.parameters():
            param.requires_grad = False

        # Hooks to extract the intermediate feature maps
        def hook(module, input, output) -> None:
            """This hook saves the extracted feature map on self.features."""
            self.features.append(output)

        self.model.layer2[-1].register_forward_hook(hook)
        self.model.layer3[-1].register_forward_hook(hook)

    def forward(self, input):
        self.features = []
        with torch.no_grad():
            _ = self.model(input)

        self.avg = torch.nn.AvgPool2d(3, stride=1)
        fmap_size = self.features[0].shape[-2]  # Feature map size h, w
        self.resize = torch.nn.AdaptiveAvgPool2d(fmap_size)

        resized_maps = [self.resize(self.avg(fmap)) for fmap in self.features]
        patch = torch.cat(resized_maps, 1)           # Merge the resized feature maps
        patch = patch.reshape(patch.shape[1], -1).T  # Create a column tensor of patch features

        return patch


image = Image.open(r'carpet/test/color/000.png')
image = transform(image).unsqueeze(0)

backbone = resnet_feature_extractor()
feature = backbone(image)

# print(backbone.features[0].shape)
# print(backbone.features[1].shape)
print(feature.shape)
# plt.imshow(image[0].permute(1,2,0))

# Build the memory bank from patch features of the defect-free training images
memory_bank = []
folder_path = Path(r'carpet/train/good')

for pth in tqdm(folder_path.iterdir(), leave=False):
    with torch.no_grad():
        data = transform(Image.open(pth)).unsqueeze(0)
        features = backbone(data)
        memory_bank.append(features.cpu().detach())

memory_bank = torch.cat(memory_bank, dim=0)
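# NOTE: the thresholding loop below scores the same 'train/good' images that
# populate the memory bank, so each patch can find its own feature at distance
# ~0. PatchCore-style pipelines usually subsample the bank first; a minimal
# sketch (random subsampling in place of a greedy coreset, with an assumed,
# untuned 10% keep-ratio) would be:
# selected_idx = torch.randperm(memory_bank.shape[0])[:memory_bank.shape[0] // 10]
# memory_bank = memory_bank[selected_idx]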
# Derive a first anomaly-score threshold from the defect-free training images
y_score = []
folder_path = Path(r'carpet/train/good')

for pth in tqdm(folder_path.iterdir(), leave=False):
    data = transform(Image.open(pth)).unsqueeze(0)
    with torch.no_grad():
        features = backbone(data)
    distances = torch.cdist(features, memory_bank, p=2.0)
    dist_score, dist_score_idxs = torch.min(distances, dim=1)  # Nearest-neighbour distance per patch
    s_star = torch.max(dist_score)                             # Image-level anomaly score
    segm_map = dist_score.view(1, 1, 28, 28)
    y_score.append(s_star.cpu().numpy())

# Statistical threshold: mean + 2 standard deviations of the good-image scores
best_threshold = np.mean(y_score) + 2 * np.std(y_score)
plt.hist(y_score, bins=50)
plt.vlines(x=best_threshold, ymin=0, ymax=30, color='r')
plt.show()

# Score every test class (good plus the five defect types)
y_score = []
y_true = []

for classes in ['color', 'good', 'cut', 'hole', 'metal_contamination', 'thread']:
    folder_path = Path(f'carpet/test/{classes}')

    for pth in tqdm(folder_path.iterdir(), leave=False):
        class_label = pth.parts[-2]
        with torch.no_grad():
            test_image = transform(Image.open(pth)).unsqueeze(0)
            features = backbone(test_image)
        distances = torch.cdist(features, memory_bank, p=2.0)
        dist_score, dist_score_idxs = torch.min(distances, dim=1)
        s_star = torch.max(dist_score)
        segm_map = dist_score.view(1, 1, 28, 28)
        y_score.append(s_star.cpu().numpy())
        y_true.append(0 if class_label == 'good' else 1)

# Plot the y_score values that do not belong to the 'good' class
y_score_nok = [score for score, true in zip(y_score, y_true) if true == 1]
plt.hist(y_score_nok, bins=50)
plt.vlines(x=best_threshold, ymin=0, ymax=30, color='r')
plt.show()

# Anomaly heatmap for a single defective test image
test_image = transform(Image.open(r'carpet/test/color/000.png')).unsqueeze(0)
features = backbone(test_image)
distances = torch.cdist(features, memory_bank, p=2.0)
dist_score, dist_score_idxs = torch.min(distances, dim=1)
s_star = torch.max(dist_score)
segm_map = dist_score.view(1, 1, 28, 28)
segm_map = torch.nn.functional.interpolate(  # Upscale by bilinear interpolation to match the original input resolution
    segm_map,
    size=(224, 224),
    mode='bilinear'
)
plt.imshow(segm_map.cpu().squeeze(), cmap='jet')

from sklearn.metrics import roc_auc_score, roc_curve, confusion_matrix, ConfusionMatrixDisplay, f1_score

# Convert to arrays so that vectorised comparisons below work
y_true = np.array(y_true)
y_score = np.array(y_score)

# Calculate AUC-ROC score
auc_roc_score = roc_auc_score(y_true, y_score)
print("AUC-ROC Score:", auc_roc_score)

# Plot ROC curve
fpr, tpr, thresholds = roc_curve(y_true, y_score)
plt.figure()
plt.plot(fpr, tpr, color='darkorange', lw=2, label='ROC curve (area = %0.2f)' % auc_roc_score)
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic (ROC) Curve')
plt.legend(loc="lower right")
plt.show()

# Select the best threshold based on F1 score
f1_scores = [f1_score(y_true, y_score >= threshold) for threshold in thresholds]
best_threshold = thresholds[np.argmax(f1_scores)]
print(f'best_threshold = {best_threshold}')

# Generate confusion matrix
cm = confusion_matrix(y_true, (y_score >= best_threshold).astype(int))
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=['OK', 'NOK'])
disp.plot()
plt.show()

backbone.eval()

import io

import gradio as gr


def detect_fault(uploaded_image):
    # Convert the uploaded image and extract its patch features
    test_image = transform(uploaded_image).unsqueeze(0)
    with torch.no_grad():
        features = backbone(test_image)
    distances = torch.cdist(features, memory_bank, p=2.0)
    dist_score, dist_score_idxs = torch.min(distances, dim=1)
    s_star = torch.max(dist_score)
    segm_map = dist_score.view(1, 1, 28, 28)
    segm_map = torch.nn.functional.interpolate(
        segm_map,
        size=(224, 224),
        mode='bilinear'
    ).cpu().squeeze().numpy()

    y_score_image = s_star.cpu().numpy()
    y_pred_image = int(y_score_image >= best_threshold)
    class_label = ['Image Is OK', 'Image is Not OK']

    # --- Plot results ---
    fig, axs = plt.subplots(1, 3, figsize=(15, 5))

    # Original image
    axs[0].imshow(test_image.squeeze().permute(1, 2, 0).cpu().numpy())
    axs[0].set_title("Original Image")
    axs[0].axis("off")

    # Heatmap
    axs[1].imshow(segm_map, cmap='jet')
    axs[1].set_title(f"Anomaly Score: {y_score_image / best_threshold:0.4f}\nPrediction: {class_label[y_pred_image]}")
    axs[1].axis("off")

    # Segmentation map
    axs[2].imshow((segm_map > best_threshold * 1.25), cmap='gray')
    axs[2].set_title("Fault Segmentation Map")
    axs[2].axis("off")

    # Save the plot to an image buffer
    buf = io.BytesIO()
    plt.savefig(buf, format="png")
    buf.seek(0)
    result_image = Image.open(buf)
    plt.close(fig)

    return result_image


# Gradio UI
demo = gr.Interface(
    fn=detect_fault,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=gr.Image(type="pil", label="Detection Result"),
    title="Fault Detection in Images",
    description="Upload an image and the model will detect if there are any faults and show the segmentation map."
)
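# Optional sanity check before launching the UI: run the inference function on
# a known defective image from the dataset downloaded above (uses only names
# defined in this script; uncomment to try):
# preview = detect_fault(Image.open('carpet/test/color/000.png'))
# preview.show()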
if __name__ == "__main__":
    demo.launch()