MLX Runtime (Apple silicon) — Added Files & Usage

This fork adds a lightweight MLX runtime so you can run the original MobileLLM‑R1‑950M weights with Apple’s MLX on Apple silicon. It keeps the original weights (model.safetensors) and tokenizer; only the runtime is added. Additional code is provided to reproduce the (also included) mlx-lm conversion and 4-bit quant, though some manual modifications are needed to mlx-lm at this time to get it to run.

Technical Documentation

For detailed technical information about this port, see:

What’s included (added files)

  • model.py — Minimal MLX implementation of the architecture with GQA, optional Q/K norm, RoPE, and output weight tying.
  • inference.py — Simple text generation CLI with temperature, top‑p, greedy mode, optional chat template, EOS handling, plus boxed‑answer controls for math.
  • test_model.py — Diagnostics to verify model structure/parameter shapes and key weight presence.
  • check_shape.py — Heuristic check to inspect the MLP variant from model.safetensors and config.json.
  • main.py — Convenience entry for quick manual tests.

Notes

  • This is an MLX runtime; it does not change or fine‑tune the weights. The README front‑matter marks this repo as a derivative of facebook/MobileLLM-R1-950M via base_model so it appears correctly on Hugging Face.
  • Tested via uv on macOS with Python 3.13; deps are pinned in uv.lock/pyproject.toml.

Quick start (MLX, local safetensors)

  • Install and run with uv: uv run python inference.py --prompt "What is 2+2?" --temperature 0.0 --max-tokens 64
  • Use chat template (default if chat_template.jinja present): uv run python inference.py --prompt "Explain quicksort in 1–2 sentences." --temperature 0.7 --top-p 0.9
  • Disable chat template: uv run python inference.py --prompt "Explain quicksort in 1–2 sentences." --disable-chat-template --temperature 0.7 --top-p 0.9
  • Math mode, final answer only: uv run python inference.py --prompt "Compute 17 * 23. Put your final answer in \\boxed{.}" --temperature 0.0 --final-only --stop-at-boxed --extract-boxed --max-tokens 128

Tips

  • If a sampled response stops mid‑sentence, increase --max-tokens (e.g., 192–256) or use a lower --temperature/--top-p.
  • For concise answers with the chat template, pass a system prompt: --system "Be concise. Answer in 1–2 sentences.".

Diagnostics

  • Structure/weights check: uv run python test_model.py
  • MLP variant heuristic: uv run python check_shape.py .

Details

  • The loader maps HF weight names to MLX module names and detects the MLP variant from weight keys to ensure correct layer wiring.
  • Attention uses standard 1/sqrt(d) scaling for best generation quality.

Installation

This project uses uv for dependency management.

Using uv (recommended)

# 1. Clone the repo
git clone <your-repo>
cd <your-repo>

# 2. Sync all dependencies (includes the default set)
uv sync

# 3. (Optional) Add the torch group if you plan to customize/train models
uv sync --extra torch

Without uv

If you prefer pip/venv, a requirements.txt is provided:

python -m venv .venv
source .venv/bin/activate  # Windows: .venv\Scripts\activate
pip install -r requirements.txt

The torch extra is only required if you intend to fine-tune or swap model back-ends; the default installation already supports inference.

MLX Inference Examples (safetensors)

  • Basic greedy generation:
    • uv run python inference.py --prompt "MobileLLM-R1 runs on MLX." --temperature 0 --max-tokens 64
  • Chat-style with template:
    • uv run python inference.py --prompt "Briefly summarize quicksort." --temperature 0.7 --top-p 0.9
  • Disable the chat template:
    • uv run python inference.py --prompt "Briefly summarize quicksort." --disable-chat-template --temperature 0.7 --top-p 0.9
  • Math/coding “final answer only”:
    • uv run python inference.py --prompt "Solve: 128 / 8. Put final answer in \\boxed{.}" --temperature 0 --final-only --stop-at-boxed --extract-boxed

Design Choices (why not a trivial block)

This runtime mirrors the functional details of the released weights so they load 1:1 and generate well in MLX. A minimal “one size fits all” block hides critical differences and leads to poor output quality. Key choices:

  • Attention layout and features

    • Grouped-Query Attention (GQA): separate num_attention_heads vs num_key_value_heads with head_dim from config. We implement a custom Attention so K/V can be repeated across groups and still match the HF weight layout.
    • Q/K normalization: optional RMSNorm applied to per-head Q and K, controlled by use_qk_norm.
    • RoPE: MLX nn.RoPE with the model’s rope_theta (8e6 here), and a per-layer toggle via no_rope_layers. We gate RoPE per block, with a safe fallback if the list disables all layers.
    • Scaling: we use standard 1/sqrt(d) for SDPA. Some configs expose an attn_scale used for training tricks; applying it at inference severely degraded outputs, so it’s not multiplied into SDPA.
  • MLP variant detection

    • MobileLLM variants use either standard SwiGLU (gate_proj/up_proj/down_proj) or a dual-branch dense MLP. We detect the variant from weight keys in model.safetensors and instantiate the correct module so shapes and semantics match.
  • Weight tying and mapping

    • Tie output logits to the token embedding matrix when tie_word_embeddings is true, matching HF behavior and saving memory.
    • Map HF names to MLX names during load: model.embed_tokenstok_embeddings, layer/attn/norm renames, mlp.feed_forward., model.normnorm.
  • Template and decoding

    • The provided Jinja chat template is supported for parity with HF chat usage, but allow --disable-chat-template for raw prompting. Multiple EOS IDs are supported.
    • Sampling: temperature, top‑p, and greedy; optional repetition/frequency penalties; math helpers --final-only/--stop-at-boxed/--extract-boxed to keep answers concise.

Model Details

We present MobileLLM-R1, a new series of efficient reasoning models in the MobileLLM family. The release includes two categories of models:

Base models:

Final models:

Note: These models are not general-purpose chat models. They are Supervised Fine-Tuned (SFT) models, specifically trained to address mathematical, programming (Python, C++), and scientific problems.

In addition to the models, we release the complete training recipes and data sources to ensure reproducibility and support further research.

Remarkably, the MobileLLM-R1 950M, pre-trained on only ~2T high-quality tokens and with fewer than 5T total training tokens, achieves comparable or superior performance to Qwen3 0.6B, which was trained on 36T tokens, across MATH, GSM8K, MMLU, and LiveCodeBench benchmarks.

Compared to existing fully open-source models, MobileLLM-R1 950M model achieves ~5× higher accuracy on MATH compared to the Olmo 1.24B model and ~2× higher accuracy relative to the SmolLM2 1.7B model, despite being substantially smaller in parameter scale. In addition, MobileLLM-R1 950M outperforms both Olmo 1.24B and SmolLM2 1.7B by a wide margin on coding benchmarks, establishing a new state-of-the-art among fully open-source models.

Highlights

Pretrained Model

image/jpeg

Token efficiency comparison across pretrained models

image/jpeg

Post-trained Model

image/png

Model Architecture:

# Layers # Attnetion Heads # KV Heads Dim Hidden Dim Params
MobileLLM-R1-140M 15 9 3 576 2048 140M
MobileLLM-R1-360M 15 16 4 1024 4096 359M
MobileLLM-R1-950M 22 24 6 1536 6144 949M
Input modalities Output modalities Context Length Vocaburary Size Shared Embeddings
MobileLLM-R1-140M-base Text Text 4k 128k Yes
MobileLLM-R1-360M-base Text Text 4k 128k Yes
MobileLLM-R1-950M-base Text Text 4k 128k Yes
MobileLLM-R1-140M Text Text 32k 128k Yes
MobileLLM-R1-360M Text Text 32k 128k Yes
MobileLLM-R1-950M Text Text 32k 128k Yes

How to use

To load the pretrained model for further finetuning or evaluation:

from transformers import AutoModelForCausalLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("facebook/MobileLLM-R1-950M")
model = AutoModelForCausalLM.from_pretrained("facebook/MobileLLM-R1-950M")

Inference examples

Inference (MLX)

Use the MLX runtime provided in this repo to run the local model.safetensors on Apple silicon.

  • Basic: uv run python inference.py --prompt "Hello MLX" --temperature 0.7 --top-p 0.9
  • Deterministic: uv run python inference.py --prompt "Hello MLX" --temperature 0 --max-tokens 64

Flags in inference.py

  • --model-path: path to model directory (default: .)
  • --prompt: input text
  • --max-tokens: number of tokens to generate
  • --temperature: 0 for greedy, >0 for sampling
  • --top-p: nucleus sampling cutoff
  • --system: optional system message when using chat template
  • --final-only: instructs model to output only a final boxed answer
  • --stop-at-boxed: stop generation after closing } following \boxed{
  • --extract-boxed: print the last \boxed{...} content
  • --disable-chat-template: bypass chat_template.jinja and send raw prompt (with BOS)
  • --repetition-penalty: discourage previously generated tokens (>1.0)
  • --frequency-penalty: subtract alpha * token frequency from logits

See also: the “MLX Runtime (Apple silicon) — Added Files & Usage” section above for more examples and notes.

Inference (MLX-LM)

Two mlx-lm models are also provided, a conversion and a dynamic 4 bit quantization. code to reproduce and a handy inference runtime are provided in custom_mlx_lm/. After installation the following examples should work (I am forgetting, you may need to first copy the model into mlx_lm/ as llama4_text.py)

mobilellm-infer --model-path MobileLLM-R1-950M-mixed-4bit-mlx --prompt "What is the nearest prime to 9^2?

mobilellm-infer --model-path MobileLLM-R1-950M-mlx/ --prompt "What is the nearest prime to 9^2?"

Transformers

from transformers import pipeline
import torch

model_id = "facebook/MobileLLM-R1-950M"

pipe = pipeline(
    "text-generation",
    model=model_id,
    torch_dtype="auto",
    device_map="auto",
)

# Math problem / default scenario
messages = [
    {
        "role": "system",
        "content": "Please reason step by step, and put your final answer within \\boxed{}."
    },
    {"role": "user", "content": "Compute: $1-2+3-4+5- \\dots +99-100$."},
]

# C++ coding scenario
messages = [
    {
        "role": "system",
        "content": (
            "\nYou are a helpful and harmless assistant. You should think step-by-step before responding to the instruction below.\n\n"
            "Please use c++ programming language only.\n"
            "You must use ```cpp for just the final solution code block with the following format:\n"
            "```cpp\n# Your code here\n```\n"
        )
    },
    {"role": "user", "content": "Write a C++ program that prints 'Hello, World!'."},
]

# Python coding scenario
messages = [
    {
        "role": "system",
        "content": (
            "\nYou are a helpful and harmless assistant. You should think step-by-step before responding to the instruction below.\n\n"
            "Please use python programming language only.\n"
            "You must use ```python for just the final solution code block with the following format:\n"
            "```python\n# Your code here\n```\n"
        )
    },
    {"role": "user", "content": "Write a Python function that returns the square of a number."},
]

outputs = pipe(
    messages,
    max_new_tokens=8192,
)
print(outputs[0]["generated_text"][-1])

You can also run inference with vLLM. You only need to register the model architecture Llama4ForCausalLM with the vLLM ModelRegistry.

from vllm.model_executor.models.llama4 import Llama4ForCausalLM
from vllm.model_executor.models.registry import ModelRegistry
ModelRegistry.register_model("Llama4ForCausalLM", Llama4ForCausalLM)

Evaluation

MobileLLM-R1 base model

Model Size MATH500 GSM8K MBPP HumanEval CommonSense Avg. MMLU
4-shot
em
8-shot
em
3-shot
pass@1
0-shot
pass@1
0-shot
accuracy
5-shot
accuracy
<150M
SmolLM2-135M-base 135M 0.4 1.8 3.8 0.0 50.7 --
MobileLLM-R1-140M-base 140M 4.6 16.3 5.4 15.9 44.3 --
150M - 400M
Gemma-3-270M-pt 268M 0.6 1.1 2.0 3.1 48.4 26.5
SmolLM2-360M-base 362M 1.8 5.0 19.4 0.0 56.6 24.7
MobileLLM-R1-360M-base 359M 13.4 39.4 20.8 32.9 51.0 26.8
400M - 1B
Qwen2.5-0.5B-base 494M 14.8 41.8 29.6 28.1 52.3 47.5
Qwen3-0.6B-base 596M 29.8 60.9 39.0 30.5 55.3 52.4
MobileLLM-R1-950M-base 949M 26.8 61.6 39.2 46.3 58.6 47.4
> 1B
Gemma-3-1B-pt 1.0B 0.6 2.4 9.4 6.1 57.3 26.1
LLaMA3.2-1B-base 1.24B 1.6 6.8 26.6 17.1 58.4 32.0
OLMo-2-0425-1B-base 1.48B 5.2 39.8 7.8 6.7 61.0 42.4
Qwen2.5-1.5B-base 1.54B 31.0 68.4 44.6 36.6 58.7 61.2
SmolLM2-1.7B-base 1.71B 11.6 31.8 35.4 0.6 62.9 50.0
Qwen3-1.7B-base 2.03B 38.5 76.2 56.4 47.6 60.9 62.1

Here, CommonSense Avg. denotes an average of 8 tasks in CommonSense Reasoning benchmarks including ARC-easy, ARC-challenge, BoolQ, PIQA, SIQA, HellaSwag, OBQA, and WinoGrand. Models with fewer than 150M parameters do not yield reliable MMLU scores and are therefore denoted as '—'.

MobileLLM-R1 post-trained model

Model Size MATH500 GSM8K AIME'24 AIME'25 LiveCodeBench-v6
0-shot
pass@1
0-shot
pass@1
0-shot
pass@1, n=64
0-shot
pass@1, n=64
0-shot
pass@1, n=16
<150M
SmolLM2-135M-Instruct 135M 3.0 2.4 -- -- 0.0
MobileLLM-R1-140M 140M 7.4 3.0 -- -- 1.0
150M - 400M
Gemma-3-270m-it 268M 6.8 8.4 -- -- 0.0
SmolLM2-360M-Instruct 362M 3.4 8.1 -- -- 0.7
MobileLLM-R1-360M 359M 26.6 22.7 -- -- 4.8
400M - 1B
Qwen2.5-0.5B-Instruct 494M 31.2 48.1 0.1 0.3 3.6
Qwen3-0.6B 596M 73.0 79.2 11.3 17.0 14.9
MobileLLM-R1-950M 949M 74.0 67.5 15.5 16.3 19.9
> 1B
Gemma-3-1B-it 1.0B 45.4 62.9 0.9 0.0 2.0
LLaMA3.2-1B-Instruct 1.24B 24.8 38.8 1.1 0.2 4.1
OLMo-2-0425-1B-Instruct 1.48B 19.2 69.7 0.6 0.1 0.0
OpenReasoning-Nemotron-1.5B 1.54B 83.4 76.7 49.7 40.4 28.3
DeepSeek-R1-Distill-Qwen-1.5B 1.54B 83.2 77.3 29.1 23.4 19.9
Qwen2.5-1.5B-Instruct 1.54B 54.0 70.0 2.5 0.9 7.9
SmolLM2-1.7B-Instruct 1.71B 19.2 41.8 0.3 0.1 4.4
Qwen3-1.7B 2.03B 89.4 90.3 47.0 37.0 29.8

For AIME, we evaluate models across 64 runs and report the average accuracy. For LiveCodeBench, results are reported as the average accuracy across 16 runs. Models with fewer than 400M parameters do not produce reliable AIME scores and are therefore denoted as '—'.

Training

Training Process

image/jpeg

Training stages and hyperparameter details

In the pretraining phase, MobileLLM-R1 models are randomly initialized and optimized using the Adam optimizer with hyperparameters (β_1, β_2, ε) = (0.9, 0.95, 1e-8), coupled with a weight decay coefficient of 0.1. The learning rate follows a 2k-step warmup schedule and then decays linearly from its peak to 10% of the maximum.

In the mid-training phase, we use Adam optimizer with learning rate linearly decays from its maximum value to zero. We employ knowledge distillation with Llama-3.1-8B-Instruct model as the teacher, where the student is trained via minimizing the KL divergence between its output logits and the teacher logits.

In the post-training phase, we use the Adam optimizer with zero weight decay. The learning rate warmup ratio is set to 0.03 for general-purpose SFT and 0.1 for reasoning-specific SFT, and it linearly decays from its maximum value to zero. Full training hyperparameters are provided in the table below.

Stage Phase Tokens / Samples BS Sequence Length Steps LR #GPUs Training Time
Pre-training Phase1 2T tokens 16 2k 500k 4.00E-03 16 x 8 4-5 days
Phase2 2T tokens 16 2k 500k 4.00E-03 16 x 8 4-5 days
Mid-training Phase1 100B tokens 4 4k 50K 3.60E-04 16 x 8 1-2 days
Phase2 100B tokens 4 4k 50K 3.60E-04 16 x 8 1-2 days
Post-training General SFT 866K samples 4 4k 2 epochs 5.00E-06 16 x 8 ~2h
Reasoning SFT 6.2M samples 8 32k 4 epochs 8.00E-05 16 x 8 ~2.5days

Data Mix

Pre-training

Dataset Rows Tokens (B) Phase1 Mix Ratio Phase2 Mix Ratio
StarCoder 206,640,114 263.8 10.66% 0.52%
OpenWebMath 6,117,786 12.6 6.93% 23.33%
FineWeb-Edu 1,279,107,432 1300 63.75% 54.83%
Wiki 7,222,303 3.7 5.03% 0.14%
Arxiv 1,533,917 28 6.36% 1.32%
StackExchange 29,249,120 19.6 5.03% 0.86%
Algebraic stack 3,404,331 12.6 2.25% 1.26%
Nemotron science 708,920 2 -- 0.03%
Nemotron code 10,108,883 16 -- 0.72%
Nemotron math 22,066,397 15 -- 3.01%
Cosmopedia 31,064,744 25 -- 2.70%
Facebook natural reasoning 1,145,824 1.8 -- 3.18%
FineMath 48,283,984 34 -- 8.01%
peS2o 38,800,000 50 -- 0.08%
Total 100% 100%

Mid-training

Dataset Subset Rows (M) Phase1 Mix Ratio Phase2 Mix Ratio
Dolmino DCLM Baseline 606 37.03% 6.51%
FLAN 57.3 4.10% 0.72%
peS2o 38.8 11.41% 2.01%
Wiki 6.17 2.66% 0.47%
StackExchange 2.48 2.12% 2.00%
Math 21 11.63% 29.10%
Nemotron Nemotron-Pretraining-Code-v1 882 20.69% 29.10%
Nemotron-CC-Math-v1 144 3.45% 19.40%
StarCoder StarCoder 206 6.90% 9.70%
Benchmark training set TriviaQA (train)
OBQA (train)
NaturalQuestions (train)
PIQA (train)
GSM8K (train)
BoolQ (train)
ARC-Easy (train)
ARC-Challenge (train)
~0.01 -- 0.97%
Total 100.00% 100.00%

Post-training

Phase Dataset Rows
General SFT Tulu-3-sft-olmo-2-mixture-0225 866K samples
Reasoning SFT OpenMathReasoning 3.2M samples
OpenScienceReasoning-2 803K samples
OpenCodeReasoning-2 2.16M samples

Citation

If you find our model useful for your research, please consider citing:

@misc{mobilellm_r1_2025,
  title={MobileLLM-R1: Model Card},
  author={Zechun Liu*, Ernie Chang*, Changsheng Zhao*, Chia-Jung Chang, Wei Wen, Chen Lai, Rick Cao, Yuandong Tian, Raghuraman Krishnamoorthi, Yangyang Shi, Vikas Chandra},
  year={2025},
  url = {https://huggingface.co/mobilellm-r1}
}

Contact

Zechun Liu, Meta Inc (zechunliu at meta dot com)

Ernie Chang, Meta Inc (erniecyc at meta dot com)

Changsheng Zhao, Meta Inc (cszhao at meta dot com)

License

MobileLLM-R1 is FAIR NC licensed as of now

Downloads last month
39
Safetensors
Model size
0.9B params
Tensor type
F32
·
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support

Model tree for robbiemu/MobileLLM-R1-950M-MLX

Finetuned
(3)
this model