import sys
from pathlib import Path

# Add the current directory to the Python path so model.py can be imported.
sys.path.append(str(Path.cwd()))

from model import load_model
from mlx.utils import tree_flatten


def run_diagnostic_checks():
    """
    Performs the verification checks outlined in the review.
    """
    print("--- Running Diagnostic Checks ---")

    # 1. Load the model and check for errors.
    try:
        model = load_model(".")
        print("Successfully loaded model definition.")
    except Exception as e:
        print(f"Error loading model: {e}")
        return

    # 2. Print the total parameter count.
    try:
        params = model.parameters()
        num_params = sum(p.size for _, p in tree_flatten(params))
        print(f"Total number of parameters: {num_params / 1e6:.2f}M")
    except Exception as e:
        print(f"Error calculating parameters: {e}")

    # 3. Verify the MLP weight shapes.
    print("--- Verifying MLP Weight Shapes ---")
    try:
        first_block = model.layers[0]
        args = model.args
        print(f"use_dual_mlp detected: {args.use_dual_mlp}")
        if args.use_dual_mlp:
            g_up_shape = first_block.feed_forward.g_up.weight.shape
            p_up_shape = first_block.feed_forward.p_up.weight.shape
            print(f"Gated MLP branch (g_up) weight shape: {g_up_shape}")
            print(f"Plain MLP branch (p_up) weight shape: {p_up_shape}")
            assert g_up_shape == (args.intermediate_size, args.hidden_size)
            assert p_up_shape == (args.intermediate_size_mlp, args.hidden_size)
            print("DualMLP weight shapes are correct.")
        else:
            gate_proj_shape = first_block.feed_forward.gate_proj.weight.shape
            up_proj_shape = first_block.feed_forward.up_proj.weight.shape
            print(f"SwiGLUMLP gate_proj weight shape: {gate_proj_shape}")
            print(f"SwiGLUMLP up_proj weight shape: {up_proj_shape}")
            assert gate_proj_shape == (args.intermediate_size_mlp, args.hidden_size)
            assert up_proj_shape == (args.intermediate_size_mlp, args.hidden_size)
            print("SwiGLUMLP weight shapes are correct.")
    except AttributeError as e:
        print(
            f"Error accessing MLP weights. The structure is not as expected: {e}"
        )
    except AssertionError:
        print("Error: MLP weight shapes do not match the configuration.")
    except Exception as e:
        print(f"An unexpected error occurred while verifying shapes: {e}")

    # 4. Verify the embedding shape.
    print("--- Verifying Embedding Shape ---")
    try:
        embedding_shape = model.tok_embeddings.weight.shape
        print(f"Embedding weight shape: {embedding_shape}")
        args = model.args
        print(f"Expected embedding shape: ({args.vocab_size}, {args.hidden_size})")
        assert embedding_shape == (args.vocab_size, args.hidden_size)
        print("Embedding shape is correct.")
    except Exception as e:
        print(f"An unexpected error occurred while verifying embedding shape: {e}")

    # 5. Sanity-check that the expected weight attributes exist.
    print("--- Sanity Checking Loaded Weights ---")
    try:
        # Check that the attributes expected for this architecture are present.
        if model.args.use_dual_mlp:
            _ = model.layers[0].feed_forward.g_gate.weight
            _ = model.layers[0].feed_forward.g_up.weight
            _ = model.layers[0].feed_forward.g_down.weight
            _ = model.layers[0].feed_forward.p_up.weight
            _ = model.layers[0].feed_forward.p_down.weight
            print("Found dual-branch MLP weights in the model.")
        else:
            _ = model.layers[0].feed_forward.gate_proj.weight
            _ = model.layers[0].feed_forward.up_proj.weight
            _ = model.layers[0].feed_forward.down_proj.weight
            print("Found SwiGLU MLP weights in the model.")
        print("Weight presence sanity check passed.")
    except Exception as e:
        print(f"An error occurred during sanity check: {e}")

    print("--- Diagnostic Checks Complete ---")


if __name__ == "__main__":
    run_diagnostic_checks()
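
# Usage sketch (assumptions: this script sits next to model.py and the model
# weights/config in the current working directory, and the filename
# "run_diagnostics.py" is illustrative, not prescribed by the repo):
#
#     python run_diagnostics.py
#
# Every check catches its own exceptions and prints the failure instead of
# raising, so one broken check never hides the results of the others.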