File size: 1,591 Bytes
dd98f48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
# app.py
import io
from collections import Counter

import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from PIL import Image

from models import ModelA, ModelB, ModelC, transform_small, transform_large

# 1. spin up FastAPI
app = FastAPI()

# 2. load your saved weights
device = torch.device('cpu')


def _load_model(model_cls, weights_path):
    """Instantiate ``model_cls``, load its saved state dict, and put it in eval mode.

    Args:
        model_cls: Zero-argument model constructor (e.g. ``ModelA``); presumably
            a ``torch.nn.Module`` subclass — confirm against ``models.py``.
        weights_path: Path to the state-dict file saved for this model.

    Returns:
        The constructed model with its weights loaded, switched to eval mode.
    """
    model = model_cls()
    # weights_only=True restricts torch.load to tensors/primitive containers
    # rather than arbitrary pickled objects; tensors are mapped onto `device`.
    state = torch.load(weights_path, map_location=device, weights_only=True)
    model.load_state_dict(state)
    model.eval()  # inference behaviour for layers such as Dropout/BatchNorm, if present
    return model


modelA = _load_model(ModelA, 'modelA.pth')
modelB = _load_model(ModelB, 'modelB.pth')
modelC = _load_model(ModelC, 'modelC.pth')

@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image as "Real" or "Fake" by majority vote of three models.

    ``modelA`` and ``modelB`` receive the ``transform_small`` tensor and
    ``modelC`` receives the ``transform_large`` tensor. Each model casts one
    vote, the arg-max class index of its output along dim 1, and the most
    common index wins. Class index 1 maps to "Real"; any other index maps to
    "Fake".

    Args:
        file: The uploaded image, in any format PIL can decode.

    Returns:
        ``{"prediction": "Real" | "Fake", "confidence": "<pct>%"}``. The
        confidence is the fraction of models that voted for the winning class,
        formatted with one decimal place (e.g. ``"66.7%"``).

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    # read image bytes → PIL
    # NOTE(review): the whole upload is read into memory with no size cap;
    # consider enforcing a maximum upload size for an externally exposed service.
    data = await file.read()
    try:
        img = Image.open(io.BytesIO(data)).convert('RGB')
    except (OSError, Image.DecompressionBombError) as err:
        # UnidentifiedImageError (not an image) and truncated/corrupt data are
        # OSError subclasses. Without this handler they surface as a 500 error
        # instead of a client error.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from err

    # preprocess: unsqueeze(0) adds a leading batch dimension of size 1
    t_small = transform_small(img).unsqueeze(0)  # for A & B
    t_large = transform_large(img).unsqueeze(0)  # for C

    # run inference: one vote (predicted class index) per model.
    # NOTE(review): this CPU-bound work runs directly on the event loop inside
    # an `async def`, so it blocks other requests while it executes. Consider
    # offloading it (e.g. fastapi.concurrency.run_in_threadpool) if
    # concurrent throughput matters.
    votes = []
    with torch.no_grad():
        for model, inp in ((modelA, t_small), (modelB, t_small), (modelC, t_large)):
            scores = model(inp)
            votes.append(int(scores.argmax(dim=1).item()))

    # majority vote + confidence. Counter.most_common orders equal counts by
    # first occurrence, so on a three-way tie modelA's vote wins.
    vote_count = Counter(votes)
    final_label, winning_votes = vote_count.most_common(1)[0]
    confidence = winning_votes / len(votes)

    return {
        "prediction": "Real" if final_label == 1 else "Fake",
        "confidence": f"{confidence*100:.1f}%"
    }