"""Single-image classification with a timm model loaded from files next to this module.

At import time this module reads ``config.json`` and ``pytorch_model.bin`` from
its own directory, builds the model, and moves it to ``device``.
"""
import json
import os

import timm
import torch
import torchvision.transforms as T
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"

# Directory containing this module; config and weights are resolved relative to it.
_MODEL_DIR = os.path.dirname(__file__)


def _load_config(path):
    """Read and parse the JSON config at *path*, closing the file afterwards."""
    # Explicit UTF-8 so non-ASCII class names decode the same on every platform.
    with open(path, encoding="utf-8") as fh:
        return json.load(fh)


# Keys read below: "architecture", "num_classes", "class_names".
CFG = _load_config(os.path.join(_MODEL_DIR, "config.json"))

MODEL = timm.create_model(
    CFG["architecture"], pretrained=False, num_classes=CFG["num_classes"]
)
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code.
# Only ship checkpoints from a trusted source; on torch >= 1.13 consider passing
# weights_only=True, since only a state_dict is expected here.
MODEL.load_state_dict(
    torch.load(os.path.join(_MODEL_DIR, "pytorch_model.bin"), map_location=device)
)
MODEL.eval().to(device)

# Resize the shorter side to 224, center-crop to 224x224, and normalize with the
# standard ImageNet mean/std (presumably what the checkpoint was trained with;
# confirm against the training recipe).
TFM = T.Compose(
    [
        T.Resize(224),
        T.CenterCrop(224),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)


def predict(path) -> dict:
    """Classify the image stored at *path*.

    Args:
        path: Filesystem path (or anything ``PIL.Image.open`` accepts) of the image.

    Returns:
        A dict with ``label_id`` (index of the highest-probability class),
        ``probs`` (softmax probabilities as a list of floats, one per class), and
        ``label_name`` (``CFG["class_names"][label_id]``).
    """
    # Context manager closes the underlying file once the tensor is built.
    with Image.open(path) as img:
        x = TFM(img.convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        probs = torch.softmax(MODEL(x), dim=1)[0]
    # Take the argmax directly on the tensor instead of rebuilding one from a list.
    i = int(probs.argmax().item())
    return {
        "label_id": i,
        "probs": probs.cpu().tolist(),
        "label_name": CFG["class_names"][i],
    }