Commit 5ce4c31
Parent(s): 867401e

Teach HIP grouped_gemm about autograd
- wrap the ROCm grouped GEMM call in a torch.autograd.Function so hidden states and expert weights receive gradients (a gradient-parity sketch follows the test command below)
- reuse the backend kernel for the backward matmuls and normalize batch size tensors on the host
- note the hipBLASLt opt-in flag in grouped_gemm.hip while keeping it off by default

Tests: python -m pytest axolotl.shisa/tests/e2e/test_ring_moe_grouped.py -k megablocks_gradient_parity -s
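As a companion to that test command, here is a minimal, self-contained sketch of the gradient-parity idea. It is illustrative only: it assumes gmm is the autograd wrapper added to backend.py below, and the shapes, expert count, tolerances, and the dense_reference helper are invented for the example rather than taken from the commit.

import torch

def dense_reference(a, b, batch_sizes):
    # Plain-PyTorch per-expert matmul; autograd through this gives reference grads.
    outs, start = [], 0
    for i, rows in enumerate(batch_sizes.tolist()):
        outs.append(a[start:start + rows] @ b[i])
        start += rows
    return torch.cat(outs)

def gradient_parity_check(gmm, device="cuda", dtype=torch.bfloat16):
    # Hypothetical sizes: 4 experts, hidden size 64, output size 32.
    batch_sizes = torch.tensor([8, 0, 16, 8])  # tokens per expert, int64 on the host
    a = torch.randn(int(batch_sizes.sum()), 64, device=device, dtype=dtype, requires_grad=True)
    b = torch.randn(4, 64, 32, device=device, dtype=dtype, requires_grad=True)

    # Grouped GEMM forward + backward through the autograd wrapper.
    gmm(a, b, batch_sizes).sum().backward()

    # Same computation in plain PyTorch as the gradient reference.
    a_ref = a.detach().clone().requires_grad_(True)
    b_ref = b.detach().clone().requires_grad_(True)
    dense_reference(a_ref, b_ref, batch_sizes).sum().backward()

    torch.testing.assert_close(a.grad, a_ref.grad, rtol=1e-2, atol=1e-2)
    torch.testing.assert_close(b.grad, b_ref.grad, rtol=1e-2, atol=1e-2)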
csrc/grouped_gemm/grouped_gemm.hip (CHANGED)

@@ -17,6 +17,9 @@
 namespace grouped_gemm {
 namespace {
 
+// Experimental: toggled via MEGABLOCKS_GG_USE_HIPBLASLT=1. This flag is
+// intentionally off by default because the hipBLASLt path still fails on the
+// largest `tests/ops_test.py` configurations.
 bool use_hipblaslt_backend() {
   static int cached = [] {
     const char* raw = std::getenv("MEGABLOCKS_GG_USE_HIPBLASLT");
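Because the flag is read once through std::getenv and cached in a function-local static, it has to be present in the environment before the first grouped GEMM call of the process. A minimal sketch of opting in from Python; only the variable name comes from the diff, everything else is illustrative:

import os

# Opt in to the experimental hipBLASLt path (off by default). The C++ side
# caches the parsed value on first use, so set this before any grouped GEMM runs.
os.environ["MEGABLOCKS_GG_USE_HIPBLASLT"] = "1"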
torch-ext/megablocks/grouped_gemm/backend.py (CHANGED)

@@ -1,5 +1,9 @@
 # NOTE: Torch needs to be imported before the custom
 # extensions. Otherwise libc10.so cannot be found.
+from __future__ import annotations
+
+from typing import Optional
+
 import torch
 
 # # TODO(tgale): Wrap this in a try-block with better

@@ -13,6 +17,7 @@ import torch
 # from megablocks._ops import ops as backend  # type: ignore
 from .._ops import ops as backend  # type: ignore
 
+
 def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
     assert not (trans_a and trans_b)
     assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes"

@@ -32,8 +37,99 @@ def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
     # reproduced by `_dev/debug-gg-small.py`.
     return torch.zeros(*shape, device=a.device, dtype=a.dtype)
 
-def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None):
+
+def _normalize_batch_sizes(batch_sizes: torch.Tensor) -> torch.Tensor:
+    if batch_sizes.device.type != "cpu":
+        batch_sizes = batch_sizes.to(device="cpu", dtype=torch.int64)
+    else:
+        batch_sizes = batch_sizes.to(dtype=torch.int64)
+    return batch_sizes
+
+
+def _run_backend(
+    a: torch.Tensor,
+    b: torch.Tensor,
+    batch_sizes: torch.Tensor,
+    trans_a: bool,
+    trans_b: bool,
+    c: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    batch_sizes_cpu = _normalize_batch_sizes(batch_sizes)
     if c is None:
-        c = _allocate_output(a, b, batch_sizes, trans_a, trans_b)
-    backend.gmm(a, b, c, batch_sizes, trans_a, trans_b)
+        c = _allocate_output(a, b, batch_sizes_cpu, trans_a, trans_b)
+    backend.gmm(a, b, c, batch_sizes_cpu, trans_a, trans_b)
     return c
+
+
+class _GroupedGemmFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(  # type: ignore[override]
+        ctx,
+        a: torch.Tensor,
+        b: torch.Tensor,
+        batch_sizes: torch.Tensor,
+        trans_a: bool,
+        trans_b: bool,
+        c: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        if trans_a:
+            raise NotImplementedError("Grouped GEMM autograd currently requires trans_a=False.")
+
+        batch_sizes_cpu = _normalize_batch_sizes(batch_sizes)
+        output = _run_backend(a, b, batch_sizes_cpu, trans_a, trans_b, c)
+
+        ctx.save_for_backward(a, b, batch_sizes_cpu)
+        ctx.trans_a = trans_a
+        ctx.trans_b = trans_b
+
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output: torch.Tensor):  # type: ignore[override]
+        a, b, batch_sizes_cpu = ctx.saved_tensors
+        trans_a = ctx.trans_a
+        trans_b = ctx.trans_b
+
+        if trans_a:
+            raise NotImplementedError("Grouped GEMM backward currently requires trans_a=False.")
+
+        grad_output = grad_output.contiguous()
+        grad_output_cast = grad_output
+        if grad_output_cast.dtype != a.dtype:
+            grad_output_cast = grad_output_cast.to(dtype=a.dtype)
+
+        grad_a = grad_b = None
+
+        with torch.no_grad():
+            if ctx.needs_input_grad[0]:
+                grad_a = _run_backend(
+                    grad_output_cast,
+                    b.detach(),
+                    batch_sizes_cpu,
+                    trans_a=False,
+                    trans_b=not trans_b,
+                )
+
+            if ctx.needs_input_grad[1]:
+                grad_b_eff = _run_backend(
+                    a.detach(),
+                    grad_output_cast,
+                    batch_sizes_cpu,
+                    trans_a=True,
+                    trans_b=False,
+                )
+                grad_b = (
+                    grad_b_eff.transpose(-2, -1) if trans_b else grad_b_eff
+                )
+
+        if grad_a is not None and grad_a.dtype != a.dtype:
+            grad_a = grad_a.to(dtype=a.dtype)
+        if grad_b is not None and grad_b.dtype != b.dtype:
+            grad_b = grad_b.to(dtype=b.dtype)
+
+        # None returned for batch_sizes / trans flags / optional c.
+        return grad_a, grad_b, None, None, None, None
+
+
+def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None):
+    return _GroupedGemmFunction.apply(a, b, batch_sizes, trans_a, trans_b, c)