import argparse
import json
from pathlib import Path

from safetensors import safe_open

# Projection names of a standard SwiGLU MLP. Any other weight key containing
# "mlp" is treated as a hint that the checkpoint carries an extra MLP branch.
_STANDARD_MLP_PROJECTIONS = ("gate_proj", "up_proj", "down_proj")


def _has_dual_mlp_config(config: dict) -> bool:
    """Return True if *config* declares a positive ``intermediate_size_mlp``.

    Missing, ``null`` or non-numeric values count as "not declared" instead of
    raising ``TypeError`` on the ``> 0`` comparison.
    """
    value = config.get("intermediate_size_mlp", 0)
    return isinstance(value, (int, float)) and value > 0


def _read_weight_keys(weights_source: Path) -> list:
    """Return all tensor names stored in *weights_source*.

    *weights_source* is either a single ``model.safetensors`` file or a sharded
    checkpoint's ``model.safetensors.index.json``. For the index, the tensor
    names are the keys of its ``weight_map``, so no shard needs to be opened.
    Any read/parse error propagates to the caller.
    """
    if weights_source.suffix == ".json":
        with open(weights_source, "r", encoding="utf-8") as index_file:
            index = json.load(index_file)
        return list(index["weight_map"])
    with safe_open(weights_source, framework="mlx") as weights_file:
        return list(weights_file.keys())


def _find_dual_mlp_key(weight_keys):
    """Return the first key that looks like a non-SwiGLU MLP weight, else None.

    This is a simple heuristic and not foolproof, because naming schemes vary.
    For example, an MoE router weight such as ``mlp.gate.weight`` also matches.
    """
    return next(
        (
            key
            for key in weight_keys
            if "mlp" in key
            and not any(proj in key for proj in _STANDARD_MLP_PROJECTIONS)
        ),
        None,
    )


def check_model_shape(model_path: str):
    """Inspect a model's config and weights to determine its MLP structure.

    Reads ``config.json`` and the weight names from either ``model.safetensors``
    or, for sharded checkpoints, ``model.safetensors.index.json``. Prints the
    findings and a conclusion to stdout.

    Args:
        model_path: Path to the model directory.

    Returns:
        None. On a missing or unreadable file, an error is printed and the
        function returns early.
    """
    model_path = Path(model_path)
    config_path = model_path / "config.json"
    weights_path = model_path / "model.safetensors"
    index_path = model_path / "model.safetensors.index.json"

    if not config_path.exists():
        print(f"Error: config.json not found in {model_path}")
        return
    if not weights_path.exists() and not index_path.exists():
        print(f"Error: model.safetensors not found in {model_path}")
        return

    print(f"--- Checking model shape in {model_path} ---")

    # 1. Inspect config.json
    try:
        with open(config_path, "r", encoding="utf-8") as config_file:
            config = json.load(config_file)
    except (OSError, ValueError) as e:  # ValueError covers JSON/Unicode decode errors
        print(f"Could not read config.json: {e}")
        return
    if not isinstance(config, dict):
        print("Could not read config.json: expected a JSON object at the top level")
        return

    has_dual_mlp_config = _has_dual_mlp_config(config)
    print(f"Config has 'intermediate_size_mlp': {has_dual_mlp_config}")

    # 2. Inspect weight keys. Prefer the single-file checkpoint, and fall back
    # to the sharded checkpoint's index.
    weights_source = weights_path if weights_path.exists() else index_path
    try:
        weight_keys = _read_weight_keys(weights_source)
    except Exception as e:  # diagnostic CLI boundary: report and stop
        print(f"Could not read weights from {weights_source.name}: {e}")
        return

    dual_key = _find_dual_mlp_key(weight_keys)
    has_dual_mlp_weights = dual_key is not None
    if has_dual_mlp_weights:
        print(f"Found potential dual-branch weight: {dual_key}")
    print(f"Found potential dual-branch MLP weights: {has_dual_mlp_weights}")

    # 3. Report conclusion
    print("\n--- Conclusion ---")
    if has_dual_mlp_config and has_dual_mlp_weights:
        print("✅ The model appears to be a DUAL-BRANCH MLP variant.")
    elif has_dual_mlp_config:
        print(
            "⚠️ The model configuration suggests a dual-branch MLP, but no corresponding weights were found."
        )
        print(" It will likely run as a SINGLE-BRANCH model.")
    elif has_dual_mlp_weights:
        # Previously reported as SINGLE-BRANCH, which contradicted the finding above.
        print(
            "⚠️ Weights that look like a dual-branch MLP were found, but the configuration does not declare 'intermediate_size_mlp'."
        )
        print(" The heuristic may have matched an unrelated weight; check config.json against the checkpoint.")
    else:
        print("✅ The model appears to be a SINGLE-BRANCH MLP variant.")
    print("--------------------\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Check the MLP shape of a model variant."
    )
    parser.add_argument(
        "model_path",
        type=str,
        nargs="?",
        default=".",
        help="Path to the model directory to check.",
    )
    args = parser.parse_args()
    check_model_shape(args.model_path)