Commit bdd2678
Parent(s): 8535e80

fix(muon): delete intermediate tensors immediately to lower peak mem usage

Files changed:

- build/torch26-cxx11-cu118-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch26-cxx11-cu118-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so +3 -0
- build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py +10 -19
- build/torch26-cxx11-cu124-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch26-cxx11-cu124-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so +3 -0
- build/torch26-cxx11-cu124-x86_64-linux/optimizer/muon.py +10 -19
- build/torch26-cxx11-cu126-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch26-cxx11-cu126-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so +3 -0
- build/torch26-cxx11-cu126-x86_64-linux/optimizer/muon.py +10 -19
- build/torch26-cxx11-rocm62-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch26-cxx11-rocm62-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so +3 -0
- build/torch26-cxx11-rocm62-x86_64-linux/optimizer/muon.py +10 -19
- build/torch26-cxx98-cu118-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch26-cxx98-cu118-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so +3 -0
- build/torch26-cxx98-cu118-x86_64-linux/optimizer/muon.py +10 -19
- build/torch26-cxx98-cu124-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch26-cxx98-cu124-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so +3 -0
- build/torch26-cxx98-cu124-x86_64-linux/optimizer/muon.py +10 -19
- build/torch26-cxx98-cu126-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch26-cxx98-cu126-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so +3 -0
- build/torch26-cxx98-cu126-x86_64-linux/optimizer/muon.py +10 -19
- build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch27-cxx11-cu118-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so +3 -0
- build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py +10 -19
- build/torch27-cxx11-cu126-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so +3 -0
- build/torch27-cxx11-cu126-x86_64-linux/optimizer/muon.py +10 -19
- build/torch27-cxx11-cu128-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch27-cxx11-cu128-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so +3 -0
- build/torch27-cxx11-cu128-x86_64-linux/optimizer/muon.py +10 -19
- build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/__init__.cpython-310.pyc +0 -0
- build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/muon.cpython-310.pyc +0 -0
- build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so +3 -0
- build/torch27-cxx11-rocm63-x86_64-linux/optimizer/muon.py +10 -19
- torch-ext/optimizer/muon.py +10 -19
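
Across every build variant the change is the same: _compute_u and _scatter now
release their intermediate tensors (gathered_grad, computed_u) the moment the
consuming stream has been recorded, rather than keeping them alive until the
end of the step. A minimal sketch of that record_stream-then-del pattern (not
the repository's code; the _State class and the .sign() stand-in are
illustrative, and a CUDA device is assumed):

    import torch

    compute_stream = torch.cuda.Stream()

    class _State:
        def __init__(self, grad: torch.Tensor):
            self.gathered_grad = grad  # allocated on the default stream

    def compute(state: _State) -> torch.Tensor:
        with torch.cuda.stream(compute_stream):
            u = state.gathered_grad.sign()  # stand-in for the real orthogonalization
        # The grad was allocated on one stream but consumed on compute_stream:
        # record that use so the caching allocator will not recycle its memory
        # before the queued work on compute_stream finishes ...
        state.gathered_grad.record_stream(compute_stream)
        # ... then drop the reference immediately, lowering peak memory.
        del state.gathered_grad
        return u

Without the record_stream call, freeing the tensor here could let the allocator
hand its memory to a new allocation on the default stream while compute_stream
is still reading it.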
 
    	
build/torch26-cxx11-cu118-x86_64-linux/optimizer/_ops.py  CHANGED

@@ -1,9 +1,9 @@
 import torch
-from . import 
-ops = torch.ops.
+from . import _optimizer_8535e80_dirty
+ops = torch.ops._optimizer_8535e80_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"
+    return f"_optimizer_8535e80_dirty::{op_name}"
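
These _ops.py shims are regenerated on each build so that Python resolves ops
from the freshly built extension named after the commit. A hypothetical use of
the helper (the op name "foo" is illustrative, not an op this extension is
known to define):

    from optimizer._ops import ops, add_op_namespace_prefix

    qualified = add_op_namespace_prefix("foo")  # "_optimizer_8535e80_dirty::foo"
    # Equivalent direct lookup through the bound namespace, if "foo" existed:
    # ops.foo(...)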
    	
build/torch26-cxx11-cu118-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so  ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a46d9e65efcfa82522950d9ebf2b2b4594d9ed5abc28704352a1f7de2dae707a
+size 1787272
    	
build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py  CHANGED

@@ -48,7 +48,6 @@ class _muon_state:
     worker_rank: int | None = None
     gathered_grad: torch.Tensor | None = None
     computed_u: torch.Tensor | None = None
-    scattered_u: torch.Tensor | None = None
     gather_event: torch.cuda.Event | None = None
     compute_event: torch.cuda.Event | None = None

@@ -93,12 +92,14 @@ def _compute_u(state, steps, rank, compute_stream):
             state.computed_u = u
             state.compute_event = torch.cuda.Event()
             state.compute_event.record()
+            state.gathered_grad.record_stream(compute_stream)
+            del state.gathered_grad
         else:
             state.computed_u = None
             state.compute_event = None


-def _scatter(p, state, rank, comm_stream):
+def _scatter(p, state, lr, wd, rank, comm_stream):
     u = state.computed_u
     mesh = p.device_mesh

@@ -118,13 +119,16 @@ def _scatter(p, state, rank, comm_stream):
             src=state.worker_rank,
             group=mesh.get_group(),
         )
+        if rank == state.worker_rank:
+            state.computed_u.record_stream(comm_stream)
+            del state.computed_u
         u = DTensor.from_local(
             u,
             placements=p.placements,
             device_mesh=mesh,
         )
-
-
+        p.data.mul_(1 - lr * wd)
+        p.data.add_(u, alpha=-lr)


 class Muon(torch.optim.Optimizer):

@@ -353,7 +357,8 @@ class Muon(torch.optim.Optimizer):
         def enqueue_scatters(start_idx, chunk_size):
             for p in ordered_params[start_idx : start_idx + chunk_size]:
                 state = param_to_state[id(p)]
-
+                adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
+                _scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream)

         chunk_size = params[0].device_mesh.mesh.numel()

@@ -368,20 +373,6 @@ class Muon(torch.optim.Optimizer):

         torch.cuda.current_stream().wait_stream(self.comm_stream)

-        for p in params:
-            g = p.grad
-            if g is None:
-                continue
-
-            # Update p with sharded u
-            state = param_to_state[id(p)]
-            self._update_p(
-                p,
-                state.scattered_u,
-                lr=lr,
-                wd=wd,
-            )
-
     def step(self, closure=None):
         """Perform a single optimization step.
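
The net effect in muon.py: the separate end-of-step loop that applied
state.scattered_u via self._update_p is gone, and the parameter update is fused
into _scatter itself, right after the DTensor is assembled. Restating the two
added lines, with lr already rescaled per parameter by adjust_lr_for_muon:

    p.data.mul_(1 - lr * wd)   # decoupled weight decay: p <- (1 - lr*wd) * p
    p.data.add_(u, alpha=-lr)  # Muon step:              p <- p - lr * u

Because the update happens as soon as each shard's u arrives, scattered_u no
longer needs to be stored per parameter, which is why that field is removed
from _muon_state.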
    	
build/torch26-cxx11-cu124-x86_64-linux/optimizer/_ops.py  CHANGED

Same +3/-3 rebinding to _optimizer_8535e80_dirty as in the torch26-cxx11-cu118 variant of _ops.py above.
    	
build/torch26-cxx11-cu124-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so  ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d351a600884b7378f546a345afe65c176e1399bb42fb7dfe4333b0e90975803b
+size 1824224
    	
build/torch26-cxx11-cu124-x86_64-linux/optimizer/muon.py  CHANGED

Same +10/-19 diff as in the torch26-cxx11-cu118 variant of muon.py above.
    	
build/torch26-cxx11-cu126-x86_64-linux/optimizer/_ops.py  CHANGED

Same +3/-3 rebinding to _optimizer_8535e80_dirty as in the torch26-cxx11-cu118 variant of _ops.py above.
    	
build/torch26-cxx11-cu126-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so  ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2c0843f38cee494b7a5939eb62d27039d76dc3f69401d411efbacaa25cb0d67a
+size 1824224
    	
build/torch26-cxx11-cu126-x86_64-linux/optimizer/muon.py  CHANGED

Same +10/-19 diff as in the torch26-cxx11-cu118 variant of muon.py above.
    	
build/torch26-cxx11-rocm62-x86_64-linux/optimizer/_ops.py  CHANGED

Same +3/-3 rebinding to _optimizer_8535e80_dirty as in the torch26-cxx11-cu118 variant of _ops.py above.
    	
build/torch26-cxx11-rocm62-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so  ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:acdba99ce95532a9ca6a8987a7ab61a257657872f2cc672c91e8e5fe809aa24e
+size 1749744
    	
build/torch26-cxx11-rocm62-x86_64-linux/optimizer/muon.py  CHANGED

Same +10/-19 diff as in the torch26-cxx11-cu118 variant of muon.py above.
    	
build/torch26-cxx98-cu118-x86_64-linux/optimizer/_ops.py  CHANGED

Same +3/-3 rebinding to _optimizer_8535e80_dirty as in the torch26-cxx11-cu118 variant of _ops.py above.
    	
build/torch26-cxx98-cu118-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so  ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f7d5e76c002507f66f2a227d02c2b11aa3fdc3f07a2a0b82faaa34133adb77ef
+size 1787192
    	
        build/torch26-cxx98-cu118-x86_64-linux/optimizer/muon.py
    CHANGED
    
    | 
         @@ -48,7 +48,6 @@ class _muon_state: 
     | 
|
| 48 | 
         
             
                worker_rank: int | None = None
         
     | 
| 49 | 
         
             
                gathered_grad: torch.Tensor | None = None
         
     | 
| 50 | 
         
             
                computed_u: torch.Tensor | None = None
         
-    scattered_u: torch.Tensor | None = None
     gather_event: torch.cuda.Event | None = None
     compute_event: torch.cuda.Event | None = None
 
@@ -93,12 +92,14 @@ def _compute_u(state, steps, rank, compute_stream):
             state.computed_u = u
             state.compute_event = torch.cuda.Event()
             state.compute_event.record()
+            state.gathered_grad.record_stream(compute_stream)
+            del state.gathered_grad
         else:
             state.computed_u = None
             state.compute_event = None
 
 
-def _scatter(p, state, rank, comm_stream):
+def _scatter(p, state, lr, wd, rank, comm_stream):
     u = state.computed_u
     mesh = p.device_mesh
 
@@ -118,13 +119,16 @@ def _scatter(p, state, rank, comm_stream):
             src=state.worker_rank,
             group=mesh.get_group(),
         )
+        if rank == state.worker_rank:
+            state.computed_u.record_stream(comm_stream)
+            del state.computed_u
         u = DTensor.from_local(
             u,
             placements=p.placements,
             device_mesh=mesh,
         )
-
-
+    p.data.mul_(1 - lr * wd)
+    p.data.add_(u, alpha=-lr)
 
 
 class Muon(torch.optim.Optimizer):

@@ -353,7 +357,8 @@ class Muon(torch.optim.Optimizer):
         def enqueue_scatters(start_idx, chunk_size):
             for p in ordered_params[start_idx : start_idx + chunk_size]:
                 state = param_to_state[id(p)]
-
+                adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
+                _scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream)
 
         chunk_size = params[0].device_mesh.mesh.numel()
 
@@ -368,20 +373,6 @@ class Muon(torch.optim.Optimizer):
 
         torch.cuda.current_stream().wait_stream(self.comm_stream)
 
-        for p in params:
-            g = p.grad
-            if g is None:
-                continue
-
-            # Update p with sharded u
-            state = param_to_state[id(p)]
-            self._update_p(
-                p,
-                state.scattered_u,
-                lr=lr,
-                wd=wd,
-            )
-
     def step(self, closure=None):
         """Perform a single optimization step.
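The muon.py hunks above (repeated verbatim for every build variant below) carry the commit's actual fix: as soon as `gathered_grad` has been consumed by `_compute_u` on `compute_stream`, and `computed_u` by the scatter on `comm_stream`, the reference is dropped with `del` instead of surviving until the end of `step()`. Because each tensor is used on a stream other than the one that allocated it, `Tensor.record_stream` is called first, so the caching allocator will not recycle the block while the side stream may still be reading it. A minimal self-contained sketch of the pattern, with illustrative tensor and stream names (not the repo's code):

import torch

def consume_and_free(t: torch.Tensor, side_stream: torch.cuda.Stream) -> torch.Tensor:
    # Order the side stream after any pending writes to `t`.
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        out = t @ t.T  # `t` is read on `side_stream`
    # Defer reuse of `t`'s block until `side_stream` catches up; without
    # this, freeing `t` could hand its memory to a new allocation while
    # the matmul above is still running.
    t.record_stream(side_stream)
    return out

stream = torch.cuda.Stream()
g = torch.randn(1024, 1024, device="cuda")
u = consume_and_free(g, stream)
del g  # safe to drop immediately; this early release is what lowers peak memory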
        build/torch26-cxx98-cu124-x86_64-linux/optimizer/_ops.py
    CHANGED
    
@@ -1,9 +1,9 @@
 import torch
-from . import 
-ops = torch.ops.
+from . import _optimizer_8535e80_dirty
+ops = torch.ops._optimizer_8535e80_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"
+    return f"_optimizer_8535e80_dirty::{op_name}"
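The `_ops.py` change in each variant simply re-points the Python shim at the freshly built native extension, `_optimizer_8535e80_dirty`, and the helper prefixes op names with that namespace so they resolve under `torch.ops`. For a hypothetical op name:

# add_op_namespace_prefix("muon_step") == "_optimizer_8535e80_dirty::muon_step"
# i.e. the op is then reachable as torch.ops._optimizer_8535e80_dirty.muon_step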
        build/torch26-cxx98-cu124-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so
    ADDED
    
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:becccd250f38a84803350cfb5fac3a6682b1e594968a714642724cbc71246b4a
+size 1824184
        build/torch26-cxx98-cu124-x86_64-linux/optimizer/muon.py
    CHANGED
    
@@ -48,7 +48,6 @@ class _muon_state:
     worker_rank: int | None = None
     gathered_grad: torch.Tensor | None = None
     computed_u: torch.Tensor | None = None
-    scattered_u: torch.Tensor | None = None
     gather_event: torch.cuda.Event | None = None
     compute_event: torch.cuda.Event | None = None
 
@@ -93,12 +92,14 @@ def _compute_u(state, steps, rank, compute_stream):
             state.computed_u = u
             state.compute_event = torch.cuda.Event()
             state.compute_event.record()
+            state.gathered_grad.record_stream(compute_stream)
+            del state.gathered_grad
         else:
             state.computed_u = None
             state.compute_event = None
 
 
-def _scatter(p, state, rank, comm_stream):
+def _scatter(p, state, lr, wd, rank, comm_stream):
     u = state.computed_u
     mesh = p.device_mesh
 
@@ -118,13 +119,16 @@ def _scatter(p, state, rank, comm_stream):
             src=state.worker_rank,
             group=mesh.get_group(),
         )
+        if rank == state.worker_rank:
+            state.computed_u.record_stream(comm_stream)
+            del state.computed_u
         u = DTensor.from_local(
             u,
             placements=p.placements,
             device_mesh=mesh,
         )
-
-
+    p.data.mul_(1 - lr * wd)
+    p.data.add_(u, alpha=-lr)
 
 
 class Muon(torch.optim.Optimizer):

@@ -353,7 +357,8 @@ class Muon(torch.optim.Optimizer):
         def enqueue_scatters(start_idx, chunk_size):
             for p in ordered_params[start_idx : start_idx + chunk_size]:
                 state = param_to_state[id(p)]
-
+                adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
+                _scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream)
 
         chunk_size = params[0].device_mesh.mesh.numel()
 
@@ -368,20 +373,6 @@ class Muon(torch.optim.Optimizer):
 
         torch.cuda.current_stream().wait_stream(self.comm_stream)
 
-        for p in params:
-            g = p.grad
-            if g is None:
-                continue
-
-            # Update p with sharded u
-            state = param_to_state[id(p)]
-            self._update_p(
-                p,
-                state.scattered_u,
-                lr=lr,
-                wd=wd,
-            )
-
     def step(self, closure=None):
         """Perform a single optimization step.
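With `scattered_u` removed from `_muon_state`, the new `_scatter` applies the parameter update in place rather than stashing the sharded update for the second pass over `params` that the old step path performed (the loop deleted above). The arithmetic is decoupled weight decay followed by the scaled update, with no temporary tensor. A worked single-parameter equivalent, with arbitrary shapes and hyperparameters:

import torch

lr, wd = 0.02, 0.01
p = torch.randn(128, 128)  # parameter (shard)
u = torch.randn(128, 128)  # orthogonalized update for this shard

# Closed form of the update: p <- (1 - lr*wd) * p - lr * u
expected = p * (1 - lr * wd) - lr * u

p.mul_(1 - lr * wd)   # decoupled weight decay, in place
p.add_(u, alpha=-lr)  # update step, in place

assert torch.allclose(p, expected)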
        build/torch26-cxx98-cu126-x86_64-linux/optimizer/_ops.py
    CHANGED
    
@@ -1,9 +1,9 @@
 import torch
-from . import 
-ops = torch.ops.
+from . import _optimizer_8535e80_dirty
+ops = torch.ops._optimizer_8535e80_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"
+    return f"_optimizer_8535e80_dirty::{op_name}"
        build/torch26-cxx98-cu126-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so
    ADDED
    
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:34215ecc274ef516967962c8457dad214e9bbf618bf5eee8f467371f4f620284
+size 1824184
        build/torch26-cxx98-cu126-x86_64-linux/optimizer/muon.py
    CHANGED
    
@@ -48,7 +48,6 @@ class _muon_state:
     worker_rank: int | None = None
     gathered_grad: torch.Tensor | None = None
     computed_u: torch.Tensor | None = None
-    scattered_u: torch.Tensor | None = None
     gather_event: torch.cuda.Event | None = None
     compute_event: torch.cuda.Event | None = None
 
@@ -93,12 +92,14 @@ def _compute_u(state, steps, rank, compute_stream):
             state.computed_u = u
             state.compute_event = torch.cuda.Event()
             state.compute_event.record()
+            state.gathered_grad.record_stream(compute_stream)
+            del state.gathered_grad
         else:
             state.computed_u = None
             state.compute_event = None
 
 
-def _scatter(p, state, rank, comm_stream):
+def _scatter(p, state, lr, wd, rank, comm_stream):
     u = state.computed_u
     mesh = p.device_mesh
 
@@ -118,13 +119,16 @@ def _scatter(p, state, rank, comm_stream):
             src=state.worker_rank,
             group=mesh.get_group(),
         )
+        if rank == state.worker_rank:
+            state.computed_u.record_stream(comm_stream)
+            del state.computed_u
         u = DTensor.from_local(
             u,
             placements=p.placements,
             device_mesh=mesh,
         )
-
-
+    p.data.mul_(1 - lr * wd)
+    p.data.add_(u, alpha=-lr)
 
 
 class Muon(torch.optim.Optimizer):

@@ -353,7 +357,8 @@ class Muon(torch.optim.Optimizer):
         def enqueue_scatters(start_idx, chunk_size):
             for p in ordered_params[start_idx : start_idx + chunk_size]:
                 state = param_to_state[id(p)]
-
+                adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
+                _scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream)
 
         chunk_size = params[0].device_mesh.mesh.numel()
 
@@ -368,20 +373,6 @@ class Muon(torch.optim.Optimizer):
 
         torch.cuda.current_stream().wait_stream(self.comm_stream)
 
-        for p in params:
-            g = p.grad
-            if g is None:
-                continue
-
-            # Update p with sharded u
-            state = param_to_state[id(p)]
-            self._update_p(
-                p,
-                state.scattered_u,
-                lr=lr,
-                wd=wd,
-            )
-
     def step(self, closure=None):
         """Perform a single optimization step.
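All of these hunks sit inside the same pipelined schedule: gather, `_compute_u`, and `_scatter` each run on their own CUDA stream, handing results along via events (`gather_event`, `compute_event`) until the default stream finally waits on `comm_stream` before `step()` returns. A stripped-down sketch of that producer/consumer handoff, with illustrative names rather than the repo's:

import torch

compute_stream = torch.cuda.Stream()
comm_stream = torch.cuda.Stream()

x = torch.randn(512, 512, device="cuda")

# Producer: run the heavy compute on its own stream and mark completion.
compute_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(compute_stream):
    u = x @ x
compute_event = torch.cuda.Event()
compute_event.record(compute_stream)

# Consumer: the communication stream starts only after the compute event.
comm_stream.wait_event(compute_event)
with torch.cuda.stream(comm_stream):
    u16 = u.to(torch.bfloat16)  # stand-in for the real collective op

# Back on the default stream, wait for the pipeline to drain.
torch.cuda.current_stream().wait_stream(comm_stream)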
        build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py
    CHANGED
    
@@ -1,9 +1,9 @@
 import torch
-from . import 
-ops = torch.ops.
+from . import _optimizer_8535e80_dirty
+ops = torch.ops._optimizer_8535e80_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"
+    return f"_optimizer_8535e80_dirty::{op_name}"
        build/torch27-cxx11-cu118-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so
    ADDED
    
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c23a3adbe4dc1a64b4851a9f8e4aed0e3e1eeeded27322c54f5b942282a2a332
+size 1787368
        build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py
    CHANGED
    
@@ -48,7 +48,6 @@ class _muon_state:
     worker_rank: int | None = None
     gathered_grad: torch.Tensor | None = None
     computed_u: torch.Tensor | None = None
-    scattered_u: torch.Tensor | None = None
     gather_event: torch.cuda.Event | None = None
     compute_event: torch.cuda.Event | None = None
 
@@ -93,12 +92,14 @@ def _compute_u(state, steps, rank, compute_stream):
             state.computed_u = u
             state.compute_event = torch.cuda.Event()
             state.compute_event.record()
+            state.gathered_grad.record_stream(compute_stream)
+            del state.gathered_grad
         else:
             state.computed_u = None
             state.compute_event = None
 
 
-def _scatter(p, state, rank, comm_stream):
+def _scatter(p, state, lr, wd, rank, comm_stream):
     u = state.computed_u
     mesh = p.device_mesh
 
@@ -118,13 +119,16 @@ def _scatter(p, state, rank, comm_stream):
             src=state.worker_rank,
             group=mesh.get_group(),
         )
+        if rank == state.worker_rank:
+            state.computed_u.record_stream(comm_stream)
+            del state.computed_u
         u = DTensor.from_local(
             u,
             placements=p.placements,
             device_mesh=mesh,
         )
-
-
+    p.data.mul_(1 - lr * wd)
+    p.data.add_(u, alpha=-lr)
 
 
 class Muon(torch.optim.Optimizer):

@@ -353,7 +357,8 @@ class Muon(torch.optim.Optimizer):
         def enqueue_scatters(start_idx, chunk_size):
             for p in ordered_params[start_idx : start_idx + chunk_size]:
                 state = param_to_state[id(p)]
-
+                adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
+                _scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream)
 
         chunk_size = params[0].device_mesh.mesh.numel()
 
@@ -368,20 +373,6 @@ class Muon(torch.optim.Optimizer):
 
         torch.cuda.current_stream().wait_stream(self.comm_stream)
 
-        for p in params:
-            g = p.grad
-            if g is None:
-                continue
-
-            # Update p with sharded u
-            state = param_to_state[id(p)]
-            self._update_p(
-                p,
-                state.scattered_u,
-                lr=lr,
-                wd=wd,
-            )
-
     def step(self, closure=None):
         """Perform a single optimization step.
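A small Python detail makes the bare `del state.gathered_grad` / `del state.computed_u` safe: the `_muon_state` fields are declared with class-level `None` defaults, so deleting the instance attribute drops the last tensor reference while later reads fall back to `None` instead of raising `AttributeError`. Illustration with a hypothetical `_State` class, not the repo's:

import torch

class _State:
    gathered_grad: torch.Tensor | None = None  # class-level default

s = _State()
s.gathered_grad = torch.zeros(2)  # instance attribute holds the tensor
del s.gathered_grad               # releases the reference immediately
assert s.gathered_grad is None    # lookup falls back to the class default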
        build/torch27-cxx11-cu126-x86_64-linux/optimizer/_ops.py
    CHANGED
    
@@ -1,9 +1,9 @@
 import torch
-from . import 
-ops = torch.ops.
+from . import _optimizer_8535e80_dirty
+ops = torch.ops._optimizer_8535e80_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"
+    return f"_optimizer_8535e80_dirty::{op_name}"
        build/torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so
    ADDED
    
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d4aa09c22745d5efe1ef0669c4ca05615f67595dc90cabeee6e878301fa9bd22
+size 1824256
        build/torch27-cxx11-cu126-x86_64-linux/optimizer/muon.py
    CHANGED
    
@@ -48,7 +48,6 @@ class _muon_state:
     worker_rank: int | None = None
     gathered_grad: torch.Tensor | None = None
     computed_u: torch.Tensor | None = None
-    scattered_u: torch.Tensor | None = None
     gather_event: torch.cuda.Event | None = None
     compute_event: torch.cuda.Event | None = None
 
@@ -93,12 +92,14 @@ def _compute_u(state, steps, rank, compute_stream):
             state.computed_u = u
             state.compute_event = torch.cuda.Event()
             state.compute_event.record()
+            state.gathered_grad.record_stream(compute_stream)
+            del state.gathered_grad
         else:
             state.computed_u = None
             state.compute_event = None
 
 
-def _scatter(p, state, rank, comm_stream):
+def _scatter(p, state, lr, wd, rank, comm_stream):
     u = state.computed_u
     mesh = p.device_mesh
 
@@ -118,13 +119,16 @@ def _scatter(p, state, rank, comm_stream):
             src=state.worker_rank,
             group=mesh.get_group(),
         )
+        if rank == state.worker_rank:
+            state.computed_u.record_stream(comm_stream)
+            del state.computed_u
         u = DTensor.from_local(
             u,
             placements=p.placements,
             device_mesh=mesh,
         )
-
-
+    p.data.mul_(1 - lr * wd)
+    p.data.add_(u, alpha=-lr)
 
 
 class Muon(torch.optim.Optimizer):

@@ -353,7 +357,8 @@ class Muon(torch.optim.Optimizer):
         def enqueue_scatters(start_idx, chunk_size):
             for p in ordered_params[start_idx : start_idx + chunk_size]:
                 state = param_to_state[id(p)]
-
+                adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
+                _scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream)
 
         chunk_size = params[0].device_mesh.mesh.numel()
 
@@ -368,20 +373,6 @@ class Muon(torch.optim.Optimizer):
 
         torch.cuda.current_stream().wait_stream(self.comm_stream)
 
-        for p in params:
-            g = p.grad
-            if g is None:
-                continue
-
-            # Update p with sharded u
-            state = param_to_state[id(p)]
-            self._update_p(
-                p,
-                state.scattered_u,
-                lr=lr,
-                wd=wd,
-            )
-
     def step(self, closure=None):
         """Perform a single optimization step.
        build/torch27-cxx11-cu128-x86_64-linux/optimizer/_ops.py
    CHANGED
    
@@ -1,9 +1,9 @@
 import torch
-from . import 
-ops = torch.ops.
+from . import _optimizer_8535e80_dirty
+ops = torch.ops._optimizer_8535e80_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"
+    return f"_optimizer_8535e80_dirty::{op_name}"
        build/torch27-cxx11-cu128-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so
    ADDED
    
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b4baf569b70749c4657062fb0f56943fc486adb0c482e50c7aa8e31ddf5cc870
+size 1883352
build/torch27-cxx11-cu128-x86_64-linux/optimizer/muon.py
CHANGED

@@ -48,7 +48,6 @@ class _muon_state:
     worker_rank: int | None = None
     gathered_grad: torch.Tensor | None = None
     computed_u: torch.Tensor | None = None
-    scattered_u: torch.Tensor | None = None
     gather_event: torch.cuda.Event | None = None
     compute_event: torch.cuda.Event | None = None
 
@@ -93,12 +92,14 @@ def _compute_u(state, steps, rank, compute_stream):
             state.computed_u = u
             state.compute_event = torch.cuda.Event()
             state.compute_event.record()
+            state.gathered_grad.record_stream(compute_stream)
+            del state.gathered_grad
         else:
             state.computed_u = None
             state.compute_event = None
 
 
-def _scatter(p, state, rank, comm_stream):
+def _scatter(p, state, lr, wd, rank, comm_stream):
     u = state.computed_u
     mesh = p.device_mesh
 
@@ -118,13 +119,16 @@ def _scatter(p, state, rank, comm_stream):
             src=state.worker_rank,
             group=mesh.get_group(),
         )
+        if rank == state.worker_rank:
+            state.computed_u.record_stream(comm_stream)
+            del state.computed_u
        u = DTensor.from_local(
            u,
            placements=p.placements,
            device_mesh=mesh,
        )
-
-
+        p.data.mul_(1 - lr * wd)
+        p.data.add_(u, alpha=-lr)
 
 
 class Muon(torch.optim.Optimizer):
@@ -353,7 +357,8 @@ class Muon(torch.optim.Optimizer):
         def enqueue_scatters(start_idx, chunk_size):
             for p in ordered_params[start_idx : start_idx + chunk_size]:
                 state = param_to_state[id(p)]
-
+                adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
+                _scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream)
 
         chunk_size = params[0].device_mesh.mesh.numel()
 
@@ -368,20 +373,6 @@ class Muon(torch.optim.Optimizer):
 
         torch.cuda.current_stream().wait_stream(self.comm_stream)
 
-        for p in params:
-            g = p.grad
-            if g is None:
-                continue
-
-            # Update p with sharded u
-            state = param_to_state[id(p)]
-            self._update_p(
-                p,
-                state.scattered_u,
-                lr=lr,
-                wd=wd,
-            )
-
     def step(self, closure=None):
         """Perform a single optimization step.
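The two `record_stream` / `del` pairs added above free each intermediate (`gathered_grad`, `computed_u`) as soon as its consumer has been enqueued, instead of keeping every buffer alive until the end of the step. `record_stream` tells the caching allocator not to recycle the tensor's memory until the recorded stream's pending work completes, so the early `del` is safe even with kernels still in flight. A minimal standalone sketch of the pattern, with illustrative names:

    import torch

    # Hedged sketch of the record_stream-then-del pattern used in the diff.
    if torch.cuda.is_available():
        side = torch.cuda.Stream()
        buf = torch.randn(4096, 4096, device="cuda")   # allocated on the default stream
        side.wait_stream(torch.cuda.current_stream())  # `side` must see buf's contents
        with torch.cuda.stream(side):
            out = buf * 2                              # `side` still reads buf asynchronously
        buf.record_stream(side)  # allocator defers reuse until `side` catches up
        del buf                  # drop the reference now; peak memory falls earlier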
    	
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/__init__.cpython-310.pyc
CHANGED

Binary files a/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/__init__.cpython-310.pyc and b/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/__init__.cpython-310.pyc differ

build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/muon.cpython-310.pyc
CHANGED

Binary files a/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/muon.cpython-310.pyc and b/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/muon.cpython-310.pyc differ
    	
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_ops.py
CHANGED

@@ -1,9 +1,9 @@
 import torch
-from . import 
-ops = torch.ops.
+from . import _optimizer_8535e80_dirty
+ops = torch.ops._optimizer_8535e80_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"
+    return f"_optimizer_8535e80_dirty::{op_name}"
    	
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so
ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8566c9bc05e13c9394572f9f9c6bac24c31932548be485f49eb49fb249880832
+size 1749648
    	
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/muon.py
CHANGED

@@ -48,7 +48,6 @@ class _muon_state:
     worker_rank: int | None = None
     gathered_grad: torch.Tensor | None = None
     computed_u: torch.Tensor | None = None
-    scattered_u: torch.Tensor | None = None
     gather_event: torch.cuda.Event | None = None
     compute_event: torch.cuda.Event | None = None
 
@@ -93,12 +92,14 @@ def _compute_u(state, steps, rank, compute_stream):
             state.computed_u = u
             state.compute_event = torch.cuda.Event()
             state.compute_event.record()
+            state.gathered_grad.record_stream(compute_stream)
+            del state.gathered_grad
         else:
             state.computed_u = None
             state.compute_event = None
 
 
-def _scatter(p, state, rank, comm_stream):
+def _scatter(p, state, lr, wd, rank, comm_stream):
     u = state.computed_u
     mesh = p.device_mesh
 
@@ -118,13 +119,16 @@ def _scatter(p, state, rank, comm_stream):
             src=state.worker_rank,
             group=mesh.get_group(),
         )
+        if rank == state.worker_rank:
+            state.computed_u.record_stream(comm_stream)
+            del state.computed_u
        u = DTensor.from_local(
            u,
            placements=p.placements,
            device_mesh=mesh,
        )
-
-
+        p.data.mul_(1 - lr * wd)
+        p.data.add_(u, alpha=-lr)
 
 
 class Muon(torch.optim.Optimizer):
@@ -353,7 +357,8 @@ class Muon(torch.optim.Optimizer):
         def enqueue_scatters(start_idx, chunk_size):
             for p in ordered_params[start_idx : start_idx + chunk_size]:
                 state = param_to_state[id(p)]
-
+                adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
+                _scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream)
 
         chunk_size = params[0].device_mesh.mesh.numel()
 
@@ -368,20 +373,6 @@ class Muon(torch.optim.Optimizer):
 
         torch.cuda.current_stream().wait_stream(self.comm_stream)
 
-        for p in params:
-            g = p.grad
-            if g is None:
-                continue
-
-            # Update p with sharded u
-            state = param_to_state[id(p)]
-            self._update_p(
-                p,
-                state.scattered_u,
-                lr=lr,
-                wd=wd,
-            )
-
     def step(self, closure=None):
         """Perform a single optimization step.
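With `scattered_u` gone, `_scatter` now applies the update in place as soon as each shard arrives: decoupled weight decay, then the step, so no per-parameter buffer has to survive until the end-of-step loop that this commit removes. The arithmetic of that fused update, extracted standalone:

    import torch

    # Hedged sketch of the in-place update now performed inside _scatter.
    def apply_update(p: torch.Tensor, u: torch.Tensor, lr: float, wd: float) -> None:
        p.data.mul_(1 - lr * wd)   # decoupled weight decay: p <- p * (1 - lr*wd)
        p.data.add_(u, alpha=-lr)  # orthogonalized step:    p <- p - lr * u

    p = torch.randn(128, 64)
    u = torch.randn(128, 64)
    apply_update(p, u, lr=1e-3, wd=0.1)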
    	
torch-ext/optimizer/muon.py
CHANGED

@@ -48,7 +48,6 @@ class _muon_state:
     worker_rank: int | None = None
     gathered_grad: torch.Tensor | None = None
     computed_u: torch.Tensor | None = None
-    scattered_u: torch.Tensor | None = None
     gather_event: torch.cuda.Event | None = None
     compute_event: torch.cuda.Event | None = None
 
@@ -93,12 +92,14 @@ def _compute_u(state, steps, rank, compute_stream):
             state.computed_u = u
             state.compute_event = torch.cuda.Event()
             state.compute_event.record()
+            state.gathered_grad.record_stream(compute_stream)
+            del state.gathered_grad
         else:
             state.computed_u = None
             state.compute_event = None
 
 
-def _scatter(p, state, rank, comm_stream):
+def _scatter(p, state, lr, wd, rank, comm_stream):
     u = state.computed_u
     mesh = p.device_mesh
 
@@ -118,13 +119,16 @@ def _scatter(p, state, rank, comm_stream):
             src=state.worker_rank,
             group=mesh.get_group(),
         )
+        if rank == state.worker_rank:
+            state.computed_u.record_stream(comm_stream)
+            del state.computed_u
        u = DTensor.from_local(
            u,
            placements=p.placements,
            device_mesh=mesh,
        )
-
-
+        p.data.mul_(1 - lr * wd)
+        p.data.add_(u, alpha=-lr)
 
 
 class Muon(torch.optim.Optimizer):
@@ -353,7 +357,8 @@ class Muon(torch.optim.Optimizer):
         def enqueue_scatters(start_idx, chunk_size):
             for p in ordered_params[start_idx : start_idx + chunk_size]:
                 state = param_to_state[id(p)]
-
+                adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
+                _scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream)
 
         chunk_size = params[0].device_mesh.mesh.numel()
 
@@ -368,20 +373,6 @@ class Muon(torch.optim.Optimizer):
 
         torch.cuda.current_stream().wait_stream(self.comm_stream)
 
-        for p in params:
-            g = p.grad
-            if g is None:
-                continue
-
-            # Update p with sharded u
-            state = param_to_state[id(p)]
-            self._update_p(
-                p,
-                state.scattered_u,
-                lr=lr,
-                wd=wd,
-            )
-
     def step(self, closure=None):
         """Perform a single optimization step.
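`enqueue_scatters` now computes a per-parameter `adjusted_lr` via `adjust_lr_for_muon(lr, p.shape)` before handing it to `_scatter`. The exact scaling rule lives elsewhere in this file and is not shown in the diff; the variant below (scale by 0.2 * sqrt(max(rows, cols)), as in some public Muon implementations) is only an assumed stand-in for illustration, not this repo's definitive formula:

    import math

    # Hypothetical shape-based LR adjustment; NOT necessarily this repo's rule.
    def adjust_lr_for_muon(lr: float, shape) -> float:
        rows, cols = shape[0], shape[1]
        return lr * 0.2 * math.sqrt(max(rows, cols))

    print(adjust_lr_for_muon(1e-2, (4096, 1024)))  # larger matrices take larger steps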