MLX Runtime (Apple silicon) — Added Files & Usage
This fork adds a lightweight MLX runtime so you can run the original MobileLLM‑R1‑950M weights with Apple’s MLX on Apple silicon. It keeps the original weights (model.safetensors) and tokenizer; only the runtime is added. Additional code is included to reproduce the bundled mlx-lm conversion and 4-bit quantization, though mlx-lm currently needs some manual modifications before it will run.
Technical Documentation
For detailed technical information about this port, see:
- MLX Technical Summary - Challenges and solutions for porting MobileLLM-R1 to MLX in this PoC conversion.
- Conversion Log - Details of the model conversion process
- Quantization Log - Information about quantization procedures and results
What’s included (added files)
- `model.py` — Minimal MLX implementation of the architecture with GQA, optional Q/K norm, RoPE, and output weight tying.
- `inference.py` — Simple text generation CLI with temperature, top‑p, greedy mode, optional chat template, EOS handling, plus boxed‑answer controls for math.
- `test_model.py` — Diagnostics to verify model structure/parameter shapes and key weight presence.
- `check_shape.py` — Heuristic check to inspect the MLP variant from `model.safetensors` and `config.json`.
- `main.py` — Convenience entry point for quick manual tests.
Notes
- This is an MLX runtime; it does not change or fine‑tune the weights. The README front‑matter marks this repo as a derivative of `facebook/MobileLLM-R1-950M` via `base_model` so it appears correctly on Hugging Face.
- Tested via `uv` on macOS with Python 3.13; deps are pinned in `uv.lock`/`pyproject.toml`.
Quick start (MLX, local safetensors)
- Install and run with uv:
  `uv run python inference.py --prompt "What is 2+2?" --temperature 0.0 --max-tokens 64`
- Use the chat template (default if `chat_template.jinja` is present):
  `uv run python inference.py --prompt "Explain quicksort in 1–2 sentences." --temperature 0.7 --top-p 0.9`
- Disable the chat template:
  `uv run python inference.py --prompt "Explain quicksort in 1–2 sentences." --disable-chat-template --temperature 0.7 --top-p 0.9`
- Math mode, final answer only:
  `uv run python inference.py --prompt "Compute 17 * 23. Put your final answer in \\boxed{.}" --temperature 0.0 --final-only --stop-at-boxed --extract-boxed --max-tokens 128`
Tips
- If a sampled response stops mid‑sentence, increase `--max-tokens` (e.g., 192–256) or use a lower `--temperature`/`--top-p`.
- For concise answers with the chat template, pass a system prompt: `--system "Be concise. Answer in 1–2 sentences."`
Diagnostics
- Structure/weights check: `uv run python test_model.py`
- MLP variant heuristic: `uv run python check_shape.py .`
Details
- The loader maps HF weight names to MLX module names and detects the MLP variant from weight keys to ensure correct layer wiring.
- Attention uses standard `1/sqrt(d)` scaling for best generation quality.
Installation
This project uses uv for dependency management.
Using uv (recommended)
```bash
# 1. Clone the repo
git clone <your-repo>
cd <your-repo>
# 2. Sync all dependencies (includes the default set)
uv sync
# 3. (Optional) Add the torch group if you plan to customize/train models
uv sync --extra torch
```
Without uv
If you prefer pip/venv, a requirements.txt is provided:
```bash
python -m venv .venv
source .venv/bin/activate # Windows: .venv\Scripts\activate
pip install -r requirements.txt
```
The `torch` extra is only required if you intend to fine-tune or swap model back-ends; the default installation already supports inference.
MLX Inference Examples (safetensors)
- Basic greedy generation:
  `uv run python inference.py --prompt "MobileLLM-R1 runs on MLX." --temperature 0 --max-tokens 64`
- Chat-style with template:
  `uv run python inference.py --prompt "Briefly summarize quicksort." --temperature 0.7 --top-p 0.9`
- Disable the chat template:
  `uv run python inference.py --prompt "Briefly summarize quicksort." --disable-chat-template --temperature 0.7 --top-p 0.9`
- Math/coding “final answer only”:
  `uv run python inference.py --prompt "Solve: 128 / 8. Put final answer in \\boxed{.}" --temperature 0 --final-only --stop-at-boxed --extract-boxed`
Design Choices (why not a trivial block)
This runtime mirrors the functional details of the released weights so they load 1:1 and generate well in MLX. A minimal “one size fits all” block hides critical differences and leads to poor output quality. Key choices:
Attention layout and features
- Grouped-Query Attention (GQA): separate `num_attention_heads` vs `num_key_value_heads`, with `head_dim` taken from the config. We implement a custom `Attention` so K/V can be repeated across groups and still match the HF weight layout (a minimal sketch follows this list).
- Q/K normalization: optional RMSNorm applied to per-head Q and K, controlled by `use_qk_norm`.
- RoPE: MLX `nn.RoPE` with the model’s `rope_theta` (8e6 here), and a per-layer toggle via `no_rope_layers`. We gate RoPE per block, with a safe fallback if the list disables all layers.
- Scaling: we use standard `1/sqrt(d)` for SDPA. Some configs expose an `attn_scale` used for training tricks; applying it at inference severely degraded outputs, so it is not multiplied into SDPA.
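To make the GQA wiring concrete, here is a minimal, illustrative sketch (not the repo’s `Attention` class) of K/V heads being repeated across their query groups with `1/sqrt(d)` scaling; causal masking and Q/K RMSNorm are omitted for brevity:

```python
# Illustrative GQA sketch only; assumes (B, heads, T, head_dim) layouts.
import math
import mlx.core as mx

def gqa_attention(q, k, v):
    # q: (B, n_heads, T, d); k, v: (B, n_kv_heads, T, d)
    n_heads, n_kv_heads, d = q.shape[1], k.shape[1], q.shape[-1]
    n_rep = n_heads // n_kv_heads            # query heads per K/V head
    if n_rep > 1:
        k = mx.repeat(k, n_rep, axis=1)      # align K/V heads with query groups
        v = mx.repeat(v, n_rep, axis=1)
    scores = (q @ mx.transpose(k, (0, 1, 3, 2))) * (1.0 / math.sqrt(d))
    return mx.softmax(scores, axis=-1) @ v
```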
MLP variant detection
- MobileLLM variants use either standard SwiGLU (`gate_proj`/`up_proj`/`down_proj`) or a dual-branch dense MLP. We detect the variant from weight keys in `model.safetensors` and instantiate the correct module so shapes and semantics match (see the sketch below).
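A minimal sketch of this key-based detection, assuming the HF checkpoint naming (this is illustrative, not the repo’s `check_shape.py`):

```python
# Illustrative sketch: infer the MLP variant from weight names in the checkpoint.
import mlx.core as mx

def detect_mlp_variant(path="model.safetensors"):
    keys = mx.load(path).keys()
    if any(k.endswith("mlp.gate_proj.weight") for k in keys):
        return "swiglu"        # gate_proj/up_proj/down_proj layout present
    return "dual_branch"       # otherwise assume the dual-branch dense MLP
```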
Weight tying and mapping
- Tie output logits to the token embedding matrix when `tie_word_embeddings` is true, matching HF behavior and saving memory.
- Map HF names to MLX names during load: `model.embed_tokens` → `tok_embeddings`, layer/attn/norm renames, `mlp.` → `feed_forward.`, `model.norm` → `norm` (a small sketch of this renaming follows below).
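The renaming amounts to a string remap applied to each checkpoint key at load time; a simplified, illustrative version (the real loader handles more renames than shown here):

```python
# Illustrative HF -> MLX key remap; layer/attention renames are elided.
def remap_key(hf_key: str) -> str:
    key = hf_key.replace("model.embed_tokens", "tok_embeddings")
    key = key.replace("model.norm", "norm")
    key = key.replace(".mlp.", ".feed_forward.")
    return key
```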
Template and decoding
- The provided Jinja chat template is supported for parity with HF chat usage, but `--disable-chat-template` allows raw prompting. Multiple EOS IDs are supported.
- Sampling: temperature, top‑p, and greedy; optional repetition/frequency penalties; math helpers `--final-only`/`--stop-at-boxed`/`--extract-boxed` keep answers concise (a top‑p sampling sketch follows below).
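As an illustration of the sampling path, here is a minimal top‑p (nucleus) sampling step, not the repo’s exact sampler:

```python
# Illustrative top-p sampling over a single logits vector.
import numpy as np

def sample_top_p(logits, temperature=0.7, top_p=0.9, rng=np.random.default_rng()):
    if temperature <= 0:
        return int(np.argmax(logits))                   # greedy decoding
    scaled = logits / temperature
    probs = np.exp(scaled - np.max(scaled))
    probs /= probs.sum()
    order = np.argsort(-probs)                          # most likely first
    cumulative = np.cumsum(probs[order])
    keep = order[: int(np.searchsorted(cumulative, top_p)) + 1]
    keep_probs = probs[keep] / probs[keep].sum()        # renormalize the nucleus
    return int(rng.choice(keep, p=keep_probs))
```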
Model Details
We present MobileLLM-R1, a new series of efficient reasoning models in the MobileLLM family. The release includes two categories of models: base models and final post-trained models.
Note: These models are not general-purpose chat models. They are Supervised Fine-Tuned (SFT) models, specifically trained to address mathematical, programming (Python, C++), and scientific problems.
In addition to the models, we release the complete training recipes and data sources to ensure reproducibility and support further research.
Remarkably, the MobileLLM-R1 950M, pre-trained on only ~2T high-quality tokens and with fewer than 5T total training tokens, achieves comparable or superior performance to Qwen3 0.6B, which was trained on 36T tokens, across MATH, GSM8K, MMLU, and LiveCodeBench benchmarks.
Compared to existing fully open-source models, the MobileLLM-R1 950M model achieves ~5× higher accuracy on MATH than the Olmo 1.24B model and ~2× higher accuracy than the SmolLM2 1.7B model, despite being substantially smaller in parameter scale. In addition, MobileLLM-R1 950M outperforms both Olmo 1.24B and SmolLM2 1.7B by a wide margin on coding benchmarks, establishing a new state of the art among fully open-source models.
Highlights
Pretrained Model
Token efficiency comparison across pretrained models
Post-trained Model
Model Architecture:
| Model | # Layers | # Attention Heads | # KV Heads | Dim | Hidden Dim | Params |
|---|---|---|---|---|---|---|
| MobileLLM-R1-140M | 15 | 9 | 3 | 576 | 2048 | 140M |
| MobileLLM-R1-360M | 15 | 16 | 4 | 1024 | 4096 | 359M |
| MobileLLM-R1-950M | 22 | 24 | 6 | 1536 | 6144 | 949M |
| Model | Input modalities | Output modalities | Context Length | Vocabulary Size | Shared Embeddings |
|---|---|---|---|---|---|
| MobileLLM-R1-140M-base | Text | Text | 4k | 128k | Yes |
| MobileLLM-R1-360M-base | Text | Text | 4k | 128k | Yes |
| MobileLLM-R1-950M-base | Text | Text | 4k | 128k | Yes |
| MobileLLM-R1-140M | Text | Text | 32k | 128k | Yes |
| MobileLLM-R1-360M | Text | Text | 32k | 128k | Yes |
| MobileLLM-R1-950M | Text | Text | 32k | 128k | Yes |
How to use
To load the pretrained model for further finetuning or evaluation:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("facebook/MobileLLM-R1-950M")
model = AutoModelForCausalLM.from_pretrained("facebook/MobileLLM-R1-950M")
```
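For a quick smoke test after loading, a minimal generation call looks like this (illustrative; not from the original model card):

```python
# Illustrative smoke test after the load above.
inputs = tokenizer("The capital of France is", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```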
Inference examples
Inference (MLX)
Use the MLX runtime provided in this repo to run the local model.safetensors on Apple silicon.
- Basic:
  `uv run python inference.py --prompt "Hello MLX" --temperature 0.7 --top-p 0.9`
- Deterministic:
  `uv run python inference.py --prompt "Hello MLX" --temperature 0 --max-tokens 64`
Flags in inference.py
- `--model-path`: path to model directory (default: `.`)
- `--prompt`: input text
- `--max-tokens`: number of tokens to generate
- `--temperature`: 0 for greedy, >0 for sampling
- `--top-p`: nucleus sampling cutoff
- `--system`: optional system message when using the chat template
- `--final-only`: instructs the model to output only a final boxed answer
- `--stop-at-boxed`: stop generation after the closing `}` following `\boxed{`
- `--extract-boxed`: print the last `\boxed{...}` content
- `--disable-chat-template`: bypass `chat_template.jinja` and send the raw prompt (with BOS)
- `--repetition-penalty`: discourage previously generated tokens (>1.0)
- `--frequency-penalty`: subtract alpha * token frequency from logits
See also: the “MLX Runtime (Apple silicon) — Added Files & Usage” section above for more examples and notes.
Inference (MLX-LM)
Two mlx-lm models are also provided: a straight conversion and a dynamic 4-bit quantization. Code to reproduce them, plus a handy inference runtime, lives in custom_mlx_lm/. After installation the following examples should work (you may need to first copy the model definition into mlx_lm/ as llama4_text.py):

```bash
mobilellm-infer --model-path MobileLLM-R1-950M-mixed-4bit-mlx --prompt "What is the nearest prime to 9^2?"
mobilellm-infer --model-path MobileLLM-R1-950M-mlx/ --prompt "What is the nearest prime to 9^2?"
```
Transformers
```python
from transformers import pipeline
import torch
model_id = "facebook/MobileLLM-R1-950M"
pipe = pipeline(
"text-generation",
model=model_id,
torch_dtype="auto",
device_map="auto",
)
# Math problem / default scenario
messages = [
{
"role": "system",
"content": "Please reason step by step, and put your final answer within \\boxed{}."
},
{"role": "user", "content": "Compute: $1-2+3-4+5- \\dots +99-100$."},
]
# C++ coding scenario
messages = [
{
"role": "system",
"content": (
"\nYou are a helpful and harmless assistant. You should think step-by-step before responding to the instruction below.\n\n"
"Please use c++ programming language only.\n"
"You must use ```cpp for just the final solution code block with the following format:\n"
"```cpp\n# Your code here\n```\n"
)
},
{"role": "user", "content": "Write a C++ program that prints 'Hello, World!'."},
]
# Python coding scenario
messages = [
{
"role": "system",
"content": (
"\nYou are a helpful and harmless assistant. You should think step-by-step before responding to the instruction below.\n\n"
"Please use python programming language only.\n"
"You must use ```python for just the final solution code block with the following format:\n"
"```python\n# Your code here\n```\n"
)
},
{"role": "user", "content": "Write a Python function that returns the square of a number."},
]
outputs = pipe(
messages,
max_new_tokens=8192,
)
print(outputs[0]["generated_text"][-1])
```
You can also run inference with vLLM. You only need to register the model architecture Llama4ForCausalLM with the vLLM ModelRegistry.
```python
from vllm.model_executor.models.llama4 import Llama4ForCausalLM
from vllm.model_executor.models.registry import ModelRegistry
ModelRegistry.register_model("Llama4ForCausalLM", Llama4ForCausalLM)
```
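Once the architecture is registered, a minimal generation call might look like this (illustrative sketch; assumes a working vLLM install, not from the original model card):

```python
# Illustrative vLLM usage after the registration above.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/MobileLLM-R1-950M")
params = SamplingParams(temperature=0.0, max_tokens=256)
print(llm.generate(["Compute 17 * 23."], params)[0].outputs[0].text)
```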
Evaluation
MobileLLM-R1 base model
| Model | Size | MATH500 | GSM8K | MBPP | HumanEval | CommonSense Avg. | MMLU |
|---|---|---|---|---|---|---|---|
| | | 4-shot em | 8-shot em | 3-shot pass@1 | 0-shot pass@1 | 0-shot accuracy | 5-shot accuracy |
| <150M | |||||||
| SmolLM2-135M-base | 135M | 0.4 | 1.8 | 3.8 | 0.0 | 50.7 | -- |
| MobileLLM-R1-140M-base | 140M | 4.6 | 16.3 | 5.4 | 15.9 | 44.3 | -- |
| 150M - 400M | |||||||
| Gemma-3-270M-pt | 268M | 0.6 | 1.1 | 2.0 | 3.1 | 48.4 | 26.5 |
| SmolLM2-360M-base | 362M | 1.8 | 5.0 | 19.4 | 0.0 | 56.6 | 24.7 |
| MobileLLM-R1-360M-base | 359M | 13.4 | 39.4 | 20.8 | 32.9 | 51.0 | 26.8 |
| 400M - 1B | |||||||
| Qwen2.5-0.5B-base | 494M | 14.8 | 41.8 | 29.6 | 28.1 | 52.3 | 47.5 |
| Qwen3-0.6B-base | 596M | 29.8 | 60.9 | 39.0 | 30.5 | 55.3 | 52.4 |
| MobileLLM-R1-950M-base | 949M | 26.8 | 61.6 | 39.2 | 46.3 | 58.6 | 47.4 |
| > 1B | |||||||
| Gemma-3-1B-pt | 1.0B | 0.6 | 2.4 | 9.4 | 6.1 | 57.3 | 26.1 |
| LLaMA3.2-1B-base | 1.24B | 1.6 | 6.8 | 26.6 | 17.1 | 58.4 | 32.0 |
| OLMo-2-0425-1B-base | 1.48B | 5.2 | 39.8 | 7.8 | 6.7 | 61.0 | 42.4 |
| Qwen2.5-1.5B-base | 1.54B | 31.0 | 68.4 | 44.6 | 36.6 | 58.7 | 61.2 |
| SmolLM2-1.7B-base | 1.71B | 11.6 | 31.8 | 35.4 | 0.6 | 62.9 | 50.0 |
| Qwen3-1.7B-base | 2.03B | 38.5 | 76.2 | 56.4 | 47.6 | 60.9 | 62.1 |
Here, CommonSense Avg. denotes an average over 8 tasks in the CommonSense Reasoning benchmarks: ARC-easy, ARC-challenge, BoolQ, PIQA, SIQA, HellaSwag, OBQA, and WinoGrande. Models with fewer than 150M parameters do not yield reliable MMLU scores and are therefore denoted as '--'.
MobileLLM-R1 post-trained model
| Model | Size | MATH500 | GSM8K | AIME'24 | AIME'25 | LiveCodeBench-v6 |
|---|---|---|---|---|---|---|
| | | 0-shot pass@1 | 0-shot pass@1 | 0-shot pass@1, n=64 | 0-shot pass@1, n=64 | 0-shot pass@1, n=16 |
| <150M | ||||||
| SmolLM2-135M-Instruct | 135M | 3.0 | 2.4 | -- | -- | 0.0 |
| MobileLLM-R1-140M | 140M | 7.4 | 3.0 | -- | -- | 1.0 |
| 150M - 400M | ||||||
| Gemma-3-270m-it | 268M | 6.8 | 8.4 | -- | -- | 0.0 |
| SmolLM2-360M-Instruct | 362M | 3.4 | 8.1 | -- | -- | 0.7 |
| MobileLLM-R1-360M | 359M | 26.6 | 22.7 | -- | -- | 4.8 |
| 400M - 1B | ||||||
| Qwen2.5-0.5B-Instruct | 494M | 31.2 | 48.1 | 0.1 | 0.3 | 3.6 |
| Qwen3-0.6B | 596M | 73.0 | 79.2 | 11.3 | 17.0 | 14.9 |
| MobileLLM-R1-950M | 949M | 74.0 | 67.5 | 15.5 | 16.3 | 19.9 |
| > 1B | ||||||
| Gemma-3-1B-it | 1.0B | 45.4 | 62.9 | 0.9 | 0.0 | 2.0 |
| LLaMA3.2-1B-Instruct | 1.24B | 24.8 | 38.8 | 1.1 | 0.2 | 4.1 |
| OLMo-2-0425-1B-Instruct | 1.48B | 19.2 | 69.7 | 0.6 | 0.1 | 0.0 |
| OpenReasoning-Nemotron-1.5B | 1.54B | 83.4 | 76.7 | 49.7 | 40.4 | 28.3 |
| DeepSeek-R1-Distill-Qwen-1.5B | 1.54B | 83.2 | 77.3 | 29.1 | 23.4 | 19.9 |
| Qwen2.5-1.5B-Instruct | 1.54B | 54.0 | 70.0 | 2.5 | 0.9 | 7.9 |
| SmolLM2-1.7B-Instruct | 1.71B | 19.2 | 41.8 | 0.3 | 0.1 | 4.4 |
| Qwen3-1.7B | 2.03B | 89.4 | 90.3 | 47.0 | 37.0 | 29.8 |
For AIME, we evaluate models across 64 runs and report the average accuracy. For LiveCodeBench, results are reported as the average accuracy across 16 runs. Models with fewer than 400M parameters do not produce reliable AIME scores and are therefore denoted as '—'.
Training
Training Process
Training stages and hyperparameter details
In the pretraining phase, MobileLLM-R1 models are randomly initialized and optimized using the Adam optimizer with hyperparameters (β_1, β_2, ε) = (0.9, 0.95, 1e-8), coupled with a weight decay coefficient of 0.1. The learning rate follows a 2k-step warmup schedule and then decays linearly from its peak to 10% of the maximum.
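As a concrete illustration of that schedule, using the peak learning rate and step count from the table below (the exact implementation is an assumption):

```python
# Illustrative pre-training LR schedule: 2k-step linear warmup to the peak,
# then linear decay to 10% of the peak over the remaining steps.
def pretrain_lr(step, peak_lr=4.0e-3, warmup_steps=2_000, total_steps=500_000):
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return peak_lr * (1.0 - 0.9 * progress)   # reaches 0.1 * peak_lr at the end
```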
In the mid-training phase, we use the Adam optimizer with a learning rate that decays linearly from its maximum value to zero. We employ knowledge distillation with the Llama-3.1-8B-Instruct model as the teacher, where the student is trained by minimizing the KL divergence between its output logits and the teacher's logits.
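A minimal sketch of such a distillation objective in PyTorch (the temperature and reduction are illustrative assumptions, not the released recipe):

```python
import torch.nn.functional as F

# KL divergence between the student's and teacher's next-token distributions.
def distillation_loss(student_logits, teacher_logits, temperature=1.0):
    s = F.log_softmax(student_logits / temperature, dim=-1)
    t = F.softmax(teacher_logits / temperature, dim=-1)
    return F.kl_div(s, t, reduction="batchmean") * temperature ** 2
```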
In the post-training phase, we use the Adam optimizer with zero weight decay. The learning rate warmup ratio is set to 0.03 for general-purpose SFT and 0.1 for reasoning-specific SFT, and it linearly decays from its maximum value to zero. Full training hyperparameters are provided in the table below.
| Stage | Phase | Tokens / Samples | BS | Sequence Length | Steps | LR | #GPUs | Training Time |
|---|---|---|---|---|---|---|---|---|
| Pre-training | Phase1 | 2T tokens | 16 | 2k | 500k | 4.00E-03 | 16 x 8 | 4-5 days |
| | Phase2 | 2T tokens | 16 | 2k | 500k | 4.00E-03 | 16 x 8 | 4-5 days |
| Mid-training | Phase1 | 100B tokens | 4 | 4k | 50K | 3.60E-04 | 16 x 8 | 1-2 days |
| | Phase2 | 100B tokens | 4 | 4k | 50K | 3.60E-04 | 16 x 8 | 1-2 days |
| Post-training | General SFT | 866K samples | 4 | 4k | 2 epochs | 5.00E-06 | 16 x 8 | ~2h |
| | Reasoning SFT | 6.2M samples | 8 | 32k | 4 epochs | 8.00E-05 | 16 x 8 | ~2.5 days |
Data Mix
Pre-training
| Dataset | Rows | Tokens (B) | Phase1 Mix Ratio | Phase2 Mix Ratio |
|---|---|---|---|---|
| StarCoder | 206,640,114 | 263.8 | 10.66% | 0.52% |
| OpenWebMath | 6,117,786 | 12.6 | 6.93% | 23.33% |
| FineWeb-Edu | 1,279,107,432 | 1300 | 63.75% | 54.83% |
| Wiki | 7,222,303 | 3.7 | 5.03% | 0.14% |
| Arxiv | 1,533,917 | 28 | 6.36% | 1.32% |
| StackExchange | 29,249,120 | 19.6 | 5.03% | 0.86% |
| Algebraic stack | 3,404,331 | 12.6 | 2.25% | 1.26% |
| Nemotron science | 708,920 | 2 | -- | 0.03% |
| Nemotron code | 10,108,883 | 16 | -- | 0.72% |
| Nemotron math | 22,066,397 | 15 | -- | 3.01% |
| Cosmopedia | 31,064,744 | 25 | -- | 2.70% |
| Facebook natural reasoning | 1,145,824 | 1.8 | -- | 3.18% |
| FineMath | 48,283,984 | 34 | -- | 8.01% |
| peS2o | 38,800,000 | 50 | -- | 0.08% |
| Total | | | 100% | 100% |
Mid-training
| Dataset | Subset | Rows (M) | Phase1 Mix Ratio | Phase2 Mix Ratio |
|---|---|---|---|---|
| Dolmino | DCLM Baseline | 606 | 37.03% | 6.51% |
| | FLAN | 57.3 | 4.10% | 0.72% |
| | peS2o | 38.8 | 11.41% | 2.01% |
| | Wiki | 6.17 | 2.66% | 0.47% |
| | StackExchange | 2.48 | 2.12% | 2.00% |
| | Math | 21 | 11.63% | 29.10% |
| Nemotron | Nemotron-Pretraining-Code-v1 | 882 | 20.69% | 29.10% |
| | Nemotron-CC-Math-v1 | 144 | 3.45% | 19.40% |
| StarCoder | StarCoder | 206 | 6.90% | 9.70% |
| Benchmark training set | TriviaQA (train), OBQA (train), NaturalQuestions (train), PIQA (train), GSM8K (train), BoolQ (train), ARC-Easy (train), ARC-Challenge (train) | ~0.01 | -- | 0.97% |
| Total | | | 100.00% | 100.00% |
Post-training
| Phase | Dataset | Rows |
|---|---|---|
| General SFT | Tulu-3-sft-olmo-2-mixture-0225 | 866K samples |
| Reasoning SFT | OpenMathReasoning | 3.2M samples |
| | OpenScienceReasoning-2 | 803K samples |
| | OpenCodeReasoning-2 | 2.16M samples |
Citation
If you find our model useful for your research, please consider citing:
@misc{mobilellm_r1_2025,
title={MobileLLM-R1: Model Card},
author={Zechun Liu*, Ernie Chang*, Changsheng Zhao*, Chia-Jung Chang, Wei Wen, Chen Lai, Rick Cao, Yuandong Tian, Raghuraman Krishnamoorthi, Yangyang Shi, Vikas Chandra},
year={2025},
url = {https://huggingface.co/mobilellm-r1}
}
Contact
Zechun Liu, Meta Inc (zechunliu at meta dot com)
Ernie Chang, Meta Inc (erniecyc at meta dot com)
Changsheng Zhao, Meta Inc (cszhao at meta dot com)
License
MobileLLM-R1 is currently released under the FAIR NC (non-commercial) license.