import io
import os
import shutil
import tarfile
import urllib.request
from pathlib import Path

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.optim as optim
import torchvision
from PIL import Image
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.models import resnet50, ResNet50_Weights
from torchvision.transforms import transforms
from tqdm.auto import tqdm

# Transform: resize every image to the fixed 224x224 input that the rest of the
# script assumes, then convert it to a float tensor.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Dataset download ("carpet" category archive).
DATASET_URL = (
    "https://www.mydrive.ch/shares/38536/3830184030e49fe74747669442f0f282/"
    "download/420937484-1629951672/carpet.tar.xz"
)
ARCHIVE_PATH = Path("carpet.tar.xz")
DATASET_DIR = Path("carpet")


def _download_dataset():
    """Download and unpack the dataset unless the training images already exist."""
    if (DATASET_DIR / "train" / "good").is_dir():
        return  # extracted on a previous run; skip the large download
    if not ARCHIVE_PATH.is_file():
        # Download under a temporary name so an interrupted transfer never
        # leaves a truncated archive that a later run would try to extract.
        partial = ARCHIVE_PATH.with_name(ARCHIVE_PATH.name + ".part")
        urllib.request.urlretrieve(DATASET_URL, partial)
        partial.replace(ARCHIVE_PATH)
    with tarfile.open(ARCHIVE_PATH) as f:
        if hasattr(tarfile, "data_filter"):
            # Python 3.12+ (and security backports): refuse absolute paths,
            # ".." components and special files inside the archive.
            f.extractall(".", filter="data")
        else:
            f.extractall(".")


_download_dataset()


# Feature extractor class
class resnet_feature_extractor(torch.nn.Module):
    """Frozen ResNet-50 that turns one image into a set of patch feature vectors.

    Forward hooks capture the outputs of the last block of ``layer2`` and
    ``layer3``. Each map is smoothed with a 3x3 average pool, resized to the
    spatial size of the ``layer2`` map and concatenated along channels. Every
    spatial position then becomes one row (one patch feature).
    """

    def __init__(self):
        super().__init__()
        # Pretrained weights; frozen and in eval mode (no BN statistic updates).
        self.model = resnet50(weights=ResNet50_Weights.DEFAULT)
        self.model.eval()
        for param in self.model.parameters():
            param.requires_grad = False

        # Local neighbourhood aggregation; parameter-free, so built once here
        # instead of on every forward call.
        self.avg = torch.nn.AvgPool2d(3, stride=1)
        # Filled by the hooks during forward().
        self.features = []

        def hook(_module, _inputs, output):
            self.features.append(output)

        self.model.layer2[-1].register_forward_hook(hook)
        self.model.layer3[-1].register_forward_hook(hook)

    def forward(self, input):
        """Return a (num_patches, channels) tensor of patch features.

        For a batch of B images the rows are ordered image by image, then
        row-major over the feature-map grid. With B == 1 this is identical to
        one row per grid position.
        """
        self.features = []
        with torch.no_grad():
            self.model(input)  # classifier output is unused; hooks collect maps

        # Target grid = spatial size of the first hooked (layer2) map.
        fmap_size = self.features[0].shape[-2]
        resized_maps = [
            torch.nn.functional.adaptive_avg_pool2d(self.avg(fmap), fmap_size)
            for fmap in self.features
        ]
        patch = torch.cat(resized_maps, dim=1)  # (B, C, H, W)
        # (B, C, H, W) -> (B*H*W, C). Unlike reshape(C, -1).T this keeps the
        # features of different images apart when B > 1.
        return patch.permute(0, 2, 3, 1).reshape(-1, patch.shape[1])


# Initialize backbone
backbone = resnet_feature_extractor()

# Memory bank: every patch feature of every image in carpet/train/good.
memory_bank = []
folder_path = Path("carpet/train/good")
for pth in tqdm(sorted(folder_path.iterdir()), leave=False):
    # The context manager closes the file handle PIL would otherwise keep open.
    # convert("RGB") guarantees the 3 channels the ResNet stem expects.
    with Image.open(pth) as img:
        data = transform(img.convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        features = backbone(data)
    memory_bank.append(features.cpu().detach())
if not memory_bank:
    raise RuntimeError(f"No training images found in {folder_path}")
memory_bank = torch.cat(memory_bank, dim=0)
# Threshold
# The memory bank holds every patch of every image in carpet/train/good.
# Scoring those same images would find each patch in the bank and give
# near-zero distances, so the threshold would be degenerate (almost everything
# flagged, and the relative score below divided by ~0). Calibrate on held-out
# images instead.
# NOTE(review): assumes the extracted archive provides carpet/test/good.
calibration_path = Path("carpet/test/good")


def _score_patches(image_tensor):
    """Score one preprocessed image against the memory bank.

    Args:
        image_tensor: batch of one image, as produced by
            ``transform(img).unsqueeze(0)``.

    Returns:
        ``(dist_score, s_star)``. ``dist_score`` holds, for each patch feature,
        the Euclidean distance to its nearest memory-bank entry. ``s_star`` is
        the largest of those distances (the image-level anomaly score).
    """
    with torch.no_grad():
        features = backbone(image_tensor)
        distances = torch.cdist(features, memory_bank, p=2.0)
    dist_score, _ = torch.min(distances, dim=1)
    return dist_score, torch.max(dist_score)


y_score = []
for pth in tqdm(sorted(calibration_path.iterdir()), leave=False):
    with Image.open(pth) as img:
        data = transform(img.convert("RGB")).unsqueeze(0)
    _, s_star = _score_patches(data)
    y_score.append(s_star.cpu().numpy())
if not y_score:
    raise RuntimeError(f"No calibration images found in {calibration_path}")
# Image-level threshold: mean + 2 standard deviations of held-out scores.
best_threshold = np.mean(y_score) + 2 * np.std(y_score)


# Gradio Function
def detect_fault(uploaded_image):
    """Classify an image as OK / not OK and render its anomaly map.

    Args:
        uploaded_image: PIL image supplied by the Gradio ``Image`` input
            (``None`` when nothing was uploaded).

    Returns:
        A PIL image with three panels: the resized input, the per-pixel
        anomaly heat map (titled with the score relative to the threshold and
        the prediction), and the binarised fault segmentation.

    Raises:
        gr.Error: if no image was provided.
    """
    if uploaded_image is None:
        raise gr.Error("Please upload an image.")

    # Force 3 channels: RGBA / grayscale / palette inputs would otherwise
    # produce a tensor the ResNet stem cannot consume.
    test_image = transform(uploaded_image.convert("RGB")).unsqueeze(0)
    dist_score, s_star = _score_patches(test_image)

    # The patch scores form a square grid (28x28 for the 224x224 transform).
    side = int(round(dist_score.numel() ** 0.5))
    segm_map = dist_score.view(1, 1, side, side)
    segm_map = torch.nn.functional.interpolate(
        segm_map, size=tuple(test_image.shape[-2:]), mode='bilinear'
    ).cpu().squeeze().numpy()

    y_score_image = float(s_star)
    y_pred_image = int(y_score_image >= best_threshold)
    class_label = ['Image Is OK', 'Image Is Not OK']
    # Guard the display ratio against a zero threshold.
    relative_score = (
        y_score_image / best_threshold if best_threshold > 0 else float("inf")
    )

    # Plot results
    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    axs[0].imshow(test_image.squeeze().permute(1, 2, 0).cpu().numpy())
    axs[0].set_title("Original Image")
    axs[0].axis("off")

    axs[1].imshow(segm_map, cmap='jet')
    axs[1].set_title(
        f"Anomaly Score: {relative_score:0.4f}\n"
        f"Prediction: {class_label[y_pred_image]}"
    )
    axs[1].axis("off")

    # Pixels scoring 25% above the image-level threshold count as faulty.
    axs[2].imshow((segm_map > best_threshold * 1.25), cmap='gray')
    axs[2].set_title("Fault Segmentation Map")
    axs[2].axis("off")

    # Save this specific figure rather than pyplot's "current" figure, which
    # is global state shared across concurrent requests.
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    plt.close(fig)
    buf.seek(0)
    result_image = Image.open(buf)
    result_image.load()
    return result_image


# Launch Gradio App
demo = gr.Interface(
    fn=detect_fault,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=gr.Image(type="pil", label="Detection Result"),
    title="Fault Detection in Images",
    description="Upload an image and the model will detect if there are any faults and show the segmentation map."
)

if __name__ == "__main__":
    demo.launch()