|
|
|
|
|
import os, json, numpy as np, torch, torch.nn as nn
|
|
|
|
|
|
# Supported objective names. score_wings converts a name to its index in this
# list, and that index is the objective id passed to the model, so the order
# here must match the order the model was trained with.
OBJECTIVES = ["min_cd", "max_cl", "max_ld"]
|
|
|
|
|
|
class MLPSelector(nn.Module):
    """Small MLP that scores feature rows under an objective and an airfoil.

    The network input is the concatenation of the feature vector ``x``, a
    one-hot code of the objective id, and a learned airfoil embedding. The
    output is one raw (unbounded) score per row; callers such as
    ``score_wings`` apply a sigmoid to it.
    """

    def __init__(self, in_dim: int, n_airfoils: int, obj_dim: int = 3,
                 af_embed_dim: int = 8, hidden: int = 128):
        """Build the embedding table and the 3-layer MLP.

        Args:
            in_dim: Number of input features per row.
            n_airfoils: Number of airfoil ids (size of the embedding table).
            obj_dim: Number of objectives, i.e. width of the one-hot
                objective code.
            af_embed_dim: Width of the learned airfoil embedding.
            hidden: Width of both hidden layers.
        """
        super().__init__()
        # Plain attribute (not a parameter/buffer) so state_dict keys stay the
        # same as before; forward needs it to size the one-hot code.
        self.obj_dim = obj_dim
        self.af_emb = nn.Embedding(n_airfoils, af_embed_dim)
        self.net = nn.Sequential(
            nn.Linear(in_dim + obj_dim + af_embed_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, x, obj_id, af_id):
        """Return one raw score per row of ``x``.

        Args:
            x: 2-D feature tensor of shape (B, in_dim).
            obj_id: Integer tensor of length B with values in [0, obj_dim).
            af_id: Integer tensor of length B with values in [0, n_airfoils).

        Returns:
            Tensor of shape (B,).
        """
        batch = x.size(0)
        # The one-hot width must equal the obj_dim the first Linear layer was
        # built with (previously hard-coded to 3, which broke obj_dim != 3).
        obj_oh = torch.zeros(batch, self.obj_dim, device=x.device, dtype=x.dtype)
        obj_oh[torch.arange(batch, device=x.device), obj_id] = 1.0
        af_e = self.af_emb(af_id)
        z = torch.cat([x, obj_oh, af_e], dim=1)
        return self.net(z).squeeze(1)
|
|
|
|
|
|
def load_selector(local_dir=".", device="cpu"):
    """Load a trained MLPSelector and its feature statistics from *local_dir*.

    ``best.pt`` is preferred; ``last.pt`` is used only when ``best.pt`` is
    absent. The model is rebuilt with MLPSelector's default ``obj_dim``,
    ``af_embed_dim`` and ``hidden`` values; the checkpoint is assumed to have
    been trained with those defaults (TODO confirm if these ever change).

    Args:
        local_dir: Directory containing the checkpoint file(s).
        device: ``map_location`` for ``torch.load`` and the device the model
            is moved to.

    Returns:
        ``(model, cfg)`` where ``model`` is in eval mode on ``device`` and
        ``cfg`` holds ``in_dim``, ``n_airfoils`` and ``feat_stats``
        (``means`` / ``stds`` as float32 NumPy arrays).

    Raises:
        FileNotFoundError: If neither checkpoint file exists.
    """
    # Lazily take the first existing file, in order of preference.
    candidates = (os.path.join(local_dir, fname) for fname in ("best.pt", "last.pt"))
    ckpt_path = next((path for path in candidates if os.path.exists(path)), None)
    if ckpt_path is None:
        raise FileNotFoundError("best.pt/last.pt not found in " + local_dir)

    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(ckpt_path, map_location=device)

    cfg = {
        "in_dim": int(checkpoint["in_dim"]),
        "n_airfoils": int(checkpoint["n_airfoils"]),
        "feat_stats": {
            stat: np.array(checkpoint["feat_stats"][stat], dtype=np.float32)
            for stat in ("means", "stds")
        },
    }

    selector = MLPSelector(cfg["in_dim"], cfg["n_airfoils"])
    selector.load_state_dict(checkpoint["model"])
    selector = selector.to(device)
    selector.eval()
    return selector, cfg
|
|
|
|
|
|
def standardize(X_raw: np.ndarray, means: np.ndarray, stds: np.ndarray) -> np.ndarray:
    """Z-score ``X_raw`` column-wise using precomputed statistics.

    Non-finite entries (NaN, +/-inf) are first replaced by the corresponding
    value of ``means``, so they standardize to 0. Any ``stds`` entry equal to
    0 is treated as 1 to avoid division by zero.

    Args:
        X_raw: Raw feature values; must broadcast against ``means``/``stds``.
        means: Per-feature means.
        stds: Per-feature standard deviations.

    Returns:
        The standardized array.
    """
    is_finite = np.isfinite(X_raw)
    imputed = np.where(is_finite, X_raw, means)
    divisor = np.where(stds == 0, 1.0, stds)
    centered = imputed - means
    return centered / divisor
|
|
|
|
|
|
def score_wings(model, X_std: np.ndarray, airfoil_id: int, objective: str, device="cpu"):
    """Return sigmoid probabilities from ``model`` for every row of ``X_std``.

    Every row is scored under the same ``airfoil_id`` and ``objective``.

    Args:
        model: Callable as ``model(x, obj_ids, af_ids)`` returning one raw
            score per row (e.g. an ``MLPSelector``).
        X_std: 2-D feature matrix, presumably the output of ``standardize``.
        airfoil_id: Airfoil index applied to every row.
        objective: One of the names in ``OBJECTIVES``.
        device: Torch device the input tensors are created on; should match
            the model's device.

    Returns:
        NumPy array with one probability per row.

    Raises:
        ValueError: If ``objective`` is not one of ``OBJECTIVES``.
    """
    try:
        obj_id = OBJECTIVES.index(objective)
    except ValueError:
        # list.index's own message ("'x' is not in list") doesn't tell the
        # caller what the valid choices are.
        raise ValueError(
            f"unknown objective {objective!r}; expected one of {OBJECTIVES}"
        ) from None

    X = torch.tensor(X_std, dtype=torch.float32, device=device)
    n_rows = X.size(0)
    obj_ids = torch.full((n_rows,), obj_id, dtype=torch.long, device=device)
    af_ids = torch.full((n_rows,), airfoil_id, dtype=torch.long, device=device)

    with torch.no_grad():
        probs = torch.sigmoid(model(X, obj_ids, af_ids)).cpu().numpy()
    return probs
|
|
|
|