medmekk (HF Staff) committed
Commit bd6a211 · verified · 1 Parent(s): 55d99de

Upload custom kernels

build/torch-universal/triton_llama_mlp/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from . import layers
+ __all__ = ["layers"]
build/torch-universal/triton_llama_mlp/layers.py ADDED
@@ -0,0 +1,3 @@
+ from .mlp import TritonLlamaMLP
+
+ __all__ = ["TritonLlamaMLP"]
build/torch-universal/triton_llama_mlp/mlp.py ADDED
@@ -0,0 +1,210 @@
+ import torch
+ import torch.nn as nn
+ import triton
+ import triton.language as tl
+ from typing import Callable, Optional
+
+
+ @triton.jit
+ def matmul_kernel(
+     # Pointers to matrices
+     a_ptr, b_ptr, c_ptr,
+     # Matrix dimensions
+     M, N, K,
+     # The stride variables represent how much to increase the ptr by when moving by 1
+     # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
+     # by to get the element one row down (A has M rows).
+     stride_am, stride_ak,  #
+     stride_bk, stride_bn,  #
+     stride_cm, stride_cn,
+     # Meta-parameters
+     BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,  #
+     GROUP_SIZE_M: tl.constexpr,  #
+     ACTIVATION: tl.constexpr = None  #
+ ):
+     """Kernel for computing the matmul C = A x B.
+     A has shape (M, K), B has shape (K, N) and C has shape (M, N)
+     """
+     # -----------------------------------------------------------
+     # Map program ids `pid` to the block of C it should compute.
+     # This is done in a grouped ordering to promote L2 data reuse.
+     # See the `L2 Cache Optimizations` section of the Triton matmul tutorial for details.
+     pid = tl.program_id(axis=0)
+     num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+     num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+     num_pid_in_group = GROUP_SIZE_M * num_pid_n
+     group_id = pid // num_pid_in_group
+     first_pid_m = group_id * GROUP_SIZE_M
+     group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+     pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
+     pid_n = (pid % num_pid_in_group) // group_size_m
+
+     # ----------------------------------------------------------
+     # Create pointers for the first blocks of A and B.
+     # We will advance this pointer as we move in the K direction
+     # and accumulate
+     # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
+     # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
+     # See the `Pointer Arithmetic` section of the Triton matmul tutorial for details.
+     offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
+     offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+     offs_k = tl.arange(0, BLOCK_SIZE_K)
+     a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
+     b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+
+     # -----------------------------------------------------------
+     # Iterate to compute a block of the C matrix.
+     # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
+     # of fp32 values for higher accuracy.
+     # `accumulator` stays in fp32 and is written to the fp32 output buffer after the loop.
+     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+         # Load the next block of A and B, generate a mask by checking the K dimension.
+         # If it is out of bounds, set it to 0.
+         a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
+         b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+         # We accumulate along the K dimension.
+         accumulator += tl.dot(a, b)
+         # Advance the ptrs to the next K block.
+         a_ptrs += BLOCK_SIZE_K * stride_ak
+         b_ptrs += BLOCK_SIZE_K * stride_bk
+     # You can fuse arbitrary activation functions here
+     # while the accumulator is still in FP32!
+
+     c = accumulator.to(tl.float32)
+
+     # -----------------------------------------------------------
+     # Write back the block of the output matrix C with masks.
+     offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+     c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
+     c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+     tl.store(c_ptrs, c, mask=c_mask)
+
+
+ # `silu` can be fused into `matmul_kernel` via the `ACTIVATION` meta-parameter
+ # (the kernel currently accepts the parameter but does not apply the activation).
+ @triton.jit
+ def silu(x):
+     return x * tl.sigmoid(x)
+
+ def matmul(a, b, activation=""):
+     # Check constraints.
+     assert a.shape[1] == b.shape[0], "Incompatible dimensions"
+     assert a.is_contiguous(), "Matrix A must be contiguous"
+     M, K = a.shape
+     K, N = b.shape
+     BLOCK_SIZE_M = 32
+     BLOCK_SIZE_N = 32
+     BLOCK_SIZE_K = 32
+     GROUP_SIZE_M = 8
+     # Allocates output.
+     c = torch.empty((M, N), device=a.device, dtype=torch.float)
+     # 1D launch kernel where each block gets its own program.
+     grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
+     matmul_kernel[grid](
+         a, b, c,  #
+         M, N, K,  #
+         a.stride(0), a.stride(1),  #
+         b.stride(0), b.stride(1),  #
+         c.stride(0), c.stride(1),  #
+         BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K,  #
+         GROUP_SIZE_M,  #
+         ACTIVATION=activation  #
+     )
+     return c
+
+
+ class TritonLlamaMLP(nn.Module):
+     """LlamaMLP implementation using Triton kernels for matrix multiplication."""
+
+     # The projection layers and activation are supplied by the model whose MLP
+     # forward this layer replaces (see the commented-out test below).
+     gate_proj: nn.Linear
+     up_proj: nn.Linear
+     down_proj: nn.Linear
+     act_fn: Callable
+
+     def forward(self, x):
+         # Replace nn.Linear with matmul using the Triton kernel.
+         # Save original shape for reshaping back later
+         original_shape = x.shape
+         # Reshape input to 2D for matmul: (*, hidden_size) -> (batch_size*seq_len, hidden_size)
+         x_2d = x.reshape(-1, x.size(-1))
+
+         # Gate projection
+         gate_output = matmul(x_2d, self.gate_proj.weight.t())
+         if self.gate_proj.bias is not None:
+             gate_output += self.gate_proj.bias
+
+         # Up projection
+         up_output = matmul(x_2d, self.up_proj.weight.t())
+         if self.up_proj.bias is not None:
+             up_output += self.up_proj.bias
+
+         # Apply activation function and element-wise multiplication
+         intermediate_output = self.act_fn(gate_output) * up_output
+
+         # Final projection
+         down_output = matmul(intermediate_output, self.down_proj.weight.t())
+         if self.down_proj.bias is not None:
+             down_output += self.down_proj.bias
+
+         # Reshape back to original dimensions: (batch_size*seq_len, hidden_size) -> (*, hidden_size)
+         return down_output.reshape(original_shape)
+
+
+ # def test_triton_llama_mlp():
+ #     """Test that TritonLlamaMLP produces the same output as LlamaMLP from transformers."""
+ #     import torch
+ #     import torch.nn.functional as F
+
+ #     # Skip test if CUDA is not available
+ #     if not torch.cuda.is_available():
+ #         print("CUDA not available, skipping test")
+ #         return True
+
+ #     # Define test parameters
+ #     batch_size = 2
+ #     seq_len = 4
+ #     hidden_size = 128
+ #     intermediate_size = 256
+
+ #     # Create input tensor
+ #     x = torch.randn(batch_size, seq_len, hidden_size, device="cuda", dtype=torch.float)
+
+ #     # Create a standard PyTorch implementation for comparison
+ #     class StandardLlamaMLP(nn.Module):
+ #         def __init__(self, hidden_size, intermediate_size):
+ #             super().__init__()
+ #             self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
+ #             self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
+ #             self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
+ #             self.act_fn = F.silu
+
+ #         def forward(self, x):
+ #             return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+
+ #     # Initialize models
+ #     standard_mlp = StandardLlamaMLP(hidden_size, intermediate_size).to("cuda").to(torch.float)
+
+ #     # Create our Triton implementation
+ #     triton_mlp = TritonLlamaMLP()
+ #     triton_mlp.gate_proj = standard_mlp.gate_proj
+ #     triton_mlp.up_proj = standard_mlp.up_proj
+ #     triton_mlp.down_proj = standard_mlp.down_proj
+ #     triton_mlp.act_fn = standard_mlp.act_fn
+
+ #     # Run both implementations
+ #     with torch.no_grad():
+ #         standard_output = standard_mlp(x)
+ #         triton_output = triton_mlp(x)
+
+ #     # Compare outputs
+ #     max_diff = torch.max(torch.abs(standard_output - triton_output))
+ #     print(f"Maximum difference between standard and Triton implementation: {max_diff}")
+
+ #     # Check if outputs are close enough (allowing for some floating point differences)
+ #     is_close = torch.allclose(standard_output, triton_output, rtol=1e-2, atol=1e-2)
+ #     print(f"Outputs match within tolerance: {is_close}")
+
+ #     return is_close
+
+ # if __name__ == "__main__":
+ #     test_triton_llama_mlp()
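
For context, a minimal usage sketch of the module added above (not part of the commit). It assumes a CUDA device, that the built package is importable as `triton_llama_mlp`, and illustrative sizes; in practice the projection layers come from an existing LlamaMLP module, as the commented-out test shows.

# Hypothetical usage sketch (not in the commit): assumes CUDA and that the
# built package is importable as `triton_llama_mlp`; sizes are illustrative.
import torch
import torch.nn as nn
import torch.nn.functional as F
from triton_llama_mlp.layers import TritonLlamaMLP

hidden_size, intermediate_size = 128, 256

mlp = TritonLlamaMLP()
# Projection layers and activation are attached externally, as in the commented-out test.
mlp.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False).cuda()
mlp.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False).cuda()
mlp.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False).cuda()
mlp.act_fn = F.silu

x = torch.randn(2, 4, hidden_size, device="cuda", dtype=torch.float32)
with torch.no_grad():
    y = mlp(x)
print(y.shape)  # torch.Size([2, 4, 128])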
torch-ext/triton_llama_mlp/mlp.py CHANGED
@@ -71,7 +71,7 @@ def matmul_kernel(
      # You can fuse arbitrary activation functions here
      # while the accumulator is still in FP32!
  
-     c = accumulator.to(tl.float16)
+     c = accumulator.to(tl.float32)
  
      # -----------------------------------------------------------
      # Write back the block of the output matrix C with masks.
@@ -98,7 +98,7 @@ def matmul(a, b, activation=""):
      BLOCK_SIZE_K = 32
      GROUP_SIZE_M = 8
      # Allocates output.
-     c = torch.empty((M, N), device=a.device, dtype=torch.float16)
+     c = torch.empty((M, N), device=a.device, dtype=torch.float)
      # 1D launch kernel where each block gets its own program.
      grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
      matmul_kernel[grid](
@@ -166,7 +166,7 @@ class TritonLlamaMLP(nn.Module):
  #     intermediate_size = 256
  
  #     # Create input tensor
- #     x = torch.randn(batch_size, seq_len, hidden_size, device="cuda", dtype=torch.float16)
+ #     x = torch.randn(batch_size, seq_len, hidden_size, device="cuda", dtype=torch.float)
  
  #     # Create a standard PyTorch implementation for comparison
  #     class StandardLlamaMLP(nn.Module):
@@ -181,7 +181,7 @@ class TritonLlamaMLP(nn.Module):
  #         return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  
  #     # Initialize models
- #     standard_mlp = StandardLlamaMLP(hidden_size, intermediate_size).to("cuda").to(torch.float16)
+ #     standard_mlp = StandardLlamaMLP(hidden_size, intermediate_size).to("cuda").to(torch.float)
  
  #     # Create our Triton implementation
  #     triton_mlp = TritonLlamaMLP()
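
All four hunks switch the kernel's output path from fp16 to fp32, matching the fp32 `tl.dot` accumulator. A quick sanity check of the changed `matmul` wrapper might look like the following sketch; it is not part of the commit, assumes a CUDA device, and the shapes and tolerances are illustrative.

# Hypothetical sanity check for the fp16 -> fp32 change (not in the commit).
import torch
from triton_llama_mlp.mlp import matmul  # module added in this commit's build output

a = torch.randn(64, 96, device="cuda", dtype=torch.float32)
b = torch.randn(96, 32, device="cuda", dtype=torch.float32)

c_triton = matmul(a, b)       # now returns a float32 tensor
c_ref = torch.matmul(a, b)

print(c_triton.dtype)                                          # torch.float32
print(torch.allclose(c_triton, c_ref, rtol=1e-3, atol=1e-3))   # expected: True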