TaehyunKim
committed on
Update torch-ext/optimizer/muon.py
Browse files
torch-ext/optimizer/muon.py
CHANGED
|
@@ -656,7 +656,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 656 |
for n, p in zip(ordered_names, ordered_params):
|
| 657 |
if mesh is None:
|
| 658 |
mesh = p.device_mesh
|
| 659 |
-
shard_mesh, process_group = self.get_shard_mesh(p)
|
| 660 |
elif mesh != p.device_mesh:
|
| 661 |
raise ValueError("All parameters must be on the same mesh.")
|
| 662 |
num_ranks = dist.get_world_size(group=process_group)
|
|
|
|
| 656 |
for n, p in zip(ordered_names, ordered_params):
|
| 657 |
if mesh is None:
|
| 658 |
mesh = p.device_mesh
|
| 659 |
+
shard_mesh, process_group = self.get_shard_mesh(p)
|
| 660 |
elif mesh != p.device_mesh:
|
| 661 |
raise ValueError("All parameters must be on the same mesh.")
|
| 662 |
num_ranks = dist.get_world_size(group=process_group)
|