fix(optimizer): resolve bug where weight decay was multiplied by wrong lr value
#5
by
dongseokmotif
- opened
torch-ext/optimizer/muon.py
CHANGED
|
@@ -104,7 +104,7 @@ def _compute_u(state, steps, rank, compute_stream):
|
|
| 104 |
|
| 105 |
|
| 106 |
@torch.no_grad()
|
| 107 |
-
def _scatter(p, state, lr, weight_decay, rank, comm_stream):
|
| 108 |
u = state.computed_u
|
| 109 |
|
| 110 |
with torch.cuda.stream(comm_stream):
|
|
@@ -133,7 +133,7 @@ def _scatter(p, state, lr, weight_decay, rank, comm_stream):
|
|
| 133 |
device_mesh=p.device_mesh,
|
| 134 |
)
|
| 135 |
p.data.mul_(1 - lr * weight_decay)
|
| 136 |
-
p.data.add_(u, alpha=-lr)
|
| 137 |
|
| 138 |
|
| 139 |
def default_is_muon(x, name):
|
|
@@ -387,7 +387,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 387 |
state = param_to_state[id(p)]
|
| 388 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 389 |
_scatter(
|
| 390 |
-
p, state, adjusted_lr, weight_decay, self.rank, self.comm_stream
|
| 391 |
)
|
| 392 |
|
| 393 |
chunk_size = dist.get_world_size(param_to_state[id(params[0])].process_group)
|
|
|
|
| 104 |
|
| 105 |
|
| 106 |
@torch.no_grad()
|
| 107 |
+
def _scatter(p, state, lr, adjusted_lr, weight_decay, rank, comm_stream):
|
| 108 |
u = state.computed_u
|
| 109 |
|
| 110 |
with torch.cuda.stream(comm_stream):
|
|
|
|
| 133 |
device_mesh=p.device_mesh,
|
| 134 |
)
|
| 135 |
p.data.mul_(1 - lr * weight_decay)
|
| 136 |
+
p.data.add_(u, alpha=-adjusted_lr)
|
| 137 |
|
| 138 |
|
| 139 |
def default_is_muon(x, name):
|
|
|
|
| 387 |
state = param_to_state[id(p)]
|
| 388 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 389 |
_scatter(
|
| 390 |
+
p, state, lr, adjusted_lr, weight_decay, self.rank, self.comm_stream
|
| 391 |
)
|
| 392 |
|
| 393 |
chunk_size = dist.get_world_size(param_to_state[id(params[0])].process_group)
|