File size: 2,885 Bytes
8dff9a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
# model.py

import os
import sys
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from huggingface_hub import hf_hub_download
from PIL import Image
import numpy as np

# --- Cấu hình chung ---
DEVICE       = "cuda" if torch.cuda.is_available() else "cpu"
HF_REPO      = "VanNguyen1214/detect_faceshape"  # repo của bạn trên HF Hub
HF_FILENAME  = "best_model.pth"                  # file ở root của repo
LOCAL_CKPT   = "models/best_model.pth"           # sẽ lưu tại đây
CLASS_NAMES  = ['Heart', 'Oblong', 'Oval', 'Round', 'Square']
NUM_CLASSES  = len(CLASS_NAMES)

# --- Transform cho ảnh trước inference ---
_TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std =[0.229, 0.224, 0.225]),
])

def _ensure_checkpoint() -> str:
    """

    Kiểm tra xem LOCAL_CKPT đã tồn tại chưa.

    Nếu chưa, tải best_model.pth từ HF_REPO và lưu vào ./models/

    """
    if os.path.exists(LOCAL_CKPT):
        return LOCAL_CKPT

    try:
        ckpt_path = hf_hub_download(
            repo_id=HF_REPO,
            filename=HF_FILENAME,
            local_dir="models",
        )
        return ckpt_path
    except Exception as e:
        print(f"❌ Không tải được model từ HF Hub: {e}")
        sys.exit(1)

def _load_model(ckpt_path: str) -> torch.nn.Module:
    """

    Tái tạo kiến trúc EfficientNet-B4, load state_dict, đưa về eval mode.

    """
    # 1) Khởi tạo EfficientNet-B4
    model = torchvision.models.efficientnet_b4(pretrained=False)
    in_features = model.classifier[1].in_features
    model.classifier = nn.Sequential(
        nn.Dropout(p=0.3, inplace=True),
        nn.Linear(in_features, NUM_CLASSES)
    )

    # 2) Load trọng số
    state = torch.load(ckpt_path, map_location=DEVICE)
    model.load_state_dict(state)

    # 3) Đưa model về chế độ evaluation
    return model.to(DEVICE).eval()

# === Build model ngay khi import ===
_CKPT_PATH = _ensure_checkpoint()
_MODEL     = _load_model(_CKPT_PATH)

def predict(image: Image.Image) -> dict:
    """

    Chức năng inference:

      - image: numpy array H×W×3 RGB

      - Trả về dict:

          {

            "predicted_class": str,

            "confidence": float,

            "probabilities": { class_name: prob, ... }

          }

    """
    # Convert về PIL + transform
    img = image.convert("RGB")
    x   = _TRANSFORM(img).unsqueeze(0).to(DEVICE)

    # Inference
    with torch.no_grad():
        logits = _MODEL(x)
        probs  = torch.softmax(logits, dim=1).squeeze().cpu().numpy()

    idx = int(probs.argmax())
    return {"predicted_class": CLASS_NAMES[idx]}