#!/usr/bin/env python3
# convert_to_neox.py
import argparse, re, torch
from safetensors.torch import load_file, save_file


def cat_if_exists(keys, state, dim=0):
    """Concatenate the tensors at `keys` along `dim` if they all exist in `state`, else return None."""
    if all(k in state for k in keys):
        return torch.cat([state[k] for k in keys], dim=dim)
    return None


def convert(path_in: str, path_out: str, dtype: torch.dtype):
    src = load_file(path_in)
    tgt = {}

    # --- top-level tensors ---------------------------------------------------
    tgt["model.embed_tokens.weight"] = src["transformer.wte.weight"].to(dtype)
    tgt["model.final_layernorm.weight"] = src["transformer.ln_f.weight"].to(dtype)
    tgt["model.final_layernorm.bias"] = src["transformer.ln_f.bias"].to(dtype)
    tgt["lm_head.weight"] = src["lm_head.weight"].to(dtype)
    tgt["lm_head.bias"] = src["lm_head.bias"].to(dtype)

    # --- per-layer tensors ---------------------------------------------------
    pat = re.compile(r"transformer\.h\.(\d+)\.")
    layer_ids = sorted({int(pat.match(k).group(1)) for k in src if pat.match(k)})

    for i in layer_ids:
        p_old = f"transformer.h.{i}"
        p_new = f"model.layers.{i}"

        ## attention: the fused-QKV variant below targets the NeoX-style
        ## `attention.query_key_value` layout; kept for reference.
        # qkv_w = cat_if_exists(
        #     [f"{p_old}.attn.q_proj.weight",
        #      f"{p_old}.attn.k_proj.weight",
        #      f"{p_old}.attn.v_proj.weight"],
        #     src)
        # qkv_b = cat_if_exists(
        #     [f"{p_old}.attn.q_proj.bias",
        #      f"{p_old}.attn.k_proj.bias",
        #      f"{p_old}.attn.v_proj.bias"],
        #     src)
        #
        # tgt[f"{p_new}.attention.query_key_value.weight"] = qkv_w.to(dtype)
        # if qkv_b is not None:
        #     tgt[f"{p_new}.attention.query_key_value.bias"] = qkv_b.to(dtype)
        # else:
        #     tgt[f"{p_new}.attention.query_key_value.bias"] = torch.zeros(qkv_w.shape[0], dtype=dtype)

        # GPT-J attention projections carry no biases, so synthesize zero
        # biases (in the target dtype) to satisfy the Phi-1.5 parameter layout.
        for proj in ("q_proj", "k_proj", "v_proj"):
            w = src[f"{p_old}.attn.{proj}.weight"].to(dtype)
            tgt[f"{p_new}.self_attn.{proj}.weight"] = w
            tgt[f"{p_new}.self_attn.{proj}.bias"] = torch.zeros(w.shape[0], dtype=dtype)

        tgt[f"{p_new}.self_attn.dense.weight"] = src[f"{p_old}.attn.out_proj.weight"].to(dtype)
        tgt[f"{p_new}.self_attn.dense.bias"] = torch.zeros(
            tgt[f"{p_new}.self_attn.dense.weight"].shape[0], dtype=dtype)

        # layer norms
        tgt[f"{p_new}.input_layernorm.weight"] = src[f"{p_old}.ln_1.weight"].to(dtype)
        tgt[f"{p_new}.input_layernorm.bias"] = src[f"{p_old}.ln_1.bias"].to(dtype)

        # MLP
        tgt[f"{p_new}.mlp.fc1.weight"] = src[f"{p_old}.mlp.fc_in.weight"].to(dtype)
        tgt[f"{p_new}.mlp.fc1.bias"] = src[f"{p_old}.mlp.fc_in.bias"].to(dtype)
        tgt[f"{p_new}.mlp.fc2.weight"] = src[f"{p_old}.mlp.fc_out.weight"].to(dtype)
        tgt[f"{p_new}.mlp.fc2.bias"] = src[f"{p_old}.mlp.fc_out.bias"].to(dtype)

    # ------------------------------------------------------------------------
    save_file(tgt, path_out)
    print(f"✓ wrote {len(tgt):,} tensors to {path_out}")


if __name__ == "__main__":
    ap = argparse.ArgumentParser(
        description="convert a GPT-J-style safetensors checkpoint to the Phi-1.5 layout")
    ap.add_argument("--in", dest="inp", required=True, help="source .safetensors")
    ap.add_argument("--out", dest="outp", required=True, help="destination .safetensors")
    ap.add_argument("--dtype", default="float16",
                    choices=["float16", "bfloat16", "float32"],
                    help="cast parameters to this dtype in the output file")
    args = ap.parse_args()
    convert(args.inp, args.outp, getattr(torch, args.dtype))
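
# Usage sketch (the checkpoint file names here are hypothetical examples,
# not paths from this repository):
#
#   python convert_to_neox.py --in gptj.safetensors --out phi.safetensors --dtype bfloat16
#
# Assuming the output path above, the converted file can be spot-checked by
# loading it back with safetensors and counting the tensors:
#
#   python -c "from safetensors.torch import load_file; print(len(load_file('phi.safetensors')))"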