# wing-selector-mlp / inference.py
# Author: ecopus
# Commit bcf89e2 (verified): "Upload folder using huggingface_hub"
import os, json, numpy as np, torch, torch.nn as nn
# Supported objective names. The position of each name is the objective id
# that score_wings passes to the model (via OBJECTIVES.index), i.e. the
# index of the hot bit in MLPSelector's one-hot objective code — do not
# reorder without retraining.
OBJECTIVES = [
    "min_cd",  # presumably: minimise drag coefficient
    "max_cl",  # presumably: maximise lift coefficient
    "max_ld",  # presumably: maximise lift-to-drag ratio
]
class MLPSelector(nn.Module):
    """Score (features, objective, airfoil) rows with a small MLP.

    The MLP input is the concatenation of the feature vector ``x``, a one-hot
    encoding of the objective id, and a learned airfoil embedding. The
    network ends in a bare ``Linear(hidden, 1)``, so it produces one raw
    (unactivated) score per row.

    Args:
        in_dim: Number of input features per row of ``x``.
        n_airfoils: Number of airfoil ids (rows of the embedding table).
        obj_dim: Number of objectives, i.e. width of the one-hot objective code.
        af_embed_dim: Width of the learned airfoil embedding.
        hidden: Width of both hidden layers.
    """

    def __init__(self, in_dim: int, n_airfoils: int, obj_dim: int = 3,
                 af_embed_dim: int = 8, hidden: int = 128):
        super().__init__()
        # Plain attribute (not a parameter/buffer) so the state_dict layout is
        # unchanged and previously saved checkpoints still load.
        self.obj_dim = obj_dim
        self.af_emb = nn.Embedding(n_airfoils, af_embed_dim)
        self.net = nn.Sequential(
            nn.Linear(in_dim + obj_dim + af_embed_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, x, obj_id, af_id):
        """Return one raw score per row.

        Args:
            x: Feature tensor concatenated along dim 1; assumes shape
                (B, in_dim) — TODO confirm for callers other than score_wings.
            obj_id: Integer tensor of B objective ids in [0, obj_dim).
            af_id: Integer tensor of B airfoil ids in [0, n_airfoils).

        Returns:
            Tensor of shape (B,) with unactivated scores (score_wings applies
            a sigmoid to them).
        """
        batch = x.size(0)
        # One-hot width follows obj_dim (it was hard-coded to 3, which broke
        # any model constructed with obj_dim != 3). The index tensor and the
        # dtype follow x so this works on any device / floating dtype.
        obj_oh = torch.zeros(batch, self.obj_dim, device=x.device, dtype=x.dtype)
        obj_oh[torch.arange(batch, device=x.device), obj_id] = 1.0
        af_e = self.af_emb(af_id)
        z = torch.cat([x, obj_oh, af_e], dim=1)
        return self.net(z).squeeze(1)
def load_selector(local_dir=".", device="cpu"):
    """Load a trained MLPSelector plus its feature statistics from a directory.

    ``best.pt`` is preferred; ``last.pt`` is used only if ``best.pt`` is absent.

    Args:
        local_dir: Directory holding the checkpoint; a ``str`` or any
            ``os.PathLike`` such as ``pathlib.Path``.
        device: Passed to ``torch.load(map_location=...)`` and ``Module.to``.

    Returns:
        ``(model, cfg)``. ``model`` is an MLPSelector on ``device`` in eval
        mode. ``cfg`` holds ``in_dim`` and ``n_airfoils`` (ints) and
        ``feat_stats`` with ``means``/``stds`` as float32 numpy arrays.

    Raises:
        FileNotFoundError: If neither ``best.pt`` nor ``last.pt`` exists.
    """
    ckpt_path = None
    for name in ("best.pt", "last.pt"):
        candidate = os.path.join(local_dir, name)
        if os.path.exists(candidate):
            ckpt_path = candidate
            break
    if ckpt_path is None:
        # f-string instead of "+": concatenating str + Path raised TypeError
        # and hid the intended FileNotFoundError for PathLike inputs.
        raise FileNotFoundError(f"best.pt/last.pt not found in {local_dir}")
    # NOTE(review): depending on the PyTorch version/defaults, torch.load may
    # fully unpickle the file — only load checkpoints from trusted sources.
    ckpt = torch.load(ckpt_path, map_location=device)
    cfg = {
        "in_dim": int(ckpt["in_dim"]),
        "n_airfoils": int(ckpt["n_airfoils"]),
        "feat_stats": {
            "means": np.array(ckpt["feat_stats"]["means"], dtype=np.float32),
            "stds": np.array(ckpt["feat_stats"]["stds"], dtype=np.float32),
        },
    }
    # Built with the class's default obj_dim / af_embed_dim / hidden; the
    # checkpoint must have been trained with those same defaults.
    model = MLPSelector(cfg["in_dim"], cfg["n_airfoils"])
    model.load_state_dict(ckpt["model"])
    model.to(device).eval()
    return model, cfg
def standardize(X_raw: np.ndarray, means: np.ndarray, stds: np.ndarray) -> np.ndarray:
    """Impute non-finite entries with ``means``, then z-score with ``means``/``stds``.

    NaN and +/-inf values in ``X_raw`` are replaced by the broadcast
    ``means`` value at that position before centering. Any std equal to 0 is
    replaced by 1.0 so that constant features do not divide by zero.
    ``means``/``stds`` must broadcast against ``X_raw`` (presumably one value
    per feature column — confirm against the training pipeline).
    """
    finite_mask = np.isfinite(X_raw)
    filled = np.where(finite_mask, X_raw, means)
    divisor = np.where(stds == 0, 1.0, stds)
    centered = filled - means
    return centered / divisor
def score_wings(model, X_std: np.ndarray, airfoil_id: int, objective: str, device="cpu"):
    """Score candidate wings for one airfoil under one objective.

    Args:
        model: Called as ``model(X, obj_ids, af_ids)`` and expected to return
            one raw score per row (e.g. an ``MLPSelector``); must already live
            on ``device``.
        X_std: Standardized features, one row per wing (e.g. the output of
            ``standardize``). A single 1-D feature vector is treated as a
            batch of one row.
        airfoil_id: Airfoil id applied to every row.
        objective: One of ``OBJECTIVES``.
        device: Device on which the input tensors are created.

    Returns:
        numpy array of sigmoid scores in [0, 1], one per row (for an
        MLPSelector, shape (B,)); higher = better.

    Raises:
        ValueError: If ``objective`` is not one of ``OBJECTIVES``.
    """
    try:
        obj_id = OBJECTIVES.index(objective)
    except ValueError:
        raise ValueError(
            f"unknown objective {objective!r}; expected one of {OBJECTIVES}"
        ) from None
    # Promote a lone feature vector to shape (1, n_features) so the model
    # always sees (batch, features); 2-D input passes through unchanged.
    X_2d = np.atleast_2d(np.asarray(X_std))
    X = torch.tensor(X_2d, dtype=torch.float32, device=device)
    n_rows = X.size(0)
    obj_ids = torch.full((n_rows,), obj_id, dtype=torch.long, device=device)
    af_ids = torch.full((n_rows,), airfoil_id, dtype=torch.long, device=device)
    with torch.no_grad():
        probs = torch.sigmoid(model(X, obj_ids, af_ids)).cpu().numpy()
    return probs  # higher = better