Spaces:
Runtime error
Runtime error
| # model.py | |
| import os | |
| import sys | |
| import torch | |
| import torch.nn as nn | |
| import torchvision | |
| from torchvision import transforms | |
| from huggingface_hub import hf_hub_download | |
| from PIL import Image | |
| import numpy as np | |
| # --- Cấu hình chung --- | |
| DEVICE = "cuda" if torch.cuda.is_available() else "cpu" | |
| HF_REPO = "VanNguyen1214/detect_faceshape" # repo của bạn trên HF Hub | |
| HF_FILENAME = "best_model.pth" # file ở root của repo | |
| LOCAL_CKPT = "models/best_model.pth" # sẽ lưu tại đây | |
| CLASS_NAMES = ['Heart', 'Oblong', 'Oval', 'Round', 'Square'] | |
| NUM_CLASSES = len(CLASS_NAMES) | |
| # --- Transform cho ảnh trước inference --- | |
| _TRANSFORM = transforms.Compose([ | |
| transforms.Resize((224, 224)), | |
| transforms.ToTensor(), | |
| transforms.Normalize(mean=[0.485, 0.456, 0.406], | |
| std =[0.229, 0.224, 0.225]), | |
| ]) | |
| def _ensure_checkpoint() -> str: | |
| """ | |
| Kiểm tra xem LOCAL_CKPT đã tồn tại chưa. | |
| Nếu chưa, tải best_model.pth từ HF_REPO và lưu vào ./models/ | |
| """ | |
| if os.path.exists(LOCAL_CKPT): | |
| return LOCAL_CKPT | |
| try: | |
| ckpt_path = hf_hub_download( | |
| repo_id=HF_REPO, | |
| filename=HF_FILENAME, | |
| local_dir="models", | |
| ) | |
| return ckpt_path | |
| except Exception as e: | |
| print(f"❌ Không tải được model từ HF Hub: {e}") | |
| sys.exit(1) | |
| def _load_model(ckpt_path: str) -> torch.nn.Module: | |
| """ | |
| Tái tạo kiến trúc EfficientNet-B4, load state_dict, đưa về eval mode. | |
| """ | |
| # 1) Khởi tạo EfficientNet-B4 | |
| model = torchvision.models.efficientnet_b4(pretrained=False) | |
| in_features = model.classifier[1].in_features | |
| model.classifier = nn.Sequential( | |
| nn.Dropout(p=0.3, inplace=True), | |
| nn.Linear(in_features, NUM_CLASSES) | |
| ) | |
| # 2) Load trọng số | |
| state = torch.load(ckpt_path, map_location=DEVICE) | |
| model.load_state_dict(state) | |
| # 3) Đưa model về chế độ evaluation | |
| return model.to(DEVICE).eval() | |
| # === Build model ngay khi import === | |
| _CKPT_PATH = _ensure_checkpoint() | |
| _MODEL = _load_model(_CKPT_PATH) | |
| def predict(image: Image.Image) -> dict: | |
| """ | |
| Chức năng inference: | |
| - image: numpy array H×W×3 RGB | |
| - Trả về dict: | |
| { | |
| "predicted_class": str, | |
| "confidence": float, | |
| "probabilities": { class_name: prob, ... } | |
| } | |
| """ | |
| # Convert về PIL + transform | |
| img = image.convert("RGB") | |
| x = _TRANSFORM(img).unsqueeze(0).to(DEVICE) | |
| # Inference | |
| with torch.no_grad(): | |
| logits = _MODEL(x) | |
| probs = torch.softmax(logits, dim=1).squeeze().cpu().numpy() | |
| idx = int(probs.argmax()) | |
| return {"predicted_class": CLASS_NAMES[idx]} | |