import mlx.core as mx import mlx.nn as nn import json from dataclasses import dataclass from pathlib import Path @dataclass class ModelArgs: hidden_size: int num_attention_heads: int num_hidden_layers: int vocab_size: int intermediate_size: int intermediate_size_mlp: int = None num_key_value_heads: int = 0 rms_norm_eps: float = 1e-5 rope_theta: float = 10000.0 head_dim: int = None use_dual_mlp: bool = False tie_word_embeddings: bool = True use_qk_norm: bool = False attn_scale: float = 1.0 no_rope_layers: list | None = None attention_chunk_size: int | None = None attn_temperature_tuning: bool = False @classmethod def from_dict(cls, params): return cls( hidden_size=params["hidden_size"], num_attention_heads=params["num_attention_heads"], num_hidden_layers=params["num_hidden_layers"], vocab_size=params["vocab_size"], intermediate_size=params["intermediate_size"], intermediate_size_mlp=params.get("intermediate_size_mlp"), num_key_value_heads=params.get("num_key_value_heads", 0), rms_norm_eps=params.get("rms_norm_eps", 1e-5), rope_theta=params.get("rope_theta", 10000.0), head_dim=params.get("head_dim"), # Default: off. We'll detect from weights in load_model. use_dual_mlp=False, tie_word_embeddings=params.get("tie_word_embeddings", True), use_qk_norm=params.get("use_qk_norm", False), attn_scale=params.get("attn_scale", 1.0), no_rope_layers=params.get("no_rope_layers"), attention_chunk_size=params.get("attention_chunk_size"), attn_temperature_tuning=params.get("attn_temperature_tuning", False), ) class RMSNorm(nn.Module): def __init__(self, dims: int, eps: float = 1e-5): super().__init__() self.weight = mx.ones((dims,)) self.eps = eps def _norm(self, x): return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps) def __call__(self, x): output = self._norm(x.astype(mx.float32)).astype(x.dtype) return self.weight * output class Attention(nn.Module): def __init__(self, args: ModelArgs): super().__init__() self.args = args self.n_heads = args.num_attention_heads self.n_kv_heads = ( args.num_key_value_heads if args.num_key_value_heads > 0 else args.num_attention_heads ) self.head_dim = ( args.head_dim if getattr(args, "head_dim", None) is not None else (args.hidden_size // self.n_heads) ) # Use standard LLaMA scaling. The attn_scale field in some configs # does not correspond to SDPA scaling and degrades outputs if applied here. self.scale = self.head_dim**-0.5 self.q_proj = nn.Linear( args.hidden_size, self.n_heads * self.head_dim, bias=False ) self.k_proj = nn.Linear( args.hidden_size, self.n_kv_heads * self.head_dim, bias=False ) self.v_proj = nn.Linear( args.hidden_size, self.n_kv_heads * self.head_dim, bias=False ) self.o_proj = nn.Linear( self.n_heads * self.head_dim, args.hidden_size, bias=False ) self.q_norm = ( RMSNorm(self.head_dim, eps=args.rms_norm_eps) if getattr(args, "use_qk_norm", False) else None ) self.k_norm = ( RMSNorm(self.head_dim, eps=args.rms_norm_eps) if getattr(args, "use_qk_norm", False) else None ) # Llama 4 text models commonly use traditional RoPE application self.rope = nn.RoPE(self.head_dim, traditional=True, base=args.rope_theta) def __call__( self, x, mask=None, cache=None, apply_rope: bool = True, attn_temp: float | None = None, ): B, L, D = x.shape queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) if self.q_norm is not None: queries = self.q_norm(queries) keys = self.k_norm(keys) # Optionally apply RoPE depending on per-layer setting if apply_rope: if cache is not None: queries = self.rope(queries, offset=cache.offset) keys = self.rope(keys, offset=cache.offset) keys, values = cache.update_and_fetch(keys, values) else: queries = self.rope(queries) keys = self.rope(keys) else: if cache is not None: keys, values = cache.update_and_fetch(keys, values) if self.n_kv_heads != self.n_heads: repeat = self.n_heads // self.n_kv_heads keys = mx.repeat(keys, repeat, axis=1) values = mx.repeat(values, repeat, axis=1) # Optional attention temperature tuning (scale the softmax input) scale = self.scale if attn_temp is None else (self.scale * attn_temp) output = mx.fast.scaled_dot_product_attention( queries, keys, values, scale=scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output) class SwiGLUMLP(nn.Module): """Standard LLaMA-style gated MLP (SwiGLU).""" def __init__(self, dim, intermediate_size, activation=nn.silu): super().__init__() self.gate_proj = nn.Linear(dim, intermediate_size, bias=False) self.up_proj = nn.Linear(dim, intermediate_size, bias=False) self.down_proj = nn.Linear(intermediate_size, dim, bias=False) # self.activation = activation def __call__(self, x): # return self.down_proj(self.activation(self.gate_proj(x)) * self.up_proj(x)) return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) class DualMLP(nn.Module): """Dense dual-branch MLP: gated + plain.""" def __init__(self, dim, intermediate_gated, intermediate_plain, activation=nn.silu): super().__init__() self.g_up = nn.Linear(dim, intermediate_gated, bias=False) self.g_gate = nn.Linear(dim, intermediate_gated, bias=False) self.g_down = nn.Linear(intermediate_gated, dim, bias=False) self.p_up = nn.Linear(dim, intermediate_plain, bias=False) self.p_down = nn.Linear(intermediate_plain, dim, bias=False) # self.activation = activation def __call__(self, x): # gated_out = self.g_down(self.activation(self.g_gate(x)) * self.g_up(x)) # plain_out = self.p_down(self.activation(self.p_up(x))) gated_out = self.g_down(nn.silu(self.g_gate(x)) * self.g_up(x)) plain_out = self.p_down(nn.silu(self.p_up(x))) return gated_out + plain_out class TransformerBlock(nn.Module): def __init__(self, args: ModelArgs, layer_idx: int): super().__init__() self.attention = Attention(args) self.layer_idx = layer_idx # RoPE gating per layer. # If the config provides a per-layer no_rope mask: # - If it disables ALL layers, ignore it (apply RoPE everywhere) # - Otherwise, honor the per-layer flag. if ( isinstance(args.no_rope_layers, list) and len(args.no_rope_layers) > layer_idx ): all_marked = all(bool(v) for v in args.no_rope_layers) if all_marked: disable_rope = False else: disable_rope = bool(args.no_rope_layers[layer_idx]) else: disable_rope = False self.apply_rope = not disable_rope self.layer_idx = layer_idx if args.use_dual_mlp and args.intermediate_size_mlp: self.feed_forward = DualMLP( args.hidden_size, args.intermediate_size, args.intermediate_size_mlp, ) else: self.feed_forward = SwiGLUMLP( args.hidden_size, args.intermediate_size_mlp, ) self.attention_norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) self.ffn_norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) def __call__(self, x, mask=None, cache=None): L = x.shape[1] # Use standard causal mask; iRoPE chunking is not applied for now attn_mask = ( None if L <= 1 else nn.MultiHeadAttention.create_additive_causal_mask(L).astype(x.dtype) ) args = self.attention.args apply_rope = self.apply_rope attn_temp = 1.0 if getattr(args, "attn_temperature_tuning", False) else None r = self.attention( self.attention_norm(x), attn_mask, cache, apply_rope=apply_rope, attn_temp=attn_temp, ) h = x + r r = self.feed_forward(self.ffn_norm(h)) return h + r class Model(nn.Module): def __init__(self, args: ModelArgs): super().__init__() self.args = args self.vocab_size = args.vocab_size self.tok_embeddings = nn.Embedding(args.vocab_size, args.hidden_size) # Plain Python list is fine in MLX self.layers = [ TransformerBlock(args=args, layer_idx=i) for i in range(args.num_hidden_layers) ] self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) if not self.args.tie_word_embeddings: self.output = nn.Linear(args.hidden_size, args.vocab_size, bias=False) def __call__(self, inputs, cache=None): h = self.tok_embeddings(inputs) if cache is None: cache = [None] * len(self.layers) for layer, c in zip(self.layers, cache): h = layer(h, None, c) h = self.norm(h) if self.args.tie_word_embeddings: return h @ self.tok_embeddings.weight.T else: return self.output(h) def load_model(model_path: str): model_path = Path(model_path) with open(model_path / "config.json", "r") as f: config = json.load(f) from safetensors import safe_open from mlx.utils import tree_unflatten # Peek at weights to decide MLP variant with safe_open(model_path / "model.safetensors", framework="mlx") as f: keys = list(f.keys()) has_dual = any( (".feed_forward.g_up.weight" in k) or (".mlp.g_up.weight" in k) or (".feed_forward.p_up.weight" in k) or (".mlp.p_up.weight" in k) for k in keys ) args = ModelArgs.from_dict(config) args.use_dual_mlp = bool(has_dual) model = Model(args) weights = {} with safe_open(model_path / "model.safetensors", framework="mlx") as f: for k in f.keys(): v = f.get_tensor(k) # The keys in the safetensors file are from the Hugging Face model. # We need to map them to the names in our MLX model. k = k.replace("model.embed_tokens", "tok_embeddings") k = k.replace("model.layers", "layers") k = k.replace("self_attn", "attention") k = k.replace("input_layernorm", "attention_norm") k = k.replace("post_attention_layernorm", "ffn_norm") k = k.replace("mlp.", "feed_forward.") k = k.replace("model.norm", "norm") # For the MLP, the names are conveniently the same if using SwiGLUMLP # k = k.replace("feed_forward.gate_proj", "feed_forward.gate_proj") # k = k.replace("feed_forward.up_proj", "feed_forward.up_proj") # k = k.replace("feed_forward.down_proj", "feed_forward.down_proj") weights[k] = v # The output layer is tied to the token embeddings, so we don't load weights for it separately. if config.get("tie_word_embeddings", True): weights.pop("output.weight", None) model.update(tree_unflatten(list(weights.items()))) return model