Instructions for using Motif-Technologies/optimizer with libraries, inference providers, notebooks, and local apps. Follow the links below to get started.
- Libraries
  - Kernels
- Notebooks
  - Google Colab
  - Kaggle

How to use Motif-Technologies/optimizer with Kernels:

    # !pip install kernels
    from kernels import get_kernel

    kernel = get_kernel("Motif-Technologies/optimizer")
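Once loaded, the kernel behaves like an ordinary Python module. The sketch below is illustrative only: it assumes the loaded kernel exposes the same `muon.Muon` class and `muon.get_default_muon_param_groups` helper that the test file below imports, and the toy `nn.Linear` model is a stand-in, not part of the library.

    # Illustrative sketch, not the library's documented API.
    # Assumptions: the kernel exposes `muon.Muon` and
    # `muon.get_default_muon_param_groups`; a CUDA device is available.
    import torch
    from kernels import get_kernel

    optimizer = get_kernel("Motif-Technologies/optimizer")

    model = torch.nn.Linear(64, 64, device="cuda")  # hypothetical toy model
    param_groups = optimizer.muon.get_default_muon_param_groups(model)
    optim = optimizer.muon.Muon(params=param_groups)

    model(torch.randn(8, 64, device="cuda")).sum().backward()
    optim.step()

The remainder of this page is the distributed Muon test suite, which exercises the optimizer under FSDP, uneven and empty shards, and pipeline parallelism.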
import copy
import logging
import time
from contextlib import nullcontext

import pytest
import torch
import torch.distributed as dist
from optimizer.muon import Muon, get_default_muon_param_groups
from optimizer.newton_schulz import set_ns_compile
from torch.distributed.tensor import (DTensor, Replicate, Shard,
                                      distribute_tensor)
from torch.profiler import ProfilerActivity, profile

from .utils import (ParallelDims, _apply_fsdp, assert_params_equal,
                    parallelize_motif, parallelize_qk_logits)

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


def apply_muon_step(
    model: torch.nn.Module,
    parallel_dims: ParallelDims | None,
    grads: list[torch.Tensor],
    warmup_step: int,
    chunk_size: int,
    qk_logits: dict[int, torch.Tensor] | None = None,
    use_distributed_muon: bool = False,
    measure_perf: bool = False,
    do_profile: bool = False,
    test_name: str | None = None,
) -> tuple[torch.nn.Module, tuple[float, float] | None]:
    """Apply a single Muon step with optional QK clipping."""
    # 1. Apply gradients to model parameters
    assert len(grads) == len(list(model.parameters()))
    for grad, param in zip(grads, model.parameters()):
        grad = grad.to(param.device)
        if isinstance(param.data, DTensor):
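            # Wrap the full gradient as a replicated DTensor, then
            # redistribute it to match the parameter's sharded placements.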
            unsharded_grad = DTensor.from_local(
                grad,
                device_mesh=param.data.device_mesh,
                placements=[Replicate()] * param.data.device_mesh.ndim,
            )
            sharded_grad = unsharded_grad.redistribute(
                device_mesh=param.data.device_mesh,
                placements=param.data.placements)
            param.grad = sharded_grad
        else:
            param.grad = grad

    # 2. Set up the Muon optimizer
    params = get_default_muon_param_groups(model)
    clip_config = {
        "q_indices": list(range(model.config.num_attention_heads)),
        "k_indices": list(range(model.config.num_attention_heads)),
        "head_dim": model.config.hidden_size // model.config.num_attention_heads,
        "threshold": 0.5,
    }
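    # QK clipping is active only when qk_logits are passed to step().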
    optim = Muon(
        params=params,
        clip_config=clip_config if qk_logits is not None else None,
        none_grad=False,
        warmup_step=warmup_step,
        chunk_size=chunk_size,
        use_distributed_muon=use_distributed_muon,
    )
    optim.step(qk_logits=qk_logits)

    timing_result: tuple[float, float] | None = None
    if measure_perf:
        # extra warm-up step before timing
        optim.step(qk_logits=qk_logits)
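        # CUDA events time the GPU work; peak-memory stats are reset so the
        # measurement covers only the timed iterations.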
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        torch.cuda.reset_peak_memory_stats()
        start.record()
        num_iters = 20
        if do_profile:
            context = profile(
                activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                record_shapes=True)
        else:
            context = nullcontext()
        with context as prof:
            for _i in range(num_iters):
                optim.step(qk_logits=qk_logits)
        end.record()
        end.synchronize()
        if prof is not None:
            date = time.strftime("%Y%m%d_%H%M%S", time.localtime())
            name = test_name or "trace"
            rank = dist.get_rank()
            prof.export_chrome_trace(f"{name}_{date}_rank{rank}.json")
        peak_memory = torch.cuda.max_memory_allocated()
        elapsed_time_ms = start.elapsed_time(end) / num_iters
        timing_result = (elapsed_time_ms, peak_memory)

    return model, timing_result


@pytest.fixture  # injected into test_parallel_muon below
def sequential_muon_result(
        skip_verify,  # from conftest.py
        inputs,  # from conftest.py
) -> dict[tuple[bool, bool], torch.nn.Module] | None:
    """Run the Muon optimizer on the sequential (non-parallelized) model
    to produce baseline results.

    Returns a dict keyed by ``(apply_qk_clip, use_compile)``, or ``None``
    when verification is skipped.
    """
    if skip_verify:
        logger.info("Skipping verification tests as per user request")
        return None
    model, grads, qk_logits = inputs
    results: dict[tuple[bool, bool], torch.nn.Module] = {}
    for use_compile in [False, True]:
        set_ns_compile(use_compile)
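        # Toggle the compiled Newton-Schulz implementation for this pass
        # (see optimizer.newton_schulz).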
        results[(False, use_compile)] = apply_muon_step(
            model=copy.deepcopy(model).cuda(),
            parallel_dims=None,
            grads=grads,
            warmup_step=-1,
            chunk_size=-1,
            qk_logits=None,
        )[0].cpu()
        results[(True, use_compile)] = apply_muon_step(
            model=copy.deepcopy(model).cuda(),
            parallel_dims=None,
            grads=grads,
            warmup_step=-1,
            chunk_size=-1,
            qk_logits=qk_logits,
        )[0].cpu()
    set_ns_compile(True)  # restore default
    return results


OVERLAP_STEPS = [5]
CHUNK_SIZES = [2]


def test_parallel_muon(
    request,
    sequential_muon_result: dict[tuple[bool, bool], torch.nn.Module],
    parallel_dims: ParallelDims,
    apply_qk_clip: bool,
    use_distributed_muon: bool,
    warmup_step: int,
    chunk_size: int,
    use_compile: bool,
    inputs: tuple[torch.nn.Module, list[torch.Tensor],
                  dict[int, torch.Tensor]],  # from conftest.py
    measure_perf,  # from conftest.py
    do_profile,  # from conftest.py
) -> None:
    if use_distributed_muon and chunk_size != CHUNK_SIZES[0]:
        pytest.skip("Distributed Muon is not affected by chunk size")
    if use_distributed_muon and warmup_step != OVERLAP_STEPS[0]:
        pytest.skip("Distributed Muon is not affected by warmup step")
    set_ns_compile(use_compile)
    model, grads, qk_logits = inputs
    if not apply_qk_clip:
        qk_logits = None

    # Deepcopy the model to avoid in-place modification
    model = copy.deepcopy(model).cuda()
    parallelized_model = parallelize_motif(model, parallel_dims)
    if qk_logits is not None:
        # Deepcopy the qk logits to avoid in-place modification
        qk_logits = copy.deepcopy(qk_logits)
        qk_logits = parallelize_qk_logits(qk_logits, parallel_dims)
    parallelized_model, timing_result = apply_muon_step(
        model=parallelized_model,
        parallel_dims=parallel_dims,
        grads=grads,
        warmup_step=warmup_step,
        chunk_size=chunk_size,
        qk_logits=qk_logits,
        use_distributed_muon=use_distributed_muon,
        measure_perf=measure_perf,
        do_profile=do_profile,
        test_name=request.node.name,
    )
    if measure_perf:
        assert timing_result is not None
        avg_time_ms, peak_memory = timing_result
        logger.info(
            f"\nParallel dims: {parallel_dims}, "
            f"\nUse distributed Muon: {use_distributed_muon}, "
            f"\nApply QK clip: {apply_qk_clip} => "
            f"\nChunk Size, Warmup Step, Avg Time (ms), Peak Memory (MB):"
            f"\n{chunk_size}, {warmup_step}, {avg_time_ms:.2f}, "
            f"{peak_memory / (1024**2):.2f}")
    if sequential_muon_result is None:
        logger.info("Skipping correctness check as sequential result is None")
    elif measure_perf:
        logger.info("Skipping correctness check as timing is enabled")
    else:
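        # The compiled Newton-Schulz path may reorder floating-point ops,
        # so allow a small tolerance there; the eager path must match
        # bit-exactly.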
        atol = 1e-5 if use_compile else 0
        rtol = 1e-2 if use_compile else 0
        assert_params_equal(parallelized_model,
                            sequential_muon_result[(apply_qk_clip,
                                                    use_compile)],
                            atol=atol,
                            rtol=rtol)


def test_parallel_muon_empty_shard(init_dist):
    """Regression: parallel Muon must handle chunks where some ranks have
    empty local shards (dim-0 < world_size).

    With 8-way Shard(0) and dim-0 of size 4, ranks 4-7 get 0-element local
    shards. Previously ``_launch_gather`` hit ``assert total_send > 0``.
    """
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    mesh = dist.init_device_mesh("cuda", (world_size, ),
                                 mesh_dim_names=("dp", ))
    set_ns_compile(False)

    # dim-0 = 4 < 8 ranks → ranks 4-7 have empty local shards with Shard(0)
    small_dim = 4
    num_params = 4
    torch.manual_seed(42)
    muon_params = []
    muon_names = []
    for i in range(num_params):
        full = torch.randn(small_dim, 64, device="cuda")
        dt = distribute_tensor(full, mesh, [Shard(0)])
        p = torch.nn.Parameter(dt)
        grad_full = torch.randn(small_dim, 64, device="cuda")
        p.grad = distribute_tensor(grad_full, mesh, [Shard(0)])
        muon_params.append(p)
        muon_names.append(f"layer.{i}.weight")
    param_groups = [{
        "params": muon_params,
        "names": muon_names,
        "use_muon": True,
        "lr": 0.02,
        "weight_decay": 0.01,
        "momentum": 0.95,
        "nesterov": True,
        "ns_steps": 5,
        "none_grad": False,
    }]
    optim = Muon(params=param_groups, chunk_size=1, warmup_step=0)

    # Must not raise AssertionError: total_send > 0
    optim.step()

    # Run a second step to verify the cached path also works
    for p in muon_params:
        grad_full = torch.randn(small_dim, 64, device="cuda")
        p.grad = distribute_tensor(grad_full, mesh, [Shard(0)])
    optim.step()

    set_ns_compile(True)
    logger.info("test_parallel_muon_empty_shard PASSED (rank %d)", rank)


def test_parallel_muon_uneven_shard(init_dist, uneven_dim):
    """Test that parallel Muon produces correct results when parameter
    dimensions are not evenly divisible by the number of shard ranks.

    For example, dim=33 with 8 ranks gives 7 ranks with 4 rows and
    1 rank with 5 rows. This exercises the remainder-handling logic
    in ``get_slices_of_dtensor`` and the all-to-all pipeline.
    """
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    mesh = dist.init_device_mesh("cuda", (world_size, ),
                                 mesh_dim_names=("dp", ))
    set_ns_compile(False)
    torch.manual_seed(42)
    other_dim = 64
    num_params = 3

    # --- Build sharded params + grads ---
    muon_params = []
    muon_names = []
    full_params_snapshot = []
    full_grads = []
    for i in range(num_params):
        full = torch.randn(uneven_dim, other_dim, device="cuda")
        full_params_snapshot.append(full.clone())
        dt = distribute_tensor(full, mesh, [Shard(0)])
        p = torch.nn.Parameter(dt)
        grad_full = torch.randn(uneven_dim, other_dim, device="cuda")
        full_grads.append(grad_full.clone())
        p.grad = distribute_tensor(grad_full, mesh, [Shard(0)])
        muon_params.append(p)
        muon_names.append(f"layer.{i}.weight")
    # --- Parallel path (all2all pipeline) ---
    param_groups_par = [{
        "params": muon_params,
        "names": muon_names,
        "use_muon": True,
        "lr": 0.02,
        "weight_decay": 0.01,
        "momentum": 0.95,
        "nesterov": True,
        "ns_steps": 5,
        "none_grad": False,
    }]
    optim_par = Muon(params=param_groups_par, chunk_size=1, warmup_step=0)
    optim_par.step()

    # --- Sequential baseline (base path, no sharding) ---
    seq_params = []
    seq_names = []
    for i in range(num_params):
        p = torch.nn.Parameter(full_params_snapshot[i].clone())
        p.grad = full_grads[i].clone()
        seq_params.append(p)
        seq_names.append(f"layer.{i}.weight")
    param_groups_seq = [{
        "params": seq_params,
        "names": seq_names,
        "use_muon": True,
        "lr": 0.02,
        "weight_decay": 0.01,
        "momentum": 0.95,
        "nesterov": True,
        "ns_steps": 5,
        "none_grad": False,
    }]
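    # Plain (non-DTensor) params take the sequential base path inside Muon.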
    optim_seq = Muon(params=param_groups_seq)
    optim_seq.step()

    # --- Compare: parallel result (gathered) must match sequential ---
    for i in range(num_params):
        par_full = muon_params[i].data.full_tensor()
        seq_full = seq_params[i].data
        torch.testing.assert_close(par_full, seq_full, atol=0, rtol=0)

    set_ns_compile(True)
    logger.info("test_parallel_muon_uneven_shard (dim=%d) PASSED (rank %d)",
                uneven_dim, rank)


def test_pp_dp_replicate_no_deadlock(init_dist, inputs):
    """PP regression test using a real Motif model.

    PP=2, dp_replicate=2, dp_shard=2 on 8 GPUs. Splits the
    Motif-2.6B-4layer model across 2 pipeline stages following the
    torchtitan pattern (deep copy → delete non-stage layers → per-stage
    FSDP). Each stage independently runs the Muon optimizer and the result
    is verified against a sequential baseline (atol=0, rtol=0).

    Without use_local_synchronization=True in construct_shard_mesh(),
    different stages would deadlock on dist.new_group() because they
    call it for different parameters.
    """
    import re

    import torch.nn as nn

    from optimizer.distributed.utils import _ranks_to_dist_cache

    rank = dist.get_rank()
    assert dist.get_world_size() == 8
    set_ns_compile(False)
    _ranks_to_dist_cache.clear()
    model_orig, grads_orig, _ = inputs

    # Build name → grad mapping from the original model
    grad_dict = {
        name: grad
        for (name, _), grad in zip(model_orig.named_parameters(), grads_orig)
    }

    # Full mesh: PP=2, dp_replicate=2, dp_shard=2
    full_mesh = dist.init_device_mesh(
        "cuda",
        (2, 2, 2),
        mesh_dim_names=("pp", "dp_replicate", "dp_shard"),
    )
    dp_mesh = full_mesh["dp_replicate", "dp_shard"]
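    # The 2-D (dp_replicate, dp_shard) submesh is what each stage hands to
    # FSDP below; the "pp" dim only selects which layers this rank owns.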
    pp_rank = full_mesh.get_local_rank("pp")

    # -- Helpers ----------------------------------------------------------
    def _split_motif(model):
        """Split the Motif model per PP stage (torchtitan pattern).

        Stage 0: embed_tokens + layers[0:2]
        Stage 1: layers[2:4] + norm + output
        Non-stage components are replaced with nn.Identity (no params).
        """
        all_layers = list(model.model.layers)
        if pp_rank == 0:
            model.model.layers = nn.ModuleList(all_layers[:2])
            model.model.norm = nn.Identity()
            if hasattr(model, "output"):
                model.output = nn.Identity()
            if hasattr(model, "lm_head"):
                model.lm_head = nn.Identity()
        else:
            model.model.layers = nn.ModuleList(all_layers[2:])
            model.model.embed_tokens = nn.Identity()
        return model
    layer_offset = 0 if pp_rank == 0 else 2

    def _remap(name):
        """Map a stage param name → original param name (layer index offset).

        Also handles weight tying: Motif ties lm_head.weight to
        model.embed_tokens.weight, so named_parameters() only lists the
        latter. After the stage split, stage 1 loses embed_tokens but keeps
        lm_head, so we remap it back.
        """
        # Weight tying: lm_head.weight ↔ model.embed_tokens.weight
        if name == "lm_head.weight":
            return "model.embed_tokens.weight"
        if layer_offset == 0:
            return name

        def _replace(m):
            return f"layers.{int(m.group(1)) + layer_offset}."

        return re.sub(r"layers\.(\d+)\.", _replace, name)

    def _stage_grads(model):
        """Build a grads list aligned with the stage model's parameters."""
        return [grad_dict[_remap(n)] for n, _ in model.named_parameters()]
    # -- Parallel path: split → FSDP → Muon step --------------------------
    par_model = _split_motif(copy.deepcopy(model_orig).cuda())
    _apply_fsdp(par_model, dp_mesh)
    par_model, _ = apply_muon_step(
        model=par_model,
        parallel_dims=None,
        grads=_stage_grads(par_model),
        warmup_step=5,
        chunk_size=2,
        qk_logits=None,
    )

    # -- Sequential baseline: split → no FSDP → base Muon ------------------
    seq_model = _split_motif(copy.deepcopy(model_orig).cuda())
    seq_model, _ = apply_muon_step(
        model=seq_model,
        parallel_dims=None,
        grads=_stage_grads(seq_model),
        warmup_step=-1,
        chunk_size=-1,
        qk_logits=None,
    )

    # Correctness: parallel must match sequential exactly
    assert_params_equal(par_model, seq_model, atol=0, rtol=0)
    set_ns_compile(True)
    logger.info(
        "test_pp_dp_replicate_no_deadlock PASSED (rank %d, pp_rank %d)", rank,
        pp_rank)