Kernels
ca1207 committed on
Commit
35894d1
·
1 Parent(s): 6e9baad
test/test_muon/test.py CHANGED
@@ -2,7 +2,7 @@ import logging
2
 
3
  import torch
4
  import torch.distributed as dist
5
- from muon import Muon, get_default_muon_param_groups
6
  from torch.distributed.fsdp import FSDPModule, fully_shard
7
  from torch.distributed.tensor import DTensor
8
  from torch.distributed.tensor.placement_types import Replicate
 
2
 
3
  import torch
4
  import torch.distributed as dist
5
+ from optimizer.muon import Muon, get_default_muon_param_groups
6
  from torch.distributed.fsdp import FSDPModule, fully_shard
7
  from torch.distributed.tensor import DTensor
8
  from torch.distributed.tensor.placement_types import Replicate
torch-ext/optimizer/muon.py CHANGED
@@ -701,10 +701,10 @@ class Muon(torch.optim.Optimizer):
701
  new_scale = math.sqrt(threshold / v_ele)
702
  if new_scale < scales_full[head_idx]:
703
  scales_full[head_idx] = new_scale
704
- #logger.info(
705
- # f"[{kind}] Head {head_idx} exceeded threshold "
706
- # f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}"
707
- #)
708
  scaling += 1
709
 
710
  return scales_full if scaling > 0 else None
 
701
  new_scale = math.sqrt(threshold / v_ele)
702
  if new_scale < scales_full[head_idx]:
703
  scales_full[head_idx] = new_scale
704
+ logger.info(
705
+ f"[{kind}] Head {head_idx} exceeded threshold "
706
+ f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}"
707
+ )
708
  scaling += 1
709
 
710
  return scales_full if scaling > 0 else None