|
|
import argparse |
|
|
import json |
|
|
from pathlib import Path |
|
|
from safetensors import safe_open |
|
|
|
|
|
|
|
|
def check_model_shape(model_path: str): |
|
|
"""Inspects a model's config and weights to determine its MLP structure.""" |
|
|
model_path = Path(model_path) |
|
|
config_path = model_path / "config.json" |
|
|
weights_path = model_path / "model.safetensors" |
|
|
|
|
|
if not config_path.exists(): |
|
|
print(f"Error: config.json not found in {model_path}") |
|
|
return |
|
|
|
|
|
if not weights_path.exists(): |
|
|
print(f"Error: model.safetensors not found in {model_path}") |
|
|
return |
|
|
|
|
|
print(f"--- Checking model shape in {model_path} ---") |
|
|
|
|
|
|
|
|
with open(config_path, "r") as f: |
|
|
config = json.load(f) |
|
|
|
|
|
has_dual_mlp_config = config.get("intermediate_size_mlp", 0) > 0 |
|
|
print(f"Config has 'intermediate_size_mlp': {has_dual_mlp_config}") |
|
|
|
|
|
|
|
|
has_dual_mlp_weights = False |
|
|
try: |
|
|
with safe_open(weights_path, framework="mlx") as f: |
|
|
weight_keys = f.keys() |
|
|
|
|
|
|
|
|
for key in weight_keys: |
|
|
if ( |
|
|
"mlp" in key |
|
|
and "gate_proj" not in key |
|
|
and "up_proj" not in key |
|
|
and "down_proj" not in key |
|
|
): |
|
|
print(f"Found potential dual-branch weight: {key}") |
|
|
has_dual_mlp_weights = True |
|
|
break |
|
|
except Exception as e: |
|
|
print(f"Could not read weights from model.safetensors: {e}") |
|
|
return |
|
|
|
|
|
print(f"Found potential dual-branch MLP weights: {has_dual_mlp_weights}") |
|
|
|
|
|
|
|
|
print("\n--- Conclusion ---") |
|
|
if has_dual_mlp_config and has_dual_mlp_weights: |
|
|
print("✅ The model appears to be a DUAL-BRANCH MLP variant.") |
|
|
elif has_dual_mlp_config and not has_dual_mlp_weights: |
|
|
print( |
|
|
"⚠️ The model configuration suggests a dual-branch MLP, but no corresponding weights were found." |
|
|
) |
|
|
print(" It will likely run as a SINGLE-BRANCH model.") |
|
|
else: |
|
|
print("✅ The model appears to be a SINGLE-BRANCH MLP variant.") |
|
|
print("--------------------\n") |
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: the model directory is an optional positional
    # argument that falls back to the current working directory.
    cli = argparse.ArgumentParser(description="Check the MLP shape of a model variant.")
    cli.add_argument(
        "model_path",
        nargs="?",
        default=".",
        type=str,
        help="Path to the model directory to check.",
    )
    check_model_shape(cli.parse_args().model_path)
|
|
|