feat(muon): add test for muon
- README.md +4 -0
- test/test_muon/README.md +21 -0
- test/test_muon/__init__.py +0 -0
- test/test_muon/muon.py +1 -0
- test/test_muon/run_test.sh +2 -0
- test/test_muon/test.py +88 -0

README.md
CHANGED
@@ -77,3 +77,7 @@ The following tools are run via pre-commit:
 ```bash
 pre-commit run isort --all-files
 ```
+
+### Test
+
+- There is a [simple unit test for Parallel Muon](./test/test_muon/README.md)
test/test_muon/README.md
ADDED
@@ -0,0 +1,21 @@
# Muon Optimizer Test

This directory contains a test script for the **Muon optimizer**.

To execute the test, simply run:

```bash
# By default, the test will use 8 GPUs.
./run_test.sh
```

The number of GPUs can be controlled with the `NGPU` environment variable.
For example, to run with 4 GPUs:

```bash
NGPU=4 ./run_test.sh
```

## Limitations

- Multi-node execution is not supported yet.
- Ensure that the specified number of GPUs is available on your machine before running (a quick check is shown below).
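
A quick way to perform that availability check before launching (illustrative, not part of the commit):

```bash
# Count the GPUs visible to PyTorch; run_test.sh expects this to be >= NGPU.
python -c "import torch; print(torch.cuda.device_count())"
```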

test/test_muon/__init__.py
ADDED
File without changes

test/test_muon/muon.py
ADDED
@@ -0,0 +1 @@
../../torch-ext/optimizer/muon.py
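
Note: this one-line entry appears to be a link to the optimizer implementation at `../../torch-ext/optimizer/muon.py`; it is what lets `test.py` import the optimizer under test with `from muon import Muon`.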

test/test_muon/run_test.sh
ADDED
@@ -0,0 +1,2 @@
NGPU=${NGPU:-"8"}
torchrun --nproc-per-node=${NGPU} test.py
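
For reference, `${NGPU:-"8"}` is standard shell parameter expansion: it keeps an existing `NGPU` value and falls back to 8 when the variable is unset, and forwarding it to `--nproc-per-node` is what makes `NGPU=4 ./run_test.sh` behave as the README describes. A minimal illustration of the expansion (not part of the commit):

```bash
NGPU=4 bash -c 'echo ${NGPU:-8}'   # prints 4
bash -c 'echo ${NGPU:-8}'          # prints 8 when NGPU is unset
```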

test/test_muon/test.py
ADDED
@@ -0,0 +1,88 @@
import logging

import torch
import torch.distributed as dist
from muon import Muon
from torch.distributed.fsdp import fully_shard
from torch.distributed.tensor import DTensor
from torch.distributed.tensor.placement_types import Replicate
from transformers import AutoModelForCausalLM

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


def load_model(fsdp: bool) -> torch.nn.Module:
    model_name = "Motif-Technologies/Motif-2.6B"
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
    ).bfloat16().cuda()

    # Generate deterministic gradients so the parallel (FSDP) and
    # sequential runs receive identical inputs.
    torch.manual_seed(0)
    random_grads = []
    for param in model.parameters():
        random_grads.append(
            torch.randn_like(param, device=param.device, dtype=param.dtype))

    if fsdp:
        for layer in model.model.layers:
            fully_shard(layer)
            layer.reshard()
        fully_shard(model)
        model.reshard()

    # Attach the pre-generated gradients. When the model is sharded, wrap
    # each gradient as a replicated DTensor and redistribute it to match
    # the parameter's placements.
    for i, param in enumerate(model.parameters()):
        if isinstance(param.data, DTensor):
            unsharded_grad = DTensor.from_local(
                random_grads[i],
                device_mesh=param.data.device_mesh,
                placements=[Replicate()] * param.data.device_mesh.ndim,
            )
            param.grad = unsharded_grad.redistribute(
                device_mesh=param.data.device_mesh,
                placements=param.data.placements)
        else:
            param.grad = random_grads[i]

    return model


def run_muon(fsdp: bool) -> torch.nn.Module:
    model = load_model(fsdp=fsdp)
    optim = Muon(model)
    optim.step()

    return model


def compare_results(parallel_muon_result: torch.nn.Module,
                    sequential_muon_result: torch.nn.Module) -> None:
    for (name_p, p), (_, s) in zip(
            parallel_muon_result.named_parameters(),
            sequential_muon_result.named_parameters()):
        p = p.data.full_tensor()  # gather the sharded parameter
        s = s.data
        # Parallel Muon should exactly match Sequential Muon.
        max_diff = torch.abs(p - s).max()
        if max_diff > 0:
            logger.error(
                f"Models differ at parameter {name_p} (max diff: {max_diff})")
            return
    logger.info("Models match!")


def test_muon():
    parallel_muon_result = run_muon(fsdp=True)
    sequential_muon_result = run_muon(fsdp=False)

    compare_results(parallel_muon_result, sequential_muon_result)


if __name__ == "__main__":
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    test_muon()
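
The gradient setup in `load_model` uses a replicate-then-redistribute DTensor pattern: every rank builds the same seeded gradient, wraps it as a replicated DTensor, and reshards it to match the FSDP parameter's placements, which is why the test can demand bit-exact agreement with the sequential run. A minimal standalone sketch of that pattern (illustrative shapes and names, assumes `torchrun` with at least two GPUs; not part of the commit):

```python
# Sketch: build a replicated DTensor from identical per-rank tensors, then
# redistribute it to a sharded placement (shard over dim 0), as the test
# does for gradients.
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor
from torch.distributed.tensor.placement_types import Replicate, Shard

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
mesh = init_device_mesh("cuda", (dist.get_world_size(),))

torch.manual_seed(0)  # every rank generates the same tensor
full = torch.randn(8, 4, device="cuda")

replicated = DTensor.from_local(full, device_mesh=mesh,
                                placements=[Replicate()])
sharded = replicated.redistribute(device_mesh=mesh, placements=[Shard(0)])

# Each rank now holds a slice of rows of the same logical tensor.
print(f"rank {dist.get_rank()}: local shard {tuple(sharded.to_local().shape)}")
dist.destroy_process_group()
```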