# MobileLLM-R1-950M-MLX / check_shape.py
# Last commit e39ff3a by robbiemu: "add mlx and mlx-lm support" (2.77 kB)
import argparse
import json
from pathlib import Path
from safetensors import safe_open
def _has_dual_mlp_config(config) -> bool:
    """Return True if *config* declares a positive numeric ``intermediate_size_mlp``."""
    # The key may be missing, JSON null, or hold a non-numeric value; only a
    # positive number counts. bool is excluded because it subclasses int.
    value = config.get("intermediate_size_mlp") if isinstance(config, dict) else None
    return (
        isinstance(value, (int, float))
        and not isinstance(value, bool)
        and value > 0
    )


def _find_weights(model_path: Path):
    """Return the single-file weights path, else the shard index path, else None."""
    for name in ("model.safetensors", "model.safetensors.index.json"):
        candidate = model_path / name
        if candidate.exists():
            return candidate
    return None


def _read_weight_keys(weights_path: Path) -> list:
    """Return the tensor names stored in *weights_path*.

    A ``.json`` path is treated as a sharded-checkpoint index, and the keys of
    its ``weight_map`` object are returned without opening any shard. Any
    other path is opened with safetensors ``safe_open`` and its keys are read.
    """
    if weights_path.suffix == ".json":
        with open(weights_path, encoding="utf-8") as f:
            return list(json.load(f)["weight_map"])
    with safe_open(str(weights_path), framework="mlx") as f:
        return list(f.keys())


def _is_dual_branch_key(key: str) -> bool:
    """Heuristic: an MLP weight that is not a standard SwiGLU projection.

    This is not foolproof, because naming varies between checkpoints, but it
    is a good indicator of an extra MLP branch.
    """
    return "mlp" in key and not any(
        proj in key for proj in ("gate_proj", "up_proj", "down_proj")
    )


def check_model_shape(model_path: str):
    """Inspect a model's config and weights to determine its MLP structure.

    Prints a human-readable report to stdout. Errors (missing or unreadable
    files) are printed and the function returns early.

    Args:
        model_path: Directory containing ``config.json`` plus either
            ``model.safetensors`` or a sharded-checkpoint index
            ``model.safetensors.index.json``. The single file takes
            precedence when both exist.

    Returns:
        None. The result is only reported on stdout.
    """
    model_path = Path(model_path)
    config_path = model_path / "config.json"

    if not config_path.exists():
        print(f"Error: config.json not found in {model_path}")
        return
    weights_path = _find_weights(model_path)
    if weights_path is None:
        print(
            "Error: model.safetensors (or model.safetensors.index.json) "
            f"not found in {model_path}"
        )
        return

    print(f"--- Checking model shape in {model_path} ---")

    # 1. Inspect config.json. JSONDecodeError and UnicodeDecodeError are both
    #    ValueError subclasses, so one clause covers malformed files.
    try:
        with open(config_path, encoding="utf-8") as f:
            config = json.load(f)
    except (OSError, ValueError) as e:
        print(f"Could not read config.json: {e}")
        return
    has_dual_mlp_config = _has_dual_mlp_config(config)
    print(f"Config has 'intermediate_size_mlp': {has_dual_mlp_config}")

    # 2. Inspect weight keys. This is a script boundary, so any reader failure
    #    (missing framework, corrupt file, malformed index) is reported, not raised.
    try:
        weight_keys = _read_weight_keys(weights_path)
    except Exception as e:
        print(f"Could not read weights from {weights_path.name}: {e}")
        return

    dual_key = next((k for k in weight_keys if _is_dual_branch_key(k)), None)
    has_dual_mlp_weights = dual_key is not None
    if has_dual_mlp_weights:
        print(f"Found potential dual-branch weight: {dual_key}")
    print(f"Found potential dual-branch MLP weights: {has_dual_mlp_weights}")

    # 3. Report conclusion. All four config/weights combinations get a distinct verdict.
    print("\n--- Conclusion ---")
    if has_dual_mlp_config and has_dual_mlp_weights:
        print("✅ The model appears to be a DUAL-BRANCH MLP variant.")
    elif has_dual_mlp_config:
        print(
            "⚠️ The model configuration suggests a dual-branch MLP, but no corresponding weights were found."
        )
        print(" It will likely run as a SINGLE-BRANCH model.")
    elif has_dual_mlp_weights:
        print(
            "⚠️ Weights that look like a dual-branch MLP were found, but the config has no positive 'intermediate_size_mlp'."
        )
        print(" Loaders driven by config.json may ignore or reject these weights.")
    else:
        print("✅ The model appears to be a SINGLE-BRANCH MLP variant.")
    print("--------------------\n")
if __name__ == "__main__":
    # Command-line entry point: a single optional positional argument, the
    # model directory, which defaults to the current working directory.
    cli = argparse.ArgumentParser(description="Check the MLP shape of a model variant.")
    cli.add_argument(
        "model_path",
        nargs="?",
        default=".",
        type=str,
        help="Path to the model directory to check.",
    )
    check_model_shape(cli.parse_args().model_path)