misc
Browse files- test/test_muon/test.py +1 -1
- torch-ext/optimizer/muon.py +4 -4
test/test_muon/test.py
CHANGED
|
@@ -2,7 +2,7 @@ import logging
|
|
| 2 |
|
| 3 |
import torch
|
| 4 |
import torch.distributed as dist
|
| 5 |
-
from muon import Muon, get_default_muon_param_groups
|
| 6 |
from torch.distributed.fsdp import FSDPModule, fully_shard
|
| 7 |
from torch.distributed.tensor import DTensor
|
| 8 |
from torch.distributed.tensor.placement_types import Replicate
|
|
|
|
| 2 |
|
| 3 |
import torch
|
| 4 |
import torch.distributed as dist
|
| 5 |
+
from optimizer.muon import Muon, get_default_muon_param_groups
|
| 6 |
from torch.distributed.fsdp import FSDPModule, fully_shard
|
| 7 |
from torch.distributed.tensor import DTensor
|
| 8 |
from torch.distributed.tensor.placement_types import Replicate
|
torch-ext/optimizer/muon.py
CHANGED
|
@@ -701,10 +701,10 @@ class Muon(torch.optim.Optimizer):
|
|
| 701 |
new_scale = math.sqrt(threshold / v_ele)
|
| 702 |
if new_scale < scales_full[head_idx]:
|
| 703 |
scales_full[head_idx] = new_scale
|
| 704 |
-
|
| 705 |
-
|
| 706 |
-
|
| 707 |
-
|
| 708 |
scaling += 1
|
| 709 |
|
| 710 |
return scales_full if scaling > 0 else None
|
|
|
|
| 701 |
new_scale = math.sqrt(threshold / v_ele)
|
| 702 |
if new_scale < scales_full[head_idx]:
|
| 703 |
scales_full[head_idx] = new_scale
|
| 704 |
+
logger.info(
|
| 705 |
+
f"[{kind}] Head {head_idx} exceeded threshold "
|
| 706 |
+
f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}"
|
| 707 |
+
)
|
| 708 |
scaling += 1
|
| 709 |
|
| 710 |
return scales_full if scaling > 0 else None
|