# model.py import os import sys import torch import torch.nn as nn import torchvision from torchvision import transforms from huggingface_hub import hf_hub_download from PIL import Image import numpy as np # --- Cấu hình chung --- DEVICE = "cuda" if torch.cuda.is_available() else "cpu" HF_REPO = "VanNguyen1214/detect_faceshape" # repo của bạn trên HF Hub HF_FILENAME = "best_model.pth" # file ở root của repo LOCAL_CKPT = "models/best_model.pth" # sẽ lưu tại đây CLASS_NAMES = ['Heart', 'Oblong', 'Oval', 'Round', 'Square'] NUM_CLASSES = len(CLASS_NAMES) # --- Transform cho ảnh trước inference --- _TRANSFORM = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std =[0.229, 0.224, 0.225]), ]) def _ensure_checkpoint() -> str: """ Kiểm tra xem LOCAL_CKPT đã tồn tại chưa. Nếu chưa, tải best_model.pth từ HF_REPO và lưu vào ./models/ """ if os.path.exists(LOCAL_CKPT): return LOCAL_CKPT try: ckpt_path = hf_hub_download( repo_id=HF_REPO, filename=HF_FILENAME, local_dir="models", ) return ckpt_path except Exception as e: print(f"❌ Không tải được model từ HF Hub: {e}") sys.exit(1) def _load_model(ckpt_path: str) -> torch.nn.Module: """ Tái tạo kiến trúc EfficientNet-B4, load state_dict, đưa về eval mode. """ # 1) Khởi tạo EfficientNet-B4 model = torchvision.models.efficientnet_b4(pretrained=False) in_features = model.classifier[1].in_features model.classifier = nn.Sequential( nn.Dropout(p=0.3, inplace=True), nn.Linear(in_features, NUM_CLASSES) ) # 2) Load trọng số state = torch.load(ckpt_path, map_location=DEVICE) model.load_state_dict(state) # 3) Đưa model về chế độ evaluation return model.to(DEVICE).eval() # === Build model ngay khi import === _CKPT_PATH = _ensure_checkpoint() _MODEL = _load_model(_CKPT_PATH) def predict(image: Image.Image) -> dict: """ Chức năng inference: - image: numpy array H×W×3 RGB - Trả về dict: { "predicted_class": str, "confidence": float, "probabilities": { class_name: prob, ... } } """ # Convert về PIL + transform img = image.convert("RGB") x = _TRANSFORM(img).unsqueeze(0).to(DEVICE) # Inference with torch.no_grad(): logits = _MODEL(x) probs = torch.softmax(logits, dim=1).squeeze().cpu().numpy() idx = int(probs.argmax()) return {"predicted_class": CLASS_NAMES[idx]}