real_finals / detect_face.py
VanNguyen1214's picture
Upload 58 files
8dff9a2 verified
# model.py
import os
import sys
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from huggingface_hub import hf_hub_download
from PIL import Image
import numpy as np
# --- Cấu hình chung ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
HF_REPO = "VanNguyen1214/detect_faceshape" # repo của bạn trên HF Hub
HF_FILENAME = "best_model.pth" # file ở root của repo
LOCAL_CKPT = "models/best_model.pth" # sẽ lưu tại đây
CLASS_NAMES = ['Heart', 'Oblong', 'Oval', 'Round', 'Square']
NUM_CLASSES = len(CLASS_NAMES)
# --- Transform cho ảnh trước inference ---
_TRANSFORM = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std =[0.229, 0.224, 0.225]),
])
def _ensure_checkpoint() -> str:
"""
Kiểm tra xem LOCAL_CKPT đã tồn tại chưa.
Nếu chưa, tải best_model.pth từ HF_REPO và lưu vào ./models/
"""
if os.path.exists(LOCAL_CKPT):
return LOCAL_CKPT
try:
ckpt_path = hf_hub_download(
repo_id=HF_REPO,
filename=HF_FILENAME,
local_dir="models",
)
return ckpt_path
except Exception as e:
print(f"❌ Không tải được model từ HF Hub: {e}")
sys.exit(1)
def _load_model(ckpt_path: str) -> torch.nn.Module:
"""
Tái tạo kiến trúc EfficientNet-B4, load state_dict, đưa về eval mode.
"""
# 1) Khởi tạo EfficientNet-B4
model = torchvision.models.efficientnet_b4(pretrained=False)
in_features = model.classifier[1].in_features
model.classifier = nn.Sequential(
nn.Dropout(p=0.3, inplace=True),
nn.Linear(in_features, NUM_CLASSES)
)
# 2) Load trọng số
state = torch.load(ckpt_path, map_location=DEVICE)
model.load_state_dict(state)
# 3) Đưa model về chế độ evaluation
return model.to(DEVICE).eval()
# === Build model ngay khi import ===
_CKPT_PATH = _ensure_checkpoint()
_MODEL = _load_model(_CKPT_PATH)
def predict(image: Image.Image) -> dict:
"""
Chức năng inference:
- image: numpy array H×W×3 RGB
- Trả về dict:
{
"predicted_class": str,
"confidence": float,
"probabilities": { class_name: prob, ... }
}
"""
# Convert về PIL + transform
img = image.convert("RGB")
x = _TRANSFORM(img).unsqueeze(0).to(DEVICE)
# Inference
with torch.no_grad():
logits = _MODEL(x)
probs = torch.softmax(logits, dim=1).squeeze().cpu().numpy()
idx = int(probs.argmax())
return {"predicted_class": CLASS_NAMES[idx]}