leonardlin committed
Commit 5ce4c31 · 1 Parent(s): 867401e

Teach HIP grouped_gemm about autograd


- wrap the ROCm grouped GEMM call in a torch.autograd.Function so hidden states and expert weights receive gradients

- reuse the backend kernel for backward matmuls and normalize batch size tensors on the host

- note the hipBLASLt opt-in flag in grouped_gemm.hip while keeping it off by default

Tests: python -m pytest axolotl.shisa/tests/e2e/test_ring_moe_grouped.py -k megablocks_gradient_parity -s (a rough sketch of this kind of parity check is shown below)
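
For orientation, here is a minimal sketch of the kind of gradient-parity check referenced above, assuming the module path from the torch-ext layout and the usual grouped-GEMM shape convention (a: [tokens, k], b: [num_experts, k, n], trans_a=trans_b=False); the shapes, dtypes, and tolerances are illustrative, not copied from the actual test:

import torch

from megablocks.grouped_gemm.backend import gmm  # import path is an assumption

def dense_reference(a, b, batch_sizes):
    # Per-expert dense matmul: rows of `a` are grouped by `batch_sizes`,
    # and group e is multiplied by the expert weight b[e].
    outs, start = [], 0
    for e, rows in enumerate(batch_sizes.tolist()):
        outs.append(a[start:start + rows] @ b[e])
        start += rows
    return torch.cat(outs, dim=0)

num_experts, k, n = 4, 64, 32
batch_sizes = torch.tensor([8, 4, 16, 8])  # host int64, sums to the token count
a = torch.randn(int(batch_sizes.sum()), k, device="cuda", dtype=torch.bfloat16, requires_grad=True)
b = torch.randn(num_experts, k, n, device="cuda", dtype=torch.bfloat16, requires_grad=True)

gmm(a, b, batch_sizes).sum().backward()  # grouped kernel plus the new autograd path

a_ref = a.detach().clone().requires_grad_(True)
b_ref = b.detach().clone().requires_grad_(True)
dense_reference(a_ref, b_ref, batch_sizes).sum().backward()

torch.testing.assert_close(a.grad, a_ref.grad, atol=1e-2, rtol=1e-2)  # tolerances are illustrative
torch.testing.assert_close(b.grad, b_ref.grad, atol=1e-2, rtol=1e-2)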

csrc/grouped_gemm/grouped_gemm.hip CHANGED
@@ -17,6 +17,9 @@
 namespace grouped_gemm {
 namespace {
 
+// Experimental: toggled via MEGABLOCKS_GG_USE_HIPBLASLT=1. This flag is
+// intentionally off by default because the hipBLASLt path still fails on the
+// largest `tests/ops_test.py` configurations.
 bool use_hipblaslt_backend() {
   static int cached = [] {
     const char* raw = std::getenv("MEGABLOCKS_GG_USE_HIPBLASLT");
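
Since use_hipblaslt_backend() caches the result of the environment lookup in a function-local static, the opt-in has to be in place before the first grouped GEMM call reads it. A minimal sketch of enabling the experimental path from Python (the import path is an assumption from the torch-ext layout; exporting the variable when launching the process works equally well):

import os

os.environ["MEGABLOCKS_GG_USE_HIPBLASLT"] = "1"  # must be set before the first gmm() call caches the flag

import torch
from megablocks.grouped_gemm.backend import gmm  # import path is an assumption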
torch-ext/megablocks/grouped_gemm/backend.py CHANGED
@@ -1,5 +1,9 @@
 # NOTE: Torch needs to be imported before the custom
 # extensions. Otherwise libc10.so cannot be found.
+from __future__ import annotations
+
+from typing import Optional
+
 import torch
 
 # # TODO(tgale): Wrap this in a try-block with better
@@ -13,6 +17,7 @@ import torch
 # from megablocks._ops import ops as backend # type: ignore
 from .._ops import ops as backend # type: ignore
 
+
 def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
     assert not (trans_a and trans_b)
     assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes"
@@ -32,8 +37,99 @@ def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
     # reproduced by `_dev/debug-gg-small.py`.
     return torch.zeros(*shape, device=a.device, dtype=a.dtype)
 
-def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None):
+
+def _normalize_batch_sizes(batch_sizes: torch.Tensor) -> torch.Tensor:
+    if batch_sizes.device.type != "cpu":
+        batch_sizes = batch_sizes.to(device="cpu", dtype=torch.int64)
+    else:
+        batch_sizes = batch_sizes.to(dtype=torch.int64)
+    return batch_sizes
+
+
+def _run_backend(
+    a: torch.Tensor,
+    b: torch.Tensor,
+    batch_sizes: torch.Tensor,
+    trans_a: bool,
+    trans_b: bool,
+    c: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    batch_sizes_cpu = _normalize_batch_sizes(batch_sizes)
     if c is None:
-        c = _allocate_output(a, b, batch_sizes, trans_a, trans_b)
-    backend.gmm(a, b, c, batch_sizes, trans_a, trans_b)
+        c = _allocate_output(a, b, batch_sizes_cpu, trans_a, trans_b)
+    backend.gmm(a, b, c, batch_sizes_cpu, trans_a, trans_b)
     return c
+
+
+class _GroupedGemmFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(  # type: ignore[override]
+        ctx,
+        a: torch.Tensor,
+        b: torch.Tensor,
+        batch_sizes: torch.Tensor,
+        trans_a: bool,
+        trans_b: bool,
+        c: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        if trans_a:
+            raise NotImplementedError("Grouped GEMM autograd currently requires trans_a=False.")
+
+        batch_sizes_cpu = _normalize_batch_sizes(batch_sizes)
+        output = _run_backend(a, b, batch_sizes_cpu, trans_a, trans_b, c)
+
+        ctx.save_for_backward(a, b, batch_sizes_cpu)
+        ctx.trans_a = trans_a
+        ctx.trans_b = trans_b
+
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output: torch.Tensor):  # type: ignore[override]
+        a, b, batch_sizes_cpu = ctx.saved_tensors
+        trans_a = ctx.trans_a
+        trans_b = ctx.trans_b
+
+        if trans_a:
+            raise NotImplementedError("Grouped GEMM backward currently requires trans_a=False.")
+
+        grad_output = grad_output.contiguous()
+        grad_output_cast = grad_output
+        if grad_output_cast.dtype != a.dtype:
+            grad_output_cast = grad_output_cast.to(dtype=a.dtype)
+
+        grad_a = grad_b = None
+
+        with torch.no_grad():
+            if ctx.needs_input_grad[0]:
+                grad_a = _run_backend(
+                    grad_output_cast,
+                    b.detach(),
+                    batch_sizes_cpu,
+                    trans_a=False,
+                    trans_b=not trans_b,
+                )
+
+            if ctx.needs_input_grad[1]:
+                grad_b_eff = _run_backend(
+                    a.detach(),
+                    grad_output_cast,
+                    batch_sizes_cpu,
+                    trans_a=True,
+                    trans_b=False,
+                )
+                grad_b = (
+                    grad_b_eff.transpose(-2, -1) if trans_b else grad_b_eff
+                )
+
+        if grad_a is not None and grad_a.dtype != a.dtype:
+            grad_a = grad_a.to(dtype=a.dtype)
+        if grad_b is not None and grad_b.dtype != b.dtype:
+            grad_b = grad_b.to(dtype=b.dtype)
+
+        # None returned for batch_sizes / trans flags / optional c.
+        return grad_a, grad_b, None, None, None, None
+
+
+def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None):
+    return _GroupedGemmFunction.apply(a, b, batch_sizes, trans_a, trans_b, c)
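
Per expert group, the backward matmuls that _GroupedGemmFunction routes back through the same grouped kernel are dA_e = dY_e @ B_e^T (the call with trans_b flipped) and dB_e = A_e^T @ dY_e (the call with trans_a=True). A small dense reference sketch, assuming the usual layout (a: [tokens, k], b: [num_experts, k, n]) and trans_a=trans_b=False:

import torch

def grouped_gemm_backward_reference(a, b, grad_out, batch_sizes):
    # Dense per-expert equivalent of the two _run_backend calls in backward():
    #   grad_a rows for group e:   dY_e @ B_e^T   (kernel call with trans_b flipped)
    #   grad_b slice for expert e: A_e^T @ dY_e   (kernel call with trans_a=True)
    grad_a = torch.empty_like(a)
    grad_b = torch.empty_like(b)
    start = 0
    for e, rows in enumerate(batch_sizes.tolist()):
        sl = slice(start, start + rows)
        grad_a[sl] = grad_out[sl] @ b[e].transpose(-2, -1)
        grad_b[e] = a[sl].transpose(-2, -1) @ grad_out[sl]
        start += rows
    return grad_a, grad_b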