Kernels
TaehyunKim commited on
Commit
b0230e7
·
unverified ·
1 Parent(s): ff2fcfb

Update torch-ext/optimizer/muon.py

Browse files
Files changed (1) hide show
  1. torch-ext/optimizer/muon.py +1 -1
torch-ext/optimizer/muon.py CHANGED
@@ -656,7 +656,7 @@ class Muon(torch.optim.Optimizer):
656
  for n, p in zip(ordered_names, ordered_params):
657
  if mesh is None:
658
  mesh = p.device_mesh
659
- shard_mesh, process_group = self.get_shard_mesh(p)
660
  elif mesh != p.device_mesh:
661
  raise ValueError("All parameters must be on the same mesh.")
662
  num_ranks = dist.get_world_size(group=process_group)
 
656
  for n, p in zip(ordered_names, ordered_params):
657
  if mesh is None:
658
  mesh = p.device_mesh
659
+ shard_mesh, process_group = self.get_shard_mesh(p)
660
  elif mesh != p.device_mesh:
661
  raise ValueError("All parameters must be on the same mesh.")
662
  num_ranks = dist.get_world_size(group=process_group)