Kernels

danieldk committed
Commit 839efba · 1 parent: d1a2b62

Revert "Remove old flash-attn3 builds"

This reverts commit d1a2b627342457010f361fad8266f553c7d3400b.

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the complete set.
Files changed (50)
  1. build/torch26-cxx11-cu124-x86_64-linux/flash_attn3/__init__.py +17 -0
  2. build/torch26-cxx11-cu124-x86_64-linux/flash_attn3/_flash_attn3_2e75662.abi3.so +3 -0
  3. build/torch26-cxx11-cu124-x86_64-linux/flash_attn3/_flash_attn3_557701f.abi3.so +3 -0
  4. build/torch26-cxx11-cu124-x86_64-linux/flash_attn3/_ops.py +9 -0
  5. build/torch26-cxx11-cu124-x86_64-linux/flash_attn3/flash_attn_interface.py +828 -0
  6. build/torch26-cxx11-cu126-x86_64-linux/flash_attn3/__init__.py +17 -0
  7. build/torch26-cxx11-cu126-x86_64-linux/flash_attn3/_flash_attn3_2e75662.abi3.so +3 -0
  8. build/torch26-cxx11-cu126-x86_64-linux/flash_attn3/_flash_attn3_557701f.abi3.so +3 -0
  9. build/torch26-cxx11-cu126-x86_64-linux/flash_attn3/_ops.py +9 -0
  10. build/torch26-cxx11-cu126-x86_64-linux/flash_attn3/flash_attn_interface.py +828 -0
  11. build/torch26-cxx98-cu124-x86_64-linux/flash_attn3/__init__.py +17 -0
  12. build/torch26-cxx98-cu124-x86_64-linux/flash_attn3/_flash_attn3_2e75662.abi3.so +3 -0
  13. build/torch26-cxx98-cu124-x86_64-linux/flash_attn3/_flash_attn3_557701f.abi3.so +3 -0
  14. build/torch26-cxx98-cu124-x86_64-linux/flash_attn3/_ops.py +9 -0
  15. build/torch26-cxx98-cu124-x86_64-linux/flash_attn3/flash_attn_interface.py +828 -0
  16. build/torch26-cxx98-cu126-x86_64-linux/flash_attn3/__init__.py +17 -0
  17. build/torch26-cxx98-cu126-x86_64-linux/flash_attn3/_flash_attn3_2e75662.abi3.so +3 -0
  18. build/torch26-cxx98-cu126-x86_64-linux/flash_attn3/_flash_attn3_557701f.abi3.so +3 -0
  19. build/torch26-cxx98-cu126-x86_64-linux/flash_attn3/_ops.py +9 -0
  20. build/torch26-cxx98-cu126-x86_64-linux/flash_attn3/flash_attn_interface.py +828 -0
  21. build/torch27-cxx11-cu126-x86_64-linux/flash_attn3/__init__.py +17 -0
  22. build/torch27-cxx11-cu126-x86_64-linux/flash_attn3/_flash_attn3_2e75662.abi3.so +3 -0
  23. build/torch27-cxx11-cu126-x86_64-linux/flash_attn3/_ops.py +9 -0
  24. build/torch27-cxx11-cu126-x86_64-linux/flash_attn3/flash_attn_interface.py +828 -0
  25. build/torch27-cxx11-cu128-x86_64-linux/flash_attn3/__init__.py +17 -0
  26. build/torch27-cxx11-cu128-x86_64-linux/flash_attn3/_flash_attn3_2e75662.abi3.so +3 -0
  27. build/torch27-cxx11-cu128-x86_64-linux/flash_attn3/_ops.py +9 -0
  28. build/torch27-cxx11-cu128-x86_64-linux/flash_attn3/flash_attn_interface.py +828 -0
  29. build/torch28-cxx11-cu126-aarch64-linux/flash_attn3/__init__.py +17 -0
  30. build/torch28-cxx11-cu126-aarch64-linux/flash_attn3/__pycache__/__init__.cpython-313.pyc +0 -0
  31. build/torch28-cxx11-cu126-aarch64-linux/flash_attn3/__pycache__/_ops.cpython-313.pyc +0 -0
  32. build/torch28-cxx11-cu126-aarch64-linux/flash_attn3/__pycache__/flash_attn_interface.cpython-313.pyc +0 -0
  33. build/torch28-cxx11-cu126-aarch64-linux/flash_attn3/_flash_attn3_8d4f83f.abi3.so +3 -0
  34. build/torch28-cxx11-cu126-aarch64-linux/flash_attn3/_ops.py +9 -0
  35. build/torch28-cxx11-cu126-aarch64-linux/flash_attn3/flash_attn_interface.py +828 -0
  36. build/torch28-cxx11-cu126-x86_64-linux/flash_attn3/__init__.py +17 -0
  37. build/torch28-cxx11-cu126-x86_64-linux/flash_attn3/_flash_attn3_48fe103_dirty.abi3.so +3 -0
  38. build/torch28-cxx11-cu126-x86_64-linux/flash_attn3/_ops.py +9 -0
  39. build/torch28-cxx11-cu126-x86_64-linux/flash_attn3/flash_attn_interface.py +828 -0
  40. build/torch28-cxx11-cu128-aarch64-linux/flash_attn3/__init__.py +17 -0
  41. build/torch28-cxx11-cu128-aarch64-linux/flash_attn3/__pycache__/__init__.cpython-313.pyc +0 -0
  42. build/torch28-cxx11-cu128-aarch64-linux/flash_attn3/__pycache__/_ops.cpython-313.pyc +0 -0
  43. build/torch28-cxx11-cu128-aarch64-linux/flash_attn3/__pycache__/flash_attn_interface.cpython-313.pyc +0 -0
  44. build/torch28-cxx11-cu128-aarch64-linux/flash_attn3/_flash_attn3_8d4f83f.abi3.so +3 -0
  45. build/torch28-cxx11-cu128-aarch64-linux/flash_attn3/_ops.py +9 -0
  46. build/torch28-cxx11-cu128-aarch64-linux/flash_attn3/flash_attn_interface.py +828 -0
  47. build/torch28-cxx11-cu129-aarch64-linux/flash_attn3/__init__.py +17 -0
  48. build/torch28-cxx11-cu129-aarch64-linux/flash_attn3/__pycache__/__init__.cpython-313.pyc +0 -0
  49. build/torch28-cxx11-cu129-aarch64-linux/flash_attn3/__pycache__/_ops.cpython-313.pyc +0 -0
  50. build/torch28-cxx11-cu129-aarch64-linux/flash_attn3/__pycache__/flash_attn_interface.cpython-313.pyc +0 -0
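The 50 directories listed above all follow the same naming scheme, torch{major}{minor}-{cxx11|cxx98}-cu{CUDA version}-{x86_64|aarch64}-linux, so a consumer can pick the build variant that matches its environment. A minimal sketch of that selection, assuming a CUDA build of PyTorch and a C++11-ABI variant (the helper below is illustrative and not part of this repository):

    import platform
    import torch

    def guess_build_dir() -> str:
        """Guess the variant directory name for the current environment (illustrative only)."""
        torch_tag = "torch" + "".join(torch.__version__.split("+")[0].split(".")[:2])  # e.g. "torch26"
        cuda_tag = "cu" + torch.version.cuda.replace(".", "")  # e.g. "cu124"; assumes a CUDA build of torch
        arch = platform.machine()  # "x86_64" or "aarch64"
        return f"{torch_tag}-cxx11-{cuda_tag}-{arch}-linux"

    print(guess_build_dir())  # e.g. "torch26-cxx11-cu124-x86_64-linux"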
build/torch26-cxx11-cu124-x86_64-linux/flash_attn3/__init__.py ADDED
@@ -0,0 +1,17 @@
+ from .flash_attn_interface import (
+     flash_attn_combine,
+     flash_attn_func,
+     flash_attn_qkvpacked_func,
+     flash_attn_varlen_func,
+     flash_attn_with_kvcache,
+     get_scheduler_metadata,
+ )
+
+ __all__ = [
+     "flash_attn_combine",
+     "flash_attn_func",
+     "flash_attn_qkvpacked_func",
+     "flash_attn_varlen_func",
+     "flash_attn_with_kvcache",
+     "get_scheduler_metadata",
+ ]
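Every build directory ships this same thin __init__.py, which only re-exports the public API from flash_attn_interface. A quick check of that surface, assuming the LFS binaries have been fetched, a matching PyTorch build is installed, and the chosen build directory has been put on sys.path (on the Hub these variants are normally resolved by a kernel loader rather than imported by hand):

    import sys
    sys.path.insert(0, "build/torch26-cxx11-cu124-x86_64-linux")

    import flash_attn3
    print(flash_attn3.__all__)
    # ['flash_attn_combine', 'flash_attn_func', 'flash_attn_qkvpacked_func',
    #  'flash_attn_varlen_func', 'flash_attn_with_kvcache', 'get_scheduler_metadata']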
build/torch26-cxx11-cu124-x86_64-linux/flash_attn3/_flash_attn3_2e75662.abi3.so ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:21b44e8e5e447a8b8ee051d347f0e32a3446a750f79d0bd1755e553f2119aa3b
+ size 838459656
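The .abi3.so entries in this commit are Git LFS pointer files like the three lines above, not the binaries themselves; the roughly 838 MB shared object is fetched by git-lfs at checkout time. A small sketch that detects an unfetched pointer and reports the advertised hash and size (a hypothetical helper, not part of this repository):

    from pathlib import Path

    def read_lfs_pointer(path: str):
        """Return (sha256, size) if `path` is still a Git LFS pointer file, else None."""
        data = Path(path).read_bytes()
        if not data.startswith(b"version https://git-lfs.github.com/spec/v1"):
            return None  # already the real binary (or not LFS-tracked)
        fields = dict(line.split(" ", 1) for line in data.decode().splitlines())
        return fields["oid"].removeprefix("sha256:"), int(fields["size"])

    print(read_lfs_pointer(
        "build/torch26-cxx11-cu124-x86_64-linux/flash_attn3/_flash_attn3_2e75662.abi3.so"
    ))  # ('21b44e8e...', 838459656) until `git lfs pull` replaces the pointer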
build/torch26-cxx11-cu124-x86_64-linux/flash_attn3/_flash_attn3_557701f.abi3.so ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:12d4ff964085fd02252777a2008f5ca47c90ea6a93da590e2fc5065dd5330207
+ size 838459656
build/torch26-cxx11-cu124-x86_64-linux/flash_attn3/_ops.py ADDED
@@ -0,0 +1,9 @@
+ import torch
+ from . import _flash_attn3_557701f
+ ops = torch.ops._flash_attn3_557701f
+
+ def add_op_namespace_prefix(op_name: str):
+     """
+     Prefix op by namespace.
+     """
+     return f"_flash_attn3_557701f::{op_name}"
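This _ops.py shim is the bridge between the Python wrappers and the compiled extension: importing _flash_attn3_557701f loads the .abi3.so, which registers its kernels under the torch.ops._flash_attn3_557701f namespace, and add_op_namespace_prefix builds fully qualified op names. A short illustration of how flash_attn_interface.py resolves kernels through it (a sketch, assuming the package is importable as in the example above):

    from flash_attn3._ops import ops, add_op_namespace_prefix

    # Loading the shared object registers the custom ops with torch.ops;
    # flash_attn_interface.py then calls them as ops.fwd, ops.bwd, ops.fwd_combine, ...
    print(ops.fwd)                         # the registered forward kernel
    print(add_op_namespace_prefix("fwd"))  # "_flash_attn3_557701f::fwd"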
build/torch26-cxx11-cu124-x86_64-linux/flash_attn3/flash_attn_interface.py ADDED
@@ -0,0 +1,828 @@
1
+ # Copyright (c) 2023, Tri Dao.
2
+
3
+ from typing import Optional, Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from ._ops import ops as flash_attn_3_cuda
9
+
10
+ def maybe_contiguous(x):
11
+ return x.contiguous() if x is not None and x.stride(-1) != 1 else x
12
+
13
+
14
+ def _flash_attn_forward(
15
+ q,
16
+ k,
17
+ v,
18
+ k_new,
19
+ v_new,
20
+ qv,
21
+ out,
22
+ cu_seqlens_q,
23
+ cu_seqlens_k,
24
+ cu_seqlens_k_new,
25
+ seqused_q,
26
+ seqused_k,
27
+ max_seqlen_q,
28
+ max_seqlen_k,
29
+ page_table,
30
+ kv_batch_idx,
31
+ leftpad_k,
32
+ rotary_cos,
33
+ rotary_sin,
34
+ seqlens_rotary,
35
+ q_descale,
36
+ k_descale,
37
+ v_descale,
38
+ softmax_scale,
39
+ causal,
40
+ window_size=(-1, -1),
41
+ attention_chunk=0,
42
+ softcap=0.0,
43
+ rotary_interleaved=True,
44
+ scheduler_metadata=None,
45
+ num_splits=1,
46
+ pack_gqa=None,
47
+ sm_margin=0):
48
+ q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)]
49
+ v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
50
+ cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [
51
+ maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new)
52
+ ]
53
+ seqused_q, seqused_k = [maybe_contiguous(x) for x in (seqused_q, seqused_k)]
54
+ page_table, kv_batch_idx, leftpad_k = [
55
+ maybe_contiguous(x) for x in (page_table, kv_batch_idx, leftpad_k)
56
+ ]
57
+ rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
58
+ seqlens_rotary = maybe_contiguous(seqlens_rotary)
59
+ out, softmax_lse, *rest = flash_attn_3_cuda.fwd(
60
+ q,
61
+ k,
62
+ v,
63
+ k_new,
64
+ v_new,
65
+ qv,
66
+ out,
67
+ cu_seqlens_q,
68
+ cu_seqlens_k,
69
+ cu_seqlens_k_new,
70
+ seqused_q,
71
+ seqused_k,
72
+ max_seqlen_q,
73
+ max_seqlen_k,
74
+ page_table,
75
+ kv_batch_idx,
76
+ leftpad_k,
77
+ rotary_cos,
78
+ rotary_sin,
79
+ seqlens_rotary,
80
+ q_descale,
81
+ k_descale,
82
+ v_descale,
83
+ softmax_scale,
84
+ causal,
85
+ window_size[0],
86
+ window_size[1],
87
+ attention_chunk,
88
+ softcap,
89
+ rotary_interleaved,
90
+ scheduler_metadata,
91
+ num_splits,
92
+ pack_gqa,
93
+ sm_margin,
94
+ )
95
+ return out, softmax_lse, *rest
96
+
97
+
98
+ def _flash_attn_backward(
99
+ dout,
100
+ q,
101
+ k,
102
+ v,
103
+ out,
104
+ softmax_lse,
105
+ cu_seqlens_q,
106
+ cu_seqlens_k,
107
+ sequed_q,
108
+ sequed_k,
109
+ max_seqlen_q,
110
+ max_seqlen_k,
111
+ dq,
112
+ dk,
113
+ dv,
114
+ softmax_scale,
115
+ causal,
116
+ window_size=(-1, -1),
117
+ softcap=0.0,
118
+ deterministic=False,
119
+ sm_margin=0,
120
+ ):
121
+ # dq, dk, dv are allocated by us so they should already be contiguous
122
+ dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
123
+ dq, dk, dv, softmax_d, *rest = flash_attn_3_cuda.bwd(
124
+ dout,
125
+ q,
126
+ k,
127
+ v,
128
+ out,
129
+ softmax_lse,
130
+ dq,
131
+ dk,
132
+ dv,
133
+ cu_seqlens_q,
134
+ cu_seqlens_k,
135
+ sequed_q,
136
+ sequed_k,
137
+ max_seqlen_q,
138
+ max_seqlen_k,
139
+ softmax_scale,
140
+ causal,
141
+ window_size[0],
142
+ window_size[1],
143
+ softcap,
144
+ deterministic,
145
+ sm_margin,
146
+ )
147
+ return dq, dk, dv, softmax_d
148
+
149
+
150
+ class FlashAttnQKVPackedFunc(torch.autograd.Function):
151
+ @staticmethod
152
+ def forward(
153
+ ctx,
154
+ qkv,
155
+ softmax_scale,
156
+ causal,
157
+ q_descale=None, k_descale=None, v_descale=None,
158
+ window_size=(-1, -1),
159
+ attention_chunk=0,
160
+ softcap=0.0,
161
+ deterministic=False,
162
+ num_heads_q=None,
163
+ sm_margin=0,
164
+ ):
165
+ if softmax_scale is None:
166
+ softmax_scale = qkv.shape[-1] ** (-0.5)
167
+ if qkv.dim() == 5:
168
+ assert qkv.shape[-3] == 3
169
+ q, k, v = qkv.unbind(dim=-3)
170
+ else:
171
+ assert qkv.dim() == 4
172
+ assert num_heads_q is not None
173
+ num_heads_k = (qkv.shape[2] - num_heads_q) // 2
174
+ assert num_heads_k * 2 + num_heads_q == qkv.shape[2]
175
+ q, k, v = qkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
176
+ out, softmax_lse, *rest = _flash_attn_forward(
177
+ q,
178
+ k,
179
+ v,
180
+ None, None, # k_new, v_new
181
+ None, # qv
182
+ None, # out
183
+ None, None, None, # cu_seqlens_q/k/k_new
184
+ None, None, # seqused_q/k
185
+ None, None, # max_seqlen_q/k
186
+ None, None, None, # page_table, kv_batch_idx, leftpad_k,
187
+ None, None, None, # rotary_cos/sin, seqlens_rotary
188
+ q_descale, k_descale, v_descale,
189
+ softmax_scale,
190
+ causal=causal,
191
+ window_size=window_size,
192
+ attention_chunk=attention_chunk,
193
+ softcap=softcap,
194
+ sm_margin=sm_margin,
195
+ )
196
+ # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
197
+ ctx.save_for_backward(q, k, v, out, softmax_lse)
198
+ ctx.softmax_scale = softmax_scale
199
+ ctx.causal = causal
200
+ ctx.window_size = window_size
201
+ ctx.attention_chunk = attention_chunk
202
+ ctx.softcap = softcap
203
+ ctx.deterministic = deterministic
204
+ ctx.ndim = qkv.dim()
205
+ ctx.sm_margin = sm_margin
206
+ # return out, softmax_lse
207
+ return out
208
+
209
+ @staticmethod
210
+ def backward(ctx, dout, *args):
211
+ q, k, v, out, softmax_lse = ctx.saved_tensors
212
+ assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
213
+ if ctx.ndim == 5:
214
+ qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
215
+ dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
216
+ dq, dk, dv = dqkv.unbind(dim=-3)
217
+ else:
218
+ num_heads_q = q.shape[2]
219
+ num_heads_k = k.shape[2]
220
+ qkv_shape = q.shape[:-2] + (num_heads_q + num_heads_k * 2, *q.shape[-1:])
221
+ dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
222
+ dq, dk, dv = dqkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
223
+ _flash_attn_backward(
224
+ dout,
225
+ q,
226
+ k,
227
+ v,
228
+ out,
229
+ softmax_lse,
230
+ None, None, # cu_seqlens_q, cu_seqlens_k,
231
+ None, None, # sequed_q, sequed_k,
232
+ None, None, # max_seqlen_q, max_seqlen_k,
233
+ dq,
234
+ dk,
235
+ dv,
236
+ ctx.softmax_scale,
237
+ ctx.causal,
238
+ ctx.window_size,
239
+ ctx.softcap,
240
+ ctx.deterministic,
241
+ ctx.sm_margin,
242
+ )
243
+ dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
244
+ return dqkv, None, None, None, None, None, None, None, None, None, None, None
245
+
246
+
247
+ class FlashAttnFunc(torch.autograd.Function):
248
+
249
+ @staticmethod
250
+ def forward(
251
+ ctx,
252
+ q,
253
+ k,
254
+ v,
255
+ softmax_scale,
256
+ causal,
257
+ qv=None,
258
+ q_descale=None, k_descale=None, v_descale=None,
259
+ window_size=(-1, -1),
260
+ attention_chunk=0,
261
+ softcap=0.0,
262
+ num_splits=1,
263
+ pack_gqa=None,
264
+ deterministic=False,
265
+ sm_margin=0,
266
+ ):
267
+ if softmax_scale is None:
268
+ softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
269
+ # out, q, k, v, out_padded, softmax_lse = _flash_attn_forward(
270
+ out, softmax_lse, *rest = _flash_attn_forward(
271
+ q,
272
+ k,
273
+ v,
274
+ None, None, # k_new, v_new
275
+ qv, # qv
276
+ None, # out
277
+ None, None, None, # cu_seqlens_q/k/k_new
278
+ None, None, # seqused_q/k
279
+ None, None, # max_seqlen_q/k
280
+ None, None, None, # page_table, kv_batch_idx, leftpad_k,
281
+ None, None, None, # rotary_cos/sin, seqlens_rotary
282
+ q_descale, k_descale, v_descale,
283
+ softmax_scale,
284
+ causal=causal,
285
+ window_size=window_size,
286
+ attention_chunk=attention_chunk,
287
+ softcap=softcap,
288
+ num_splits=num_splits,
289
+ pack_gqa=pack_gqa,
290
+ sm_margin=sm_margin,
291
+ )
292
+ # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
293
+ ctx.save_for_backward(q, k, v, out, softmax_lse)
294
+ ctx.softmax_scale = softmax_scale
295
+ ctx.causal = causal
296
+ ctx.window_size = window_size
297
+ ctx.attention_chunk = attention_chunk
298
+ ctx.softcap = softcap
299
+ ctx.deterministic = deterministic
300
+ ctx.sm_margin = sm_margin
301
+ return out, softmax_lse
302
+
303
+ @staticmethod
304
+ def backward(ctx, dout, *args):
305
+ q, k, v, out, softmax_lse = ctx.saved_tensors
306
+ assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
307
+ dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
308
+ _flash_attn_backward(
309
+ dout,
310
+ q,
311
+ k,
312
+ v,
313
+ out,
314
+ softmax_lse,
315
+ None, None, # cu_seqlens_q, cu_seqlens_k,
316
+ None, None, # sequed_q, sequed_k,
317
+ None, None, # max_seqlen_q, max_seqlen_k,
318
+ dq,
319
+ dk,
320
+ dv,
321
+ ctx.softmax_scale,
322
+ ctx.causal,
323
+ ctx.window_size,
324
+ ctx.softcap,
325
+ ctx.deterministic,
326
+ ctx.sm_margin,
327
+ )
328
+ dq = dq[..., : q.shape[-1]] # We could have padded the head dimension
329
+ dk = dk[..., : k.shape[-1]]
330
+ dv = dv[..., : v.shape[-1]]
331
+ return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None
332
+
333
+
334
+ class FlashAttnVarlenFunc(torch.autograd.Function):
335
+
336
+ @staticmethod
337
+ def forward(
338
+ ctx,
339
+ q,
340
+ k,
341
+ v,
342
+ cu_seqlens_q,
343
+ cu_seqlens_k,
344
+ seqused_q,
345
+ seqused_k,
346
+ max_seqlen_q,
347
+ max_seqlen_k,
348
+ softmax_scale,
349
+ causal,
350
+ qv=None,
351
+ q_descale=None, k_descale=None, v_descale=None,
352
+ window_size=(-1, -1),
353
+ attention_chunk=0,
354
+ softcap=0.0,
355
+ num_splits=1,
356
+ pack_gqa=None,
357
+ deterministic=False,
358
+ sm_margin=0,
359
+ ):
360
+ if softmax_scale is None:
361
+ softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
362
+ # out, q, k, v, out_padded, softmax_lse = _flash_attn_varlen_forward(
363
+ out, softmax_lse, *rest = _flash_attn_forward(
364
+ q,
365
+ k,
366
+ v,
367
+ None, None, # k_new, v_new
368
+ qv, # qv
369
+ None, # out
370
+ cu_seqlens_q,
371
+ cu_seqlens_k,
372
+ None, # cu_seqlens_k_new
373
+ seqused_q,
374
+ seqused_k,
375
+ max_seqlen_q,
376
+ max_seqlen_k,
377
+ None, None, None, # page_table, kv_batch_idx, leftpad_k,
378
+ None, None, None, # rotary_cos/sin, seqlens_rotary
379
+ q_descale, k_descale, v_descale,
380
+ softmax_scale,
381
+ causal=causal,
382
+ window_size=window_size,
383
+ attention_chunk=attention_chunk,
384
+ softcap=softcap,
385
+ num_splits=num_splits,
386
+ pack_gqa=pack_gqa,
387
+ sm_margin=sm_margin,
388
+ )
389
+ # ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
390
+ ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
391
+ ctx.max_seqlen_q = max_seqlen_q
392
+ ctx.max_seqlen_k = max_seqlen_k
393
+ ctx.softmax_scale = softmax_scale
394
+ ctx.causal = causal
395
+ ctx.window_size = window_size
396
+ ctx.attention_chunk = attention_chunk
397
+ ctx.softcap = softcap
398
+ ctx.deterministic = deterministic
399
+ ctx.sm_margin = sm_margin
400
+ return out, softmax_lse
401
+
402
+ @staticmethod
403
+ def backward(ctx, dout, *args):
404
+ q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
405
+ assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
406
+ dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
407
+ _flash_attn_backward(
408
+ dout,
409
+ q,
410
+ k,
411
+ v,
412
+ out,
413
+ softmax_lse,
414
+ cu_seqlens_q,
415
+ cu_seqlens_k,
416
+ seqused_q,
417
+ seqused_k,
418
+ ctx.max_seqlen_q,
419
+ ctx.max_seqlen_k,
420
+ dq,
421
+ dk,
422
+ dv,
423
+ ctx.softmax_scale,
424
+ ctx.causal,
425
+ ctx.window_size,
426
+ ctx.softcap,
427
+ ctx.deterministic,
428
+ ctx.sm_margin,
429
+ )
430
+ dq = dq[..., : q.shape[-1]] # We could have padded the head dimension
431
+ dk = dk[..., : k.shape[-1]]
432
+ dv = dv[..., : v.shape[-1]]
433
+ return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
434
+
435
+
436
+ def flash_attn_qkvpacked_func(
437
+ qkv,
438
+ softmax_scale=None,
439
+ causal=False,
440
+ q_descale=None, k_descale=None, v_descale=None,
441
+ window_size=(-1, -1),
442
+ attention_chunk=0,
443
+ softcap=0.0,
444
+ deterministic=False,
445
+ num_heads_q=None,
446
+ sm_margin=0,
447
+ ):
448
+ """dropout_p should be set to 0.0 during evaluation
449
+ If Q, K, V are already stacked into 1 tensor, this function will be faster than
450
+ calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
451
+ of the gradients of Q, K, V.
452
+ For multi-query and grouped-query attention (MQA/GQA), please see
453
+ flash_attn_kvpacked_func and flash_attn_func.
454
+
455
+ If window_size != (-1, -1), implements sliding window local attention. Query at position i
456
+ will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
457
+
458
+ Arguments:
459
+ qkv: (batch_size, seqlen, 3, nheads, headdim)
460
+ dropout_p: float. Dropout probability.
461
+ softmax_scale: float. The scaling of QK^T before applying softmax.
462
+ Default to 1 / sqrt(headdim).
463
+ causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
464
+ window_size: (left, right). If not (-1, -1), implements sliding window local attention.
465
+ softcap: float. Anything > 0 activates softcapping attention.
466
+ alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
467
+ the attention score of query i and key j.
468
+ deterministic: bool. Whether to use the deterministic implementation of the backward pass,
469
+ which is slightly slower and uses more memory. The forward pass is always deterministic.
470
+ return_attn_probs: bool. Whether to return the attention probabilities. This option is for
471
+ testing only. The returned probabilities are not guaranteed to be correct
472
+ (they might not have the right scaling).
473
+ Return:
474
+ out: (batch_size, seqlen, nheads, headdim).
475
+ softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
476
+ logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
477
+ normalization factor).
478
+ S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
479
+ The output of softmax (possibly with different scaling). It also encodes the dropout
480
+ pattern (negative means that location was dropped, nonnegative means it was kept).
481
+ """
482
+ return FlashAttnQKVPackedFunc.apply(
483
+ qkv,
484
+ softmax_scale,
485
+ causal,
486
+ q_descale, k_descale, v_descale,
487
+ window_size,
488
+ attention_chunk,
489
+ softcap,
490
+ deterministic,
491
+ num_heads_q,
492
+ sm_margin,
493
+ )
494
+
495
+
496
+ def flash_attn_func(
497
+ q,
498
+ k,
499
+ v,
500
+ softmax_scale=None,
501
+ causal=False,
502
+ qv=None,
503
+ q_descale=None, k_descale=None, v_descale=None,
504
+ window_size=(-1, -1),
505
+ attention_chunk=0,
506
+ softcap=0.0,
507
+ num_splits=1,
508
+ pack_gqa=None,
509
+ deterministic=False,
510
+ sm_margin=0,
511
+ ):
512
+ """dropout_p should be set to 0.0 during evaluation
513
+ Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
514
+ than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
515
+ For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
516
+ 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
517
+
518
+ If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
519
+ For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
520
+ 1 1 1 1 0
521
+ 1 1 1 1 1
522
+ If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
523
+ 0 0
524
+ 0 0
525
+ 0 0
526
+ 1 0
527
+ 1 1
528
+ If the row of the mask is all zero, the output will be zero.
529
+
530
+ If window_size != (-1, -1), implements sliding window local attention. Query at position i
531
+ will only attend to keys between
532
+ [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
533
+
534
+ Arguments:
535
+ q: (batch_size, seqlen, nheads, headdim)
536
+ k: (batch_size, seqlen, nheads_k, headdim)
537
+ v: (batch_size, seqlen, nheads_k, headdim)
538
+ dropout_p: float. Dropout probability.
539
+ softmax_scale: float. The scaling of QK^T before applying softmax.
540
+ Default to 1 / sqrt(headdim).
541
+ causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
542
+ window_size: (left, right). If not (-1, -1), implements sliding window local attention.
543
+ alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
544
+ (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
545
+ is added to the attention score of query i and key j.
546
+ deterministic: bool. Whether to use the deterministic implementation of the backward pass,
547
+ which is slightly slower and uses more memory. The forward pass is always deterministic.
548
+ return_attn_probs: bool. Whether to return the attention probabilities. This option is for
549
+ testing only. The returned probabilities are not guaranteed to be correct
550
+ (they might not have the right scaling).
551
+ Return:
552
+ out: (batch_size, seqlen, nheads, headdim).
553
+ softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
554
+ logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
555
+ normalization factor).
556
+ """
557
+ return FlashAttnFunc.apply(
558
+ q,
559
+ k,
560
+ v,
561
+ softmax_scale,
562
+ causal,
563
+ qv,
564
+ q_descale, k_descale, v_descale,
565
+ window_size,
566
+ attention_chunk,
567
+ softcap,
568
+ num_splits,
569
+ pack_gqa,
570
+ deterministic,
571
+ sm_margin,
572
+ )
573
+
574
+
575
+ def flash_attn_varlen_func(
576
+ q,
577
+ k,
578
+ v,
579
+ cu_seqlens_q,
580
+ cu_seqlens_k,
581
+ max_seqlen_q,
582
+ max_seqlen_k,
583
+ seqused_q=None,
584
+ seqused_k=None,
585
+ softmax_scale=None,
586
+ causal=False,
587
+ qv=None,
588
+ q_descale=None, k_descale=None, v_descale=None,
589
+ window_size=(-1, -1),
590
+ attention_chunk=0,
591
+ softcap=0.0,
592
+ num_splits=1,
593
+ pack_gqa=None,
594
+ deterministic=False,
595
+ sm_margin=0,
596
+ ):
597
+ return FlashAttnVarlenFunc.apply(
598
+ q,
599
+ k,
600
+ v,
601
+ cu_seqlens_q,
602
+ cu_seqlens_k,
603
+ seqused_q,
604
+ seqused_k,
605
+ max_seqlen_q,
606
+ max_seqlen_k,
607
+ softmax_scale,
608
+ causal,
609
+ qv,
610
+ q_descale, k_descale, v_descale,
611
+ window_size,
612
+ attention_chunk,
613
+ softcap,
614
+ num_splits,
615
+ pack_gqa,
616
+ deterministic,
617
+ sm_margin,
618
+ )
619
+
620
+
621
+ def flash_attn_combine(out_partial, lse_partial, out=None, out_dtype=None):
622
+ return flash_attn_3_cuda.fwd_combine(out_partial, lse_partial, out, out_dtype)
623
+
624
+
625
+ def flash_attn_with_kvcache(
626
+ q,
627
+ k_cache,
628
+ v_cache,
629
+ k=None,
630
+ v=None,
631
+ qv=None,
632
+ rotary_cos=None,
633
+ rotary_sin=None,
634
+ cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
635
+ cache_batch_idx: Optional[torch.Tensor] = None,
636
+ cache_leftpad: Optional[torch.Tensor] = None,
637
+ page_table: Optional[torch.Tensor] = None,
638
+ cu_seqlens_q: Optional[torch.Tensor] = None,
639
+ cu_seqlens_k_new: Optional[torch.Tensor] = None,
640
+ max_seqlen_q: Optional[int] = None,
641
+ rotary_seqlens: Optional[torch.Tensor] = None,
642
+ q_descale: Optional[torch.Tensor] = None,
643
+ k_descale: Optional[torch.Tensor] = None,
644
+ v_descale: Optional[torch.Tensor] = None,
645
+ softmax_scale=None,
646
+ causal=False,
647
+ window_size=(-1, -1), # -1 means infinite context window
648
+ attention_chunk=0,
649
+ softcap=0.0, # 0.0 means deactivated
650
+ rotary_interleaved=True,
651
+ scheduler_metadata=None,
652
+ num_splits=0, # Can be tuned for speed
653
+ pack_gqa=None, # Can be tuned for speed
654
+ sm_margin=0, # Can be tuned if some SMs are used for communication
655
+ return_softmax_lse=False,
656
+ ):
657
+ """
658
+ If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
659
+ k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
660
+ the previous step, and update them with the new keys/values from the current step, and do
661
+ attention with the updated cache, all in 1 kernel.
662
+
663
+ If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
664
+ For example, the KV cache could be pre-allocated with the max sequence length, and you can use
665
+ cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
666
+
667
+ Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
668
+ rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
669
+ If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
670
+ and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
671
+ If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
672
+ indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).
673
+
674
+ See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
675
+
676
+ Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
677
+ than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
678
+ For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
679
+ 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
680
+
681
+ If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
682
+ For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
683
+ 1 1 1 1 0
684
+ 1 1 1 1 1
685
+ If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
686
+ 0 0
687
+ 0 0
688
+ 0 0
689
+ 1 0
690
+ 1 1
691
+ If the row of the mask is all zero, the output will be zero.
692
+
693
+ If window_size != (-1, -1), implements sliding window local attention. Query at position i
694
+ will only attend to keys between
695
+ [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
696
+
697
+ Note: Does not support backward pass.
698
+
699
+ Arguments:
700
+ q: (batch_size, seqlen, nheads, headdim)
701
+ k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table,
702
+ or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache)
703
+ page_block_size must be a multiple of 256.
704
+ v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table,
705
+ or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache)
706
+ k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
707
+ k with k_cache, starting at the indices specified by cache_seqlens.
708
+ v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k.
709
+ qv [optional]: (batch_size, seqlen, nheads, headdim_v)
710
+ rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
711
+ to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
712
+ rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
713
+ cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
714
+ KV cache.
715
+ cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
716
+ If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
717
+ If the indices are not distinct, and k and v are provided, the values updated in the cache
718
+ might come from any of the duplicate indices.
719
+ cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0.
720
+ page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
721
+ softmax_scale: float. The scaling of QK^T before applying softmax.
722
+ Default to 1 / sqrt(headdim).
723
+ causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
724
+ window_size: (left, right). If not (-1, -1), implements sliding window local attention.
725
+ softcap: float. Anything > 0 activates softcapping attention.
726
+ rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
727
+ If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
728
+ rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
729
+ (i.e. GPT-NeoX style).
730
+ num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
731
+ If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
732
+ to automatically determine the number of splits.
733
+ Don't change this unless you know what you are doing.
734
+ return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.
735
+
736
+ Return:
737
+ out: (batch_size, seqlen, nheads, headdim).
738
+ softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
739
+ logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
740
+ normalization factor).
741
+ """
742
+ assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
743
+ assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
744
+ if softmax_scale is None:
745
+ softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
746
+ if cache_seqlens is not None and isinstance(cache_seqlens, int):
747
+ cache_seqlens = torch.full(
748
+ (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
749
+ )
750
+ cache_seqlens = maybe_contiguous(cache_seqlens)
751
+ out, softmax_lse, *rest = _flash_attn_forward(
752
+ q,
753
+ k_cache,
754
+ v_cache,
755
+ k,
756
+ v,
757
+ qv,
758
+ None, # out
759
+ cu_seqlens_q,
760
+ None, # cu_seqlens_k
761
+ cu_seqlens_k_new,
762
+ None, # seqused_q
763
+ cache_seqlens,
764
+ max_seqlen_q,
765
+ None, # max_seqlen_k
766
+ page_table,
767
+ cache_batch_idx,
768
+ cache_leftpad,
769
+ rotary_cos,
770
+ rotary_sin,
771
+ rotary_seqlens,
772
+ q_descale, k_descale, v_descale,
773
+ softmax_scale,
774
+ causal=causal,
775
+ window_size=window_size,
776
+ attention_chunk=attention_chunk,
777
+ softcap=softcap,
778
+ rotary_interleaved=rotary_interleaved,
779
+ scheduler_metadata=scheduler_metadata,
780
+ num_splits=num_splits,
781
+ pack_gqa=pack_gqa,
782
+ sm_margin=sm_margin,
783
+ )
784
+ # return (out, softmax_lse) if return_softmax_lse else out
785
+ return (out, softmax_lse, *rest) if return_softmax_lse else out
786
+
787
+
788
+ def get_scheduler_metadata(
789
+ batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim,
790
+ cache_seqlens: torch.Tensor,
791
+ qkv_dtype=torch.bfloat16,
792
+ headdim_v=None,
793
+ cu_seqlens_q: Optional[torch.Tensor] = None,
794
+ cu_seqlens_k_new: Optional[torch.Tensor] = None,
795
+ cache_leftpad: Optional[torch.Tensor] = None,
796
+ page_size: Optional[int] = None,
797
+ max_seqlen_k_new=0,
798
+ causal=False,
799
+ window_size=(-1, -1), # -1 means infinite context window
800
+ attention_chunk=0,
801
+ has_softcap=False,
802
+ num_splits=0, # Can be tuned for speed
803
+ pack_gqa=None, # Can be tuned for speed
804
+ sm_margin=0, # Can be tuned if some SMs are used for communication
805
+ ):
806
+ cache_seqlens = maybe_contiguous(cache_seqlens)
807
+ if headdim_v is None:
808
+ headdim_v = headdim
809
+ scheduler_metadata = flash_attn_3_cuda.get_scheduler_metadata(
810
+ batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, headdim_v,
811
+ qkv_dtype,
812
+ cache_seqlens,
813
+ cu_seqlens_q,
814
+ None, # cu_seqlens_k
815
+ cu_seqlens_k_new,
816
+ None, # seqused_q
817
+ cache_leftpad,
818
+ page_size,
819
+ max_seqlen_k_new,
820
+ causal,
821
+ window_size[0], window_size[1],
822
+ attention_chunk,
823
+ has_softcap,
824
+ num_splits,
825
+ pack_gqa,
826
+ sm_margin,
827
+ )
828
+ return scheduler_metadata
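The docstrings above spell out the calling conventions for the two most common entry points. The sketch below exercises them on toy data, assuming the package is importable as flash_attn3 and a GPU supported by FlashAttention-3 is available; shapes follow the (batch_size, seqlen, nheads, headdim) layout documented above:

    import torch
    from flash_attn3 import flash_attn_func, flash_attn_with_kvcache

    device, dtype = "cuda", torch.bfloat16

    # GQA attention: 8 query heads sharing 2 KV heads, causal mask.
    q = torch.randn(2, 1024, 8, 128, device=device, dtype=dtype)
    k = torch.randn(2, 1024, 2, 128, device=device, dtype=dtype)
    v = torch.randn(2, 1024, 2, 128, device=device, dtype=dtype)
    out, softmax_lse = flash_attn_func(q, k, v, causal=True)
    print(out.shape, softmax_lse.shape)  # (2, 1024, 8, 128) and, per the docstring, (2, 8, 1024)

    # Incremental decoding: append one new token per sequence to a pre-allocated
    # KV cache (updated in place) and attend against the updated cache.
    k_cache = torch.zeros(2, 4096, 2, 128, device=device, dtype=dtype)
    v_cache = torch.zeros(2, 4096, 2, 128, device=device, dtype=dtype)
    q_step = torch.randn(2, 1, 8, 128, device=device, dtype=dtype)
    k_step = torch.randn(2, 1, 2, 128, device=device, dtype=dtype)
    v_step = torch.randn(2, 1, 2, 128, device=device, dtype=dtype)
    out_step = flash_attn_with_kvcache(
        q_step, k_cache, v_cache, k=k_step, v=v_step,
        cache_seqlens=1024,  # current length of each cached sequence
        causal=True,
    )
    print(out_step.shape)  # (2, 1, 8, 128)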
build/torch26-cxx11-cu126-x86_64-linux/flash_attn3/__init__.py ADDED
@@ -0,0 +1,17 @@
+ from .flash_attn_interface import (
+     flash_attn_combine,
+     flash_attn_func,
+     flash_attn_qkvpacked_func,
+     flash_attn_varlen_func,
+     flash_attn_with_kvcache,
+     get_scheduler_metadata,
+ )
+
+ __all__ = [
+     "flash_attn_combine",
+     "flash_attn_func",
+     "flash_attn_qkvpacked_func",
+     "flash_attn_varlen_func",
+     "flash_attn_with_kvcache",
+     "get_scheduler_metadata",
+ ]
build/torch26-cxx11-cu126-x86_64-linux/flash_attn3/_flash_attn3_2e75662.abi3.so ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:21b44e8e5e447a8b8ee051d347f0e32a3446a750f79d0bd1755e553f2119aa3b
+ size 838459656
build/torch26-cxx11-cu126-x86_64-linux/flash_attn3/_flash_attn3_557701f.abi3.so ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:12d4ff964085fd02252777a2008f5ca47c90ea6a93da590e2fc5065dd5330207
+ size 838459656
build/torch26-cxx11-cu126-x86_64-linux/flash_attn3/_ops.py ADDED
@@ -0,0 +1,9 @@
+ import torch
+ from . import _flash_attn3_557701f
+ ops = torch.ops._flash_attn3_557701f
+
+ def add_op_namespace_prefix(op_name: str):
+     """
+     Prefix op by namespace.
+     """
+     return f"_flash_attn3_557701f::{op_name}"
build/torch26-cxx11-cu126-x86_64-linux/flash_attn3/flash_attn_interface.py ADDED
@@ -0,0 +1,828 @@
1
+ # Copyright (c) 2023, Tri Dao.
2
+
3
+ from typing import Optional, Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from ._ops import ops as flash_attn_3_cuda
9
+
10
+ def maybe_contiguous(x):
11
+ return x.contiguous() if x is not None and x.stride(-1) != 1 else x
12
+
13
+
14
+ def _flash_attn_forward(
15
+ q,
16
+ k,
17
+ v,
18
+ k_new,
19
+ v_new,
20
+ qv,
21
+ out,
22
+ cu_seqlens_q,
23
+ cu_seqlens_k,
24
+ cu_seqlens_k_new,
25
+ seqused_q,
26
+ seqused_k,
27
+ max_seqlen_q,
28
+ max_seqlen_k,
29
+ page_table,
30
+ kv_batch_idx,
31
+ leftpad_k,
32
+ rotary_cos,
33
+ rotary_sin,
34
+ seqlens_rotary,
35
+ q_descale,
36
+ k_descale,
37
+ v_descale,
38
+ softmax_scale,
39
+ causal,
40
+ window_size=(-1, -1),
41
+ attention_chunk=0,
42
+ softcap=0.0,
43
+ rotary_interleaved=True,
44
+ scheduler_metadata=None,
45
+ num_splits=1,
46
+ pack_gqa=None,
47
+ sm_margin=0):
48
+ q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)]
49
+ v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
50
+ cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [
51
+ maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new)
52
+ ]
53
+ seqused_q, seqused_k = [maybe_contiguous(x) for x in (seqused_q, seqused_k)]
54
+ page_table, kv_batch_idx, leftpad_k = [
55
+ maybe_contiguous(x) for x in (page_table, kv_batch_idx, leftpad_k)
56
+ ]
57
+ rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
58
+ seqlens_rotary = maybe_contiguous(seqlens_rotary)
59
+ out, softmax_lse, *rest = flash_attn_3_cuda.fwd(
60
+ q,
61
+ k,
62
+ v,
63
+ k_new,
64
+ v_new,
65
+ qv,
66
+ out,
67
+ cu_seqlens_q,
68
+ cu_seqlens_k,
69
+ cu_seqlens_k_new,
70
+ seqused_q,
71
+ seqused_k,
72
+ max_seqlen_q,
73
+ max_seqlen_k,
74
+ page_table,
75
+ kv_batch_idx,
76
+ leftpad_k,
77
+ rotary_cos,
78
+ rotary_sin,
79
+ seqlens_rotary,
80
+ q_descale,
81
+ k_descale,
82
+ v_descale,
83
+ softmax_scale,
84
+ causal,
85
+ window_size[0],
86
+ window_size[1],
87
+ attention_chunk,
88
+ softcap,
89
+ rotary_interleaved,
90
+ scheduler_metadata,
91
+ num_splits,
92
+ pack_gqa,
93
+ sm_margin,
94
+ )
95
+ return out, softmax_lse, *rest
96
+
97
+
98
+ def _flash_attn_backward(
99
+ dout,
100
+ q,
101
+ k,
102
+ v,
103
+ out,
104
+ softmax_lse,
105
+ cu_seqlens_q,
106
+ cu_seqlens_k,
107
+ sequed_q,
108
+ sequed_k,
109
+ max_seqlen_q,
110
+ max_seqlen_k,
111
+ dq,
112
+ dk,
113
+ dv,
114
+ softmax_scale,
115
+ causal,
116
+ window_size=(-1, -1),
117
+ softcap=0.0,
118
+ deterministic=False,
119
+ sm_margin=0,
120
+ ):
121
+ # dq, dk, dv are allocated by us so they should already be contiguous
122
+ dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
123
+ dq, dk, dv, softmax_d, *rest = flash_attn_3_cuda.bwd(
124
+ dout,
125
+ q,
126
+ k,
127
+ v,
128
+ out,
129
+ softmax_lse,
130
+ dq,
131
+ dk,
132
+ dv,
133
+ cu_seqlens_q,
134
+ cu_seqlens_k,
135
+ sequed_q,
136
+ sequed_k,
137
+ max_seqlen_q,
138
+ max_seqlen_k,
139
+ softmax_scale,
140
+ causal,
141
+ window_size[0],
142
+ window_size[1],
143
+ softcap,
144
+ deterministic,
145
+ sm_margin,
146
+ )
147
+ return dq, dk, dv, softmax_d
148
+
149
+
150
+ class FlashAttnQKVPackedFunc(torch.autograd.Function):
151
+ @staticmethod
152
+ def forward(
153
+ ctx,
154
+ qkv,
155
+ softmax_scale,
156
+ causal,
157
+ q_descale=None, k_descale=None, v_descale=None,
158
+ window_size=(-1, -1),
159
+ attention_chunk=0,
160
+ softcap=0.0,
161
+ deterministic=False,
162
+ num_heads_q=None,
163
+ sm_margin=0,
164
+ ):
165
+ if softmax_scale is None:
166
+ softmax_scale = qkv.shape[-1] ** (-0.5)
167
+ if qkv.dim() == 5:
168
+ assert qkv.shape[-3] == 3
169
+ q, k, v = qkv.unbind(dim=-3)
170
+ else:
171
+ assert qkv.dim() == 4
172
+ assert num_heads_q is not None
173
+ num_heads_k = (qkv.shape[2] - num_heads_q) // 2
174
+ assert num_heads_k * 2 + num_heads_q == qkv.shape[2]
175
+ q, k, v = qkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
176
+ out, softmax_lse, *rest = _flash_attn_forward(
177
+ q,
178
+ k,
179
+ v,
180
+ None, None, # k_new, v_new
181
+ None, # qv
182
+ None, # out
183
+ None, None, None, # cu_seqlens_q/k/k_new
184
+ None, None, # seqused_q/k
185
+ None, None, # max_seqlen_q/k
186
+ None, None, None, # page_table, kv_batch_idx, leftpad_k,
187
+ None, None, None, # rotary_cos/sin, seqlens_rotary
188
+ q_descale, k_descale, v_descale,
189
+ softmax_scale,
190
+ causal=causal,
191
+ window_size=window_size,
192
+ attention_chunk=attention_chunk,
193
+ softcap=softcap,
194
+ sm_margin=sm_margin,
195
+ )
196
+ # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
197
+ ctx.save_for_backward(q, k, v, out, softmax_lse)
198
+ ctx.softmax_scale = softmax_scale
199
+ ctx.causal = causal
200
+ ctx.window_size = window_size
201
+ ctx.attention_chunk = attention_chunk
202
+ ctx.softcap = softcap
203
+ ctx.deterministic = deterministic
204
+ ctx.ndim = qkv.dim()
205
+ ctx.sm_margin = sm_margin
206
+ # return out, softmax_lse
207
+ return out
208
+
209
+ @staticmethod
210
+ def backward(ctx, dout, *args):
211
+ q, k, v, out, softmax_lse = ctx.saved_tensors
212
+ assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
213
+ if ctx.ndim == 5:
214
+ qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
215
+ dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
216
+ dq, dk, dv = dqkv.unbind(dim=-3)
217
+ else:
218
+ num_heads_q = q.shape[2]
219
+ num_heads_k = k.shape[2]
220
+ qkv_shape = q.shape[:-2] + (num_heads_q + num_heads_k * 2, *q.shape[-1:])
221
+ dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
222
+ dq, dk, dv = dqkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
223
+ _flash_attn_backward(
224
+ dout,
225
+ q,
226
+ k,
227
+ v,
228
+ out,
229
+ softmax_lse,
230
+ None, None, # cu_seqlens_q, cu_seqlens_k,
231
+ None, None, # sequed_q, sequed_k,
232
+ None, None, # max_seqlen_q, max_seqlen_k,
233
+ dq,
234
+ dk,
235
+ dv,
236
+ ctx.softmax_scale,
237
+ ctx.causal,
238
+ ctx.window_size,
239
+ ctx.softcap,
240
+ ctx.deterministic,
241
+ ctx.sm_margin,
242
+ )
243
+ dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
244
+ return dqkv, None, None, None, None, None, None, None, None, None, None, None
245
+
246
+
247
+ class FlashAttnFunc(torch.autograd.Function):
248
+
249
+ @staticmethod
250
+ def forward(
251
+ ctx,
252
+ q,
253
+ k,
254
+ v,
255
+ softmax_scale,
256
+ causal,
257
+ qv=None,
258
+ q_descale=None, k_descale=None, v_descale=None,
259
+ window_size=(-1, -1),
260
+ attention_chunk=0,
261
+ softcap=0.0,
262
+ num_splits=1,
263
+ pack_gqa=None,
264
+ deterministic=False,
265
+ sm_margin=0,
266
+ ):
267
+ if softmax_scale is None:
268
+ softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
269
+ # out, q, k, v, out_padded, softmax_lse = _flash_attn_forward(
270
+ out, softmax_lse, *rest = _flash_attn_forward(
271
+ q,
272
+ k,
273
+ v,
274
+ None, None, # k_new, v_new
275
+ qv, # qv
276
+ None, # out
277
+ None, None, None, # cu_seqlens_q/k/k_new
278
+ None, None, # seqused_q/k
279
+ None, None, # max_seqlen_q/k
280
+ None, None, None, # page_table, kv_batch_idx, leftpad_k,
281
+ None, None, None, # rotary_cos/sin, seqlens_rotary
282
+ q_descale, k_descale, v_descale,
283
+ softmax_scale,
284
+ causal=causal,
285
+ window_size=window_size,
286
+ attention_chunk=attention_chunk,
287
+ softcap=softcap,
288
+ num_splits=num_splits,
289
+ pack_gqa=pack_gqa,
290
+ sm_margin=sm_margin,
291
+ )
292
+ # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
293
+ ctx.save_for_backward(q, k, v, out, softmax_lse)
294
+ ctx.softmax_scale = softmax_scale
295
+ ctx.causal = causal
296
+ ctx.window_size = window_size
297
+ ctx.attention_chunk = attention_chunk
298
+ ctx.softcap = softcap
299
+ ctx.deterministic = deterministic
300
+ ctx.sm_margin = sm_margin
301
+ return out, softmax_lse
302
+
303
+ @staticmethod
304
+ def backward(ctx, dout, *args):
305
+ q, k, v, out, softmax_lse = ctx.saved_tensors
306
+ assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
307
+ dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
308
+ _flash_attn_backward(
309
+ dout,
310
+ q,
311
+ k,
312
+ v,
313
+ out,
314
+ softmax_lse,
315
+ None, None, # cu_seqlens_q, cu_seqlens_k,
316
+ None, None, # sequed_q, sequed_k,
317
+ None, None, # max_seqlen_q, max_seqlen_k,
318
+ dq,
319
+ dk,
320
+ dv,
321
+ ctx.softmax_scale,
322
+ ctx.causal,
323
+ ctx.window_size,
324
+ ctx.softcap,
325
+ ctx.deterministic,
326
+ ctx.sm_margin,
327
+ )
328
+ dq = dq[..., : q.shape[-1]] # We could have padded the head dimension
329
+ dk = dk[..., : k.shape[-1]]
330
+ dv = dv[..., : v.shape[-1]]
331
+ return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None
332
+
333
+
334
+ class FlashAttnVarlenFunc(torch.autograd.Function):
335
+
336
+ @staticmethod
337
+ def forward(
338
+ ctx,
339
+ q,
340
+ k,
341
+ v,
342
+ cu_seqlens_q,
343
+ cu_seqlens_k,
344
+ seqused_q,
345
+ seqused_k,
346
+ max_seqlen_q,
347
+ max_seqlen_k,
348
+ softmax_scale,
349
+ causal,
350
+ qv=None,
351
+ q_descale=None, k_descale=None, v_descale=None,
352
+ window_size=(-1, -1),
353
+ attention_chunk=0,
354
+ softcap=0.0,
355
+ num_splits=1,
356
+ pack_gqa=None,
357
+ deterministic=False,
358
+ sm_margin=0,
359
+ ):
360
+ if softmax_scale is None:
361
+ softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
362
+ # out, q, k, v, out_padded, softmax_lse = _flash_attn_varlen_forward(
363
+ out, softmax_lse, *rest = _flash_attn_forward(
364
+ q,
365
+ k,
366
+ v,
367
+ None, None, # k_new, v_new
368
+ qv, # qv
369
+ None, # out
370
+ cu_seqlens_q,
371
+ cu_seqlens_k,
372
+ None, # cu_seqlens_k_new
373
+ seqused_q,
374
+ seqused_k,
375
+ max_seqlen_q,
376
+ max_seqlen_k,
377
+ None, None, None, # page_table, kv_batch_idx, leftpad_k,
378
+ None, None, None, # rotary_cos/sin, seqlens_rotary
379
+ q_descale, k_descale, v_descale,
380
+ softmax_scale,
381
+ causal=causal,
382
+ window_size=window_size,
383
+ attention_chunk=attention_chunk,
384
+ softcap=softcap,
385
+ num_splits=num_splits,
386
+ pack_gqa=pack_gqa,
387
+ sm_margin=sm_margin,
388
+ )
389
+ # ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
390
+ ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
391
+ ctx.max_seqlen_q = max_seqlen_q
392
+ ctx.max_seqlen_k = max_seqlen_k
393
+ ctx.softmax_scale = softmax_scale
394
+ ctx.causal = causal
395
+ ctx.window_size = window_size
396
+ ctx.attention_chunk = attention_chunk
397
+ ctx.softcap = softcap
398
+ ctx.deterministic = deterministic
399
+ ctx.sm_margin = sm_margin
400
+ return out, softmax_lse
401
+
402
+ @staticmethod
403
+ def backward(ctx, dout, *args):
404
+ q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
405
+ assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
406
+ dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
407
+ _flash_attn_backward(
408
+ dout,
409
+ q,
410
+ k,
411
+ v,
412
+ out,
413
+ softmax_lse,
414
+ cu_seqlens_q,
415
+ cu_seqlens_k,
416
+ seqused_q,
417
+ seqused_k,
418
+ ctx.max_seqlen_q,
419
+ ctx.max_seqlen_k,
420
+ dq,
421
+ dk,
422
+ dv,
423
+ ctx.softmax_scale,
424
+ ctx.causal,
425
+ ctx.window_size,
426
+ ctx.softcap,
427
+ ctx.deterministic,
428
+ ctx.sm_margin,
429
+ )
430
+ dq = dq[..., : q.shape[-1]] # We could have padded the head dimension
431
+ dk = dk[..., : k.shape[-1]]
432
+ dv = dv[..., : v.shape[-1]]
433
+ return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
434
+
435
+
436
+ def flash_attn_qkvpacked_func(
437
+ qkv,
438
+ softmax_scale=None,
439
+ causal=False,
440
+ q_descale=None, k_descale=None, v_descale=None,
441
+ window_size=(-1, -1),
442
+ attention_chunk=0,
443
+ softcap=0.0,
444
+ deterministic=False,
445
+ num_heads_q=None,
446
+ sm_margin=0,
447
+ ):
448
+ """dropout_p should be set to 0.0 during evaluation
449
+ If Q, K, V are already stacked into 1 tensor, this function will be faster than
450
+ calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
451
+ of the gradients of Q, K, V.
452
+ For multi-query and grouped-query attention (MQA/GQA), please see
453
+ flash_attn_kvpacked_func and flash_attn_func.
454
+
455
+ If window_size != (-1, -1), implements sliding window local attention. Query at position i
456
+ will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
457
+
458
+ Arguments:
459
+ qkv: (batch_size, seqlen, 3, nheads, headdim)
460
+ dropout_p: float. Dropout probability.
461
+ softmax_scale: float. The scaling of QK^T before applying softmax.
462
+ Default to 1 / sqrt(headdim).
463
+ causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
464
+ window_size: (left, right). If not (-1, -1), implements sliding window local attention.
465
+ softcap: float. Anything > 0 activates softcapping attention.
466
+ alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
467
+ the attention score of query i and key j.
468
+ deterministic: bool. Whether to use the deterministic implementation of the backward pass,
469
+ which is slightly slower and uses more memory. The forward pass is always deterministic.
470
+ return_attn_probs: bool. Whether to return the attention probabilities. This option is for
471
+ testing only. The returned probabilities are not guaranteed to be correct
472
+ (they might not have the right scaling).
473
+ Return:
474
+ out: (batch_size, seqlen, nheads, headdim).
475
+ softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
476
+ logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
477
+ normalization factor).
478
+ S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
479
+ The output of softmax (possibly with different scaling). It also encodes the dropout
480
+ pattern (negative means that location was dropped, nonnegative means it was kept).
481
+ """
482
+ return FlashAttnQKVPackedFunc.apply(
483
+ qkv,
484
+ softmax_scale,
485
+ causal,
486
+ q_descale, k_descale, v_descale,
487
+ window_size,
488
+ attention_chunk,
489
+ softcap,
490
+ deterministic,
491
+ num_heads_q,
492
+ sm_margin,
493
+ )
494
+
495
+
496
+ def flash_attn_func(
497
+ q,
498
+ k,
499
+ v,
500
+ softmax_scale=None,
501
+ causal=False,
502
+ qv=None,
503
+ q_descale=None, k_descale=None, v_descale=None,
504
+ window_size=(-1, -1),
505
+ attention_chunk=0,
506
+ softcap=0.0,
507
+ num_splits=1,
508
+ pack_gqa=None,
509
+ deterministic=False,
510
+ sm_margin=0,
511
+ ):
512
+ """dropout_p should be set to 0.0 during evaluation
513
+ Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
514
+ than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
515
+ For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
516
+ 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
517
+
518
+ If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
519
+ For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
520
+ 1 1 1 1 0
521
+ 1 1 1 1 1
522
+ If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
523
+ 0 0
524
+ 0 0
525
+ 0 0
526
+ 1 0
527
+ 1 1
528
+ If the row of the mask is all zero, the output will be zero.
529
+
530
+ If window_size != (-1, -1), implements sliding window local attention. Query at position i
531
+ will only attend to keys between
532
+ [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
533
+
534
+ Arguments:
535
+ q: (batch_size, seqlen, nheads, headdim)
536
+ k: (batch_size, seqlen, nheads_k, headdim)
537
+ v: (batch_size, seqlen, nheads_k, headdim)
538
+ dropout_p: float. Dropout probability.
539
+ softmax_scale: float. The scaling of QK^T before applying softmax.
540
+ Default to 1 / sqrt(headdim).
541
+ causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
542
+ window_size: (left, right). If not (-1, -1), implements sliding window local attention.
543
+ alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
544
+ (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
545
+ is added to the attention score of query i and key j.
546
+ deterministic: bool. Whether to use the deterministic implementation of the backward pass,
547
+ which is slightly slower and uses more memory. The forward pass is always deterministic.
548
+ return_attn_probs: bool. Whether to return the attention probabilities. This option is for
549
+ testing only. The returned probabilities are not guaranteed to be correct
550
+ (they might not have the right scaling).
551
+ Return:
552
+ out: (batch_size, seqlen, nheads, headdim).
553
+ softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
554
+ logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
555
+ normalization factor).
556
+ """
557
+ return FlashAttnFunc.apply(
558
+ q,
559
+ k,
560
+ v,
561
+ softmax_scale,
562
+ causal,
563
+ qv,
564
+ q_descale, k_descale, v_descale,
565
+ window_size,
566
+ attention_chunk,
567
+ softcap,
568
+ num_splits,
569
+ pack_gqa,
570
+ deterministic,
571
+ sm_margin,
572
+ )
573
+
574
+
575
+ def flash_attn_varlen_func(
576
+ q,
577
+ k,
578
+ v,
579
+ cu_seqlens_q,
580
+ cu_seqlens_k,
581
+ max_seqlen_q,
582
+ max_seqlen_k,
583
+ seqused_q=None,
584
+ seqused_k=None,
585
+ softmax_scale=None,
586
+ causal=False,
587
+ qv=None,
588
+ q_descale=None, k_descale=None, v_descale=None,
589
+ window_size=(-1, -1),
590
+ attention_chunk=0,
591
+ softcap=0.0,
592
+ num_splits=1,
593
+ pack_gqa=None,
594
+ deterministic=False,
595
+ sm_margin=0,
596
+ ):
597
+ return FlashAttnVarlenFunc.apply(
598
+ q,
599
+ k,
600
+ v,
601
+ cu_seqlens_q,
602
+ cu_seqlens_k,
603
+ seqused_q,
604
+ seqused_k,
605
+ max_seqlen_q,
606
+ max_seqlen_k,
607
+ softmax_scale,
608
+ causal,
609
+ qv,
610
+ q_descale, k_descale, v_descale,
611
+ window_size,
612
+ attention_chunk,
613
+ softcap,
614
+ num_splits,
615
+ pack_gqa,
616
+ deterministic,
617
+ sm_margin,
618
+ )
619
+
620
+
621
+ def flash_attn_combine(out_partial, lse_partial, out=None, out_dtype=None):
622
+ return flash_attn_3_cuda.fwd_combine(out_partial, lse_partial, out, out_dtype)
623
+
624
+
625
+ def flash_attn_with_kvcache(
626
+ q,
627
+ k_cache,
628
+ v_cache,
629
+ k=None,
630
+ v=None,
631
+ qv=None,
632
+ rotary_cos=None,
633
+ rotary_sin=None,
634
+ cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
635
+ cache_batch_idx: Optional[torch.Tensor] = None,
636
+ cache_leftpad: Optional[torch.Tensor] = None,
637
+ page_table: Optional[torch.Tensor] = None,
638
+ cu_seqlens_q: Optional[torch.Tensor] = None,
639
+ cu_seqlens_k_new: Optional[torch.Tensor] = None,
640
+ max_seqlen_q: Optional[int] = None,
641
+ rotary_seqlens: Optional[torch.Tensor] = None,
642
+ q_descale: Optional[torch.Tensor] = None,
643
+ k_descale: Optional[torch.Tensor] = None,
644
+ v_descale: Optional[torch.Tensor] = None,
645
+ softmax_scale=None,
646
+ causal=False,
647
+ window_size=(-1, -1), # -1 means infinite context window
648
+ attention_chunk=0,
649
+ softcap=0.0, # 0.0 means deactivated
650
+ rotary_interleaved=True,
651
+ scheduler_metadata=None,
652
+ num_splits=0, # Can be tuned for speed
653
+ pack_gqa=None, # Can be tuned for speed
654
+ sm_margin=0, # Can be tuned if some SMs are used for communication
655
+ return_softmax_lse=False,
656
+ ):
657
+ """
658
+ If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
659
+ k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
660
+ the previous step, and update them with the new keys/values from the current step, and do
661
+ attention with the updated cache, all in 1 kernel.
662
+
663
+ If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
664
+ For example, the KV cache could be pre-allocated with the max sequence length, and you can use
665
+ cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
666
+
667
+ Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
668
+ rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
669
+ If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
670
+ and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
671
+ If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
672
+ indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).
673
+
674
+ See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
675
+
676
+ Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
677
+ than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
678
+ For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
679
+ 0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.
680
+
681
+ If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
682
+ For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
683
+ 1 1 1 1 0
684
+ 1 1 1 1 1
685
+ If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
686
+ 0 0
687
+ 0 0
688
+ 0 0
689
+ 1 0
690
+ 1 1
691
+ If a row of the mask is all zero, the corresponding row of the output will be zero.
692
+
693
+ If window_size != (-1, -1), implements sliding window local attention. Query at position i
694
+ will only attend to keys between
695
+ [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
696
+
697
+ Note: Does not support backward pass.
698
+
699
+ Arguments:
700
+ q: (batch_size, seqlen, nheads, headdim)
701
+ k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table,
702
+ or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache)
703
+ page_block_size must be a multiple of 256.
704
+ v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table,
705
+ or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache)
706
+ k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
707
+ k with k_cache, starting at the indices specified by cache_seqlens.
708
+ v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k.
709
+ qv [optional]: (batch_size, seqlen, nheads, headdim_v)
710
+ rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
711
+ to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
712
+ rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
713
+ cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
714
+ KV cache.
715
+ cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
716
+ If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
717
+ If the indices are not distinct, and k and v are provided, the values updated in the cache
718
+ might come from any of the duplicate indices.
719
+ cache_leftpad: (batch_size,), dtype torch.int32. The index at which the KV cache starts for each sequence. If None, assume 0.
720
+ page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
721
+ softmax_scale: float. The scaling of QK^T before applying softmax.
722
+ Default to 1 / sqrt(headdim).
723
+ causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
724
+ window_size: (left, right). If not (-1, -1), implements sliding window local attention.
725
+ softcap: float. Anything > 0 activates softcapping attention.
726
+ rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
727
+ If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
728
+ rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
729
+ (i.e. GPT-NeoX style).
730
+ num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
731
+ If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
732
+ to automatically determine the number of splits.
733
+ Don't change this unless you know what you are doing.
734
+ return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.
735
+
736
+ Return:
737
+ out: (batch_size, seqlen, nheads, headdim).
738
+ softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
739
+ logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
740
+ normalization factor).
741
+ """
742
+ assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
743
+ assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
744
+ if softmax_scale is None:
745
+ softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
746
+ if cache_seqlens is not None and isinstance(cache_seqlens, int):
747
+ cache_seqlens = torch.full(
748
+ (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
749
+ )
750
+ cache_seqlens = maybe_contiguous(cache_seqlens)
751
+ out, softmax_lse, *rest = _flash_attn_forward(
752
+ q,
753
+ k_cache,
754
+ v_cache,
755
+ k,
756
+ v,
757
+ qv,
758
+ None, # out
759
+ cu_seqlens_q,
760
+ None, # cu_seqlens_k
761
+ cu_seqlens_k_new,
762
+ None, # seqused_q
763
+ cache_seqlens,
764
+ max_seqlen_q,
765
+ None, # max_seqlen_k
766
+ page_table,
767
+ cache_batch_idx,
768
+ cache_leftpad,
769
+ rotary_cos,
770
+ rotary_sin,
771
+ rotary_seqlens,
772
+ q_descale, k_descale, v_descale,
773
+ softmax_scale,
774
+ causal=causal,
775
+ window_size=window_size,
776
+ attention_chunk=attention_chunk,
777
+ softcap=softcap,
778
+ rotary_interleaved=rotary_interleaved,
779
+ scheduler_metadata=scheduler_metadata,
780
+ num_splits=num_splits,
781
+ pack_gqa=pack_gqa,
782
+ sm_margin=sm_margin,
783
+ )
784
+ # return (out, softmax_lse) if return_softmax_lse else out
785
+ return (out, softmax_lse, *rest) if return_softmax_lse else out
786
+
787
+
788
+ def get_scheduler_metadata(
789
+ batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim,
790
+ cache_seqlens: torch.Tensor,
791
+ qkv_dtype=torch.bfloat16,
792
+ headdim_v=None,
793
+ cu_seqlens_q: Optional[torch.Tensor] = None,
794
+ cu_seqlens_k_new: Optional[torch.Tensor] = None,
795
+ cache_leftpad: Optional[torch.Tensor] = None,
796
+ page_size: Optional[int] = None,
797
+ max_seqlen_k_new=0,
798
+ causal=False,
799
+ window_size=(-1, -1), # -1 means infinite context window
800
+ attention_chunk=0,
801
+ has_softcap=False,
802
+ num_splits=0, # Can be tuned for speed
803
+ pack_gqa=None, # Can be tuned for speed
804
+ sm_margin=0, # Can be tuned if some SMs are used for communication
805
+ ):
806
+ cache_seqlens = maybe_contiguous(cache_seqlens)
807
+ if headdim_v is None:
808
+ headdim_v = headdim
809
+ scheduler_metadata = flash_attn_3_cuda.get_scheduler_metadata(
810
+ batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, headdim_v,
811
+ qkv_dtype,
812
+ cache_seqlens,
813
+ cu_seqlens_q,
814
+ None, # cu_seqlens_k
815
+ cu_seqlens_k_new,
816
+ None, # seqused_q
817
+ cache_leftpad,
818
+ page_size,
819
+ max_seqlen_k_new,
820
+ causal,
821
+ window_size[0], window_size[1],
822
+ attention_chunk,
823
+ has_softcap,
824
+ num_splits,
825
+ pack_gqa,
826
+ sm_margin,
827
+ )
828
+ return scheduler_metadata
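A minimal usage sketch for flash_attn_func as defined in the interface above (Python, illustration only and not part of this commit; it assumes a Hopper-class GPU, bf16 inputs, and that this build is importable as flash_attn3):

import torch
from flash_attn3 import flash_attn_func

# GQA: 8 query heads share 2 KV heads (nheads must be divisible by nheads_k).
batch, seqlen, nheads, nheads_k, headdim = 2, 1024, 8, 2, 128
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)
k = torch.randn(batch, seqlen, nheads_k, headdim, device="cuda", dtype=torch.bfloat16)
v = torch.randn(batch, seqlen, nheads_k, headdim, device="cuda", dtype=torch.bfloat16)

# softmax_scale defaults to headdim ** -0.5; causal=True applies the bottom-right-aligned
# mask described in the docstring. This FA3 interface returns the output together with
# the softmax logsumexp.
out, softmax_lse = flash_attn_func(q, k, v, causal=True)
print(out.shape)          # torch.Size([2, 1024, 8, 128])
print(softmax_lse.shape)  # torch.Size([2, 8, 1024])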
build/torch26-cxx98-cu124-x86_64-linux/flash_attn3/__init__.py ADDED
@@ -0,0 +1,17 @@
1
+ from .flash_attn_interface import (
2
+ flash_attn_combine,
3
+ flash_attn_func,
4
+ flash_attn_qkvpacked_func,
5
+ flash_attn_varlen_func,
6
+ flash_attn_with_kvcache,
7
+ get_scheduler_metadata,
8
+ )
9
+
10
+ __all__ = [
11
+ "flash_attn_combine",
12
+ "flash_attn_func",
13
+ "flash_attn_qkvpacked_func",
14
+ "flash_attn_varlen_func",
15
+ "flash_attn_with_kvcache",
16
+ "get_scheduler_metadata",
17
+ ]
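The package above also re-exports flash_attn_combine, which merges partial attention outputs (e.g. from split-KV execution) using their log-sum-exp values. A pure-PyTorch reference of that combination rule, for illustration only (the CUDA kernel's exact tensor layout and signature may differ):

import torch

def combine_partials(outs, lses):
    # outs: list of (batch, seqlen, nheads, headdim) partial attention outputs
    # lses: list of (batch, nheads, seqlen) partial log-sum-exp values
    lse_stack = torch.stack(lses, dim=0)              # (splits, batch, nheads, seqlen)
    lse_total = torch.logsumexp(lse_stack, dim=0)     # (batch, nheads, seqlen)
    weights = torch.exp(lse_stack - lse_total)        # per-split renormalization factors
    w = weights.permute(0, 1, 3, 2).unsqueeze(-1)     # (splits, batch, seqlen, nheads, 1)
    out = (torch.stack(outs, dim=0) * w).sum(dim=0)   # (batch, seqlen, nheads, headdim)
    return out, lse_total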
build/torch26-cxx98-cu124-x86_64-linux/flash_attn3/_flash_attn3_2e75662.abi3.so ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9627e08ec8778d2a409a2a0477572edb3e03eaca2b45e7b4810ee0a9126d6547
3
+ size 838456048
build/torch26-cxx98-cu124-x86_64-linux/flash_attn3/_flash_attn3_557701f.abi3.so ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:07fe025ba95671f6ff957991f74c66063bfb10ab6737641c88f88116c9f83718
3
+ size 838456048
build/torch26-cxx98-cu124-x86_64-linux/flash_attn3/_ops.py ADDED
@@ -0,0 +1,9 @@
1
+ import torch
2
+ from . import _flash_attn3_557701f
3
+ ops = torch.ops._flash_attn3_557701f
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_flash_attn3_557701f::{op_name}"
build/torch26-cxx98-cu124-x86_64-linux/flash_attn3/flash_attn_interface.py ADDED
@@ -0,0 +1,828 @@
1
+ # Copyright (c) 2023, Tri Dao.
2
+
3
+ from typing import Optional, Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from ._ops import ops as flash_attn_3_cuda
9
+
10
+ def maybe_contiguous(x):
11
+ return x.contiguous() if x is not None and x.stride(-1) != 1 else x
12
+
13
+
14
+ def _flash_attn_forward(
15
+ q,
16
+ k,
17
+ v,
18
+ k_new,
19
+ v_new,
20
+ qv,
21
+ out,
22
+ cu_seqlens_q,
23
+ cu_seqlens_k,
24
+ cu_seqlens_k_new,
25
+ seqused_q,
26
+ seqused_k,
27
+ max_seqlen_q,
28
+ max_seqlen_k,
29
+ page_table,
30
+ kv_batch_idx,
31
+ leftpad_k,
32
+ rotary_cos,
33
+ rotary_sin,
34
+ seqlens_rotary,
35
+ q_descale,
36
+ k_descale,
37
+ v_descale,
38
+ softmax_scale,
39
+ causal,
40
+ window_size=(-1, -1),
41
+ attention_chunk=0,
42
+ softcap=0.0,
43
+ rotary_interleaved=True,
44
+ scheduler_metadata=None,
45
+ num_splits=1,
46
+ pack_gqa=None,
47
+ sm_margin=0):
48
+ q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)]
49
+ v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
50
+ cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [
51
+ maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new)
52
+ ]
53
+ seqused_q, seqused_k = [maybe_contiguous(x) for x in (seqused_q, seqused_k)]
54
+ page_table, kv_batch_idx, leftpad_k = [
55
+ maybe_contiguous(x) for x in (page_table, kv_batch_idx, leftpad_k)
56
+ ]
57
+ rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
58
+ seqlens_rotary = maybe_contiguous(seqlens_rotary)
59
+ out, softmax_lse, *rest = flash_attn_3_cuda.fwd(
60
+ q,
61
+ k,
62
+ v,
63
+ k_new,
64
+ v_new,
65
+ qv,
66
+ out,
67
+ cu_seqlens_q,
68
+ cu_seqlens_k,
69
+ cu_seqlens_k_new,
70
+ seqused_q,
71
+ seqused_k,
72
+ max_seqlen_q,
73
+ max_seqlen_k,
74
+ page_table,
75
+ kv_batch_idx,
76
+ leftpad_k,
77
+ rotary_cos,
78
+ rotary_sin,
79
+ seqlens_rotary,
80
+ q_descale,
81
+ k_descale,
82
+ v_descale,
83
+ softmax_scale,
84
+ causal,
85
+ window_size[0],
86
+ window_size[1],
87
+ attention_chunk,
88
+ softcap,
89
+ rotary_interleaved,
90
+ scheduler_metadata,
91
+ num_splits,
92
+ pack_gqa,
93
+ sm_margin,
94
+ )
95
+ return out, softmax_lse, *rest
96
+
97
+
98
+ def _flash_attn_backward(
99
+ dout,
100
+ q,
101
+ k,
102
+ v,
103
+ out,
104
+ softmax_lse,
105
+ cu_seqlens_q,
106
+ cu_seqlens_k,
107
+ sequed_q,
108
+ sequed_k,
109
+ max_seqlen_q,
110
+ max_seqlen_k,
111
+ dq,
112
+ dk,
113
+ dv,
114
+ softmax_scale,
115
+ causal,
116
+ window_size=(-1, -1),
117
+ softcap=0.0,
118
+ deterministic=False,
119
+ sm_margin=0,
120
+ ):
121
+ # dq, dk, dv are allocated by us so they should already be contiguous
122
+ dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
123
+ dq, dk, dv, softmax_d, *rest = flash_attn_3_cuda.bwd(
124
+ dout,
125
+ q,
126
+ k,
127
+ v,
128
+ out,
129
+ softmax_lse,
130
+ dq,
131
+ dk,
132
+ dv,
133
+ cu_seqlens_q,
134
+ cu_seqlens_k,
135
+ sequed_q,
136
+ sequed_k,
137
+ max_seqlen_q,
138
+ max_seqlen_k,
139
+ softmax_scale,
140
+ causal,
141
+ window_size[0],
142
+ window_size[1],
143
+ softcap,
144
+ deterministic,
145
+ sm_margin,
146
+ )
147
+ return dq, dk, dv, softmax_d
148
+
149
+
150
+ class FlashAttnQKVPackedFunc(torch.autograd.Function):
151
+ @staticmethod
152
+ def forward(
153
+ ctx,
154
+ qkv,
155
+ softmax_scale,
156
+ causal,
157
+ q_descale=None, k_descale=None, v_descale=None,
158
+ window_size=(-1, -1),
159
+ attention_chunk=0,
160
+ softcap=0.0,
161
+ deterministic=False,
162
+ num_heads_q=None,
163
+ sm_margin=0,
164
+ ):
165
+ if softmax_scale is None:
166
+ softmax_scale = qkv.shape[-1] ** (-0.5)
167
+ if qkv.dim() == 5:
168
+ assert qkv.shape[-3] == 3
169
+ q, k, v = qkv.unbind(dim=-3)
170
+ else:
171
+ assert qkv.dim() == 4
172
+ assert num_heads_q is not None
173
+ num_heads_k = (qkv.shape[2] - num_heads_q) // 2
174
+ assert num_heads_k * 2 + num_heads_q == qkv.shape[2]
175
+ q, k, v = qkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
176
+ out, softmax_lse, *rest = _flash_attn_forward(
177
+ q,
178
+ k,
179
+ v,
180
+ None, None, # k_new, v_new
181
+ None, # qv
182
+ None, # out
183
+ None, None, None, # cu_seqlens_q/k/k_new
184
+ None, None, # seqused_q/k
185
+ None, None, # max_seqlen_q/k
186
+ None, None, None, # page_table, kv_batch_idx, leftpad_k,
187
+ None, None, None, # rotary_cos/sin, seqlens_rotary
188
+ q_descale, k_descale, v_descale,
189
+ softmax_scale,
190
+ causal=causal,
191
+ window_size=window_size,
192
+ attention_chunk=attention_chunk,
193
+ softcap=softcap,
194
+ sm_margin=sm_margin,
195
+ )
196
+ # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
197
+ ctx.save_for_backward(q, k, v, out, softmax_lse)
198
+ ctx.softmax_scale = softmax_scale
199
+ ctx.causal = causal
200
+ ctx.window_size = window_size
201
+ ctx.attention_chunk = attention_chunk
202
+ ctx.softcap = softcap
203
+ ctx.deterministic = deterministic
204
+ ctx.ndim = qkv.dim()
205
+ ctx.sm_margin = sm_margin
206
+ # return out, softmax_lse
207
+ return out
208
+
209
+ @staticmethod
210
+ def backward(ctx, dout, *args):
211
+ q, k, v, out, softmax_lse = ctx.saved_tensors
212
+ assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
213
+ if ctx.ndim == 5:
214
+ qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
215
+ dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
216
+ dq, dk, dv = dqkv.unbind(dim=-3)
217
+ else:
218
+ num_heads_q = q.shape[2]
219
+ num_heads_k = k.shape[2]
220
+ qkv_shape = q.shape[:-2] + (num_heads_q + num_heads_k * 2, *q.shape[-1:])
221
+ dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
222
+ dq, dk, dv = dqkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
223
+ _flash_attn_backward(
224
+ dout,
225
+ q,
226
+ k,
227
+ v,
228
+ out,
229
+ softmax_lse,
230
+ None, None, # cu_seqlens_q, cu_seqlens_k,
231
+ None, None, # sequed_q, sequed_k,
232
+ None, None, # max_seqlen_q, max_seqlen_k,
233
+ dq,
234
+ dk,
235
+ dv,
236
+ ctx.softmax_scale,
237
+ ctx.causal,
238
+ ctx.window_size,
239
+ ctx.softcap,
240
+ ctx.deterministic,
241
+ ctx.sm_margin,
242
+ )
243
+ dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
244
+ return dqkv, None, None, None, None, None, None, None, None, None, None, None
245
+
246
+
247
+ class FlashAttnFunc(torch.autograd.Function):
248
+
249
+ @staticmethod
250
+ def forward(
251
+ ctx,
252
+ q,
253
+ k,
254
+ v,
255
+ softmax_scale,
256
+ causal,
257
+ qv=None,
258
+ q_descale=None, k_descale=None, v_descale=None,
259
+ window_size=(-1, -1),
260
+ attention_chunk=0,
261
+ softcap=0.0,
262
+ num_splits=1,
263
+ pack_gqa=None,
264
+ deterministic=False,
265
+ sm_margin=0,
266
+ ):
267
+ if softmax_scale is None:
268
+ softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
269
+ # out, q, k, v, out_padded, softmax_lse = _flash_attn_forward(
270
+ out, softmax_lse, *rest = _flash_attn_forward(
271
+ q,
272
+ k,
273
+ v,
274
+ None, None, # k_new, v_new
275
+ qv, # qv
276
+ None, # out
277
+ None, None, None, # cu_seqlens_q/k/k_new
278
+ None, None, # seqused_q/k
279
+ None, None, # max_seqlen_q/k
280
+ None, None, None, # page_table, kv_batch_idx, leftpad_k,
281
+ None, None, None, # rotary_cos/sin, seqlens_rotary
282
+ q_descale, k_descale, v_descale,
283
+ softmax_scale,
284
+ causal=causal,
285
+ window_size=window_size,
286
+ attention_chunk=attention_chunk,
287
+ softcap=softcap,
288
+ num_splits=num_splits,
289
+ pack_gqa=pack_gqa,
290
+ sm_margin=sm_margin,
291
+ )
292
+ # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
293
+ ctx.save_for_backward(q, k, v, out, softmax_lse)
294
+ ctx.softmax_scale = softmax_scale
295
+ ctx.causal = causal
296
+ ctx.window_size = window_size
297
+ ctx.attention_chunk = attention_chunk
298
+ ctx.softcap = softcap
299
+ ctx.deterministic = deterministic
300
+ ctx.sm_margin = sm_margin
301
+ return out, softmax_lse
302
+
303
+ @staticmethod
304
+ def backward(ctx, dout, *args):
305
+ q, k, v, out, softmax_lse = ctx.saved_tensors
306
+ assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
307
+ dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
308
+ _flash_attn_backward(
309
+ dout,
310
+ q,
311
+ k,
312
+ v,
313
+ out,
314
+ softmax_lse,
315
+ None, None, # cu_seqlens_q, cu_seqlens_k,
316
+ None, None, # sequed_q, sequed_k,
317
+ None, None, # max_seqlen_q, max_seqlen_k,
318
+ dq,
319
+ dk,
320
+ dv,
321
+ ctx.softmax_scale,
322
+ ctx.causal,
323
+ ctx.window_size,
324
+ ctx.softcap,
325
+ ctx.deterministic,
326
+ ctx.sm_margin,
327
+ )
328
+ dq = dq[..., : q.shape[-1]] # We could have padded the head dimension
329
+ dk = dk[..., : k.shape[-1]]
330
+ dv = dv[..., : v.shape[-1]]
331
+ return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None
332
+
333
+
334
+ class FlashAttnVarlenFunc(torch.autograd.Function):
335
+
336
+ @staticmethod
337
+ def forward(
338
+ ctx,
339
+ q,
340
+ k,
341
+ v,
342
+ cu_seqlens_q,
343
+ cu_seqlens_k,
344
+ seqused_q,
345
+ seqused_k,
346
+ max_seqlen_q,
347
+ max_seqlen_k,
348
+ softmax_scale,
349
+ causal,
350
+ qv=None,
351
+ q_descale=None, k_descale=None, v_descale=None,
352
+ window_size=(-1, -1),
353
+ attention_chunk=0,
354
+ softcap=0.0,
355
+ num_splits=1,
356
+ pack_gqa=None,
357
+ deterministic=False,
358
+ sm_margin=0,
359
+ ):
360
+ if softmax_scale is None:
361
+ softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
362
+ # out, q, k, v, out_padded, softmax_lse = _flash_attn_varlen_forward(
363
+ out, softmax_lse, *rest = _flash_attn_forward(
364
+ q,
365
+ k,
366
+ v,
367
+ None, None, # k_new, v_new
368
+ qv, # qv
369
+ None, # out
370
+ cu_seqlens_q,
371
+ cu_seqlens_k,
372
+ None, # cu_seqlens_k_new
373
+ seqused_q,
374
+ seqused_k,
375
+ max_seqlen_q,
376
+ max_seqlen_k,
377
+ None, None, None, # page_table, kv_batch_idx, leftpad_k,
378
+ None, None, None, # rotary_cos/sin, seqlens_rotary
379
+ q_descale, k_descale, v_descale,
380
+ softmax_scale,
381
+ causal=causal,
382
+ window_size=window_size,
383
+ attention_chunk=attention_chunk,
384
+ softcap=softcap,
385
+ num_splits=num_splits,
386
+ pack_gqa=pack_gqa,
387
+ sm_margin=sm_margin,
388
+ )
389
+ # ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
390
+ ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
391
+ ctx.max_seqlen_q = max_seqlen_q
392
+ ctx.max_seqlen_k = max_seqlen_k
393
+ ctx.softmax_scale = softmax_scale
394
+ ctx.causal = causal
395
+ ctx.window_size = window_size
396
+ ctx.attention_chunk = attention_chunk
397
+ ctx.softcap = softcap
398
+ ctx.deterministic = deterministic
399
+ ctx.sm_margin = sm_margin
400
+ return out, softmax_lse
401
+
402
+ @staticmethod
403
+ def backward(ctx, dout, *args):
404
+ q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
405
+ assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
406
+ dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
407
+ _flash_attn_backward(
408
+ dout,
409
+ q,
410
+ k,
411
+ v,
412
+ out,
413
+ softmax_lse,
414
+ cu_seqlens_q,
415
+ cu_seqlens_k,
416
+ seqused_q,
417
+ seqused_k,
418
+ ctx.max_seqlen_q,
419
+ ctx.max_seqlen_k,
420
+ dq,
421
+ dk,
422
+ dv,
423
+ ctx.softmax_scale,
424
+ ctx.causal,
425
+ ctx.window_size,
426
+ ctx.softcap,
427
+ ctx.deterministic,
428
+ ctx.sm_margin,
429
+ )
430
+ dq = dq[..., : q.shape[-1]] # We could have padded the head dimension
431
+ dk = dk[..., : k.shape[-1]]
432
+ dv = dv[..., : v.shape[-1]]
433
+ return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
434
+
435
+
436
+ def flash_attn_qkvpacked_func(
437
+ qkv,
438
+ softmax_scale=None,
439
+ causal=False,
440
+ q_descale=None, k_descale=None, v_descale=None,
441
+ window_size=(-1, -1),
442
+ attention_chunk=0,
443
+ softcap=0.0,
444
+ deterministic=False,
445
+ num_heads_q=None,
446
+ sm_margin=0,
447
+ ):
448
+ """dropout_p should be set to 0.0 during evaluation
449
+ If Q, K, V are already stacked into 1 tensor, this function will be faster than
450
+ calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
451
+ of the gradients of Q, K, V.
452
+ For multi-query and grouped-query attention (MQA/GQA), please see
453
+ flash_attn_kvpacked_func and flash_attn_func.
454
+
455
+ If window_size != (-1, -1), implements sliding window local attention. Query at position i
456
+ will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
457
+
458
+ Arguments:
459
+ qkv: (batch_size, seqlen, 3, nheads, headdim)
460
+ dropout_p: float. Dropout probability.
461
+ softmax_scale: float. The scaling of QK^T before applying softmax.
462
+ Default to 1 / sqrt(headdim).
463
+ causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
464
+ window_size: (left, right). If not (-1, -1), implements sliding window local attention.
465
+ softcap: float. Anything > 0 activates softcapping attention.
466
+ alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
467
+ the attention score of query i and key j.
468
+ deterministic: bool. Whether to use the deterministic implementation of the backward pass,
469
+ which is slightly slower and uses more memory. The forward pass is always deterministic.
470
+ return_attn_probs: bool. Whether to return the attention probabilities. This option is for
471
+ testing only. The returned probabilities are not guaranteed to be correct
472
+ (they might not have the right scaling).
473
+ Return:
474
+ out: (batch_size, seqlen, nheads, headdim).
475
+ softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
476
+ logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
477
+ normalization factor).
478
+ S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
479
+ The output of softmax (possibly with different scaling). It also encodes the dropout
480
+ pattern (negative means that location was dropped, nonnegative means it was kept).
481
+ """
482
+ return FlashAttnQKVPackedFunc.apply(
483
+ qkv,
484
+ softmax_scale,
485
+ causal,
486
+ q_descale, k_descale, v_descale,
487
+ window_size,
488
+ attention_chunk,
489
+ softcap,
490
+ deterministic,
491
+ num_heads_q,
492
+ sm_margin,
493
+ )
494
+
495
+
496
+ def flash_attn_func(
497
+ q,
498
+ k,
499
+ v,
500
+ softmax_scale=None,
501
+ causal=False,
502
+ qv=None,
503
+ q_descale=None, k_descale=None, v_descale=None,
504
+ window_size=(-1, -1),
505
+ attention_chunk=0,
506
+ softcap=0.0,
507
+ num_splits=1,
508
+ pack_gqa=None,
509
+ deterministic=False,
510
+ sm_margin=0,
511
+ ):
512
+ """dropout_p should be set to 0.0 during evaluation
513
+ Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
514
+ than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
515
+ For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
516
+ 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
517
+
518
+ If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
519
+ For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
520
+ 1 1 1 1 0
521
+ 1 1 1 1 1
522
+ If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
523
+ 0 0
524
+ 0 0
525
+ 0 0
526
+ 1 0
527
+ 1 1
528
+ If the row of the mask is all zero, the output will be zero.
529
+
530
+ If window_size != (-1, -1), implements sliding window local attention. Query at position i
531
+ will only attend to keys between
532
+ [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
533
+
534
+ Arguments:
535
+ q: (batch_size, seqlen, nheads, headdim)
536
+ k: (batch_size, seqlen, nheads_k, headdim)
537
+ v: (batch_size, seqlen, nheads_k, headdim)
538
+ dropout_p: float. Dropout probability.
539
+ softmax_scale: float. The scaling of QK^T before applying softmax.
540
+ Default to 1 / sqrt(headdim).
541
+ causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
542
+ window_size: (left, right). If not (-1, -1), implements sliding window local attention.
543
+ alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
544
+ (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
545
+ is added to the attention score of query i and key j.
546
+ deterministic: bool. Whether to use the deterministic implementation of the backward pass,
547
+ which is slightly slower and uses more memory. The forward pass is always deterministic.
548
+ return_attn_probs: bool. Whether to return the attention probabilities. This option is for
549
+ testing only. The returned probabilities are not guaranteed to be correct
550
+ (they might not have the right scaling).
551
+ Return:
552
+ out: (batch_size, seqlen, nheads, headdim).
553
+ softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
554
+ logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
555
+ normalization factor).
556
+ """
557
+ return FlashAttnFunc.apply(
558
+ q,
559
+ k,
560
+ v,
561
+ softmax_scale,
562
+ causal,
563
+ qv,
564
+ q_descale, k_descale, v_descale,
565
+ window_size,
566
+ attention_chunk,
567
+ softcap,
568
+ num_splits,
569
+ pack_gqa,
570
+ deterministic,
571
+ sm_margin,
572
+ )
573
+
574
+
575
+ def flash_attn_varlen_func(
576
+ q,
577
+ k,
578
+ v,
579
+ cu_seqlens_q,
580
+ cu_seqlens_k,
581
+ max_seqlen_q,
582
+ max_seqlen_k,
583
+ seqused_q=None,
584
+ seqused_k=None,
585
+ softmax_scale=None,
586
+ causal=False,
587
+ qv=None,
588
+ q_descale=None, k_descale=None, v_descale=None,
589
+ window_size=(-1, -1),
590
+ attention_chunk=0,
591
+ softcap=0.0,
592
+ num_splits=1,
593
+ pack_gqa=None,
594
+ deterministic=False,
595
+ sm_margin=0,
596
+ ):
597
+ return FlashAttnVarlenFunc.apply(
598
+ q,
599
+ k,
600
+ v,
601
+ cu_seqlens_q,
602
+ cu_seqlens_k,
603
+ seqused_q,
604
+ seqused_k,
605
+ max_seqlen_q,
606
+ max_seqlen_k,
607
+ softmax_scale,
608
+ causal,
609
+ qv,
610
+ q_descale, k_descale, v_descale,
611
+ window_size,
612
+ attention_chunk,
613
+ softcap,
614
+ num_splits,
615
+ pack_gqa,
616
+ deterministic,
617
+ sm_margin,
618
+ )
619
+
620
+
621
+ def flash_attn_combine(out_partial, lse_partial, out=None, out_dtype=None):
622
+ return flash_attn_3_cuda.fwd_combine(out_partial, lse_partial, out, out_dtype)
623
+
624
+
625
+ def flash_attn_with_kvcache(
626
+ q,
627
+ k_cache,
628
+ v_cache,
629
+ k=None,
630
+ v=None,
631
+ qv=None,
632
+ rotary_cos=None,
633
+ rotary_sin=None,
634
+ cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
635
+ cache_batch_idx: Optional[torch.Tensor] = None,
636
+ cache_leftpad: Optional[torch.Tensor] = None,
637
+ page_table: Optional[torch.Tensor] = None,
638
+ cu_seqlens_q: Optional[torch.Tensor] = None,
639
+ cu_seqlens_k_new: Optional[torch.Tensor] = None,
640
+ max_seqlen_q: Optional[int] = None,
641
+ rotary_seqlens: Optional[torch.Tensor] = None,
642
+ q_descale: Optional[torch.Tensor] = None,
643
+ k_descale: Optional[torch.Tensor] = None,
644
+ v_descale: Optional[torch.Tensor] = None,
645
+ softmax_scale=None,
646
+ causal=False,
647
+ window_size=(-1, -1), # -1 means infinite context window
648
+ attention_chunk=0,
649
+ softcap=0.0, # 0.0 means deactivated
650
+ rotary_interleaved=True,
651
+ scheduler_metadata=None,
652
+ num_splits=0, # Can be tuned for speed
653
+ pack_gqa=None, # Can be tuned for speed
654
+ sm_margin=0, # Can be tuned if some SMs are used for communication
655
+ return_softmax_lse=False,
656
+ ):
657
+ """
658
+ If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
659
+ k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
660
+ the previous step, and update them with the new keys/values from the current step, and do
661
+ attention with the updated cache, all in 1 kernel.
662
+
663
+ If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
664
+ For example, the KV cache could be pre-allocated with the max sequence length, and you can use
665
+ cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
666
+
667
+ Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
668
+ rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
669
+ If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
670
+ and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
671
+ If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
672
+ indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).
673
+
674
+ See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
675
+
676
+ Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
677
+ than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
678
+ For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
679
+ 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
680
+
681
+ If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
682
+ For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
683
+ 1 1 1 1 0
684
+ 1 1 1 1 1
685
+ If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
686
+ 0 0
687
+ 0 0
688
+ 0 0
689
+ 1 0
690
+ 1 1
691
+ If the row of the mask is all zero, the output will be zero.
692
+
693
+ If window_size != (-1, -1), implements sliding window local attention. Query at position i
694
+ will only attend to keys between
695
+ [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
696
+
697
+ Note: Does not support backward pass.
698
+
699
+ Arguments:
700
+ q: (batch_size, seqlen, nheads, headdim)
701
+ k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table,
702
+ or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache)
703
+ page_block_size must be a multiple of 256.
704
+ v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table,
705
+ or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache)
706
+ k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
707
+ k with k_cache, starting at the indices specified by cache_seqlens.
708
+ v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k.
709
+ qv [optional]: (batch_size, seqlen, nheads, headdim_v)
710
+ rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
711
+ to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
712
+ rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
713
+ cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
714
+ KV cache.
715
+ cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
716
+ If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
717
+ If the indices are not distinct, and k and v are provided, the values updated in the cache
718
+ might come from any of the duplicate indices.
719
+ cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0.
720
+ page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
721
+ softmax_scale: float. The scaling of QK^T before applying softmax.
722
+ Default to 1 / sqrt(headdim).
723
+ causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
724
+ window_size: (left, right). If not (-1, -1), implements sliding window local attention.
725
+ softcap: float. Anything > 0 activates softcapping attention.
726
+ rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
727
+ If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
728
+ rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
729
+ (i.e. GPT-NeoX style).
730
+ num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
731
+ If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
732
+ to automatically determine the number of splits.
733
+ Don't change this unless you know what you are doing.
734
+ return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.
735
+
736
+ Return:
737
+ out: (batch_size, seqlen, nheads, headdim).
738
+ softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
739
+ logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
740
+ normalization factor).
741
+ """
742
+ assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
743
+ assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
744
+ if softmax_scale is None:
745
+ softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
746
+ if cache_seqlens is not None and isinstance(cache_seqlens, int):
747
+ cache_seqlens = torch.full(
748
+ (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
749
+ )
750
+ cache_seqlens = maybe_contiguous(cache_seqlens)
751
+ out, softmax_lse, *rest = _flash_attn_forward(
752
+ q,
753
+ k_cache,
754
+ v_cache,
755
+ k,
756
+ v,
757
+ qv,
758
+ None, # out
759
+ cu_seqlens_q,
760
+ None, # cu_seqlens_k
761
+ cu_seqlens_k_new,
762
+ None, # seqused_q
763
+ cache_seqlens,
764
+ max_seqlen_q,
765
+ None, # max_seqlen_k
766
+ page_table,
767
+ cache_batch_idx,
768
+ cache_leftpad,
769
+ rotary_cos,
770
+ rotary_sin,
771
+ rotary_seqlens,
772
+ q_descale, k_descale, v_descale,
773
+ softmax_scale,
774
+ causal=causal,
775
+ window_size=window_size,
776
+ attention_chunk=attention_chunk,
777
+ softcap=softcap,
778
+ rotary_interleaved=rotary_interleaved,
779
+ scheduler_metadata=scheduler_metadata,
780
+ num_splits=num_splits,
781
+ pack_gqa=pack_gqa,
782
+ sm_margin=sm_margin,
783
+ )
784
+ # return (out, softmax_lse) if return_softmax_lse else out
785
+ return (out, softmax_lse, *rest) if return_softmax_lse else out
786
+
787
+
788
+ def get_scheduler_metadata(
789
+ batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim,
790
+ cache_seqlens: torch.Tensor,
791
+ qkv_dtype=torch.bfloat16,
792
+ headdim_v=None,
793
+ cu_seqlens_q: Optional[torch.Tensor] = None,
794
+ cu_seqlens_k_new: Optional[torch.Tensor] = None,
795
+ cache_leftpad: Optional[torch.Tensor] = None,
796
+ page_size: Optional[int] = None,
797
+ max_seqlen_k_new=0,
798
+ causal=False,
799
+ window_size=(-1, -1), # -1 means infinite context window
800
+ attention_chunk=0,
801
+ has_softcap=False,
802
+ num_splits=0, # Can be tuned for speed
803
+ pack_gqa=None, # Can be tuned for speed
804
+ sm_margin=0, # Can be tuned if some SMs are used for communication
805
+ ):
806
+ cache_seqlens = maybe_contiguous(cache_seqlens)
807
+ if headdim_v is None:
808
+ headdim_v = headdim
809
+ scheduler_metadata = flash_attn_3_cuda.get_scheduler_metadata(
810
+ batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, headdim_v,
811
+ qkv_dtype,
812
+ cache_seqlens,
813
+ cu_seqlens_q,
814
+ None, # cu_seqlens_k
815
+ cu_seqlens_k_new,
816
+ None, # seqused_q
817
+ cache_leftpad,
818
+ page_size,
819
+ max_seqlen_k_new,
820
+ causal,
821
+ window_size[0], window_size[1],
822
+ attention_chunk,
823
+ has_softcap,
824
+ num_splits,
825
+ pack_gqa,
826
+ sm_margin,
827
+ )
828
+ return scheduler_metadata
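A hypothetical single-step decoding sketch for flash_attn_with_kvcache as defined above (illustration only; the shapes, dtypes, and import path are assumptions):

import torch
from flash_attn3 import flash_attn_with_kvcache

batch, nheads, nheads_k, headdim, max_seqlen = 4, 32, 8, 128, 4096

# Pre-allocated KV cache plus the current length of each sequence in the batch.
k_cache = torch.zeros(batch, max_seqlen, nheads_k, headdim, device="cuda", dtype=torch.bfloat16)
v_cache = torch.zeros_like(k_cache)
cache_seqlens = torch.full((batch,), 100, dtype=torch.int32, device="cuda")

# One new token per sequence: append it to the cache and attend over the updated cache
# in a single kernel call.
q = torch.randn(batch, 1, nheads, headdim, device="cuda", dtype=torch.bfloat16)
k_new = torch.randn(batch, 1, nheads_k, headdim, device="cuda", dtype=torch.bfloat16)
v_new = torch.randn(batch, 1, nheads_k, headdim, device="cuda", dtype=torch.bfloat16)

out = flash_attn_with_kvcache(
    q, k_cache, v_cache, k=k_new, v=v_new,
    cache_seqlens=cache_seqlens, causal=True,
)
print(out.shape)  # torch.Size([4, 1, 32, 128])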
build/torch26-cxx98-cu126-x86_64-linux/flash_attn3/__init__.py ADDED
@@ -0,0 +1,17 @@
1
+ from .flash_attn_interface import (
2
+ flash_attn_combine,
3
+ flash_attn_func,
4
+ flash_attn_qkvpacked_func,
5
+ flash_attn_varlen_func,
6
+ flash_attn_with_kvcache,
7
+ get_scheduler_metadata,
8
+ )
9
+
10
+ __all__ = [
11
+ "flash_attn_combine",
12
+ "flash_attn_func",
13
+ "flash_attn_qkvpacked_func",
14
+ "flash_attn_varlen_func",
15
+ "flash_attn_with_kvcache",
16
+ "get_scheduler_metadata",
17
+ ]
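When many decode steps share the same shapes, the scheduler metadata can be precomputed once with get_scheduler_metadata (exported above) and passed to flash_attn_with_kvcache. A hedged sketch under the same assumptions as the previous example (shapes, dtypes, and import path are illustrative):

import torch
from flash_attn3 import flash_attn_with_kvcache, get_scheduler_metadata

batch, nheads, nheads_k, headdim, max_seqlen = 4, 32, 8, 128, 4096
k_cache = torch.zeros(batch, max_seqlen, nheads_k, headdim, device="cuda", dtype=torch.bfloat16)
v_cache = torch.zeros_like(k_cache)
cache_seqlens = torch.full((batch,), 100, dtype=torch.int32, device="cuda")
q = torch.randn(batch, 1, nheads, headdim, device="cuda", dtype=torch.bfloat16)

# Precompute the scheduler metadata for a fixed decode shape (seqlen_q == 1)...
metadata = get_scheduler_metadata(
    batch, 1, max_seqlen, nheads, nheads_k, headdim,
    cache_seqlens, qkv_dtype=torch.bfloat16, causal=True,
)
# ...and reuse it on every step that keeps those shapes.
out = flash_attn_with_kvcache(
    q, k_cache, v_cache,
    cache_seqlens=cache_seqlens, causal=True,
    scheduler_metadata=metadata,
)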
build/torch26-cxx98-cu126-x86_64-linux/flash_attn3/_flash_attn3_2e75662.abi3.so ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9627e08ec8778d2a409a2a0477572edb3e03eaca2b45e7b4810ee0a9126d6547
3
+ size 838456048
build/torch26-cxx98-cu126-x86_64-linux/flash_attn3/_flash_attn3_557701f.abi3.so ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:07fe025ba95671f6ff957991f74c66063bfb10ab6737641c88f88116c9f83718
3
+ size 838456048
build/torch26-cxx98-cu126-x86_64-linux/flash_attn3/_ops.py ADDED
@@ -0,0 +1,9 @@
1
+ import torch
2
+ from . import _flash_attn3_557701f
3
+ ops = torch.ops._flash_attn3_557701f
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_flash_attn3_557701f::{op_name}"
build/torch26-cxx98-cu126-x86_64-linux/flash_attn3/flash_attn_interface.py ADDED
@@ -0,0 +1,828 @@
1
+ # Copyright (c) 2023, Tri Dao.
2
+
3
+ from typing import Optional, Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from ._ops import ops as flash_attn_3_cuda
9
+
10
+ def maybe_contiguous(x):
11
+ return x.contiguous() if x is not None and x.stride(-1) != 1 else x
12
+
13
+
14
+ def _flash_attn_forward(
15
+ q,
16
+ k,
17
+ v,
18
+ k_new,
19
+ v_new,
20
+ qv,
21
+ out,
22
+ cu_seqlens_q,
23
+ cu_seqlens_k,
24
+ cu_seqlens_k_new,
25
+ seqused_q,
26
+ seqused_k,
27
+ max_seqlen_q,
28
+ max_seqlen_k,
29
+ page_table,
30
+ kv_batch_idx,
31
+ leftpad_k,
32
+ rotary_cos,
33
+ rotary_sin,
34
+ seqlens_rotary,
35
+ q_descale,
36
+ k_descale,
37
+ v_descale,
38
+ softmax_scale,
39
+ causal,
40
+ window_size=(-1, -1),
41
+ attention_chunk=0,
42
+ softcap=0.0,
43
+ rotary_interleaved=True,
44
+ scheduler_metadata=None,
45
+ num_splits=1,
46
+ pack_gqa=None,
47
+ sm_margin=0):
48
+ q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)]
49
+ v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
50
+ cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [
51
+ maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new)
52
+ ]
53
+ seqused_q, seqused_k = [maybe_contiguous(x) for x in (seqused_q, seqused_k)]
54
+ page_table, kv_batch_idx, leftpad_k = [
55
+ maybe_contiguous(x) for x in (page_table, kv_batch_idx, leftpad_k)
56
+ ]
57
+ rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
58
+ seqlens_rotary = maybe_contiguous(seqlens_rotary)
59
+ out, softmax_lse, *rest = flash_attn_3_cuda.fwd(
60
+ q,
61
+ k,
62
+ v,
63
+ k_new,
64
+ v_new,
65
+ qv,
66
+ out,
67
+ cu_seqlens_q,
68
+ cu_seqlens_k,
69
+ cu_seqlens_k_new,
70
+ seqused_q,
71
+ seqused_k,
72
+ max_seqlen_q,
73
+ max_seqlen_k,
74
+ page_table,
75
+ kv_batch_idx,
76
+ leftpad_k,
77
+ rotary_cos,
78
+ rotary_sin,
79
+ seqlens_rotary,
80
+ q_descale,
81
+ k_descale,
82
+ v_descale,
83
+ softmax_scale,
84
+ causal,
85
+ window_size[0],
86
+ window_size[1],
87
+ attention_chunk,
88
+ softcap,
89
+ rotary_interleaved,
90
+ scheduler_metadata,
91
+ num_splits,
92
+ pack_gqa,
93
+ sm_margin,
94
+ )
95
+ return out, softmax_lse, *rest
96
+
97
+
98
+ def _flash_attn_backward(
99
+ dout,
100
+ q,
101
+ k,
102
+ v,
103
+ out,
104
+ softmax_lse,
105
+ cu_seqlens_q,
106
+ cu_seqlens_k,
107
+ sequed_q,
108
+ sequed_k,
109
+ max_seqlen_q,
110
+ max_seqlen_k,
111
+ dq,
112
+ dk,
113
+ dv,
114
+ softmax_scale,
115
+ causal,
116
+ window_size=(-1, -1),
117
+ softcap=0.0,
118
+ deterministic=False,
119
+ sm_margin=0,
120
+ ):
121
+ # dq, dk, dv are allocated by us so they should already be contiguous
122
+ dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
123
+ dq, dk, dv, softmax_d, *rest = flash_attn_3_cuda.bwd(
124
+ dout,
125
+ q,
126
+ k,
127
+ v,
128
+ out,
129
+ softmax_lse,
130
+ dq,
131
+ dk,
132
+ dv,
133
+ cu_seqlens_q,
134
+ cu_seqlens_k,
135
+ sequed_q,
136
+ sequed_k,
137
+ max_seqlen_q,
138
+ max_seqlen_k,
139
+ softmax_scale,
140
+ causal,
141
+ window_size[0],
142
+ window_size[1],
143
+ softcap,
144
+ deterministic,
145
+ sm_margin,
146
+ )
147
+ return dq, dk, dv, softmax_d
148
+
149
+
150
+ class FlashAttnQKVPackedFunc(torch.autograd.Function):
151
+ @staticmethod
152
+ def forward(
153
+ ctx,
154
+ qkv,
155
+ softmax_scale,
156
+ causal,
157
+ q_descale=None, k_descale=None, v_descale=None,
158
+ window_size=(-1, -1),
159
+ attention_chunk=0,
160
+ softcap=0.0,
161
+ deterministic=False,
162
+ num_heads_q=None,
163
+ sm_margin=0,
164
+ ):
165
+ if softmax_scale is None:
166
+ softmax_scale = qkv.shape[-1] ** (-0.5)
167
+ if qkv.dim() == 5:
168
+ assert qkv.shape[-3] == 3
169
+ q, k, v = qkv.unbind(dim=-3)
170
+ else:
171
+ assert qkv.dim() == 4
172
+ assert num_heads_q is not None
173
+ num_heads_k = (qkv.shape[2] - num_heads_q) // 2
174
+ assert num_heads_k * 2 + num_heads_q == qkv.shape[2]
175
+ q, k, v = qkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
176
+ out, softmax_lse, *rest = _flash_attn_forward(
177
+ q,
178
+ k,
179
+ v,
180
+ None, None, # k_new, v_new
181
+ None, # qv
182
+ None, # out
183
+ None, None, None, # cu_seqlens_q/k/k_new
184
+ None, None, # seqused_q/k
185
+ None, None, # max_seqlen_q/k
186
+ None, None, None, # page_table, kv_batch_idx, leftpad_k,
187
+ None, None, None, # rotary_cos/sin, seqlens_rotary
188
+ q_descale, k_descale, v_descale,
189
+ softmax_scale,
190
+ causal=causal,
191
+ window_size=window_size,
192
+ attention_chunk=attention_chunk,
193
+ softcap=softcap,
194
+ sm_margin=sm_margin,
195
+ )
196
+ # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
197
+ ctx.save_for_backward(q, k, v, out, softmax_lse)
198
+ ctx.softmax_scale = softmax_scale
199
+ ctx.causal = causal
200
+ ctx.window_size = window_size
201
+ ctx.attention_chunk = attention_chunk
202
+ ctx.softcap = softcap
203
+ ctx.deterministic = deterministic
204
+ ctx.ndim = qkv.dim()
205
+ ctx.sm_margin = sm_margin
206
+ # return out, softmax_lse
207
+ return out
208
+
209
+ @staticmethod
210
+ def backward(ctx, dout, *args):
211
+ q, k, v, out, softmax_lse = ctx.saved_tensors
212
+ assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
213
+ if ctx.ndim == 5:
214
+ qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
215
+ dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
216
+ dq, dk, dv = dqkv.unbind(dim=-3)
217
+ else:
218
+ num_heads_q = q.shape[2]
219
+ num_heads_k = k.shape[2]
220
+ qkv_shape = q.shape[:-2] + (num_heads_q + num_heads_k * 2, *q.shape[-1:])
221
+ dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
222
+ dq, dk, dv = dqkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
223
+ _flash_attn_backward(
224
+ dout,
225
+ q,
226
+ k,
227
+ v,
228
+ out,
229
+ softmax_lse,
230
+ None, None, # cu_seqlens_q, cu_seqlens_k,
231
+ None, None, # sequed_q, sequed_k,
232
+ None, None, # max_seqlen_q, max_seqlen_k,
233
+ dq,
234
+ dk,
235
+ dv,
236
+ ctx.softmax_scale,
237
+ ctx.causal,
238
+ ctx.window_size,
239
+ ctx.softcap,
240
+ ctx.deterministic,
241
+ ctx.sm_margin,
242
+ )
243
+ dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
244
+ return dqkv, None, None, None, None, None, None, None, None, None, None, None
245
+
246
+
247
+ class FlashAttnFunc(torch.autograd.Function):
248
+
249
+ @staticmethod
250
+ def forward(
251
+ ctx,
252
+ q,
253
+ k,
254
+ v,
255
+ softmax_scale,
256
+ causal,
257
+ qv=None,
258
+ q_descale=None, k_descale=None, v_descale=None,
259
+ window_size=(-1, -1),
260
+ attention_chunk=0,
261
+ softcap=0.0,
262
+ num_splits=1,
263
+ pack_gqa=None,
264
+ deterministic=False,
265
+ sm_margin=0,
266
+ ):
267
+ if softmax_scale is None:
268
+ softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
269
+ # out, q, k, v, out_padded, softmax_lse = _flash_attn_forward(
270
+ out, softmax_lse, *rest = _flash_attn_forward(
271
+ q,
272
+ k,
273
+ v,
274
+ None, None, # k_new, v_new
275
+ qv, # qv
276
+ None, # out
277
+ None, None, None, # cu_seqlens_q/k/k_new
278
+ None, None, # seqused_q/k
279
+ None, None, # max_seqlen_q/k
280
+ None, None, None, # page_table, kv_batch_idx, leftpad_k,
281
+ None, None, None, # rotary_cos/sin, seqlens_rotary
282
+ q_descale, k_descale, v_descale,
283
+ softmax_scale,
284
+ causal=causal,
285
+ window_size=window_size,
286
+ attention_chunk=attention_chunk,
287
+ softcap=softcap,
288
+ num_splits=num_splits,
289
+ pack_gqa=pack_gqa,
290
+ sm_margin=sm_margin,
291
+ )
292
+ # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
293
+ ctx.save_for_backward(q, k, v, out, softmax_lse)
294
+ ctx.softmax_scale = softmax_scale
295
+ ctx.causal = causal
296
+ ctx.window_size = window_size
297
+ ctx.attention_chunk = attention_chunk
298
+ ctx.softcap = softcap
299
+ ctx.deterministic = deterministic
300
+ ctx.sm_margin = sm_margin
301
+ return out, softmax_lse
302
+
303
+ @staticmethod
304
+ def backward(ctx, dout, *args):
305
+ q, k, v, out, softmax_lse = ctx.saved_tensors
306
+ assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
307
+ dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
308
+ _flash_attn_backward(
309
+ dout,
310
+ q,
311
+ k,
312
+ v,
313
+ out,
314
+ softmax_lse,
315
+ None, None, # cu_seqlens_q, cu_seqlens_k,
316
+ None, None, # sequed_q, sequed_k,
317
+ None, None, # max_seqlen_q, max_seqlen_k,
318
+ dq,
319
+ dk,
320
+ dv,
321
+ ctx.softmax_scale,
322
+ ctx.causal,
323
+ ctx.window_size,
324
+ ctx.softcap,
325
+ ctx.deterministic,
326
+ ctx.sm_margin,
327
+ )
328
+ dq = dq[..., : q.shape[-1]] # We could have padded the head dimension
329
+ dk = dk[..., : k.shape[-1]]
330
+ dv = dv[..., : v.shape[-1]]
331
+ return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None
332
+
333
+
334
+ class FlashAttnVarlenFunc(torch.autograd.Function):
335
+
336
+ @staticmethod
337
+ def forward(
338
+ ctx,
339
+ q,
340
+ k,
341
+ v,
342
+ cu_seqlens_q,
343
+ cu_seqlens_k,
344
+ seqused_q,
345
+ seqused_k,
346
+ max_seqlen_q,
347
+ max_seqlen_k,
348
+ softmax_scale,
349
+ causal,
350
+ qv=None,
351
+ q_descale=None, k_descale=None, v_descale=None,
352
+ window_size=(-1, -1),
353
+ attention_chunk=0,
354
+ softcap=0.0,
355
+ num_splits=1,
356
+ pack_gqa=None,
357
+ deterministic=False,
358
+ sm_margin=0,
359
+ ):
360
+ if softmax_scale is None:
361
+ softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
362
+ # out, q, k, v, out_padded, softmax_lse = _flash_attn_varlen_forward(
363
+ out, softmax_lse, *rest = _flash_attn_forward(
364
+ q,
365
+ k,
366
+ v,
367
+ None, None, # k_new, v_new
368
+ qv, # qv
369
+ None, # out
370
+ cu_seqlens_q,
371
+ cu_seqlens_k,
372
+ None, # cu_seqlens_k_new
373
+ seqused_q,
374
+ seqused_k,
375
+ max_seqlen_q,
376
+ max_seqlen_k,
377
+ None, None, None, # page_table, kv_batch_idx, leftpad_k,
378
+ None, None, None, # rotary_cos/sin, seqlens_rotary
379
+ q_descale, k_descale, v_descale,
380
+ softmax_scale,
381
+ causal=causal,
382
+ window_size=window_size,
383
+ attention_chunk=attention_chunk,
384
+ softcap=softcap,
385
+ num_splits=num_splits,
386
+ pack_gqa=pack_gqa,
387
+ sm_margin=sm_margin,
388
+ )
389
+ # ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
390
+ ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
391
+ ctx.max_seqlen_q = max_seqlen_q
392
+ ctx.max_seqlen_k = max_seqlen_k
393
+ ctx.softmax_scale = softmax_scale
394
+ ctx.causal = causal
395
+ ctx.window_size = window_size
396
+ ctx.attention_chunk = attention_chunk
397
+ ctx.softcap = softcap
398
+ ctx.deterministic = deterministic
399
+ ctx.sm_margin = sm_margin
400
+ return out, softmax_lse
401
+
402
+ @staticmethod
403
+ def backward(ctx, dout, *args):
404
+ q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
405
+ assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
406
+ dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
407
+ _flash_attn_backward(
408
+ dout,
409
+ q,
410
+ k,
411
+ v,
412
+ out,
413
+ softmax_lse,
414
+ cu_seqlens_q,
415
+ cu_seqlens_k,
416
+ seqused_q,
417
+ seqused_k,
418
+ ctx.max_seqlen_q,
419
+ ctx.max_seqlen_k,
420
+ dq,
421
+ dk,
422
+ dv,
423
+ ctx.softmax_scale,
424
+ ctx.causal,
425
+ ctx.window_size,
426
+ ctx.softcap,
427
+ ctx.deterministic,
428
+ ctx.sm_margin,
429
+ )
430
+ dq = dq[..., : q.shape[-1]] # We could have padded the head dimension
431
+ dk = dk[..., : k.shape[-1]]
432
+ dv = dv[..., : v.shape[-1]]
433
+ return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
434
+
435
+
436
+ def flash_attn_qkvpacked_func(
437
+ qkv,
438
+ softmax_scale=None,
439
+ causal=False,
440
+ q_descale=None, k_descale=None, v_descale=None,
441
+ window_size=(-1, -1),
442
+ attention_chunk=0,
443
+ softcap=0.0,
444
+ deterministic=False,
445
+ num_heads_q=None,
446
+ sm_margin=0,
447
+ ):
448
+ """dropout_p should be set to 0.0 during evaluation
449
+ If Q, K, V are already stacked into 1 tensor, this function will be faster than
450
+ calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
451
+ of the gradients of Q, K, V.
452
+ For multi-query and grouped-query attention (MQA/GQA), please see
453
+ flash_attn_kvpacked_func and flash_attn_func.
454
+
455
+ If window_size != (-1, -1), implements sliding window local attention. Query at position i
456
+ will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
457
+
458
+ Arguments:
459
+ qkv: (batch_size, seqlen, 3, nheads, headdim)
460
+ dropout_p: float. Dropout probability.
461
+ softmax_scale: float. The scaling of QK^T before applying softmax.
462
+ Default to 1 / sqrt(headdim).
463
+ causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
464
+ window_size: (left, right). If not (-1, -1), implements sliding window local attention.
465
+ softcap: float. Anything > 0 activates softcapping attention.
466
+ alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
467
+ the attention score of query i and key j.
468
+ deterministic: bool. Whether to use the deterministic implementation of the backward pass,
469
+ which is slightly slower and uses more memory. The forward pass is always deterministic.
470
+ return_attn_probs: bool. Whether to return the attention probabilities. This option is for
471
+ testing only. The returned probabilities are not guaranteed to be correct
472
+ (they might not have the right scaling).
473
+ Return:
474
+ out: (batch_size, seqlen, nheads, headdim).
475
+ softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
476
+ logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
477
+ normalization factor).
478
+ S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
479
+ The output of softmax (possibly with different scaling). It also encodes the dropout
480
+ pattern (negative means that location was dropped, nonnegative means it was kept).
481
+ """
482
+ return FlashAttnQKVPackedFunc.apply(
483
+ qkv,
484
+ softmax_scale,
485
+ causal,
486
+ q_descale, k_descale, v_descale,
487
+ window_size,
488
+ attention_chunk,
489
+ softcap,
490
+ deterministic,
491
+ num_heads_q,
492
+ sm_margin,
493
+ )
494
+
495
+
496
+ def flash_attn_func(
497
+ q,
498
+ k,
499
+ v,
500
+ softmax_scale=None,
501
+ causal=False,
502
+ qv=None,
503
+ q_descale=None, k_descale=None, v_descale=None,
504
+ window_size=(-1, -1),
505
+ attention_chunk=0,
506
+ softcap=0.0,
507
+ num_splits=1,
508
+ pack_gqa=None,
509
+ deterministic=False,
510
+ sm_margin=0,
511
+ ):
512
+ """dropout_p should be set to 0.0 during evaluation
513
+ Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
514
+ than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
515
+ For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
516
+ 0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.
517
+
518
+ If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
519
+ For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
520
+ 1 1 1 1 0
521
+ 1 1 1 1 1
522
+ If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
523
+ 0 0
524
+ 0 0
525
+ 0 0
526
+ 1 0
527
+ 1 1
528
+ If the row of the mask is all zero, the output will be zero.
529
+
530
+ If window_size != (-1, -1), implements sliding window local attention. Query at position i
531
+ will only attend to keys between
532
+ [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
533
+
534
+ Arguments:
535
+ q: (batch_size, seqlen, nheads, headdim)
536
+ k: (batch_size, seqlen, nheads_k, headdim)
537
+ v: (batch_size, seqlen, nheads_k, headdim)
538
+ dropout_p: float. Dropout probability.
539
+ softmax_scale: float. The scaling of QK^T before applying softmax.
540
+ Default to 1 / sqrt(headdim).
541
+ causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
542
+ window_size: (left, right). If not (-1, -1), implements sliding window local attention.
543
+ alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
544
+ (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
545
+ is added to the attention score of query i and key j.
546
+ deterministic: bool. Whether to use the deterministic implementation of the backward pass,
547
+ which is slightly slower and uses more memory. The forward pass is always deterministic.
548
+ return_attn_probs: bool. Whether to return the attention probabilities. This option is for
549
+ testing only. The returned probabilities are not guaranteed to be correct
550
+ (they might not have the right scaling).
551
+ Return:
552
+ out: (batch_size, seqlen, nheads, headdim).
553
+ softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
554
+ logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
555
+ normalization factor).
556
+ """
557
+ return FlashAttnFunc.apply(
558
+ q,
559
+ k,
560
+ v,
561
+ softmax_scale,
562
+ causal,
563
+ qv,
564
+ q_descale, k_descale, v_descale,
565
+ window_size,
566
+ attention_chunk,
567
+ softcap,
568
+ num_splits,
569
+ pack_gqa,
570
+ deterministic,
571
+ sm_margin,
572
+ )
573
+
574
+
575
+ def flash_attn_varlen_func(
576
+ q,
577
+ k,
578
+ v,
579
+ cu_seqlens_q,
580
+ cu_seqlens_k,
581
+ max_seqlen_q,
582
+ max_seqlen_k,
583
+ seqused_q=None,
584
+ seqused_k=None,
585
+ softmax_scale=None,
586
+ causal=False,
587
+ qv=None,
588
+ q_descale=None, k_descale=None, v_descale=None,
589
+ window_size=(-1, -1),
590
+ attention_chunk=0,
591
+ softcap=0.0,
592
+ num_splits=1,
593
+ pack_gqa=None,
594
+ deterministic=False,
595
+ sm_margin=0,
596
+ ):
597
+ return FlashAttnVarlenFunc.apply(
598
+ q,
599
+ k,
600
+ v,
601
+ cu_seqlens_q,
602
+ cu_seqlens_k,
603
+ seqused_q,
604
+ seqused_k,
605
+ max_seqlen_q,
606
+ max_seqlen_k,
607
+ softmax_scale,
608
+ causal,
609
+ qv,
610
+ q_descale, k_descale, v_descale,
611
+ window_size,
612
+ attention_chunk,
613
+ softcap,
614
+ num_splits,
615
+ pack_gqa,
616
+ deterministic,
617
+ sm_margin,
618
+ )
619
+
620
+
621
+ def flash_attn_combine(out_partial, lse_partial, out=None, out_dtype=None):
622
+ return flash_attn_3_cuda.fwd_combine(out_partial, lse_partial, out, out_dtype)
623
+
624
+
625
+ def flash_attn_with_kvcache(
626
+ q,
627
+ k_cache,
628
+ v_cache,
629
+ k=None,
630
+ v=None,
631
+ qv=None,
632
+ rotary_cos=None,
633
+ rotary_sin=None,
634
+ cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
635
+ cache_batch_idx: Optional[torch.Tensor] = None,
636
+ cache_leftpad: Optional[torch.Tensor] = None,
637
+ page_table: Optional[torch.Tensor] = None,
638
+ cu_seqlens_q: Optional[torch.Tensor] = None,
639
+ cu_seqlens_k_new: Optional[torch.Tensor] = None,
640
+ max_seqlen_q: Optional[int] = None,
641
+ rotary_seqlens: Optional[torch.Tensor] = None,
642
+ q_descale: Optional[torch.Tensor] = None,
643
+ k_descale: Optional[torch.Tensor] = None,
644
+ v_descale: Optional[torch.Tensor] = None,
645
+ softmax_scale=None,
646
+ causal=False,
647
+ window_size=(-1, -1), # -1 means infinite context window
648
+ attention_chunk=0,
649
+ softcap=0.0, # 0.0 means deactivated
650
+ rotary_interleaved=True,
651
+ scheduler_metadata=None,
652
+ num_splits=0, # Can be tuned for speed
653
+ pack_gqa=None, # Can be tuned for speed
654
+ sm_margin=0, # Can be tuned if some SMs are used for communication
655
+ return_softmax_lse=False,
656
+ ):
657
+ """
658
+ If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
659
+ k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
660
+ the previous step, and update them with the new keys/values from the current step, and do
661
+ attention with the updated cache, all in 1 kernel.
662
+
663
+ If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
664
+ For example, the KV cache could be pre-allocated with the max sequence length, and you can use
665
+ cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
666
+
667
+ Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
668
+ rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
669
+ If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
670
+ and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
671
+ If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
672
+ indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).
673
+
674
+ See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
675
+
676
+ Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
677
+ than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
678
+ For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
679
+ 0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.
680
+
681
+ If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
682
+ For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
683
+ 1 1 1 1 0
684
+ 1 1 1 1 1
685
+ If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
686
+ 0 0
687
+ 0 0
688
+ 0 0
689
+ 1 0
690
+ 1 1
691
+ If the row of the mask is all zero, the output will be zero.
692
+
693
+ If window_size != (-1, -1), implements sliding window local attention. Query at position i
694
+ will only attend to keys between
695
+ [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
696
+
697
+ Note: Does not support backward pass.
698
+
699
+ Arguments:
700
+ q: (batch_size, seqlen, nheads, headdim)
701
+ k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table,
702
+ or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache)
703
+ page_block_size must be a multiple of 256.
704
+ v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table,
705
+ or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache)
706
+ k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
707
+ k with k_cache, starting at the indices specified by cache_seqlens.
708
+ v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k.
709
+ qv [optional]: (batch_size, seqlen, nheads, headdim_v)
710
+ rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
711
+ to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
712
+ rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
713
+ cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
714
+ KV cache.
715
+ cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
716
+ If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
717
+ If the indices are not distinct, and k and v are provided, the values updated in the cache
718
+ might come from any of the duplicate indices.
719
+ cache_leftpad: (batch_size,), dtype torch.int32. The index at which the KV cache starts. If None, assume 0.
720
+ page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
721
+ softmax_scale: float. The scaling of QK^T before applying softmax.
722
+ Default to 1 / sqrt(headdim).
723
+ causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
724
+ window_size: (left, right). If not (-1, -1), implements sliding window local attention.
725
+ softcap: float. Anything > 0 activates softcapping attention.
726
+ rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
727
+ If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
728
+ rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
729
+ (i.e. GPT-NeoX style).
730
+ num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
731
+ If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
732
+ to automatically determine the number of splits.
733
+ Don't change this unless you know what you are doing.
734
+ return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.
735
+
736
+ Return:
737
+ out: (batch_size, seqlen, nheads, headdim).
738
+ softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
739
+ logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
740
+ normalization factor).
741
+ """
742
+ assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
743
+ assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
744
+ if softmax_scale is None:
745
+ softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
746
+ if cache_seqlens is not None and isinstance(cache_seqlens, int):
747
+ cache_seqlens = torch.full(
748
+ (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
749
+ )
750
+ cache_seqlens = maybe_contiguous(cache_seqlens)
751
+ out, softmax_lse, *rest = _flash_attn_forward(
752
+ q,
753
+ k_cache,
754
+ v_cache,
755
+ k,
756
+ v,
757
+ qv,
758
+ None, # out
759
+ cu_seqlens_q,
760
+ None, # cu_seqlens_k
761
+ cu_seqlens_k_new,
762
+ None, # seqused_q
763
+ cache_seqlens,
764
+ max_seqlen_q,
765
+ None, # max_seqlen_k
766
+ page_table,
767
+ cache_batch_idx,
768
+ cache_leftpad,
769
+ rotary_cos,
770
+ rotary_sin,
771
+ rotary_seqlens,
772
+ q_descale, k_descale, v_descale,
773
+ softmax_scale,
774
+ causal=causal,
775
+ window_size=window_size,
776
+ attention_chunk=attention_chunk,
777
+ softcap=softcap,
778
+ rotary_interleaved=rotary_interleaved,
779
+ scheduler_metadata=scheduler_metadata,
780
+ num_splits=num_splits,
781
+ pack_gqa=pack_gqa,
782
+ sm_margin=sm_margin,
783
+ )
784
+ # return (out, softmax_lse) if return_softmax_lse else out
785
+ return (out, softmax_lse, *rest) if return_softmax_lse else out
786
+
787
+
788
+ def get_scheduler_metadata(
789
+ batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim,
790
+ cache_seqlens: torch.Tensor,
791
+ qkv_dtype=torch.bfloat16,
792
+ headdim_v=None,
793
+ cu_seqlens_q: Optional[torch.Tensor] = None,
794
+ cu_seqlens_k_new: Optional[torch.Tensor] = None,
795
+ cache_leftpad: Optional[torch.Tensor] = None,
796
+ page_size: Optional[int] = None,
797
+ max_seqlen_k_new=0,
798
+ causal=False,
799
+ window_size=(-1, -1), # -1 means infinite context window
800
+ attention_chunk=0,
801
+ has_softcap=False,
802
+ num_splits=0, # Can be tuned for speed
803
+ pack_gqa=None, # Can be tuned for speed
804
+ sm_margin=0, # Can be tuned if some SMs are used for communication
805
+ ):
806
+ cache_seqlens = maybe_contiguous(cache_seqlens)
807
+ if headdim_v is None:
808
+ headdim_v = headdim
809
+ scheduler_metadata = flash_attn_3_cuda.get_scheduler_metadata(
810
+ batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, headdim_v,
811
+ qkv_dtype,
812
+ cache_seqlens,
813
+ cu_seqlens_q,
814
+ None, # cu_seqlens_k
815
+ cu_seqlens_k_new,
816
+ None, # seqused_q
817
+ cache_leftpad,
818
+ page_size,
819
+ max_seqlen_k_new,
820
+ causal,
821
+ window_size[0], window_size[1],
822
+ attention_chunk,
823
+ has_softcap,
824
+ num_splits,
825
+ pack_gqa,
826
+ sm_margin,
827
+ )
828
+ return scheduler_metadata
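
For reference, a minimal usage sketch of the dense entry point documented above. This is illustrative only: the shapes, dtype, and GQA head counts are assumptions consistent with the `flash_attn_func` docstring, and it requires a GPU supported by these FA3 builds.

import torch
from flash_attn3 import flash_attn_func

# (batch, seqlen, nheads, headdim) tensors in bf16 on the GPU, per the docstring above.
q = torch.randn(2, 1024, 16, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True)
k = torch.randn(2, 1024, 4, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True)  # GQA: 4 KV heads
v = torch.randn(2, 1024, 4, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True)

# Causal attention with the default 1/sqrt(headdim) scaling; returns (out, softmax_lse).
out, lse = flash_attn_func(q, k, v, causal=True)

# Gradients flow through FlashAttnFunc.backward defined in the file above.
out.sum().backward()
print(out.shape, lse.shape)  # expected: (2, 1024, 16, 128) and (2, 16, 1024)
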
build/torch27-cxx11-cu126-x86_64-linux/flash_attn3/__init__.py ADDED
@@ -0,0 +1,17 @@
1
+ from .flash_attn_interface import (
2
+ flash_attn_combine,
3
+ flash_attn_func,
4
+ flash_attn_qkvpacked_func,
5
+ flash_attn_varlen_func,
6
+ flash_attn_with_kvcache,
7
+ get_scheduler_metadata,
8
+ )
9
+
10
+ __all__ = [
11
+ "flash_attn_combine",
12
+ "flash_attn_func",
13
+ "flash_attn_qkvpacked_func",
14
+ "flash_attn_varlen_func",
15
+ "flash_attn_with_kvcache",
16
+ "get_scheduler_metadata",
17
+ ]
build/torch27-cxx11-cu126-x86_64-linux/flash_attn3/_flash_attn3_2e75662.abi3.so ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c0302224ac29ba4773d926d4cb16c01c45a374c6dd61286aae1f423f2bf495ea
3
+ size 838459544
build/torch27-cxx11-cu126-x86_64-linux/flash_attn3/_ops.py ADDED
@@ -0,0 +1,9 @@
1
+ import torch
2
+ from . import _flash_attn3_2e75662
3
+ ops = torch.ops._flash_attn3_2e75662
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_flash_attn3_2e75662::{op_name}"
build/torch27-cxx11-cu126-x86_64-linux/flash_attn3/flash_attn_interface.py ADDED
@@ -0,0 +1,828 @@
1
+ # Copyright (c) 2023, Tri Dao.
2
+
3
+ from typing import Optional, Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from ._ops import ops as flash_attn_3_cuda
9
+
10
+ def maybe_contiguous(x):
11
+ return x.contiguous() if x is not None and x.stride(-1) != 1 else x
12
+
13
+
14
+ def _flash_attn_forward(
15
+ q,
16
+ k,
17
+ v,
18
+ k_new,
19
+ v_new,
20
+ qv,
21
+ out,
22
+ cu_seqlens_q,
23
+ cu_seqlens_k,
24
+ cu_seqlens_k_new,
25
+ seqused_q,
26
+ seqused_k,
27
+ max_seqlen_q,
28
+ max_seqlen_k,
29
+ page_table,
30
+ kv_batch_idx,
31
+ leftpad_k,
32
+ rotary_cos,
33
+ rotary_sin,
34
+ seqlens_rotary,
35
+ q_descale,
36
+ k_descale,
37
+ v_descale,
38
+ softmax_scale,
39
+ causal,
40
+ window_size=(-1, -1),
41
+ attention_chunk=0,
42
+ softcap=0.0,
43
+ rotary_interleaved=True,
44
+ scheduler_metadata=None,
45
+ num_splits=1,
46
+ pack_gqa=None,
47
+ sm_margin=0):
48
+ q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)]
49
+ v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
50
+ cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [
51
+ maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new)
52
+ ]
53
+ seqused_q, seqused_k = [maybe_contiguous(x) for x in (seqused_q, seqused_k)]
54
+ page_table, kv_batch_idx, leftpad_k = [
55
+ maybe_contiguous(x) for x in (page_table, kv_batch_idx, leftpad_k)
56
+ ]
57
+ rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
58
+ seqlens_rotary = maybe_contiguous(seqlens_rotary)
59
+ out, softmax_lse, *rest = flash_attn_3_cuda.fwd(
60
+ q,
61
+ k,
62
+ v,
63
+ k_new,
64
+ v_new,
65
+ qv,
66
+ out,
67
+ cu_seqlens_q,
68
+ cu_seqlens_k,
69
+ cu_seqlens_k_new,
70
+ seqused_q,
71
+ seqused_k,
72
+ max_seqlen_q,
73
+ max_seqlen_k,
74
+ page_table,
75
+ kv_batch_idx,
76
+ leftpad_k,
77
+ rotary_cos,
78
+ rotary_sin,
79
+ seqlens_rotary,
80
+ q_descale,
81
+ k_descale,
82
+ v_descale,
83
+ softmax_scale,
84
+ causal,
85
+ window_size[0],
86
+ window_size[1],
87
+ attention_chunk,
88
+ softcap,
89
+ rotary_interleaved,
90
+ scheduler_metadata,
91
+ num_splits,
92
+ pack_gqa,
93
+ sm_margin,
94
+ )
95
+ return out, softmax_lse, *rest
96
+
97
+
98
+ def _flash_attn_backward(
99
+ dout,
100
+ q,
101
+ k,
102
+ v,
103
+ out,
104
+ softmax_lse,
105
+ cu_seqlens_q,
106
+ cu_seqlens_k,
107
+ sequed_q,
108
+ sequed_k,
109
+ max_seqlen_q,
110
+ max_seqlen_k,
111
+ dq,
112
+ dk,
113
+ dv,
114
+ softmax_scale,
115
+ causal,
116
+ window_size=(-1, -1),
117
+ softcap=0.0,
118
+ deterministic=False,
119
+ sm_margin=0,
120
+ ):
121
+ # dq, dk, dv are allocated by us so they should already be contiguous
122
+ dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
123
+ dq, dk, dv, softmax_d, *rest = flash_attn_3_cuda.bwd(
124
+ dout,
125
+ q,
126
+ k,
127
+ v,
128
+ out,
129
+ softmax_lse,
130
+ dq,
131
+ dk,
132
+ dv,
133
+ cu_seqlens_q,
134
+ cu_seqlens_k,
135
+ sequed_q,
136
+ sequed_k,
137
+ max_seqlen_q,
138
+ max_seqlen_k,
139
+ softmax_scale,
140
+ causal,
141
+ window_size[0],
142
+ window_size[1],
143
+ softcap,
144
+ deterministic,
145
+ sm_margin,
146
+ )
147
+ return dq, dk, dv, softmax_d
148
+
149
+
150
+ class FlashAttnQKVPackedFunc(torch.autograd.Function):
151
+ @staticmethod
152
+ def forward(
153
+ ctx,
154
+ qkv,
155
+ softmax_scale,
156
+ causal,
157
+ q_descale=None, k_descale=None, v_descale=None,
158
+ window_size=(-1, -1),
159
+ attention_chunk=0,
160
+ softcap=0.0,
161
+ deterministic=False,
162
+ num_heads_q=None,
163
+ sm_margin=0,
164
+ ):
165
+ if softmax_scale is None:
166
+ softmax_scale = qkv.shape[-1] ** (-0.5)
167
+ if qkv.dim() == 5:
168
+ assert qkv.shape[-3] == 3
169
+ q, k, v = qkv.unbind(dim=-3)
170
+ else:
171
+ assert qkv.dim() == 4
172
+ assert num_heads_q is not None
173
+ num_heads_k = (qkv.shape[2] - num_heads_q) // 2
174
+ assert num_heads_k * 2 + num_heads_q == qkv.shape[2]
175
+ q, k, v = qkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
176
+ out, softmax_lse, *rest = _flash_attn_forward(
177
+ q,
178
+ k,
179
+ v,
180
+ None, None, # k_new, v_new
181
+ None, # qv
182
+ None, # out
183
+ None, None, None, # cu_seqlens_q/k/k_new
184
+ None, None, # seqused_q/k
185
+ None, None, # max_seqlen_q/k
186
+ None, None, None, # page_table, kv_batch_idx, leftpad_k,
187
+ None, None, None, # rotary_cos/sin, seqlens_rotary
188
+ q_descale, k_descale, v_descale,
189
+ softmax_scale,
190
+ causal=causal,
191
+ window_size=window_size,
192
+ attention_chunk=attention_chunk,
193
+ softcap=softcap,
194
+ sm_margin=sm_margin,
195
+ )
196
+ # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
197
+ ctx.save_for_backward(q, k, v, out, softmax_lse)
198
+ ctx.softmax_scale = softmax_scale
199
+ ctx.causal = causal
200
+ ctx.window_size = window_size
201
+ ctx.attention_chunk = attention_chunk
202
+ ctx.softcap = softcap
203
+ ctx.deterministic = deterministic
204
+ ctx.ndim = qkv.dim()
205
+ ctx.sm_margin = sm_margin
206
+ # return out, softmax_lse
207
+ return out
208
+
209
+ @staticmethod
210
+ def backward(ctx, dout, *args):
211
+ q, k, v, out, softmax_lse = ctx.saved_tensors
212
+ assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
213
+ if ctx.ndim == 5:
214
+ qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
215
+ dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
216
+ dq, dk, dv = dqkv.unbind(dim=-3)
217
+ else:
218
+ num_heads_q = q.shape[2]
219
+ num_heads_k = k.shape[2]
220
+ qkv_shape = q.shape[:-2] + (num_heads_q + num_heads_k * 2, *q.shape[-1:])
221
+ dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
222
+ dq, dk, dv = dqkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
223
+ _flash_attn_backward(
224
+ dout,
225
+ q,
226
+ k,
227
+ v,
228
+ out,
229
+ softmax_lse,
230
+ None, None, # cu_seqlens_q, cu_seqlens_k,
231
+ None, None, # sequed_q, sequed_k,
232
+ None, None, # max_seqlen_q, max_seqlen_k,
233
+ dq,
234
+ dk,
235
+ dv,
236
+ ctx.softmax_scale,
237
+ ctx.causal,
238
+ ctx.window_size,
239
+ ctx.softcap,
240
+ ctx.deterministic,
241
+ ctx.sm_margin,
242
+ )
243
+ dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
244
+ return dqkv, None, None, None, None, None, None, None, None, None, None, None
245
+
246
+
247
+ class FlashAttnFunc(torch.autograd.Function):
248
+
249
+ @staticmethod
250
+ def forward(
251
+ ctx,
252
+ q,
253
+ k,
254
+ v,
255
+ softmax_scale,
256
+ causal,
257
+ qv=None,
258
+ q_descale=None, k_descale=None, v_descale=None,
259
+ window_size=(-1, -1),
260
+ attention_chunk=0,
261
+ softcap=0.0,
262
+ num_splits=1,
263
+ pack_gqa=None,
264
+ deterministic=False,
265
+ sm_margin=0,
266
+ ):
267
+ if softmax_scale is None:
268
+ softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
269
+ # out, q, k, v, out_padded, softmax_lse = _flash_attn_forward(
270
+ out, softmax_lse, *rest = _flash_attn_forward(
271
+ q,
272
+ k,
273
+ v,
274
+ None, None, # k_new, v_new
275
+ qv, # qv
276
+ None, # out
277
+ None, None, None, # cu_seqlens_q/k/k_new
278
+ None, None, # seqused_q/k
279
+ None, None, # max_seqlen_q/k
280
+ None, None, None, # page_table, kv_batch_idx, leftpad_k,
281
+ None, None, None, # rotary_cos/sin, seqlens_rotary
282
+ q_descale, k_descale, v_descale,
283
+ softmax_scale,
284
+ causal=causal,
285
+ window_size=window_size,
286
+ attention_chunk=attention_chunk,
287
+ softcap=softcap,
288
+ num_splits=num_splits,
289
+ pack_gqa=pack_gqa,
290
+ sm_margin=sm_margin,
291
+ )
292
+ # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
293
+ ctx.save_for_backward(q, k, v, out, softmax_lse)
294
+ ctx.softmax_scale = softmax_scale
295
+ ctx.causal = causal
296
+ ctx.window_size = window_size
297
+ ctx.attention_chunk = attention_chunk
298
+ ctx.softcap = softcap
299
+ ctx.deterministic = deterministic
300
+ ctx.sm_margin = sm_margin
301
+ return out, softmax_lse
302
+
303
+ @staticmethod
304
+ def backward(ctx, dout, *args):
305
+ q, k, v, out, softmax_lse = ctx.saved_tensors
306
+ assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
307
+ dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
308
+ _flash_attn_backward(
309
+ dout,
310
+ q,
311
+ k,
312
+ v,
313
+ out,
314
+ softmax_lse,
315
+ None, None, # cu_seqlens_q, cu_seqlens_k,
316
+ None, None, # sequed_q, sequed_k,
317
+ None, None, # max_seqlen_q, max_seqlen_k,
318
+ dq,
319
+ dk,
320
+ dv,
321
+ ctx.softmax_scale,
322
+ ctx.causal,
323
+ ctx.window_size,
324
+ ctx.softcap,
325
+ ctx.deterministic,
326
+ ctx.sm_margin,
327
+ )
328
+ dq = dq[..., : q.shape[-1]] # We could have padded the head dimension
329
+ dk = dk[..., : k.shape[-1]]
330
+ dv = dv[..., : v.shape[-1]]
331
+ return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None
332
+
333
+
334
+ class FlashAttnVarlenFunc(torch.autograd.Function):
335
+
336
+ @staticmethod
337
+ def forward(
338
+ ctx,
339
+ q,
340
+ k,
341
+ v,
342
+ cu_seqlens_q,
343
+ cu_seqlens_k,
344
+ seqused_q,
345
+ seqused_k,
346
+ max_seqlen_q,
347
+ max_seqlen_k,
348
+ softmax_scale,
349
+ causal,
350
+ qv=None,
351
+ q_descale=None, k_descale=None, v_descale=None,
352
+ window_size=(-1, -1),
353
+ attention_chunk=0,
354
+ softcap=0.0,
355
+ num_splits=1,
356
+ pack_gqa=None,
357
+ deterministic=False,
358
+ sm_margin=0,
359
+ ):
360
+ if softmax_scale is None:
361
+ softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
362
+ # out, q, k, v, out_padded, softmax_lse = _flash_attn_varlen_forward(
363
+ out, softmax_lse, *rest = _flash_attn_forward(
364
+ q,
365
+ k,
366
+ v,
367
+ None, None, # k_new, v_new
368
+ qv, # qv
369
+ None, # out
370
+ cu_seqlens_q,
371
+ cu_seqlens_k,
372
+ None, # cu_seqlens_k_new
373
+ seqused_q,
374
+ seqused_k,
375
+ max_seqlen_q,
376
+ max_seqlen_k,
377
+ None, None, None, # page_table, kv_batch_idx, leftpad_k,
378
+ None, None, None, # rotary_cos/sin, seqlens_rotary
379
+ q_descale, k_descale, v_descale,
380
+ softmax_scale,
381
+ causal=causal,
382
+ window_size=window_size,
383
+ attention_chunk=attention_chunk,
384
+ softcap=softcap,
385
+ num_splits=num_splits,
386
+ pack_gqa=pack_gqa,
387
+ sm_margin=sm_margin,
388
+ )
389
+ # ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
390
+ ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
391
+ ctx.max_seqlen_q = max_seqlen_q
392
+ ctx.max_seqlen_k = max_seqlen_k
393
+ ctx.softmax_scale = softmax_scale
394
+ ctx.causal = causal
395
+ ctx.window_size = window_size
396
+ ctx.attention_chunk = attention_chunk
397
+ ctx.softcap = softcap
398
+ ctx.deterministic = deterministic
399
+ ctx.sm_margin = sm_margin
400
+ return out, softmax_lse
401
+
402
+ @staticmethod
403
+ def backward(ctx, dout, *args):
404
+ q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
405
+ assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
406
+ dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
407
+ _flash_attn_backward(
408
+ dout,
409
+ q,
410
+ k,
411
+ v,
412
+ out,
413
+ softmax_lse,
414
+ cu_seqlens_q,
415
+ cu_seqlens_k,
416
+ seqused_q,
417
+ seqused_k,
418
+ ctx.max_seqlen_q,
419
+ ctx.max_seqlen_k,
420
+ dq,
421
+ dk,
422
+ dv,
423
+ ctx.softmax_scale,
424
+ ctx.causal,
425
+ ctx.window_size,
426
+ ctx.softcap,
427
+ ctx.deterministic,
428
+ ctx.sm_margin,
429
+ )
430
+ dq = dq[..., : q.shape[-1]] # We could have padded the head dimension
431
+ dk = dk[..., : k.shape[-1]]
432
+ dv = dv[..., : v.shape[-1]]
433
+ return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
434
+
435
+
436
+ def flash_attn_qkvpacked_func(
437
+ qkv,
438
+ softmax_scale=None,
439
+ causal=False,
440
+ q_descale=None, k_descale=None, v_descale=None,
441
+ window_size=(-1, -1),
442
+ attention_chunk=0,
443
+ softcap=0.0,
444
+ deterministic=False,
445
+ num_heads_q=None,
446
+ sm_margin=0,
447
+ ):
448
+ """dropout_p should be set to 0.0 during evaluation
449
+ If Q, K, V are already stacked into 1 tensor, this function will be faster than
450
+ calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
451
+ of the gradients of Q, K, V.
452
+ For multi-query and grouped-query attention (MQA/GQA), please see
453
+ flash_attn_kvpacked_func and flash_attn_func.
454
+
455
+ If window_size != (-1, -1), implements sliding window local attention. Query at position i
456
+ will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
457
+
458
+ Arguments:
459
+ qkv: (batch_size, seqlen, 3, nheads, headdim)
460
+ dropout_p: float. Dropout probability.
461
+ softmax_scale: float. The scaling of QK^T before applying softmax.
462
+ Default to 1 / sqrt(headdim).
463
+ causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
464
+ window_size: (left, right). If not (-1, -1), implements sliding window local attention.
465
+ softcap: float. Anything > 0 activates softcapping attention.
466
+ alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
467
+ the attention score of query i and key j.
468
+ deterministic: bool. Whether to use the deterministic implementation of the backward pass,
469
+ which is slightly slower and uses more memory. The forward pass is always deterministic.
470
+ return_attn_probs: bool. Whether to return the attention probabilities. This option is for
471
+ testing only. The returned probabilities are not guaranteed to be correct
472
+ (they might not have the right scaling).
473
+ Return:
474
+ out: (batch_size, seqlen, nheads, headdim).
475
+ softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
476
+ logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
477
+ normalization factor).
478
+ S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
479
+ The output of softmax (possibly with different scaling). It also encodes the dropout
480
+ pattern (negative means that location was dropped, nonnegative means it was kept).
481
+ """
482
+ return FlashAttnQKVPackedFunc.apply(
483
+ qkv,
484
+ softmax_scale,
485
+ causal,
486
+ q_descale, k_descale, v_descale,
487
+ window_size,
488
+ attention_chunk,
489
+ softcap,
490
+ deterministic,
491
+ num_heads_q,
492
+ sm_margin,
493
+ )
494
+
495
+
496
+ def flash_attn_func(
497
+ q,
498
+ k,
499
+ v,
500
+ softmax_scale=None,
501
+ causal=False,
502
+ qv=None,
503
+ q_descale=None, k_descale=None, v_descale=None,
504
+ window_size=(-1, -1),
505
+ attention_chunk=0,
506
+ softcap=0.0,
507
+ num_splits=1,
508
+ pack_gqa=None,
509
+ deterministic=False,
510
+ sm_margin=0,
511
+ ):
512
+ """dropout_p should be set to 0.0 during evaluation
513
+ Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
514
+ than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
515
+ For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
516
+ 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
517
+
518
+ If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
519
+ For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
520
+ 1 1 1 1 0
521
+ 1 1 1 1 1
522
+ If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
523
+ 0 0
524
+ 0 0
525
+ 0 0
526
+ 1 0
527
+ 1 1
528
+ If the row of the mask is all zero, the output will be zero.
529
+
530
+ If window_size != (-1, -1), implements sliding window local attention. Query at position i
531
+ will only attend to keys between
532
+ [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
533
+
534
+ Arguments:
535
+ q: (batch_size, seqlen, nheads, headdim)
536
+ k: (batch_size, seqlen, nheads_k, headdim)
537
+ v: (batch_size, seqlen, nheads_k, headdim)
538
+ dropout_p: float. Dropout probability.
539
+ softmax_scale: float. The scaling of QK^T before applying softmax.
540
+ Default to 1 / sqrt(headdim).
541
+ causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
542
+ window_size: (left, right). If not (-1, -1), implements sliding window local attention.
543
+ alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
544
+ (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
545
+ is added to the attention score of query i and key j.
546
+ deterministic: bool. Whether to use the deterministic implementation of the backward pass,
547
+ which is slightly slower and uses more memory. The forward pass is always deterministic.
548
+ return_attn_probs: bool. Whether to return the attention probabilities. This option is for
549
+ testing only. The returned probabilities are not guaranteed to be correct
550
+ (they might not have the right scaling).
551
+ Return:
552
+ out: (batch_size, seqlen, nheads, headdim).
553
+ softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
554
+ logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
555
+ normalization factor).
556
+ """
557
+ return FlashAttnFunc.apply(
558
+ q,
559
+ k,
560
+ v,
561
+ softmax_scale,
562
+ causal,
563
+ qv,
564
+ q_descale, k_descale, v_descale,
565
+ window_size,
566
+ attention_chunk,
567
+ softcap,
568
+ num_splits,
569
+ pack_gqa,
570
+ deterministic,
571
+ sm_margin,
572
+ )
573
+
574
+
575
+ def flash_attn_varlen_func(
576
+ q,
577
+ k,
578
+ v,
579
+ cu_seqlens_q,
580
+ cu_seqlens_k,
581
+ max_seqlen_q,
582
+ max_seqlen_k,
583
+ seqused_q=None,
584
+ seqused_k=None,
585
+ softmax_scale=None,
586
+ causal=False,
587
+ qv=None,
588
+ q_descale=None, k_descale=None, v_descale=None,
589
+ window_size=(-1, -1),
590
+ attention_chunk=0,
591
+ softcap=0.0,
592
+ num_splits=1,
593
+ pack_gqa=None,
594
+ deterministic=False,
595
+ sm_margin=0,
596
+ ):
597
+ return FlashAttnVarlenFunc.apply(
598
+ q,
599
+ k,
600
+ v,
601
+ cu_seqlens_q,
602
+ cu_seqlens_k,
603
+ seqused_q,
604
+ seqused_k,
605
+ max_seqlen_q,
606
+ max_seqlen_k,
607
+ softmax_scale,
608
+ causal,
609
+ qv,
610
+ q_descale, k_descale, v_descale,
611
+ window_size,
612
+ attention_chunk,
613
+ softcap,
614
+ num_splits,
615
+ pack_gqa,
616
+ deterministic,
617
+ sm_margin,
618
+ )
619
+
620
+
621
+ def flash_attn_combine(out_partial, lse_partial, out=None, out_dtype=None):
622
+ return flash_attn_3_cuda.fwd_combine(out_partial, lse_partial, out, out_dtype)
623
+
624
+
625
+ def flash_attn_with_kvcache(
626
+ q,
627
+ k_cache,
628
+ v_cache,
629
+ k=None,
630
+ v=None,
631
+ qv=None,
632
+ rotary_cos=None,
633
+ rotary_sin=None,
634
+ cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
635
+ cache_batch_idx: Optional[torch.Tensor] = None,
636
+ cache_leftpad: Optional[torch.Tensor] = None,
637
+ page_table: Optional[torch.Tensor] = None,
638
+ cu_seqlens_q: Optional[torch.Tensor] = None,
639
+ cu_seqlens_k_new: Optional[torch.Tensor] = None,
640
+ max_seqlen_q: Optional[int] = None,
641
+ rotary_seqlens: Optional[torch.Tensor] = None,
642
+ q_descale: Optional[torch.Tensor] = None,
643
+ k_descale: Optional[torch.Tensor] = None,
644
+ v_descale: Optional[torch.Tensor] = None,
645
+ softmax_scale=None,
646
+ causal=False,
647
+ window_size=(-1, -1), # -1 means infinite context window
648
+ attention_chunk=0,
649
+ softcap=0.0, # 0.0 means deactivated
650
+ rotary_interleaved=True,
651
+ scheduler_metadata=None,
652
+ num_splits=0, # Can be tuned for speed
653
+ pack_gqa=None, # Can be tuned for speed
654
+ sm_margin=0, # Can be tuned if some SMs are used for communication
655
+ return_softmax_lse=False,
656
+ ):
657
+ """
658
+ If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
659
+ k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
660
+ the previous step, and update them with the new keys/values from the current step, and do
661
+ attention with the updated cache, all in 1 kernel.
662
+
663
+ If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
664
+ For example, the KV cache could be pre-allocated with the max sequence length, and you can use
665
+ cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
666
+
667
+ Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
668
+ rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
669
+ If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
670
+ and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
671
+ If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
672
+ indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).
673
+
674
+ See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
675
+
676
+ Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
677
+ than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
678
+ For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
679
+ 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
680
+
681
+ If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
682
+ For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
683
+ 1 1 1 1 0
684
+ 1 1 1 1 1
685
+ If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
686
+ 0 0
687
+ 0 0
688
+ 0 0
689
+ 1 0
690
+ 1 1
691
+ If the row of the mask is all zero, the output will be zero.
692
+
693
+ If window_size != (-1, -1), implements sliding window local attention. Query at position i
694
+ will only attend to keys between
695
+ [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
696
+
697
+ Note: Does not support backward pass.
698
+
699
+ Arguments:
700
+ q: (batch_size, seqlen, nheads, headdim)
701
+ k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table,
702
+ or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache)
703
+ page_block_size must be a multiple of 256.
704
+ v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table,
705
+ or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache)
706
+ k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
707
+ k with k_cache, starting at the indices specified by cache_seqlens.
708
+ v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k.
709
+ qv [optional]: (batch_size, seqlen, nheads, headdim_v)
710
+ rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
711
+ to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
712
+ rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
713
+ cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
714
+ KV cache.
715
+ cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
716
+ If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
717
+ If the indices are not distinct, and k and v are provided, the values updated in the cache
718
+ might come from any of the duplicate indices.
719
+ cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0.
720
+ page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
721
+ softmax_scale: float. The scaling of QK^T before applying softmax.
722
+ Default to 1 / sqrt(headdim).
723
+ causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
724
+ window_size: (left, right). If not (-1, -1), implements sliding window local attention.
725
+ softcap: float. Anything > 0 activates softcapping attention.
726
+ rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
727
+ If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
728
+ rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
729
+ (i.e. GPT-NeoX style).
730
+ num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
731
+ If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
732
+ to automatically determine the number of splits.
733
+ Don't change this unless you know what you are doing.
734
+ return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.
735
+
736
+ Return:
737
+ out: (batch_size, seqlen, nheads, headdim).
738
+ softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
739
+ logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
740
+ normalization factor).
741
+ """
742
+ assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
743
+ assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
744
+ if softmax_scale is None:
745
+ softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
746
+ if cache_seqlens is not None and isinstance(cache_seqlens, int):
747
+ cache_seqlens = torch.full(
748
+ (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
749
+ )
750
+ cache_seqlens = maybe_contiguous(cache_seqlens)
751
+ out, softmax_lse, *rest = _flash_attn_forward(
752
+ q,
753
+ k_cache,
754
+ v_cache,
755
+ k,
756
+ v,
757
+ qv,
758
+ None, # out
759
+ cu_seqlens_q,
760
+ None, # cu_seqlens_k
761
+ cu_seqlens_k_new,
762
+ None, # seqused_q
763
+ cache_seqlens,
764
+ max_seqlen_q,
765
+ None, # max_seqlen_k
766
+ page_table,
767
+ cache_batch_idx,
768
+ cache_leftpad,
769
+ rotary_cos,
770
+ rotary_sin,
771
+ rotary_seqlens,
772
+ q_descale, k_descale, v_descale,
773
+ softmax_scale,
774
+ causal=causal,
775
+ window_size=window_size,
776
+ attention_chunk=attention_chunk,
777
+ softcap=softcap,
778
+ rotary_interleaved=rotary_interleaved,
779
+ scheduler_metadata=scheduler_metadata,
780
+ num_splits=num_splits,
781
+ pack_gqa=pack_gqa,
782
+ sm_margin=sm_margin,
783
+ )
784
+ # return (out, softmax_lse) if return_softmax_lse else out
785
+ return (out, softmax_lse, *rest) if return_softmax_lse else out
786
+
787
+
788
+ def get_scheduler_metadata(
789
+ batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim,
790
+ cache_seqlens: torch.Tensor,
791
+ qkv_dtype=torch.bfloat16,
792
+ headdim_v=None,
793
+ cu_seqlens_q: Optional[torch.Tensor] = None,
794
+ cu_seqlens_k_new: Optional[torch.Tensor] = None,
795
+ cache_leftpad: Optional[torch.Tensor] = None,
796
+ page_size: Optional[int] = None,
797
+ max_seqlen_k_new=0,
798
+ causal=False,
799
+ window_size=(-1, -1), # -1 means infinite context window
800
+ attention_chunk=0,
801
+ has_softcap=False,
802
+ num_splits=0, # Can be tuned for speed
803
+ pack_gqa=None, # Can be tuned for speed
804
+ sm_margin=0, # Can be tuned if some SMs are used for communication
805
+ ):
806
+ cache_seqlens = maybe_contiguous(cache_seqlens)
807
+ if headdim_v is None:
808
+ headdim_v = headdim
809
+ scheduler_metadata = flash_attn_3_cuda.get_scheduler_metadata(
810
+ batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, headdim_v,
811
+ qkv_dtype,
812
+ cache_seqlens,
813
+ cu_seqlens_q,
814
+ None, # cu_seqlens_k
815
+ cu_seqlens_k_new,
816
+ None, # seqused_q
817
+ cache_leftpad,
818
+ page_size,
819
+ max_seqlen_k_new,
820
+ causal,
821
+ window_size[0], window_size[1],
822
+ attention_chunk,
823
+ has_softcap,
824
+ num_splits,
825
+ pack_gqa,
826
+ sm_margin,
827
+ )
828
+ return scheduler_metadata
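
Because `flash_attn_with_kvcache` is the inference-only path (no backward pass, per its docstring), a minimal incremental-decoding sketch may help. It is illustrative only; the cache sizes, head counts, and dtype are assumptions, and a GPU supported by these FA3 builds is required.

import torch
from flash_attn3 import flash_attn_with_kvcache

batch, max_seqlen, nheads, nheads_k, headdim = 2, 4096, 16, 4, 128
dev, dt = "cuda", torch.bfloat16

# Pre-allocated KV cache; cache_seqlens tracks how many positions are already filled.
k_cache = torch.zeros(batch, max_seqlen, nheads_k, headdim, device=dev, dtype=dt)
v_cache = torch.zeros(batch, max_seqlen, nheads_k, headdim, device=dev, dtype=dt)
cache_seqlens = torch.full((batch,), 128, dtype=torch.int32, device=dev)

# One decoding step: append the new token's K/V to the cache in place and attend to it.
q = torch.randn(batch, 1, nheads, headdim, device=dev, dtype=dt)
k_new = torch.randn(batch, 1, nheads_k, headdim, device=dev, dtype=dt)
v_new = torch.randn(batch, 1, nheads_k, headdim, device=dev, dtype=dt)

out = flash_attn_with_kvcache(
    q, k_cache, v_cache, k=k_new, v=v_new,
    cache_seqlens=cache_seqlens, causal=True,
)
cache_seqlens += 1  # the caller advances the tracked lengths for the next step
print(out.shape)  # expected: (2, 1, 16, 128)
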
build/torch27-cxx11-cu128-x86_64-linux/flash_attn3/__init__.py ADDED
@@ -0,0 +1,17 @@
1
+ from .flash_attn_interface import (
2
+ flash_attn_combine,
3
+ flash_attn_func,
4
+ flash_attn_qkvpacked_func,
5
+ flash_attn_varlen_func,
6
+ flash_attn_with_kvcache,
7
+ get_scheduler_metadata,
8
+ )
9
+
10
+ __all__ = [
11
+ "flash_attn_combine",
12
+ "flash_attn_func",
13
+ "flash_attn_qkvpacked_func",
14
+ "flash_attn_varlen_func",
15
+ "flash_attn_with_kvcache",
16
+ "get_scheduler_metadata",
17
+ ]
build/torch27-cxx11-cu128-x86_64-linux/flash_attn3/_flash_attn3_2e75662.abi3.so ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c0302224ac29ba4773d926d4cb16c01c45a374c6dd61286aae1f423f2bf495ea
3
+ size 838459544
build/torch27-cxx11-cu128-x86_64-linux/flash_attn3/_ops.py ADDED
@@ -0,0 +1,9 @@
1
+ import torch
2
+ from . import _flash_attn3_2e75662
3
+ ops = torch.ops._flash_attn3_2e75662
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_flash_attn3_2e75662::{op_name}"
build/torch27-cxx11-cu128-x86_64-linux/flash_attn3/flash_attn_interface.py ADDED
@@ -0,0 +1,828 @@
1
+ # Copyright (c) 2023, Tri Dao.
2
+
3
+ from typing import Optional, Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from ._ops import ops as flash_attn_3_cuda
9
+
10
+ def maybe_contiguous(x):
11
+ return x.contiguous() if x is not None and x.stride(-1) != 1 else x
12
+
13
+
14
+ def _flash_attn_forward(
15
+ q,
16
+ k,
17
+ v,
18
+ k_new,
19
+ v_new,
20
+ qv,
21
+ out,
22
+ cu_seqlens_q,
23
+ cu_seqlens_k,
24
+ cu_seqlens_k_new,
25
+ seqused_q,
26
+ seqused_k,
27
+ max_seqlen_q,
28
+ max_seqlen_k,
29
+ page_table,
30
+ kv_batch_idx,
31
+ leftpad_k,
32
+ rotary_cos,
33
+ rotary_sin,
34
+ seqlens_rotary,
35
+ q_descale,
36
+ k_descale,
37
+ v_descale,
38
+ softmax_scale,
39
+ causal,
40
+ window_size=(-1, -1),
41
+ attention_chunk=0,
42
+ softcap=0.0,
43
+ rotary_interleaved=True,
44
+ scheduler_metadata=None,
45
+ num_splits=1,
46
+ pack_gqa=None,
47
+ sm_margin=0):
48
+ q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)]
49
+ v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
50
+ cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [
51
+ maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new)
52
+ ]
53
+ seqused_q, seqused_k = [maybe_contiguous(x) for x in (seqused_q, seqused_k)]
54
+ page_table, kv_batch_idx, leftpad_k = [
55
+ maybe_contiguous(x) for x in (page_table, kv_batch_idx, leftpad_k)
56
+ ]
57
+ rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
58
+ seqlens_rotary = maybe_contiguous(seqlens_rotary)
59
+ out, softmax_lse, *rest = flash_attn_3_cuda.fwd(
60
+ q,
61
+ k,
62
+ v,
63
+ k_new,
64
+ v_new,
65
+ qv,
66
+ out,
67
+ cu_seqlens_q,
68
+ cu_seqlens_k,
69
+ cu_seqlens_k_new,
70
+ seqused_q,
71
+ seqused_k,
72
+ max_seqlen_q,
73
+ max_seqlen_k,
74
+ page_table,
75
+ kv_batch_idx,
76
+ leftpad_k,
77
+ rotary_cos,
78
+ rotary_sin,
79
+ seqlens_rotary,
80
+ q_descale,
81
+ k_descale,
82
+ v_descale,
83
+ softmax_scale,
84
+ causal,
85
+ window_size[0],
86
+ window_size[1],
87
+ attention_chunk,
88
+ softcap,
89
+ rotary_interleaved,
90
+ scheduler_metadata,
91
+ num_splits,
92
+ pack_gqa,
93
+ sm_margin,
94
+ )
95
+ return out, softmax_lse, *rest
96
+
97
+
98
+ def _flash_attn_backward(
99
+ dout,
100
+ q,
101
+ k,
102
+ v,
103
+ out,
104
+ softmax_lse,
105
+ cu_seqlens_q,
106
+ cu_seqlens_k,
107
+ sequed_q,
108
+ sequed_k,
109
+ max_seqlen_q,
110
+ max_seqlen_k,
111
+ dq,
112
+ dk,
113
+ dv,
114
+ softmax_scale,
115
+ causal,
116
+ window_size=(-1, -1),
117
+ softcap=0.0,
118
+ deterministic=False,
119
+ sm_margin=0,
120
+ ):
121
+ # dq, dk, dv are allocated by us so they should already be contiguous
122
+ dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
123
+ dq, dk, dv, softmax_d, *rest = flash_attn_3_cuda.bwd(
124
+ dout,
125
+ q,
126
+ k,
127
+ v,
128
+ out,
129
+ softmax_lse,
130
+ dq,
131
+ dk,
132
+ dv,
133
+ cu_seqlens_q,
134
+ cu_seqlens_k,
135
+ sequed_q,
136
+ sequed_k,
137
+ max_seqlen_q,
138
+ max_seqlen_k,
139
+ softmax_scale,
140
+ causal,
141
+ window_size[0],
142
+ window_size[1],
143
+ softcap,
144
+ deterministic,
145
+ sm_margin,
146
+ )
147
+ return dq, dk, dv, softmax_d
148
+
149
+
150
+ class FlashAttnQKVPackedFunc(torch.autograd.Function):
151
+ @staticmethod
152
+ def forward(
153
+ ctx,
154
+ qkv,
155
+ softmax_scale,
156
+ causal,
157
+ q_descale=None, k_descale=None, v_descale=None,
158
+ window_size=(-1, -1),
159
+ attention_chunk=0,
160
+ softcap=0.0,
161
+ deterministic=False,
162
+ num_heads_q=None,
163
+ sm_margin=0,
164
+ ):
165
+ if softmax_scale is None:
166
+ softmax_scale = qkv.shape[-1] ** (-0.5)
167
+ if qkv.dim() == 5:
168
+ assert qkv.shape[-3] == 3
169
+ q, k, v = qkv.unbind(dim=-3)
170
+ else:
171
+ assert qkv.dim() == 4
172
+ assert num_heads_q is not None
173
+ num_heads_k = (qkv.shape[2] - num_heads_q) // 2
174
+ assert num_heads_k * 2 + num_heads_q == qkv.shape[2]
175
+ q, k, v = qkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
176
+ out, softmax_lse, *rest = _flash_attn_forward(
177
+ q,
178
+ k,
179
+ v,
180
+ None, None, # k_new, v_new
181
+ None, # qv
182
+ None, # out
183
+ None, None, None, # cu_seqlens_q/k/k_new
184
+ None, None, # seqused_q/k
185
+ None, None, # max_seqlen_q/k
186
+ None, None, None, # page_table, kv_batch_idx, leftpad_k,
187
+ None, None, None, # rotary_cos/sin, seqlens_rotary
188
+ q_descale, k_descale, v_descale,
189
+ softmax_scale,
190
+ causal=causal,
191
+ window_size=window_size,
192
+ attention_chunk=attention_chunk,
193
+ softcap=softcap,
194
+ sm_margin=sm_margin,
195
+ )
196
+ # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
197
+ ctx.save_for_backward(q, k, v, out, softmax_lse)
198
+ ctx.softmax_scale = softmax_scale
199
+ ctx.causal = causal
200
+ ctx.window_size = window_size
201
+ ctx.attention_chunk = attention_chunk
202
+ ctx.softcap = softcap
203
+ ctx.deterministic = deterministic
204
+ ctx.ndim = qkv.dim()
205
+ ctx.sm_margin = sm_margin
206
+ # return out, softmax_lse
207
+ return out
208
+
209
+ @staticmethod
210
+ def backward(ctx, dout, *args):
211
+ q, k, v, out, softmax_lse = ctx.saved_tensors
212
+ assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
213
+ if ctx.ndim == 5:
214
+ qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
215
+ dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
216
+ dq, dk, dv = dqkv.unbind(dim=-3)
217
+ else:
218
+ num_heads_q = q.shape[2]
219
+ num_heads_k = k.shape[2]
220
+ qkv_shape = q.shape[:-2] + (num_heads_q + num_heads_k * 2, *q.shape[-1:])
221
+ dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
222
+ dq, dk, dv = dqkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
223
+ _flash_attn_backward(
224
+ dout,
225
+ q,
226
+ k,
227
+ v,
228
+ out,
229
+ softmax_lse,
230
+ None, None, # cu_seqlens_q, cu_seqlens_k,
231
+ None, None, # sequed_q, sequed_k,
232
+ None, None, # max_seqlen_q, max_seqlen_k,
233
+ dq,
234
+ dk,
235
+ dv,
236
+ ctx.softmax_scale,
237
+ ctx.causal,
238
+ ctx.window_size,
239
+ ctx.softcap,
240
+ ctx.deterministic,
241
+ ctx.sm_margin,
242
+ )
243
+ dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
244
+ return dqkv, None, None, None, None, None, None, None, None, None, None, None
245
+
246
+
247
+ class FlashAttnFunc(torch.autograd.Function):
248
+
249
+ @staticmethod
250
+ def forward(
251
+ ctx,
252
+ q,
253
+ k,
254
+ v,
255
+ softmax_scale,
256
+ causal,
257
+ qv=None,
258
+ q_descale=None, k_descale=None, v_descale=None,
259
+ window_size=(-1, -1),
260
+ attention_chunk=0,
261
+ softcap=0.0,
262
+ num_splits=1,
263
+ pack_gqa=None,
264
+ deterministic=False,
265
+ sm_margin=0,
266
+ ):
267
+ if softmax_scale is None:
268
+ softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
269
+ # out, q, k, v, out_padded, softmax_lse = _flash_attn_forward(
270
+ out, softmax_lse, *rest = _flash_attn_forward(
271
+ q,
272
+ k,
273
+ v,
274
+ None, None, # k_new, v_new
275
+ qv, # qv
276
+ None, # out
277
+ None, None, None, # cu_seqlens_q/k/k_new
278
+ None, None, # seqused_q/k
279
+ None, None, # max_seqlen_q/k
280
+ None, None, None, # page_table, kv_batch_idx, leftpad_k,
281
+ None, None, None, # rotary_cos/sin, seqlens_rotary
282
+ q_descale, k_descale, v_descale,
283
+ softmax_scale,
284
+ causal=causal,
285
+ window_size=window_size,
286
+ attention_chunk=attention_chunk,
287
+ softcap=softcap,
288
+ num_splits=num_splits,
289
+ pack_gqa=pack_gqa,
290
+ sm_margin=sm_margin,
291
+ )
292
+ # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
293
+ ctx.save_for_backward(q, k, v, out, softmax_lse)
294
+ ctx.softmax_scale = softmax_scale
295
+ ctx.causal = causal
296
+ ctx.window_size = window_size
297
+ ctx.attention_chunk = attention_chunk
298
+ ctx.softcap = softcap
299
+ ctx.deterministic = deterministic
300
+ ctx.sm_margin = sm_margin
301
+ return out, softmax_lse
302
+
303
+ @staticmethod
304
+ def backward(ctx, dout, *args):
305
+ q, k, v, out, softmax_lse = ctx.saved_tensors
306
+ assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
307
+ dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
308
+ _flash_attn_backward(
309
+ dout,
310
+ q,
311
+ k,
312
+ v,
313
+ out,
314
+ softmax_lse,
315
+ None, None, # cu_seqlens_q, cu_seqlens_k,
316
+ None, None, # sequed_q, sequed_k,
317
+ None, None, # max_seqlen_q, max_seqlen_k,
318
+ dq,
319
+ dk,
320
+ dv,
321
+ ctx.softmax_scale,
322
+ ctx.causal,
323
+ ctx.window_size,
324
+ ctx.softcap,
325
+ ctx.deterministic,
326
+ ctx.sm_margin,
327
+ )
328
+ dq = dq[..., : q.shape[-1]] # We could have padded the head dimension
329
+ dk = dk[..., : k.shape[-1]]
330
+ dv = dv[..., : v.shape[-1]]
331
+ return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None
332
+
333
+
334
+ class FlashAttnVarlenFunc(torch.autograd.Function):
335
+
336
+ @staticmethod
337
+ def forward(
338
+ ctx,
339
+ q,
340
+ k,
341
+ v,
342
+ cu_seqlens_q,
343
+ cu_seqlens_k,
344
+ seqused_q,
345
+ seqused_k,
346
+ max_seqlen_q,
347
+ max_seqlen_k,
348
+ softmax_scale,
349
+ causal,
350
+ qv=None,
351
+ q_descale=None, k_descale=None, v_descale=None,
352
+ window_size=(-1, -1),
353
+ attention_chunk=0,
354
+ softcap=0.0,
355
+ num_splits=1,
356
+ pack_gqa=None,
357
+ deterministic=False,
358
+ sm_margin=0,
359
+ ):
360
+ if softmax_scale is None:
361
+ softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
362
+ # out, q, k, v, out_padded, softmax_lse = _flash_attn_varlen_forward(
363
+ out, softmax_lse, *rest = _flash_attn_forward(
364
+ q,
365
+ k,
366
+ v,
367
+ None, None, # k_new, v_new
368
+ qv, # qv
369
+ None, # out
370
+ cu_seqlens_q,
371
+ cu_seqlens_k,
372
+ None, # cu_seqlens_k_new
373
+ seqused_q,
374
+ seqused_k,
375
+ max_seqlen_q,
376
+ max_seqlen_k,
377
+ None, None, None, # page_table, kv_batch_idx, leftpad_k,
378
+ None, None, None, # rotary_cos/sin, seqlens_rotary
379
+ q_descale, k_descale, v_descale,
380
+ softmax_scale,
381
+ causal=causal,
382
+ window_size=window_size,
383
+ attention_chunk=attention_chunk,
384
+ softcap=softcap,
385
+ num_splits=num_splits,
386
+ pack_gqa=pack_gqa,
387
+ sm_margin=sm_margin,
388
+ )
389
+ # ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
390
+ ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
391
+ ctx.max_seqlen_q = max_seqlen_q
392
+ ctx.max_seqlen_k = max_seqlen_k
393
+ ctx.softmax_scale = softmax_scale
394
+ ctx.causal = causal
395
+ ctx.window_size = window_size
396
+ ctx.attention_chunk = attention_chunk
397
+ ctx.softcap = softcap
398
+ ctx.deterministic = deterministic
399
+ ctx.sm_margin = sm_margin
400
+ return out, softmax_lse
401
+
402
+ @staticmethod
403
+ def backward(ctx, dout, *args):
404
+ q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
405
+ assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
406
+ dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
407
+ _flash_attn_backward(
408
+ dout,
409
+ q,
410
+ k,
411
+ v,
412
+ out,
413
+ softmax_lse,
414
+ cu_seqlens_q,
415
+ cu_seqlens_k,
416
+ seqused_q,
417
+ seqused_k,
418
+ ctx.max_seqlen_q,
419
+ ctx.max_seqlen_k,
420
+ dq,
421
+ dk,
422
+ dv,
423
+ ctx.softmax_scale,
424
+ ctx.causal,
425
+ ctx.window_size,
426
+ ctx.softcap,
427
+ ctx.deterministic,
428
+ ctx.sm_margin,
429
+ )
430
+ dq = dq[..., : q.shape[-1]] # We could have padded the head dimension
431
+ dk = dk[..., : k.shape[-1]]
432
+ dv = dv[..., : v.shape[-1]]
433
+ return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
434
+
435
+
436
+ def flash_attn_qkvpacked_func(
437
+ qkv,
438
+ softmax_scale=None,
439
+ causal=False,
440
+ q_descale=None, k_descale=None, v_descale=None,
441
+ window_size=(-1, -1),
442
+ attention_chunk=0,
443
+ softcap=0.0,
444
+ deterministic=False,
445
+ num_heads_q=None,
446
+ sm_margin=0,
447
+ ):
448
+ """dropout_p should be set to 0.0 during evaluation
449
+ If Q, K, V are already stacked into 1 tensor, this function will be faster than
450
+ calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
451
+ of the gradients of Q, K, V.
452
+ For multi-query and grouped-query attention (MQA/GQA), please see
453
+ flash_attn_kvpacked_func and flash_attn_func.
454
+
455
+ If window_size != (-1, -1), implements sliding window local attention. Query at position i
456
+ will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
457
+
458
+ Arguments:
459
+ qkv: (batch_size, seqlen, 3, nheads, headdim)
460
+ dropout_p: float. Dropout probability.
461
+ softmax_scale: float. The scaling of QK^T before applying softmax.
462
+ Default to 1 / sqrt(headdim).
463
+ causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
464
+ window_size: (left, right). If not (-1, -1), implements sliding window local attention.
465
+ softcap: float. Anything > 0 activates softcapping attention.
466
+ alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
467
+ the attention score of query i and key j.
468
+ deterministic: bool. Whether to use the deterministic implementation of the backward pass,
469
+ which is slightly slower and uses more memory. The forward pass is always deterministic.
470
+ return_attn_probs: bool. Whether to return the attention probabilities. This option is for
471
+ testing only. The returned probabilities are not guaranteed to be correct
472
+ (they might not have the right scaling).
473
+ Return:
474
+ out: (batch_size, seqlen, nheads, headdim).
475
+ softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
476
+ logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
477
+ normalization factor).
478
+ S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
479
+ The output of softmax (possibly with different scaling). It also encodes the dropout
480
+ pattern (negative means that location was dropped, nonnegative means it was kept).
481
+ """
482
+ return FlashAttnQKVPackedFunc.apply(
483
+ qkv,
484
+ softmax_scale,
485
+ causal,
486
+ q_descale, k_descale, v_descale,
487
+ window_size,
488
+ attention_chunk,
489
+ softcap,
490
+ deterministic,
491
+ num_heads_q,
492
+ sm_margin,
493
+ )
494
+
495
+
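A minimal call sketch for the packed wrapper above, assuming the built flash_attn3 package is importable, a CUDA GPU supported by these kernels, and bf16 inputs (the shapes and sizes are illustrative only):

    import torch
    from flash_attn3 import flash_attn_qkvpacked_func

    batch, seqlen, nheads, headdim = 2, 1024, 8, 128
    # Q, K and V stacked along dim -3: (batch, seqlen, 3, nheads, headdim)
    qkv = torch.randn(batch, seqlen, 3, nheads, headdim,
                      dtype=torch.bfloat16, device="cuda", requires_grad=True)

    out = flash_attn_qkvpacked_func(qkv, causal=True)  # (batch, seqlen, nheads, headdim)
    out.sum().backward()  # the packed gradient lands in qkv.grad with the same layout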
496
+ def flash_attn_func(
497
+ q,
498
+ k,
499
+ v,
500
+ softmax_scale=None,
501
+ causal=False,
502
+ qv=None,
503
+ q_descale=None, k_descale=None, v_descale=None,
504
+ window_size=(-1, -1),
505
+ attention_chunk=0,
506
+ softcap=0.0,
507
+ num_splits=1,
508
+ pack_gqa=None,
509
+ deterministic=False,
510
+ sm_margin=0,
511
+ ):
512
+ """dropout_p should be set to 0.0 during evaluation
513
+ Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
514
+ than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
515
+ For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
516
+ 0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.
517
+
518
+ If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
519
+ For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
520
+ 1 1 1 1 0
521
+ 1 1 1 1 1
522
+ If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
523
+ 0 0
524
+ 0 0
525
+ 0 0
526
+ 1 0
527
+ 1 1
528
+ If the row of the mask is all zero, the output will be zero.
529
+
530
+ If window_size != (-1, -1), implements sliding window local attention. Query at position i
531
+ will only attend to keys between
532
+ [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
533
+
534
+ Arguments:
535
+ q: (batch_size, seqlen, nheads, headdim)
536
+ k: (batch_size, seqlen, nheads_k, headdim)
537
+ v: (batch_size, seqlen, nheads_k, headdim)
538
+ dropout_p: float. Dropout probability.
539
+ softmax_scale: float. The scaling of QK^T before applying softmax.
540
+ Default to 1 / sqrt(headdim).
541
+ causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
542
+ window_size: (left, right). If not (-1, -1), implements sliding window local attention.
543
+ alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
544
+ (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
545
+ is added to the attention score of query i and key j.
546
+ deterministic: bool. Whether to use the deterministic implementation of the backward pass,
547
+ which is slightly slower and uses more memory. The forward pass is always deterministic.
548
+ return_attn_probs: bool. Whether to return the attention probabilities. This option is for
549
+ testing only. The returned probabilities are not guaranteed to be correct
550
+ (they might not have the right scaling).
551
+ Return:
552
+ out: (batch_size, seqlen, nheads, headdim).
553
+ softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
554
+ logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
555
+ normalization factor).
556
+ """
557
+ return FlashAttnFunc.apply(
558
+ q,
559
+ k,
560
+ v,
561
+ softmax_scale,
562
+ causal,
563
+ qv,
564
+ q_descale, k_descale, v_descale,
565
+ window_size,
566
+ attention_chunk,
567
+ softcap,
568
+ num_splits,
569
+ pack_gqa,
570
+ deterministic,
571
+ sm_margin,
572
+ )
573
+
574
+
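The dense wrapper above is easiest to read with a concrete GQA call. A minimal sketch, assuming the built flash_attn3 package is importable and bf16 inputs on a supported CUDA GPU:

    import torch
    from flash_attn3 import flash_attn_func

    batch, seqlen, headdim = 2, 4096, 128
    nheads_q, nheads_kv = 32, 8  # GQA: every 4 query heads share one KV head
    q = torch.randn(batch, seqlen, nheads_q, headdim, dtype=torch.bfloat16, device="cuda")
    k = torch.randn(batch, seqlen, nheads_kv, headdim, dtype=torch.bfloat16, device="cuda")
    v = torch.randn(batch, seqlen, nheads_kv, headdim, dtype=torch.bfloat16, device="cuda")

    # Causal attention; softmax_scale defaults to 1/sqrt(headdim).
    out, softmax_lse = flash_attn_func(q, k, v, causal=True)

    # Sliding-window variant: each query attends to at most the 1024 preceding keys.
    out_local, _ = flash_attn_func(q, k, v, causal=True, window_size=(1024, 0))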
575
+ def flash_attn_varlen_func(
576
+ q,
577
+ k,
578
+ v,
579
+ cu_seqlens_q,
580
+ cu_seqlens_k,
581
+ max_seqlen_q,
582
+ max_seqlen_k,
583
+ seqused_q=None,
584
+ seqused_k=None,
585
+ softmax_scale=None,
586
+ causal=False,
587
+ qv=None,
588
+ q_descale=None, k_descale=None, v_descale=None,
589
+ window_size=(-1, -1),
590
+ attention_chunk=0,
591
+ softcap=0.0,
592
+ num_splits=1,
593
+ pack_gqa=None,
594
+ deterministic=False,
595
+ sm_margin=0,
596
+ ):
597
+ return FlashAttnVarlenFunc.apply(
598
+ q,
599
+ k,
600
+ v,
601
+ cu_seqlens_q,
602
+ cu_seqlens_k,
603
+ seqused_q,
604
+ seqused_k,
605
+ max_seqlen_q,
606
+ max_seqlen_k,
607
+ softmax_scale,
608
+ causal,
609
+ qv,
610
+ q_descale, k_descale, v_descale,
611
+ window_size,
612
+ attention_chunk,
613
+ softcap,
614
+ num_splits,
615
+ pack_gqa,
616
+ deterministic,
617
+ sm_margin,
618
+ )
619
+
620
+
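The varlen wrapper above takes sequences packed along a single token dimension and delimited by cumulative sequence lengths. A minimal sketch, assuming the built flash_attn3 package is importable and bf16 inputs on a supported CUDA GPU:

    import torch
    from flash_attn3 import flash_attn_varlen_func

    # Two sequences of lengths 3 and 5 packed into one (total_tokens, nheads, headdim) tensor.
    cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32, device="cuda")  # prefix sums of [3, 5]
    total_tokens, nheads_q, nheads_kv, headdim = 8, 16, 4, 128

    q = torch.randn(total_tokens, nheads_q, headdim, dtype=torch.bfloat16, device="cuda")
    k = torch.randn(total_tokens, nheads_kv, headdim, dtype=torch.bfloat16, device="cuda")
    v = torch.randn(total_tokens, nheads_kv, headdim, dtype=torch.bfloat16, device="cuda")

    out, softmax_lse = flash_attn_varlen_func(
        q, k, v,
        cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
        max_seqlen_q=5, max_seqlen_k=5,
        causal=True,
    )
    # out: (total_tokens, nheads_q, headdim), attention computed independently per sequence.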
621
+ def flash_attn_combine(out_partial, lse_partial, out=None, out_dtype=None):
622
+ return flash_attn_3_cuda.fwd_combine(out_partial, lse_partial, out, out_dtype)
623
+
624
+
625
+ def flash_attn_with_kvcache(
626
+ q,
627
+ k_cache,
628
+ v_cache,
629
+ k=None,
630
+ v=None,
631
+ qv=None,
632
+ rotary_cos=None,
633
+ rotary_sin=None,
634
+ cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
635
+ cache_batch_idx: Optional[torch.Tensor] = None,
636
+ cache_leftpad: Optional[torch.Tensor] = None,
637
+ page_table: Optional[torch.Tensor] = None,
638
+ cu_seqlens_q: Optional[torch.Tensor] = None,
639
+ cu_seqlens_k_new: Optional[torch.Tensor] = None,
640
+ max_seqlen_q: Optional[int] = None,
641
+ rotary_seqlens: Optional[torch.Tensor] = None,
642
+ q_descale: Optional[torch.Tensor] = None,
643
+ k_descale: Optional[torch.Tensor] = None,
644
+ v_descale: Optional[torch.Tensor] = None,
645
+ softmax_scale=None,
646
+ causal=False,
647
+ window_size=(-1, -1), # -1 means infinite context window
648
+ attention_chunk=0,
649
+ softcap=0.0, # 0.0 means deactivated
650
+ rotary_interleaved=True,
651
+ scheduler_metadata=None,
652
+ num_splits=0, # Can be tuned for speed
653
+ pack_gqa=None, # Can be tuned for speed
654
+ sm_margin=0, # Can be tuned if some SMs are used for communication
655
+ return_softmax_lse=False,
656
+ ):
657
+ """
658
+ If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
659
+ k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
660
+ the previous step, and update them with the new keys/values from the current step, and do
661
+ attention with the updated cache, all in 1 kernel.
662
+
663
+ If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
664
+ For example, the KV cache could be pre-allocated with the max sequence length, and you can use
665
+ cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
666
+
667
+ Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
668
+ rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
669
+ If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
670
+ and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
671
+ If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
672
+ indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).
673
+
674
+ See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
675
+
676
+ Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
677
+ than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
678
+ For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
679
+ 0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.
680
+
681
+ If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
682
+ For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
683
+ 1 1 1 1 0
684
+ 1 1 1 1 1
685
+ If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
686
+ 0 0
687
+ 0 0
688
+ 0 0
689
+ 1 0
690
+ 1 1
691
+ If the row of the mask is all zero, the output will be zero.
692
+
693
+ If window_size != (-1, -1), implements sliding window local attention. Query at position i
694
+ will only attend to keys between
695
+ [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
696
+
697
+ Note: Does not support backward pass.
698
+
699
+ Arguments:
700
+ q: (batch_size, seqlen, nheads, headdim)
701
+ k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table,
702
+ or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache)
703
+ page_block_size must be a multiple of 256.
704
+ v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table,
705
+ or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache)
706
+ k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
707
+ k with k_cache, starting at the indices specified by cache_seqlens.
708
+ v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k.
709
+ qv [optional]: (batch_size, seqlen, nheads, headdim_v)
710
+ rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
711
+ to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
712
+ rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
713
+ cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
714
+ KV cache.
715
+ cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
716
+ If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
717
+ If the indices are not distinct, and k and v are provided, the values updated in the cache
718
+ might come from any of the duplicate indices.
719
+ cache_leftpad: (batch_size,), dtype torch.int32. The index at which the KV cache starts. If None, assume 0.
720
+ page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
721
+ softmax_scale: float. The scaling of QK^T before applying softmax.
722
+ Default to 1 / sqrt(headdim).
723
+ causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
724
+ window_size: (left, right). If not (-1, -1), implements sliding window local attention.
725
+ softcap: float. Anything > 0 activates softcapping attention.
726
+ rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
727
+ If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
728
+ rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
729
+ (i.e. GPT-NeoX style).
730
+ num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
731
+ If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
732
+ to automatically determine the number of splits.
733
+ Don't change this unless you know what you are doing.
734
+ return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.
735
+
736
+ Return:
737
+ out: (batch_size, seqlen, nheads, headdim).
738
+ softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
739
+ logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
740
+ normalization factor).
741
+ """
742
+ assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
743
+ assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
744
+ if softmax_scale is None:
745
+ softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
746
+ if cache_seqlens is not None and isinstance(cache_seqlens, int):
747
+ cache_seqlens = torch.full(
748
+ (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
749
+ )
750
+ cache_seqlens = maybe_contiguous(cache_seqlens)
751
+ out, softmax_lse, *rest = _flash_attn_forward(
752
+ q,
753
+ k_cache,
754
+ v_cache,
755
+ k,
756
+ v,
757
+ qv,
758
+ None, # out
759
+ cu_seqlens_q,
760
+ None, # cu_seqlens_k
761
+ cu_seqlens_k_new,
762
+ None, # seqused_q
763
+ cache_seqlens,
764
+ max_seqlen_q,
765
+ None, # max_seqlen_k
766
+ page_table,
767
+ cache_batch_idx,
768
+ cache_leftpad,
769
+ rotary_cos,
770
+ rotary_sin,
771
+ rotary_seqlens,
772
+ q_descale, k_descale, v_descale,
773
+ softmax_scale,
774
+ causal=causal,
775
+ window_size=window_size,
776
+ attention_chunk=attention_chunk,
777
+ softcap=softcap,
778
+ rotary_interleaved=rotary_interleaved,
779
+ scheduler_metadata=scheduler_metadata,
780
+ num_splits=num_splits,
781
+ pack_gqa=pack_gqa,
782
+ sm_margin=sm_margin,
783
+ )
784
+ # return (out, softmax_lse) if return_softmax_lse else out
785
+ return (out, softmax_lse, *rest) if return_softmax_lse else out
786
+
787
+
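A minimal decode-step sketch for the cache path above, assuming the built flash_attn3 package is importable, bf16 tensors on a supported CUDA GPU, and a preallocated non-paged cache:

    import torch
    from flash_attn3 import flash_attn_with_kvcache

    batch, max_seqlen, nheads_q, nheads_kv, headdim = 4, 8192, 32, 8, 128
    k_cache = torch.zeros(batch, max_seqlen, nheads_kv, headdim, dtype=torch.bfloat16, device="cuda")
    v_cache = torch.zeros_like(k_cache)
    cache_seqlens = torch.full((batch,), 100, dtype=torch.int32, device="cuda")  # tokens already cached

    # One new token per sequence: append it to the cache in place and attend to the
    # updated cache in a single kernel.
    q     = torch.randn(batch, 1, nheads_q,  headdim, dtype=torch.bfloat16, device="cuda")
    k_new = torch.randn(batch, 1, nheads_kv, headdim, dtype=torch.bfloat16, device="cuda")
    v_new = torch.randn(batch, 1, nheads_kv, headdim, dtype=torch.bfloat16, device="cuda")

    out = flash_attn_with_kvcache(
        q, k_cache, v_cache, k=k_new, v=v_new,
        cache_seqlens=cache_seqlens, causal=True,
    )
    # out: (batch, 1, nheads_q, headdim); k_cache/v_cache now hold 101 tokens per sequence.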
788
+ def get_scheduler_metadata(
789
+ batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim,
790
+ cache_seqlens: torch.Tensor,
791
+ qkv_dtype=torch.bfloat16,
792
+ headdim_v=None,
793
+ cu_seqlens_q: Optional[torch.Tensor] = None,
794
+ cu_seqlens_k_new: Optional[torch.Tensor] = None,
795
+ cache_leftpad: Optional[torch.Tensor] = None,
796
+ page_size: Optional[int] = None,
797
+ max_seqlen_k_new=0,
798
+ causal=False,
799
+ window_size=(-1, -1), # -1 means infinite context window
800
+ attention_chunk=0,
801
+ has_softcap=False,
802
+ num_splits=0, # Can be tuned for speed
803
+ pack_gqa=None, # Can be tuned for speed
804
+ sm_margin=0, # Can be tuned if some SMs are used for communication
805
+ ):
806
+ cache_seqlens = maybe_contiguous(cache_seqlens)
807
+ if headdim_v is None:
808
+ headdim_v = headdim
809
+ scheduler_metadata = flash_attn_3_cuda.get_scheduler_metadata(
810
+ batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, headdim_v,
811
+ qkv_dtype,
812
+ cache_seqlens,
813
+ cu_seqlens_q,
814
+ None, # cu_seqlens_k
815
+ cu_seqlens_k_new,
816
+ None, # seqused_q
817
+ cache_leftpad,
818
+ page_size,
819
+ max_seqlen_k_new,
820
+ causal,
821
+ window_size[0], window_size[1],
822
+ attention_chunk,
823
+ has_softcap,
824
+ num_splits,
825
+ pack_gqa,
826
+ sm_margin,
827
+ )
828
+ return scheduler_metadata
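get_scheduler_metadata above lets the caller precompute the kernel's scheduling decision and hand it back through the scheduler_metadata argument of flash_attn_with_kvcache; the attention arguments presumably need to describe the same problem in both calls. A compact sketch for a single-token decode step, under the same assumptions as the cache example earlier (importable flash_attn3 package, bf16 tensors, supported CUDA GPU):

    import torch
    from flash_attn3 import flash_attn_with_kvcache, get_scheduler_metadata

    batch, max_seqlen_k, nheads_q, nheads_kv, headdim = 4, 8192, 32, 8, 128
    cache_seqlens = torch.full((batch,), 100, dtype=torch.int32, device="cuda")

    # Precompute scheduling metadata for a decode step (max_seqlen_q == 1).
    meta = get_scheduler_metadata(
        batch, 1, max_seqlen_k, nheads_q, nheads_kv, headdim,
        cache_seqlens, qkv_dtype=torch.bfloat16, causal=True,
    )

    k_cache = torch.zeros(batch, max_seqlen_k, nheads_kv, headdim, dtype=torch.bfloat16, device="cuda")
    v_cache = torch.zeros_like(k_cache)
    q = torch.randn(batch, 1, nheads_q, headdim, dtype=torch.bfloat16, device="cuda")

    out = flash_attn_with_kvcache(
        q, k_cache, v_cache,
        cache_seqlens=cache_seqlens, causal=True,
        scheduler_metadata=meta,  # reuse the precomputed schedule
    )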
build/torch28-cxx11-cu126-aarch64-linux/flash_attn3/__init__.py ADDED
@@ -0,0 +1,17 @@
1
+ from .flash_attn_interface import (
2
+ flash_attn_combine,
3
+ flash_attn_func,
4
+ flash_attn_qkvpacked_func,
5
+ flash_attn_varlen_func,
6
+ flash_attn_with_kvcache,
7
+ get_scheduler_metadata,
8
+ )
9
+
10
+ __all__ = [
11
+ "flash_attn_combine",
12
+ "flash_attn_func",
13
+ "flash_attn_qkvpacked_func",
14
+ "flash_attn_varlen_func",
15
+ "flash_attn_with_kvcache",
16
+ "get_scheduler_metadata",
17
+ ]
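These six names form the whole public surface of the built package. They can be imported directly when the build directory is on sys.path, or fetched through the loader in the Hugging Face kernels library; the Hub id below is an assumption used only for illustration:

    import torch
    from kernels import get_kernel  # optional loader from the `kernels` library

    # Hypothetical repository id; substitute the id this build is actually published under.
    flash_attn3 = get_kernel("kernels-community/flash-attn3")

    q = torch.randn(1, 128, 8, 64, dtype=torch.bfloat16, device="cuda")
    k = torch.randn(1, 128, 8, 64, dtype=torch.bfloat16, device="cuda")
    v = torch.randn(1, 128, 8, 64, dtype=torch.bfloat16, device="cuda")
    out, softmax_lse = flash_attn3.flash_attn_func(q, k, v, causal=True)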
build/torch28-cxx11-cu126-aarch64-linux/flash_attn3/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (438 Bytes).
 
build/torch28-cxx11-cu126-aarch64-linux/flash_attn3/__pycache__/_ops.cpython-313.pyc ADDED
Binary file (530 Bytes).
 
build/torch28-cxx11-cu126-aarch64-linux/flash_attn3/__pycache__/flash_attn_interface.cpython-313.pyc ADDED
Binary file (26.2 kB).
 
build/torch28-cxx11-cu126-aarch64-linux/flash_attn3/_flash_attn3_8d4f83f.abi3.so ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e9aef52109e5974778e3ccc2f697c4e6050b365624c843a675ce894b938341cc
3
+ size 822395576
build/torch28-cxx11-cu126-aarch64-linux/flash_attn3/_ops.py ADDED
@@ -0,0 +1,9 @@
1
+ import torch
2
+ from . import _flash_attn3_8d4f83f
3
+ ops = torch.ops._flash_attn3_8d4f83f
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_flash_attn3_8d4f83f::{op_name}"
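A short sketch of how the _ops shim above is typically used: the interface module imports ops as flash_attn_3_cuda and calls the compiled kernels through it, while add_op_namespace_prefix builds the fully qualified op name that torch.library utilities expect. The register_fake call and my_fake_fwd below are illustrative assumptions and kept commented out:

    import torch
    from flash_attn3._ops import ops, add_op_namespace_prefix

    # ops is the torch.ops namespace of the compiled extension; ops.fwd / ops.bwd are the
    # entry points the Python interface calls.
    qualified = add_op_namespace_prefix("fwd")  # "_flash_attn3_8d4f83f::fwd"
    print(qualified)

    # Hypothetical use of the qualified name with torch.library, e.g. attaching a
    # meta/fake implementation for tracing:
    # torch.library.register_fake(qualified)(my_fake_fwd)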
build/torch28-cxx11-cu126-aarch64-linux/flash_attn3/flash_attn_interface.py ADDED
@@ -0,0 +1,828 @@
1
+ # Copyright (c) 2023, Tri Dao.
2
+
3
+ from typing import Optional, Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from ._ops import ops as flash_attn_3_cuda
9
+
10
+ def maybe_contiguous(x):
11
+ return x.contiguous() if x is not None and x.stride(-1) != 1 else x
12
+
13
+
14
+ def _flash_attn_forward(
15
+ q,
16
+ k,
17
+ v,
18
+ k_new,
19
+ v_new,
20
+ qv,
21
+ out,
22
+ cu_seqlens_q,
23
+ cu_seqlens_k,
24
+ cu_seqlens_k_new,
25
+ seqused_q,
26
+ seqused_k,
27
+ max_seqlen_q,
28
+ max_seqlen_k,
29
+ page_table,
30
+ kv_batch_idx,
31
+ leftpad_k,
32
+ rotary_cos,
33
+ rotary_sin,
34
+ seqlens_rotary,
35
+ q_descale,
36
+ k_descale,
37
+ v_descale,
38
+ softmax_scale,
39
+ causal,
40
+ window_size=(-1, -1),
41
+ attention_chunk=0,
42
+ softcap=0.0,
43
+ rotary_interleaved=True,
44
+ scheduler_metadata=None,
45
+ num_splits=1,
46
+ pack_gqa=None,
47
+ sm_margin=0):
48
+ q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)]
49
+ v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
50
+ cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [
51
+ maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new)
52
+ ]
53
+ seqused_q, seqused_k = [maybe_contiguous(x) for x in (seqused_q, seqused_k)]
54
+ page_table, kv_batch_idx, leftpad_k = [
55
+ maybe_contiguous(x) for x in (page_table, kv_batch_idx, leftpad_k)
56
+ ]
57
+ rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
58
+ seqlens_rotary = maybe_contiguous(seqlens_rotary)
59
+ out, softmax_lse, *rest = flash_attn_3_cuda.fwd(
60
+ q,
61
+ k,
62
+ v,
63
+ k_new,
64
+ v_new,
65
+ qv,
66
+ out,
67
+ cu_seqlens_q,
68
+ cu_seqlens_k,
69
+ cu_seqlens_k_new,
70
+ seqused_q,
71
+ seqused_k,
72
+ max_seqlen_q,
73
+ max_seqlen_k,
74
+ page_table,
75
+ kv_batch_idx,
76
+ leftpad_k,
77
+ rotary_cos,
78
+ rotary_sin,
79
+ seqlens_rotary,
80
+ q_descale,
81
+ k_descale,
82
+ v_descale,
83
+ softmax_scale,
84
+ causal,
85
+ window_size[0],
86
+ window_size[1],
87
+ attention_chunk,
88
+ softcap,
89
+ rotary_interleaved,
90
+ scheduler_metadata,
91
+ num_splits,
92
+ pack_gqa,
93
+ sm_margin,
94
+ )
95
+ return out, softmax_lse, *rest
96
+
97
+
98
+ def _flash_attn_backward(
99
+ dout,
100
+ q,
101
+ k,
102
+ v,
103
+ out,
104
+ softmax_lse,
105
+ cu_seqlens_q,
106
+ cu_seqlens_k,
107
+ sequed_q,
108
+ sequed_k,
109
+ max_seqlen_q,
110
+ max_seqlen_k,
111
+ dq,
112
+ dk,
113
+ dv,
114
+ softmax_scale,
115
+ causal,
116
+ window_size=(-1, -1),
117
+ softcap=0.0,
118
+ deterministic=False,
119
+ sm_margin=0,
120
+ ):
121
+ # dq, dk, dv are allocated by us so they should already be contiguous
122
+ dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
123
+ dq, dk, dv, softmax_d, *rest = flash_attn_3_cuda.bwd(
124
+ dout,
125
+ q,
126
+ k,
127
+ v,
128
+ out,
129
+ softmax_lse,
130
+ dq,
131
+ dk,
132
+ dv,
133
+ cu_seqlens_q,
134
+ cu_seqlens_k,
135
+ sequed_q,
136
+ sequed_k,
137
+ max_seqlen_q,
138
+ max_seqlen_k,
139
+ softmax_scale,
140
+ causal,
141
+ window_size[0],
142
+ window_size[1],
143
+ softcap,
144
+ deterministic,
145
+ sm_margin,
146
+ )
147
+ return dq, dk, dv, softmax_d
148
+
149
+
150
+ class FlashAttnQKVPackedFunc(torch.autograd.Function):
151
+ @staticmethod
152
+ def forward(
153
+ ctx,
154
+ qkv,
155
+ softmax_scale,
156
+ causal,
157
+ q_descale=None, k_descale=None, v_descale=None,
158
+ window_size=(-1, -1),
159
+ attention_chunk=0,
160
+ softcap=0.0,
161
+ deterministic=False,
162
+ num_heads_q=None,
163
+ sm_margin=0,
164
+ ):
165
+ if softmax_scale is None:
166
+ softmax_scale = qkv.shape[-1] ** (-0.5)
167
+ if qkv.dim() == 5:
168
+ assert qkv.shape[-3] == 3
169
+ q, k, v = qkv.unbind(dim=-3)
170
+ else:
171
+ assert qkv.dim() == 4
172
+ assert num_heads_q is not None
173
+ num_heads_k = (qkv.shape[2] - num_heads_q) // 2
174
+ assert num_heads_k * 2 + num_heads_q == qkv.shape[2]
175
+ q, k, v = qkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
176
+ out, softmax_lse, *rest = _flash_attn_forward(
177
+ q,
178
+ k,
179
+ v,
180
+ None, None, # k_new, v_new
181
+ None, # qv
182
+ None, # out
183
+ None, None, None, # cu_seqlens_q/k/k_new
184
+ None, None, # seqused_q/k
185
+ None, None, # max_seqlen_q/k
186
+ None, None, None, # page_table, kv_batch_idx, leftpad_k,
187
+ None, None, None, # rotary_cos/sin, seqlens_rotary
188
+ q_descale, k_descale, v_descale,
189
+ softmax_scale,
190
+ causal=causal,
191
+ window_size=window_size,
192
+ attention_chunk=attention_chunk,
193
+ softcap=softcap,
194
+ sm_margin=sm_margin,
195
+ )
196
+ # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
197
+ ctx.save_for_backward(q, k, v, out, softmax_lse)
198
+ ctx.softmax_scale = softmax_scale
199
+ ctx.causal = causal
200
+ ctx.window_size = window_size
201
+ ctx.attention_chunk = attention_chunk
202
+ ctx.softcap = softcap
203
+ ctx.deterministic = deterministic
204
+ ctx.ndim = qkv.dim()
205
+ ctx.sm_margin = sm_margin
206
+ # return out, softmax_lse
207
+ return out
208
+
209
+ @staticmethod
210
+ def backward(ctx, dout, *args):
211
+ q, k, v, out, softmax_lse = ctx.saved_tensors
212
+ assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
213
+ if ctx.ndim == 5:
214
+ qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
215
+ dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
216
+ dq, dk, dv = dqkv.unbind(dim=-3)
217
+ else:
218
+ num_heads_q = q.shape[2]
219
+ num_heads_k = k.shape[2]
220
+ qkv_shape = q.shape[:-2] + (num_heads_q + num_heads_k * 2, *q.shape[-1:])
221
+ dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
222
+ dq, dk, dv = dqkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
223
+ _flash_attn_backward(
224
+ dout,
225
+ q,
226
+ k,
227
+ v,
228
+ out,
229
+ softmax_lse,
230
+ None, None, # cu_seqlens_q, cu_seqlens_k,
231
+ None, None, # sequed_q, sequed_k,
232
+ None, None, # max_seqlen_q, max_seqlen_k,
233
+ dq,
234
+ dk,
235
+ dv,
236
+ ctx.softmax_scale,
237
+ ctx.causal,
238
+ ctx.window_size,
239
+ ctx.softcap,
240
+ ctx.deterministic,
241
+ ctx.sm_margin,
242
+ )
243
+ dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
244
+ return dqkv, None, None, None, None, None, None, None, None, None, None, None
245
+
246
+
247
+ class FlashAttnFunc(torch.autograd.Function):
248
+
249
+ @staticmethod
250
+ def forward(
251
+ ctx,
252
+ q,
253
+ k,
254
+ v,
255
+ softmax_scale,
256
+ causal,
257
+ qv=None,
258
+ q_descale=None, k_descale=None, v_descale=None,
259
+ window_size=(-1, -1),
260
+ attention_chunk=0,
261
+ softcap=0.0,
262
+ num_splits=1,
263
+ pack_gqa=None,
264
+ deterministic=False,
265
+ sm_margin=0,
266
+ ):
267
+ if softmax_scale is None:
268
+ softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
269
+ # out, q, k, v, out_padded, softmax_lse = _flash_attn_forward(
270
+ out, softmax_lse, *rest = _flash_attn_forward(
271
+ q,
272
+ k,
273
+ v,
274
+ None, None, # k_new, v_new
275
+ qv, # qv
276
+ None, # out
277
+ None, None, None, # cu_seqlens_q/k/k_new
278
+ None, None, # seqused_q/k
279
+ None, None, # max_seqlen_q/k
280
+ None, None, None, # page_table, kv_batch_idx, leftpad_k,
281
+ None, None, None, # rotary_cos/sin, seqlens_rotary
282
+ q_descale, k_descale, v_descale,
283
+ softmax_scale,
284
+ causal=causal,
285
+ window_size=window_size,
286
+ attention_chunk=attention_chunk,
287
+ softcap=softcap,
288
+ num_splits=num_splits,
289
+ pack_gqa=pack_gqa,
290
+ sm_margin=sm_margin,
291
+ )
292
+ # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
293
+ ctx.save_for_backward(q, k, v, out, softmax_lse)
294
+ ctx.softmax_scale = softmax_scale
295
+ ctx.causal = causal
296
+ ctx.window_size = window_size
297
+ ctx.attention_chunk = attention_chunk
298
+ ctx.softcap = softcap
299
+ ctx.deterministic = deterministic
300
+ ctx.sm_margin = sm_margin
301
+ return out, softmax_lse
302
+
303
+ @staticmethod
304
+ def backward(ctx, dout, *args):
305
+ q, k, v, out, softmax_lse = ctx.saved_tensors
306
+ assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
307
+ dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
308
+ _flash_attn_backward(
309
+ dout,
310
+ q,
311
+ k,
312
+ v,
313
+ out,
314
+ softmax_lse,
315
+ None, None, # cu_seqlens_q, cu_seqlens_k,
316
+ None, None, # sequed_q, sequed_k,
317
+ None, None, # max_seqlen_q, max_seqlen_k,
318
+ dq,
319
+ dk,
320
+ dv,
321
+ ctx.softmax_scale,
322
+ ctx.causal,
323
+ ctx.window_size,
324
+ ctx.softcap,
325
+ ctx.deterministic,
326
+ ctx.sm_margin,
327
+ )
328
+ dq = dq[..., : q.shape[-1]] # We could have padded the head dimension
329
+ dk = dk[..., : k.shape[-1]]
330
+ dv = dv[..., : v.shape[-1]]
331
+ return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None
332
+
333
+
334
+ class FlashAttnVarlenFunc(torch.autograd.Function):
335
+
336
+ @staticmethod
337
+ def forward(
338
+ ctx,
339
+ q,
340
+ k,
341
+ v,
342
+ cu_seqlens_q,
343
+ cu_seqlens_k,
344
+ seqused_q,
345
+ seqused_k,
346
+ max_seqlen_q,
347
+ max_seqlen_k,
348
+ softmax_scale,
349
+ causal,
350
+ qv=None,
351
+ q_descale=None, k_descale=None, v_descale=None,
352
+ window_size=(-1, -1),
353
+ attention_chunk=0,
354
+ softcap=0.0,
355
+ num_splits=1,
356
+ pack_gqa=None,
357
+ deterministic=False,
358
+ sm_margin=0,
359
+ ):
360
+ if softmax_scale is None:
361
+ softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
362
+ # out, q, k, v, out_padded, softmax_lse = _flash_attn_varlen_forward(
363
+ out, softmax_lse, *rest = _flash_attn_forward(
364
+ q,
365
+ k,
366
+ v,
367
+ None, None, # k_new, v_new
368
+ qv, # qv
369
+ None, # out
370
+ cu_seqlens_q,
371
+ cu_seqlens_k,
372
+ None, # cu_seqlens_k_new
373
+ seqused_q,
374
+ seqused_k,
375
+ max_seqlen_q,
376
+ max_seqlen_k,
377
+ None, None, None, # page_table, kv_batch_idx, leftpad_k,
378
+ None, None, None, # rotary_cos/sin, seqlens_rotary
379
+ q_descale, k_descale, v_descale,
380
+ softmax_scale,
381
+ causal=causal,
382
+ window_size=window_size,
383
+ attention_chunk=attention_chunk,
384
+ softcap=softcap,
385
+ num_splits=num_splits,
386
+ pack_gqa=pack_gqa,
387
+ sm_margin=sm_margin,
388
+ )
389
+ # ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
390
+ ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
391
+ ctx.max_seqlen_q = max_seqlen_q
392
+ ctx.max_seqlen_k = max_seqlen_k
393
+ ctx.softmax_scale = softmax_scale
394
+ ctx.causal = causal
395
+ ctx.window_size = window_size
396
+ ctx.attention_chunk = attention_chunk
397
+ ctx.softcap = softcap
398
+ ctx.deterministic = deterministic
399
+ ctx.sm_margin = sm_margin
400
+ return out, softmax_lse
401
+
402
+ @staticmethod
403
+ def backward(ctx, dout, *args):
404
+ q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
405
+ assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
406
+ dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
407
+ _flash_attn_backward(
408
+ dout,
409
+ q,
410
+ k,
411
+ v,
412
+ out,
413
+ softmax_lse,
414
+ cu_seqlens_q,
415
+ cu_seqlens_k,
416
+ seqused_q,
417
+ seqused_k,
418
+ ctx.max_seqlen_q,
419
+ ctx.max_seqlen_k,
420
+ dq,
421
+ dk,
422
+ dv,
423
+ ctx.softmax_scale,
424
+ ctx.causal,
425
+ ctx.window_size,
426
+ ctx.softcap,
427
+ ctx.deterministic,
428
+ ctx.sm_margin,
429
+ )
430
+ dq = dq[..., : q.shape[-1]] # We could have padded the head dimension
431
+ dk = dk[..., : k.shape[-1]]
432
+ dv = dv[..., : v.shape[-1]]
433
+ return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
434
+
435
+
436
+ def flash_attn_qkvpacked_func(
437
+ qkv,
438
+ softmax_scale=None,
439
+ causal=False,
440
+ q_descale=None, k_descale=None, v_descale=None,
441
+ window_size=(-1, -1),
442
+ attention_chunk=0,
443
+ softcap=0.0,
444
+ deterministic=False,
445
+ num_heads_q=None,
446
+ sm_margin=0,
447
+ ):
448
+ """dropout_p should be set to 0.0 during evaluation
449
+ If Q, K, V are already stacked into 1 tensor, this function will be faster than
450
+ calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
451
+ of the gradients of Q, K, V.
452
+ For multi-query and grouped-query attention (MQA/GQA), please see
453
+ flash_attn_kvpacked_func and flash_attn_func.
454
+
455
+ If window_size != (-1, -1), implements sliding window local attention. Query at position i
456
+ will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
457
+
458
+ Arguments:
459
+ qkv: (batch_size, seqlen, 3, nheads, headdim)
460
+ dropout_p: float. Dropout probability.
461
+ softmax_scale: float. The scaling of QK^T before applying softmax.
462
+ Default to 1 / sqrt(headdim).
463
+ causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
464
+ window_size: (left, right). If not (-1, -1), implements sliding window local attention.
465
+ softcap: float. Anything > 0 activates softcapping attention.
466
+ alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
467
+ the attention score of query i and key j.
468
+ deterministic: bool. Whether to use the deterministic implementation of the backward pass,
469
+ which is slightly slower and uses more memory. The forward pass is always deterministic.
470
+ return_attn_probs: bool. Whether to return the attention probabilities. This option is for
471
+ testing only. The returned probabilities are not guaranteed to be correct
472
+ (they might not have the right scaling).
473
+ Return:
474
+ out: (batch_size, seqlen, nheads, headdim).
475
+ softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
476
+ logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
477
+ normalization factor).
478
+ S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
479
+ The output of softmax (possibly with different scaling). It also encodes the dropout
480
+ pattern (negative means that location was dropped, nonnegative means it was kept).
481
+ """
482
+ return FlashAttnQKVPackedFunc.apply(
483
+ qkv,
484
+ softmax_scale,
485
+ causal,
486
+ q_descale, k_descale, v_descale,
487
+ window_size,
488
+ attention_chunk,
489
+ softcap,
490
+ deterministic,
491
+ num_heads_q,
492
+ sm_margin,
493
+ )
494
+
495
+
496
+ def flash_attn_func(
497
+ q,
498
+ k,
499
+ v,
500
+ softmax_scale=None,
501
+ causal=False,
502
+ qv=None,
503
+ q_descale=None, k_descale=None, v_descale=None,
504
+ window_size=(-1, -1),
505
+ attention_chunk=0,
506
+ softcap=0.0,
507
+ num_splits=1,
508
+ pack_gqa=None,
509
+ deterministic=False,
510
+ sm_margin=0,
511
+ ):
512
+ """dropout_p should be set to 0.0 during evaluation
513
+ Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
514
+ than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
515
+ For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
516
+ 0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.
517
+
518
+ If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
519
+ For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
520
+ 1 1 1 1 0
521
+ 1 1 1 1 1
522
+ If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
523
+ 0 0
524
+ 0 0
525
+ 0 0
526
+ 1 0
527
+ 1 1
528
+ If the row of the mask is all zero, the output will be zero.
529
+
530
+ If window_size != (-1, -1), implements sliding window local attention. Query at position i
531
+ will only attend to keys between
532
+ [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
533
+
534
+ Arguments:
535
+ q: (batch_size, seqlen, nheads, headdim)
536
+ k: (batch_size, seqlen, nheads_k, headdim)
537
+ v: (batch_size, seqlen, nheads_k, headdim)
538
+ dropout_p: float. Dropout probability.
539
+ softmax_scale: float. The scaling of QK^T before applying softmax.
540
+ Default to 1 / sqrt(headdim).
541
+ causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
542
+ window_size: (left, right). If not (-1, -1), implements sliding window local attention.
543
+ alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
544
+ (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
545
+ is added to the attention score of query i and key j.
546
+ deterministic: bool. Whether to use the deterministic implementation of the backward pass,
547
+ which is slightly slower and uses more memory. The forward pass is always deterministic.
548
+ return_attn_probs: bool. Whether to return the attention probabilities. This option is for
549
+ testing only. The returned probabilities are not guaranteed to be correct
550
+ (they might not have the right scaling).
551
+ Return:
552
+ out: (batch_size, seqlen, nheads, headdim).
553
+ softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
554
+ logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
555
+ normalization factor).
556
+ """
557
+ return FlashAttnFunc.apply(
558
+ q,
559
+ k,
560
+ v,
561
+ softmax_scale,
562
+ causal,
563
+ qv,
564
+ q_descale, k_descale, v_descale,
565
+ window_size,
566
+ attention_chunk,
567
+ softcap,
568
+ num_splits,
569
+ pack_gqa,
570
+ deterministic,
571
+ sm_margin,
572
+ )
573
+
574
+
575
+ def flash_attn_varlen_func(
576
+ q,
577
+ k,
578
+ v,
579
+ cu_seqlens_q,
580
+ cu_seqlens_k,
581
+ max_seqlen_q,
582
+ max_seqlen_k,
583
+ seqused_q=None,
584
+ seqused_k=None,
585
+ softmax_scale=None,
586
+ causal=False,
587
+ qv=None,
588
+ q_descale=None, k_descale=None, v_descale=None,
589
+ window_size=(-1, -1),
590
+ attention_chunk=0,
591
+ softcap=0.0,
592
+ num_splits=1,
593
+ pack_gqa=None,
594
+ deterministic=False,
595
+ sm_margin=0,
596
+ ):
597
+ return FlashAttnVarlenFunc.apply(
598
+ q,
599
+ k,
600
+ v,
601
+ cu_seqlens_q,
602
+ cu_seqlens_k,
603
+ seqused_q,
604
+ seqused_k,
605
+ max_seqlen_q,
606
+ max_seqlen_k,
607
+ softmax_scale,
608
+ causal,
609
+ qv,
610
+ q_descale, k_descale, v_descale,
611
+ window_size,
612
+ attention_chunk,
613
+ softcap,
614
+ num_splits,
615
+ pack_gqa,
616
+ deterministic,
617
+ sm_margin,
618
+ )
619
+
620
+
621
+ def flash_attn_combine(out_partial, lse_partial, out=None, out_dtype=None):
622
+ return flash_attn_3_cuda.fwd_combine(out_partial, lse_partial, out, out_dtype)
623
+
624
+
625
+ def flash_attn_with_kvcache(
626
+ q,
627
+ k_cache,
628
+ v_cache,
629
+ k=None,
630
+ v=None,
631
+ qv=None,
632
+ rotary_cos=None,
633
+ rotary_sin=None,
634
+ cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
635
+ cache_batch_idx: Optional[torch.Tensor] = None,
636
+ cache_leftpad: Optional[torch.Tensor] = None,
637
+ page_table: Optional[torch.Tensor] = None,
638
+ cu_seqlens_q: Optional[torch.Tensor] = None,
639
+ cu_seqlens_k_new: Optional[torch.Tensor] = None,
640
+ max_seqlen_q: Optional[int] = None,
641
+ rotary_seqlens: Optional[torch.Tensor] = None,
642
+ q_descale: Optional[torch.Tensor] = None,
643
+ k_descale: Optional[torch.Tensor] = None,
644
+ v_descale: Optional[torch.Tensor] = None,
645
+ softmax_scale=None,
646
+ causal=False,
647
+ window_size=(-1, -1), # -1 means infinite context window
648
+ attention_chunk=0,
649
+ softcap=0.0, # 0.0 means deactivated
650
+ rotary_interleaved=True,
651
+ scheduler_metadata=None,
652
+ num_splits=0, # Can be tuned for speed
653
+ pack_gqa=None, # Can be tuned for speed
654
+ sm_margin=0, # Can be tuned if some SMs are used for communication
655
+ return_softmax_lse=False,
656
+ ):
657
+ """
658
+ If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
659
+ k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
660
+ the previous step, and update them with the new keys/values from the current step, and do
661
+ attention with the updated cache, all in 1 kernel.
662
+
663
+ If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
664
+ For example, the KV cache could be pre-allocated with the max sequence length, and you can use
665
+ cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
666
+
667
+ Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
668
+ rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
669
+ If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
670
+ and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
671
+ If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
672
+ indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).
673
+
674
+ See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
675
+
676
+ Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
677
+ than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
678
+ For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
679
+ 0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.
680
+
681
+ If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
682
+ For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
683
+ 1 1 1 1 0
684
+ 1 1 1 1 1
685
+ If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
686
+ 0 0
687
+ 0 0
688
+ 0 0
689
+ 1 0
690
+ 1 1
691
+ If the row of the mask is all zero, the output will be zero.
692
+
693
+ If window_size != (-1, -1), implements sliding window local attention. Query at position i
694
+ will only attend to keys between
695
+ [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
696
+
697
+ Note: Does not support backward pass.
698
+
699
+ Arguments:
700
+ q: (batch_size, seqlen, nheads, headdim)
701
+ k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table,
702
+ or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache)
703
+ page_block_size must be a multiple of 256.
704
+ v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table,
705
+ or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache)
706
+ k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
707
+ k with k_cache, starting at the indices specified by cache_seqlens.
708
+ v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k.
709
+ qv [optional]: (batch_size, seqlen, nheads, headdim_v)
710
+ rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
711
+ to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
712
+ rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
713
+ cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
714
+ KV cache.
715
+ cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
716
+ If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
717
+ If the indices are not distinct, and k and v are provided, the values updated in the cache
718
+ might come from any of the duplicate indices.
719
+ cache_leftpad: (batch_size,), dtype torch.int32. The index at which the KV cache starts. If None, assume 0.
720
+ page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
721
+ softmax_scale: float. The scaling of QK^T before applying softmax.
722
+ Default to 1 / sqrt(headdim).
723
+ causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
724
+ window_size: (left, right). If not (-1, -1), implements sliding window local attention.
725
+ softcap: float. Anything > 0 activates softcapping attention.
726
+ rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
727
+ If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
728
+ rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
729
+ (i.e. GPT-NeoX style).
730
+ num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
731
+ If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
732
+ to automatically determine the number of splits.
733
+ Don't change this unless you know what you are doing.
734
+ return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.
735
+
736
+ Return:
737
+ out: (batch_size, seqlen, nheads, headdim).
738
+ softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
739
+ logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
740
+ normalization factor).
741
+ """
742
+ assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
743
+ assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
744
+ if softmax_scale is None:
745
+ softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
746
+ if cache_seqlens is not None and isinstance(cache_seqlens, int):
747
+ cache_seqlens = torch.full(
748
+ (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
749
+ )
750
+ cache_seqlens = maybe_contiguous(cache_seqlens)
751
+ out, softmax_lse, *rest = _flash_attn_forward(
752
+ q,
753
+ k_cache,
754
+ v_cache,
755
+ k,
756
+ v,
757
+ qv,
758
+ None, # out
759
+ cu_seqlens_q,
760
+ None, # cu_seqlens_k
761
+ cu_seqlens_k_new,
762
+ None, # seqused_q
763
+ cache_seqlens,
764
+ max_seqlen_q,
765
+ None, # max_seqlen_k
766
+ page_table,
767
+ cache_batch_idx,
768
+ cache_leftpad,
769
+ rotary_cos,
770
+ rotary_sin,
771
+ rotary_seqlens,
772
+ q_descale, k_descale, v_descale,
773
+ softmax_scale,
774
+ causal=causal,
775
+ window_size=window_size,
776
+ attention_chunk=attention_chunk,
777
+ softcap=softcap,
778
+ rotary_interleaved=rotary_interleaved,
779
+ scheduler_metadata=scheduler_metadata,
780
+ num_splits=num_splits,
781
+ pack_gqa=pack_gqa,
782
+ sm_margin=sm_margin,
783
+ )
784
+ # return (out, softmax_lse) if return_softmax_lse else out
785
+ return (out, softmax_lse, *rest) if return_softmax_lse else out
786
+
787
+
788
+ def get_scheduler_metadata(
789
+ batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim,
790
+ cache_seqlens: torch.Tensor,
791
+ qkv_dtype=torch.bfloat16,
792
+ headdim_v=None,
793
+ cu_seqlens_q: Optional[torch.Tensor] = None,
794
+ cu_seqlens_k_new: Optional[torch.Tensor] = None,
795
+ cache_leftpad: Optional[torch.Tensor] = None,
796
+ page_size: Optional[int] = None,
797
+ max_seqlen_k_new=0,
798
+ causal=False,
799
+ window_size=(-1, -1), # -1 means infinite context window
800
+ attention_chunk=0,
801
+ has_softcap=False,
802
+ num_splits=0, # Can be tuned for speed
803
+ pack_gqa=None, # Can be tuned for speed
804
+ sm_margin=0, # Can be tuned if some SMs are used for communication
805
+ ):
806
+ cache_seqlens = maybe_contiguous(cache_seqlens)
807
+ if headdim_v is None:
808
+ headdim_v = headdim
809
+ scheduler_metadata = flash_attn_3_cuda.get_scheduler_metadata(
810
+ batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, headdim_v,
811
+ qkv_dtype,
812
+ cache_seqlens,
813
+ cu_seqlens_q,
814
+ None, # cu_seqlens_k
815
+ cu_seqlens_k_new,
816
+ None, # seqused_q
817
+ cache_leftpad,
818
+ page_size,
819
+ max_seqlen_k_new,
820
+ causal,
821
+ window_size[0], window_size[1],
822
+ attention_chunk,
823
+ has_softcap,
824
+ num_splits,
825
+ pack_gqa,
826
+ sm_margin,
827
+ )
828
+ return scheduler_metadata
build/torch28-cxx11-cu126-x86_64-linux/flash_attn3/__init__.py ADDED
@@ -0,0 +1,17 @@
1
+ from .flash_attn_interface import (
2
+ flash_attn_combine,
3
+ flash_attn_func,
4
+ flash_attn_qkvpacked_func,
5
+ flash_attn_varlen_func,
6
+ flash_attn_with_kvcache,
7
+ get_scheduler_metadata,
8
+ )
9
+
10
+ __all__ = [
11
+ "flash_attn_combine",
12
+ "flash_attn_func",
13
+ "flash_attn_qkvpacked_func",
14
+ "flash_attn_varlen_func",
15
+ "flash_attn_with_kvcache",
16
+ "get_scheduler_metadata",
17
+ ]
build/torch28-cxx11-cu126-x86_64-linux/flash_attn3/_flash_attn3_48fe103_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bc32b815563bc9051986a333a362ff61e37cbd967893212243292fef03b461a5
3
+ size 838544688
build/torch28-cxx11-cu126-x86_64-linux/flash_attn3/_ops.py ADDED
@@ -0,0 +1,9 @@
1
+ import torch
2
+ from . import _flash_attn3_48fe103_dirty
3
+ ops = torch.ops._flash_attn3_48fe103_dirty
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_flash_attn3_48fe103_dirty::{op_name}"
build/torch28-cxx11-cu126-x86_64-linux/flash_attn3/flash_attn_interface.py ADDED
@@ -0,0 +1,828 @@
1
+ # Copyright (c) 2023, Tri Dao.
2
+
3
+ from typing import Optional, Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from ._ops import ops as flash_attn_3_cuda
9
+
10
+ def maybe_contiguous(x):
11
+ return x.contiguous() if x is not None and x.stride(-1) != 1 else x
12
+
13
+
14
+ def _flash_attn_forward(
15
+ q,
16
+ k,
17
+ v,
18
+ k_new,
19
+ v_new,
20
+ qv,
21
+ out,
22
+ cu_seqlens_q,
23
+ cu_seqlens_k,
24
+ cu_seqlens_k_new,
25
+ seqused_q,
26
+ seqused_k,
27
+ max_seqlen_q,
28
+ max_seqlen_k,
29
+ page_table,
30
+ kv_batch_idx,
31
+ leftpad_k,
32
+ rotary_cos,
33
+ rotary_sin,
34
+ seqlens_rotary,
35
+ q_descale,
36
+ k_descale,
37
+ v_descale,
38
+ softmax_scale,
39
+ causal,
40
+ window_size=(-1, -1),
41
+ attention_chunk=0,
42
+ softcap=0.0,
43
+ rotary_interleaved=True,
44
+ scheduler_metadata=None,
45
+ num_splits=1,
46
+ pack_gqa=None,
47
+ sm_margin=0):
48
+ q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)]
49
+ v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
50
+ cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [
51
+ maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new)
52
+ ]
53
+ seqused_q, seqused_k = [maybe_contiguous(x) for x in (seqused_q, seqused_k)]
54
+ page_table, kv_batch_idx, leftpad_k = [
55
+ maybe_contiguous(x) for x in (page_table, kv_batch_idx, leftpad_k)
56
+ ]
57
+ rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
58
+ seqlens_rotary = maybe_contiguous(seqlens_rotary)
59
+ out, softmax_lse, *rest = flash_attn_3_cuda.fwd(
60
+ q,
61
+ k,
62
+ v,
63
+ k_new,
64
+ v_new,
65
+ qv,
66
+ out,
67
+ cu_seqlens_q,
68
+ cu_seqlens_k,
69
+ cu_seqlens_k_new,
70
+ seqused_q,
71
+ seqused_k,
72
+ max_seqlen_q,
73
+ max_seqlen_k,
74
+ page_table,
75
+ kv_batch_idx,
76
+ leftpad_k,
77
+ rotary_cos,
78
+ rotary_sin,
79
+ seqlens_rotary,
80
+ q_descale,
81
+ k_descale,
82
+ v_descale,
83
+ softmax_scale,
84
+ causal,
85
+ window_size[0],
86
+ window_size[1],
87
+ attention_chunk,
88
+ softcap,
89
+ rotary_interleaved,
90
+ scheduler_metadata,
91
+ num_splits,
92
+ pack_gqa,
93
+ sm_margin,
94
+ )
95
+ return out, softmax_lse, *rest
96
+
97
+
98
+ def _flash_attn_backward(
99
+ dout,
100
+ q,
101
+ k,
102
+ v,
103
+ out,
104
+ softmax_lse,
105
+ cu_seqlens_q,
106
+ cu_seqlens_k,
107
+ sequed_q,
108
+ sequed_k,
109
+ max_seqlen_q,
110
+ max_seqlen_k,
111
+ dq,
112
+ dk,
113
+ dv,
114
+ softmax_scale,
115
+ causal,
116
+ window_size=(-1, -1),
117
+ softcap=0.0,
118
+ deterministic=False,
119
+ sm_margin=0,
120
+ ):
121
+ # dq, dk, dv are allocated by us so they should already be contiguous
122
+ dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
123
+ dq, dk, dv, softmax_d, *rest = flash_attn_3_cuda.bwd(
124
+ dout,
125
+ q,
126
+ k,
127
+ v,
128
+ out,
129
+ softmax_lse,
130
+ dq,
131
+ dk,
132
+ dv,
133
+ cu_seqlens_q,
134
+ cu_seqlens_k,
135
+ sequed_q,
136
+ sequed_k,
137
+ max_seqlen_q,
138
+ max_seqlen_k,
139
+ softmax_scale,
140
+ causal,
141
+ window_size[0],
142
+ window_size[1],
143
+ softcap,
144
+ deterministic,
145
+ sm_margin,
146
+ )
147
+ return dq, dk, dv, softmax_d
148
+
149
+
150
+ class FlashAttnQKVPackedFunc(torch.autograd.Function):
151
+ @staticmethod
152
+ def forward(
153
+ ctx,
154
+ qkv,
155
+ softmax_scale,
156
+ causal,
157
+ q_descale=None, k_descale=None, v_descale=None,
158
+ window_size=(-1, -1),
159
+ attention_chunk=0,
160
+ softcap=0.0,
161
+ deterministic=False,
162
+ num_heads_q=None,
163
+ sm_margin=0,
164
+ ):
165
+ if softmax_scale is None:
166
+ softmax_scale = qkv.shape[-1] ** (-0.5)
167
+ if qkv.dim() == 5:
168
+ assert qkv.shape[-3] == 3
169
+ q, k, v = qkv.unbind(dim=-3)
170
+ else:
171
+ assert qkv.dim() == 4
172
+ assert num_heads_q is not None
173
+ num_heads_k = (qkv.shape[2] - num_heads_q) // 2
174
+ assert num_heads_k * 2 + num_heads_q == qkv.shape[2]
175
+ q, k, v = qkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
176
+ out, softmax_lse, *rest = _flash_attn_forward(
177
+ q,
178
+ k,
179
+ v,
180
+ None, None, # k_new, v_new
181
+ None, # qv
182
+ None, # out
183
+ None, None, None, # cu_seqlens_q/k/k_new
184
+ None, None, # seqused_q/k
185
+ None, None, # max_seqlen_q/k
186
+ None, None, None, # page_table, kv_batch_idx, leftpad_k,
187
+ None, None, None, # rotary_cos/sin, seqlens_rotary
188
+ q_descale, k_descale, v_descale,
189
+ softmax_scale,
190
+ causal=causal,
191
+ window_size=window_size,
192
+ attention_chunk=attention_chunk,
193
+ softcap=softcap,
194
+ sm_margin=sm_margin,
195
+ )
196
+ # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
197
+ ctx.save_for_backward(q, k, v, out, softmax_lse)
198
+ ctx.softmax_scale = softmax_scale
199
+ ctx.causal = causal
200
+ ctx.window_size = window_size
201
+ ctx.attention_chunk = attention_chunk
202
+ ctx.softcap = softcap
203
+ ctx.deterministic = deterministic
204
+ ctx.ndim = qkv.dim()
205
+ ctx.sm_margin = sm_margin
206
+ # return out, softmax_lse
207
+ return out
208
+
209
+ @staticmethod
210
+ def backward(ctx, dout, *args):
211
+ q, k, v, out, softmax_lse = ctx.saved_tensors
212
+ assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
213
+ if ctx.ndim == 5:
214
+ qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
215
+ dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
216
+ dq, dk, dv = dqkv.unbind(dim=-3)
217
+ else:
218
+ num_heads_q = q.shape[2]
219
+ num_heads_k = k.shape[2]
220
+ qkv_shape = q.shape[:-2] + (num_heads_q + num_heads_k * 2, *q.shape[-1:])
221
+ dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
222
+ dq, dk, dv = dqkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
223
+ _flash_attn_backward(
224
+ dout,
225
+ q,
226
+ k,
227
+ v,
228
+ out,
229
+ softmax_lse,
230
+ None, None, # cu_seqlens_q, cu_seqlens_k,
231
+ None, None, # sequed_q, sequed_k,
232
+ None, None, # max_seqlen_q, max_seqlen_k,
233
+ dq,
234
+ dk,
235
+ dv,
236
+ ctx.softmax_scale,
237
+ ctx.causal,
238
+ ctx.window_size,
239
+ ctx.softcap,
240
+ ctx.deterministic,
241
+ ctx.sm_margin,
242
+ )
243
+ dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
244
+ return dqkv, None, None, None, None, None, None, None, None, None, None, None
245
+
246
+
247
+ class FlashAttnFunc(torch.autograd.Function):
248
+
249
+ @staticmethod
250
+ def forward(
251
+ ctx,
252
+ q,
253
+ k,
254
+ v,
255
+ softmax_scale,
256
+ causal,
257
+ qv=None,
258
+ q_descale=None, k_descale=None, v_descale=None,
259
+ window_size=(-1, -1),
260
+ attention_chunk=0,
261
+ softcap=0.0,
262
+ num_splits=1,
263
+ pack_gqa=None,
264
+ deterministic=False,
265
+ sm_margin=0,
266
+ ):
267
+ if softmax_scale is None:
268
+ softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
269
+ # out, q, k, v, out_padded, softmax_lse = _flash_attn_forward(
270
+ out, softmax_lse, *rest = _flash_attn_forward(
271
+ q,
272
+ k,
273
+ v,
274
+ None, None, # k_new, v_new
275
+ qv, # qv
276
+ None, # out
277
+ None, None, None, # cu_seqlens_q/k/k_new
278
+ None, None, # seqused_q/k
279
+ None, None, # max_seqlen_q/k
280
+ None, None, None, # page_table, kv_batch_idx, leftpad_k,
281
+ None, None, None, # rotary_cos/sin, seqlens_rotary
282
+ q_descale, k_descale, v_descale,
283
+ softmax_scale,
284
+ causal=causal,
285
+ window_size=window_size,
286
+ attention_chunk=attention_chunk,
287
+ softcap=softcap,
288
+ num_splits=num_splits,
289
+ pack_gqa=pack_gqa,
290
+ sm_margin=sm_margin,
291
+ )
292
+ # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
293
+ ctx.save_for_backward(q, k, v, out, softmax_lse)
294
+ ctx.softmax_scale = softmax_scale
295
+ ctx.causal = causal
296
+ ctx.window_size = window_size
297
+ ctx.attention_chunk = attention_chunk
298
+ ctx.softcap = softcap
299
+ ctx.deterministic = deterministic
300
+ ctx.sm_margin = sm_margin
301
+ return out, softmax_lse
302
+
303
+ @staticmethod
304
+ def backward(ctx, dout, *args):
305
+ q, k, v, out, softmax_lse = ctx.saved_tensors
306
+ assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
307
+ dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
308
+ _flash_attn_backward(
309
+ dout,
310
+ q,
311
+ k,
312
+ v,
313
+ out,
314
+ softmax_lse,
315
+ None, None, # cu_seqlens_q, cu_seqlens_k,
316
+ None, None, # sequed_q, sequed_k,
317
+ None, None, # max_seqlen_q, max_seqlen_k,
318
+ dq,
319
+ dk,
320
+ dv,
321
+ ctx.softmax_scale,
322
+ ctx.causal,
323
+ ctx.window_size,
324
+ ctx.softcap,
325
+ ctx.deterministic,
326
+ ctx.sm_margin,
327
+ )
328
+ dq = dq[..., : q.shape[-1]] # We could have padded the head dimension
329
+ dk = dk[..., : k.shape[-1]]
330
+ dv = dv[..., : v.shape[-1]]
331
+ return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None
332
+
333
+
334
+ class FlashAttnVarlenFunc(torch.autograd.Function):
335
+
336
+ @staticmethod
337
+ def forward(
338
+ ctx,
339
+ q,
340
+ k,
341
+ v,
342
+ cu_seqlens_q,
343
+ cu_seqlens_k,
344
+ seqused_q,
345
+ seqused_k,
346
+ max_seqlen_q,
347
+ max_seqlen_k,
348
+ softmax_scale,
349
+ causal,
350
+ qv=None,
351
+ q_descale=None, k_descale=None, v_descale=None,
352
+ window_size=(-1, -1),
353
+ attention_chunk=0,
354
+ softcap=0.0,
355
+ num_splits=1,
356
+ pack_gqa=None,
357
+ deterministic=False,
358
+ sm_margin=0,
359
+ ):
360
+ if softmax_scale is None:
361
+ softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
362
+ # out, q, k, v, out_padded, softmax_lse = _flash_attn_varlen_forward(
363
+ out, softmax_lse, *rest = _flash_attn_forward(
364
+ q,
365
+ k,
366
+ v,
367
+ None, None, # k_new, v_new
368
+ qv, # qv
369
+ None, # out
370
+ cu_seqlens_q,
371
+ cu_seqlens_k,
372
+ None, # cu_seqlens_k_new
373
+ seqused_q,
374
+ seqused_k,
375
+ max_seqlen_q,
376
+ max_seqlen_k,
377
+ None, None, None, # page_table, kv_batch_idx, leftpad_k,
378
+ None, None, None, # rotary_cos/sin, seqlens_rotary
379
+ q_descale, k_descale, v_descale,
380
+ softmax_scale,
381
+ causal=causal,
382
+ window_size=window_size,
383
+ attention_chunk=attention_chunk,
384
+ softcap=softcap,
385
+ num_splits=num_splits,
386
+ pack_gqa=pack_gqa,
387
+ sm_margin=sm_margin,
388
+ )
389
+ # ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
390
+ ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
391
+ ctx.max_seqlen_q = max_seqlen_q
392
+ ctx.max_seqlen_k = max_seqlen_k
393
+ ctx.softmax_scale = softmax_scale
394
+ ctx.causal = causal
395
+ ctx.window_size = window_size
396
+ ctx.attention_chunk = attention_chunk
397
+ ctx.softcap = softcap
398
+ ctx.deterministic = deterministic
399
+ ctx.sm_margin = sm_margin
400
+ return out, softmax_lse
401
+
402
+ @staticmethod
403
+ def backward(ctx, dout, *args):
404
+ q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
405
+ assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
406
+ dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
407
+ _flash_attn_backward(
408
+ dout,
409
+ q,
410
+ k,
411
+ v,
412
+ out,
413
+ softmax_lse,
414
+ cu_seqlens_q,
415
+ cu_seqlens_k,
416
+ seqused_q,
417
+ seqused_k,
418
+ ctx.max_seqlen_q,
419
+ ctx.max_seqlen_k,
420
+ dq,
421
+ dk,
422
+ dv,
423
+ ctx.softmax_scale,
424
+ ctx.causal,
425
+ ctx.window_size,
426
+ ctx.softcap,
427
+ ctx.deterministic,
428
+ ctx.sm_margin,
429
+ )
430
+ dq = dq[..., : q.shape[-1]] # We could have padded the head dimension
431
+ dk = dk[..., : k.shape[-1]]
432
+ dv = dv[..., : v.shape[-1]]
433
+ return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
434
+
435
+
436
+ def flash_attn_qkvpacked_func(
437
+ qkv,
438
+ softmax_scale=None,
439
+ causal=False,
440
+ q_descale=None, k_descale=None, v_descale=None,
441
+ window_size=(-1, -1),
442
+ attention_chunk=0,
443
+ softcap=0.0,
444
+ deterministic=False,
445
+ num_heads_q=None,
446
+ sm_margin=0,
447
+ ):
448
+ """dropout_p should be set to 0.0 during evaluation
449
+ If Q, K, V are already stacked into 1 tensor, this function will be faster than
450
+ calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
451
+ of the gradients of Q, K, V.
452
+ For multi-query and grouped-query attention (MQA/GQA), please see
453
+ flash_attn_kvpacked_func and flash_attn_func.
454
+
455
+ If window_size != (-1, -1), implements sliding window local attention. Query at position i
456
+ will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
457
+
458
+ Arguments:
459
+ qkv: (batch_size, seqlen, 3, nheads, headdim)
460
+ dropout_p: float. Dropout probability.
461
+ softmax_scale: float. The scaling of QK^T before applying softmax.
462
+ Default to 1 / sqrt(headdim).
463
+ causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
464
+ window_size: (left, right). If not (-1, -1), implements sliding window local attention.
465
+ softcap: float. Anything > 0 activates softcapping attention.
466
+ alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
467
+ the attention score of query i and key j.
468
+ deterministic: bool. Whether to use the deterministic implementation of the backward pass,
469
+ which is slightly slower and uses more memory. The forward pass is always deterministic.
470
+ return_attn_probs: bool. Whether to return the attention probabilities. This option is for
471
+ testing only. The returned probabilities are not guaranteed to be correct
472
+ (they might not have the right scaling).
473
+ Return:
474
+ out: (batch_size, seqlen, nheads, headdim).
475
+ softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
476
+ logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
477
+ normalization factor).
478
+ S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
479
+ The output of softmax (possibly with different scaling). It also encodes the dropout
480
+ pattern (negative means that location was dropped, nonnegative means it was kept).
481
+ """
482
+ return FlashAttnQKVPackedFunc.apply(
483
+ qkv,
484
+ softmax_scale,
485
+ causal,
486
+ q_descale, k_descale, v_descale,
487
+ window_size,
488
+ attention_chunk,
489
+ softcap,
490
+ deterministic,
491
+ num_heads_q,
492
+ sm_margin,
493
+ )
494
+
495
+
496
+ def flash_attn_func(
497
+ q,
498
+ k,
499
+ v,
500
+ softmax_scale=None,
501
+ causal=False,
502
+ qv=None,
503
+ q_descale=None, k_descale=None, v_descale=None,
504
+ window_size=(-1, -1),
505
+ attention_chunk=0,
506
+ softcap=0.0,
507
+ num_splits=1,
508
+ pack_gqa=None,
509
+ deterministic=False,
510
+ sm_margin=0,
511
+ ):
512
+ """dropout_p should be set to 0.0 during evaluation
513
+ Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
514
+ than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
515
+ For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
516
+ 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
517
+
518
+ If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
519
+ For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
520
+ 1 1 1 1 0
521
+ 1 1 1 1 1
522
+ If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
523
+ 0 0
524
+ 0 0
525
+ 0 0
526
+ 1 0
527
+ 1 1
528
+ If the row of the mask is all zero, the output will be zero.
529
+
530
+ If window_size != (-1, -1), implements sliding window local attention. Query at position i
531
+ will only attend to keys between
532
+ [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
533
+
534
+ Arguments:
535
+ q: (batch_size, seqlen, nheads, headdim)
536
+ k: (batch_size, seqlen, nheads_k, headdim)
537
+ v: (batch_size, seqlen, nheads_k, headdim)
538
+ dropout_p: float. Dropout probability.
539
+ softmax_scale: float. The scaling of QK^T before applying softmax.
540
+ Default to 1 / sqrt(headdim).
541
+ causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
542
+ window_size: (left, right). If not (-1, -1), implements sliding window local attention.
543
+ alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
544
+ (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
545
+ is added to the attention score of query i and key j.
546
+ deterministic: bool. Whether to use the deterministic implementation of the backward pass,
547
+ which is slightly slower and uses more memory. The forward pass is always deterministic.
548
+ return_attn_probs: bool. Whether to return the attention probabilities. This option is for
549
+ testing only. The returned probabilities are not guaranteed to be correct
550
+ (they might not have the right scaling).
551
+ Return:
552
+ out: (batch_size, seqlen, nheads, headdim).
553
+ softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
554
+ logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
555
+ normalization factor).
556
+ """
557
+ return FlashAttnFunc.apply(
558
+ q,
559
+ k,
560
+ v,
561
+ softmax_scale,
562
+ causal,
563
+ qv,
564
+ q_descale, k_descale, v_descale,
565
+ window_size,
566
+ attention_chunk,
567
+ softcap,
568
+ num_splits,
569
+ pack_gqa,
570
+ deterministic,
571
+ sm_margin,
572
+ )
573
+
574
+
575
+ def flash_attn_varlen_func(
576
+ q,
577
+ k,
578
+ v,
579
+ cu_seqlens_q,
580
+ cu_seqlens_k,
581
+ max_seqlen_q,
582
+ max_seqlen_k,
583
+ seqused_q=None,
584
+ seqused_k=None,
585
+ softmax_scale=None,
586
+ causal=False,
587
+ qv=None,
588
+ q_descale=None, k_descale=None, v_descale=None,
589
+ window_size=(-1, -1),
590
+ attention_chunk=0,
591
+ softcap=0.0,
592
+ num_splits=1,
593
+ pack_gqa=None,
594
+ deterministic=False,
595
+ sm_margin=0,
596
+ ):
597
+ return FlashAttnVarlenFunc.apply(
598
+ q,
599
+ k,
600
+ v,
601
+ cu_seqlens_q,
602
+ cu_seqlens_k,
603
+ seqused_q,
604
+ seqused_k,
605
+ max_seqlen_q,
606
+ max_seqlen_k,
607
+ softmax_scale,
608
+ causal,
609
+ qv,
610
+ q_descale, k_descale, v_descale,
611
+ window_size,
612
+ attention_chunk,
613
+ softcap,
614
+ num_splits,
615
+ pack_gqa,
616
+ deterministic,
617
+ sm_margin,
618
+ )
619
+
620
+
621
+ def flash_attn_combine(out_partial, lse_partial, out=None, out_dtype=None):
622
+ return flash_attn_3_cuda.fwd_combine(out_partial, lse_partial, out, out_dtype)
623
+
624
+
625
+ def flash_attn_with_kvcache(
626
+ q,
627
+ k_cache,
628
+ v_cache,
629
+ k=None,
630
+ v=None,
631
+ qv=None,
632
+ rotary_cos=None,
633
+ rotary_sin=None,
634
+ cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
635
+ cache_batch_idx: Optional[torch.Tensor] = None,
636
+ cache_leftpad: Optional[torch.Tensor] = None,
637
+ page_table: Optional[torch.Tensor] = None,
638
+ cu_seqlens_q: Optional[torch.Tensor] = None,
639
+ cu_seqlens_k_new: Optional[torch.Tensor] = None,
640
+ max_seqlen_q: Optional[int] = None,
641
+ rotary_seqlens: Optional[torch.Tensor] = None,
642
+ q_descale: Optional[torch.Tensor] = None,
643
+ k_descale: Optional[torch.Tensor] = None,
644
+ v_descale: Optional[torch.Tensor] = None,
645
+ softmax_scale=None,
646
+ causal=False,
647
+ window_size=(-1, -1), # -1 means infinite context window
648
+ attention_chunk=0,
649
+ softcap=0.0, # 0.0 means deactivated
650
+ rotary_interleaved=True,
651
+ scheduler_metadata=None,
652
+ num_splits=0, # Can be tuned for speed
653
+ pack_gqa=None, # Can be tuned for speed
654
+ sm_margin=0, # Can be tuned if some SMs are used for communication
655
+ return_softmax_lse=False,
656
+ ):
657
+ """
658
+ If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
659
+ k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
660
+ the previous step, and update them with the new keys/values from the current step, and do
661
+ attention with the updated cache, all in 1 kernel.
662
+
663
+ If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
664
+ For example, the KV cache could be pre-allocated with the max sequence length, and you can use
665
+ cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
666
+
667
+ Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
668
+ rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
669
+ If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
670
+ and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
671
+ If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
672
+ indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).
673
+
674
+ See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
675
+
676
+ Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
677
+ than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
678
+ For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
679
+ 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
680
+
681
+ If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
682
+ For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
683
+ 1 1 1 1 0
684
+ 1 1 1 1 1
685
+ If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
686
+ 0 0
687
+ 0 0
688
+ 0 0
689
+ 1 0
690
+ 1 1
691
+ If the row of the mask is all zero, the output will be zero.
692
+
693
+ If window_size != (-1, -1), implements sliding window local attention. Query at position i
694
+ will only attend to keys between
695
+ [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
696
+
697
+ Note: Does not support backward pass.
698
+
699
+ Arguments:
700
+ q: (batch_size, seqlen, nheads, headdim)
701
+ k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table,
702
+ or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache)
703
+ page_block_size must be a multiple of 256.
704
+ v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table,
705
+ or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache)
706
+ k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
707
+ k with k_cache, starting at the indices specified by cache_seqlens.
708
+ v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k.
709
+ qv [optional]: (batch_size, seqlen, nheads, headdim_v)
710
+ rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
711
+ to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
712
+ rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
713
+ cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
714
+ KV cache.
715
+ cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
716
+ If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
717
+ If the indices are not distinct, and k and v are provided, the values updated in the cache
718
+ might come from any of the duplicate indices.
719
+ cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0.
720
+ page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
721
+ softmax_scale: float. The scaling of QK^T before applying softmax.
722
+ Default to 1 / sqrt(headdim).
723
+ causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
724
+ window_size: (left, right). If not (-1, -1), implements sliding window local attention.
725
+ softcap: float. Anything > 0 activates softcapping attention.
726
+ rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
727
+ If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
728
+ rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
729
+ (i.e. GPT-NeoX style).
730
+ num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
731
+ If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
732
+ to automatically determine the number of splits.
733
+ Don't change this unless you know what you are doing.
734
+ return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.
735
+
736
+ Return:
737
+ out: (batch_size, seqlen, nheads, headdim).
738
+ softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
739
+ logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
740
+ normalization factor).
741
+ """
742
+ assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
743
+ assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
744
+ if softmax_scale is None:
745
+ softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
746
+ if cache_seqlens is not None and isinstance(cache_seqlens, int):
747
+ cache_seqlens = torch.full(
748
+ (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
749
+ )
750
+ cache_seqlens = maybe_contiguous(cache_seqlens)
751
+ out, softmax_lse, *rest = _flash_attn_forward(
752
+ q,
753
+ k_cache,
754
+ v_cache,
755
+ k,
756
+ v,
757
+ qv,
758
+ None, # out
759
+ cu_seqlens_q,
760
+ None, # cu_seqlens_k
761
+ cu_seqlens_k_new,
762
+ None, # seqused_q
763
+ cache_seqlens,
764
+ max_seqlen_q,
765
+ None, # max_seqlen_k
766
+ page_table,
767
+ cache_batch_idx,
768
+ cache_leftpad,
769
+ rotary_cos,
770
+ rotary_sin,
771
+ rotary_seqlens,
772
+ q_descale, k_descale, v_descale,
773
+ softmax_scale,
774
+ causal=causal,
775
+ window_size=window_size,
776
+ attention_chunk=attention_chunk,
777
+ softcap=softcap,
778
+ rotary_interleaved=rotary_interleaved,
779
+ scheduler_metadata=scheduler_metadata,
780
+ num_splits=num_splits,
781
+ pack_gqa=pack_gqa,
782
+ sm_margin=sm_margin,
783
+ )
784
+ # return (out, softmax_lse) if return_softmax_lse else out
785
+ return (out, softmax_lse, *rest) if return_softmax_lse else out
786
+
787
+
788
+ def get_scheduler_metadata(
789
+ batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim,
790
+ cache_seqlens: torch.Tensor,
791
+ qkv_dtype=torch.bfloat16,
792
+ headdim_v=None,
793
+ cu_seqlens_q: Optional[torch.Tensor] = None,
794
+ cu_seqlens_k_new: Optional[torch.Tensor] = None,
795
+ cache_leftpad: Optional[torch.Tensor] = None,
796
+ page_size: Optional[int] = None,
797
+ max_seqlen_k_new=0,
798
+ causal=False,
799
+ window_size=(-1, -1), # -1 means infinite context window
800
+ attention_chunk=0,
801
+ has_softcap=False,
802
+ num_splits=0, # Can be tuned for speed
803
+ pack_gqa=None, # Can be tuned for speed
804
+ sm_margin=0, # Can be tuned if some SMs are used for communication
805
+ ):
806
+ cache_seqlens = maybe_contiguous(cache_seqlens)
807
+ if headdim_v is None:
808
+ headdim_v = headdim
809
+ scheduler_metadata = flash_attn_3_cuda.get_scheduler_metadata(
810
+ batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, headdim_v,
811
+ qkv_dtype,
812
+ cache_seqlens,
813
+ cu_seqlens_q,
814
+ None, # cu_seqlens_k
815
+ cu_seqlens_k_new,
816
+ None, # seqused_q
817
+ cache_leftpad,
818
+ page_size,
819
+ max_seqlen_k_new,
820
+ causal,
821
+ window_size[0], window_size[1],
822
+ attention_chunk,
823
+ has_softcap,
824
+ num_splits,
825
+ pack_gqa,
826
+ sm_margin,
827
+ )
828
+ return scheduler_metadata
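To illustrate the flash_attn_func interface defined in the file above (GQA, the bottom-right-aligned causal mask, and sliding-window attention described in its docstring), here is a minimal sketch. The shapes and window size are made-up values, and a CUDA device supported by these builds is assumed.

import torch
from flash_attn3 import flash_attn_func

# Hypothetical GQA setup: 16 query heads sharing 4 KV heads.
batch, seqlen, headdim = 2, 512, 128
q = torch.randn(batch, seqlen, 16, headdim, device="cuda", dtype=torch.bfloat16)
k = torch.randn(batch, seqlen, 4, headdim, device="cuda", dtype=torch.bfloat16)
v = torch.randn(batch, seqlen, 4, headdim, device="cuda", dtype=torch.bfloat16)

# causal=True uses the bottom-right aligned mask described in the docstring;
# window_size=(256, 0) restricts each query to the previous 256 keys.
out, softmax_lse = flash_attn_func(q, k, v, causal=True, window_size=(256, 0))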
build/torch28-cxx11-cu128-aarch64-linux/flash_attn3/__init__.py ADDED
@@ -0,0 +1,17 @@
+ from .flash_attn_interface import (
+ flash_attn_combine,
+ flash_attn_func,
+ flash_attn_qkvpacked_func,
+ flash_attn_varlen_func,
+ flash_attn_with_kvcache,
+ get_scheduler_metadata,
+ )
+
+ __all__ = [
+ "flash_attn_combine",
+ "flash_attn_func",
+ "flash_attn_qkvpacked_func",
+ "flash_attn_varlen_func",
+ "flash_attn_with_kvcache",
+ "get_scheduler_metadata",
+ ]
build/torch28-cxx11-cu128-aarch64-linux/flash_attn3/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (438 Bytes).
build/torch28-cxx11-cu128-aarch64-linux/flash_attn3/__pycache__/_ops.cpython-313.pyc ADDED
Binary file (530 Bytes).
build/torch28-cxx11-cu128-aarch64-linux/flash_attn3/__pycache__/flash_attn_interface.cpython-313.pyc ADDED
Binary file (26.2 kB).
build/torch28-cxx11-cu128-aarch64-linux/flash_attn3/_flash_attn3_8d4f83f.abi3.so ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e9aef52109e5974778e3ccc2f697c4e6050b365624c843a675ce894b938341cc
+ size 822395576
build/torch28-cxx11-cu128-aarch64-linux/flash_attn3/_ops.py ADDED
@@ -0,0 +1,9 @@
+ import torch
+ from . import _flash_attn3_8d4f83f
+ ops = torch.ops._flash_attn3_8d4f83f
+
+ def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_flash_attn3_8d4f83f::{op_name}"
build/torch28-cxx11-cu128-aarch64-linux/flash_attn3/flash_attn_interface.py ADDED
@@ -0,0 +1,828 @@
1
+ # Copyright (c) 2023, Tri Dao.
2
+
3
+ from typing import Optional, Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from ._ops import ops as flash_attn_3_cuda
9
+
10
+ def maybe_contiguous(x):
11
+ return x.contiguous() if x is not None and x.stride(-1) != 1 else x
12
+
13
+
14
+ def _flash_attn_forward(
15
+ q,
16
+ k,
17
+ v,
18
+ k_new,
19
+ v_new,
20
+ qv,
21
+ out,
22
+ cu_seqlens_q,
23
+ cu_seqlens_k,
24
+ cu_seqlens_k_new,
25
+ seqused_q,
26
+ seqused_k,
27
+ max_seqlen_q,
28
+ max_seqlen_k,
29
+ page_table,
30
+ kv_batch_idx,
31
+ leftpad_k,
32
+ rotary_cos,
33
+ rotary_sin,
34
+ seqlens_rotary,
35
+ q_descale,
36
+ k_descale,
37
+ v_descale,
38
+ softmax_scale,
39
+ causal,
40
+ window_size=(-1, -1),
41
+ attention_chunk=0,
42
+ softcap=0.0,
43
+ rotary_interleaved=True,
44
+ scheduler_metadata=None,
45
+ num_splits=1,
46
+ pack_gqa=None,
47
+ sm_margin=0):
48
+ q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)]
49
+ v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
50
+ cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [
51
+ maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new)
52
+ ]
53
+ seqused_q, seqused_k = [maybe_contiguous(x) for x in (seqused_q, seqused_k)]
54
+ page_table, kv_batch_idx, leftpad_k = [
55
+ maybe_contiguous(x) for x in (page_table, kv_batch_idx, leftpad_k)
56
+ ]
57
+ rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
58
+ seqlens_rotary = maybe_contiguous(seqlens_rotary)
59
+ out, softmax_lse, *rest = flash_attn_3_cuda.fwd(
60
+ q,
61
+ k,
62
+ v,
63
+ k_new,
64
+ v_new,
65
+ qv,
66
+ out,
67
+ cu_seqlens_q,
68
+ cu_seqlens_k,
69
+ cu_seqlens_k_new,
70
+ seqused_q,
71
+ seqused_k,
72
+ max_seqlen_q,
73
+ max_seqlen_k,
74
+ page_table,
75
+ kv_batch_idx,
76
+ leftpad_k,
77
+ rotary_cos,
78
+ rotary_sin,
79
+ seqlens_rotary,
80
+ q_descale,
81
+ k_descale,
82
+ v_descale,
83
+ softmax_scale,
84
+ causal,
85
+ window_size[0],
86
+ window_size[1],
87
+ attention_chunk,
88
+ softcap,
89
+ rotary_interleaved,
90
+ scheduler_metadata,
91
+ num_splits,
92
+ pack_gqa,
93
+ sm_margin,
94
+ )
95
+ return out, softmax_lse, *rest
96
+
97
+
98
+ def _flash_attn_backward(
99
+ dout,
100
+ q,
101
+ k,
102
+ v,
103
+ out,
104
+ softmax_lse,
105
+ cu_seqlens_q,
106
+ cu_seqlens_k,
107
+ sequed_q,
108
+ sequed_k,
109
+ max_seqlen_q,
110
+ max_seqlen_k,
111
+ dq,
112
+ dk,
113
+ dv,
114
+ softmax_scale,
115
+ causal,
116
+ window_size=(-1, -1),
117
+ softcap=0.0,
118
+ deterministic=False,
119
+ sm_margin=0,
120
+ ):
121
+ # dq, dk, dv are allocated by us so they should already be contiguous
122
+ dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
123
+ dq, dk, dv, softmax_d, *rest = flash_attn_3_cuda.bwd(
124
+ dout,
125
+ q,
126
+ k,
127
+ v,
128
+ out,
129
+ softmax_lse,
130
+ dq,
131
+ dk,
132
+ dv,
133
+ cu_seqlens_q,
134
+ cu_seqlens_k,
135
+ sequed_q,
136
+ sequed_k,
137
+ max_seqlen_q,
138
+ max_seqlen_k,
139
+ softmax_scale,
140
+ causal,
141
+ window_size[0],
142
+ window_size[1],
143
+ softcap,
144
+ deterministic,
145
+ sm_margin,
146
+ )
147
+ return dq, dk, dv, softmax_d
148
+
149
+
150
+ class FlashAttnQKVPackedFunc(torch.autograd.Function):
151
+ @staticmethod
152
+ def forward(
153
+ ctx,
154
+ qkv,
155
+ softmax_scale,
156
+ causal,
157
+ q_descale=None, k_descale=None, v_descale=None,
158
+ window_size=(-1, -1),
159
+ attention_chunk=0,
160
+ softcap=0.0,
161
+ deterministic=False,
162
+ num_heads_q=None,
163
+ sm_margin=0,
164
+ ):
165
+ if softmax_scale is None:
166
+ softmax_scale = qkv.shape[-1] ** (-0.5)
167
+ if qkv.dim() == 5:
168
+ assert qkv.shape[-3] == 3
169
+ q, k, v = qkv.unbind(dim=-3)
170
+ else:
171
+ assert qkv.dim() == 4
172
+ assert num_heads_q is not None
173
+ num_heads_k = (qkv.shape[2] - num_heads_q) // 2
174
+ assert num_heads_k * 2 + num_heads_q == qkv.shape[2]
175
+ q, k, v = qkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
176
+ out, softmax_lse, *rest = _flash_attn_forward(
177
+ q,
178
+ k,
179
+ v,
180
+ None, None, # k_new, v_new
181
+ None, # qv
182
+ None, # out
183
+ None, None, None, # cu_seqlens_q/k/k_new
184
+ None, None, # seqused_q/k
185
+ None, None, # max_seqlen_q/k
186
+ None, None, None, # page_table, kv_batch_idx, leftpad_k,
187
+ None, None, None, # rotary_cos/sin, seqlens_rotary
188
+ q_descale, k_descale, v_descale,
189
+ softmax_scale,
190
+ causal=causal,
191
+ window_size=window_size,
192
+ attention_chunk=attention_chunk,
193
+ softcap=softcap,
194
+ sm_margin=sm_margin,
195
+ )
196
+ # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
197
+ ctx.save_for_backward(q, k, v, out, softmax_lse)
198
+ ctx.softmax_scale = softmax_scale
199
+ ctx.causal = causal
200
+ ctx.window_size = window_size
201
+ ctx.attention_chunk = attention_chunk
202
+ ctx.softcap = softcap
203
+ ctx.deterministic = deterministic
204
+ ctx.ndim = qkv.dim()
205
+ ctx.sm_margin = sm_margin
206
+ # return out, softmax_lse
207
+ return out
208
+
209
+ @staticmethod
210
+ def backward(ctx, dout, *args):
211
+ q, k, v, out, softmax_lse = ctx.saved_tensors
212
+ assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
213
+ if ctx.ndim == 5:
214
+ qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
215
+ dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
216
+ dq, dk, dv = dqkv.unbind(dim=-3)
217
+ else:
218
+ num_heads_q = q.shape[2]
219
+ num_heads_k = k.shape[2]
220
+ qkv_shape = q.shape[:-2] + (num_heads_q + num_heads_k * 2, *q.shape[-1:])
221
+ dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
222
+ dq, dk, dv = dqkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
223
+ _flash_attn_backward(
224
+ dout,
225
+ q,
226
+ k,
227
+ v,
228
+ out,
229
+ softmax_lse,
230
+ None, None, # cu_seqlens_q, cu_seqlens_k,
231
+ None, None, # sequed_q, sequed_k,
232
+ None, None, # max_seqlen_q, max_seqlen_k,
233
+ dq,
234
+ dk,
235
+ dv,
236
+ ctx.softmax_scale,
237
+ ctx.causal,
238
+ ctx.window_size,
239
+ ctx.softcap,
240
+ ctx.deterministic,
241
+ ctx.sm_margin,
242
+ )
243
+ dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
244
+ return dqkv, None, None, None, None, None, None, None, None, None, None, None
245
+
246
+
247
+ class FlashAttnFunc(torch.autograd.Function):
248
+
249
+ @staticmethod
250
+ def forward(
251
+ ctx,
252
+ q,
253
+ k,
254
+ v,
255
+ softmax_scale,
256
+ causal,
257
+ qv=None,
258
+ q_descale=None, k_descale=None, v_descale=None,
259
+ window_size=(-1, -1),
260
+ attention_chunk=0,
261
+ softcap=0.0,
262
+ num_splits=1,
263
+ pack_gqa=None,
264
+ deterministic=False,
265
+ sm_margin=0,
266
+ ):
267
+ if softmax_scale is None:
268
+ softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
269
+ # out, q, k, v, out_padded, softmax_lse = _flash_attn_forward(
270
+ out, softmax_lse, *rest = _flash_attn_forward(
271
+ q,
272
+ k,
273
+ v,
274
+ None, None, # k_new, v_new
275
+ qv, # qv
276
+ None, # out
277
+ None, None, None, # cu_seqlens_q/k/k_new
278
+ None, None, # seqused_q/k
279
+ None, None, # max_seqlen_q/k
280
+ None, None, None, # page_table, kv_batch_idx, leftpad_k,
281
+ None, None, None, # rotary_cos/sin, seqlens_rotary
282
+ q_descale, k_descale, v_descale,
283
+ softmax_scale,
284
+ causal=causal,
285
+ window_size=window_size,
286
+ attention_chunk=attention_chunk,
287
+ softcap=softcap,
288
+ num_splits=num_splits,
289
+ pack_gqa=pack_gqa,
290
+ sm_margin=sm_margin,
291
+ )
292
+ # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
293
+ ctx.save_for_backward(q, k, v, out, softmax_lse)
294
+ ctx.softmax_scale = softmax_scale
295
+ ctx.causal = causal
296
+ ctx.window_size = window_size
297
+ ctx.attention_chunk = attention_chunk
298
+ ctx.softcap = softcap
299
+ ctx.deterministic = deterministic
300
+ ctx.sm_margin = sm_margin
301
+ return out, softmax_lse
302
+
303
+ @staticmethod
304
+ def backward(ctx, dout, *args):
305
+ q, k, v, out, softmax_lse = ctx.saved_tensors
306
+ assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
307
+ dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
308
+ _flash_attn_backward(
309
+ dout,
310
+ q,
311
+ k,
312
+ v,
313
+ out,
314
+ softmax_lse,
315
+ None, None, # cu_seqlens_q, cu_seqlens_k,
316
+ None, None, # sequed_q, sequed_k,
317
+ None, None, # max_seqlen_q, max_seqlen_k,
318
+ dq,
319
+ dk,
320
+ dv,
321
+ ctx.softmax_scale,
322
+ ctx.causal,
323
+ ctx.window_size,
324
+ ctx.softcap,
325
+ ctx.deterministic,
326
+ ctx.sm_margin,
327
+ )
328
+ dq = dq[..., : q.shape[-1]] # We could have padded the head dimension
329
+ dk = dk[..., : k.shape[-1]]
330
+ dv = dv[..., : v.shape[-1]]
331
+ return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None
332
+
333
+
334
+ class FlashAttnVarlenFunc(torch.autograd.Function):
335
+
336
+ @staticmethod
337
+ def forward(
338
+ ctx,
339
+ q,
340
+ k,
341
+ v,
342
+ cu_seqlens_q,
343
+ cu_seqlens_k,
344
+ seqused_q,
345
+ seqused_k,
346
+ max_seqlen_q,
347
+ max_seqlen_k,
348
+ softmax_scale,
349
+ causal,
350
+ qv=None,
351
+ q_descale=None, k_descale=None, v_descale=None,
352
+ window_size=(-1, -1),
353
+ attention_chunk=0,
354
+ softcap=0.0,
355
+ num_splits=1,
356
+ pack_gqa=None,
357
+ deterministic=False,
358
+ sm_margin=0,
359
+ ):
360
+ if softmax_scale is None:
361
+ softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
362
+ # out, q, k, v, out_padded, softmax_lse = _flash_attn_varlen_forward(
363
+ out, softmax_lse, *rest = _flash_attn_forward(
364
+ q,
365
+ k,
366
+ v,
367
+ None, None, # k_new, v_new
368
+ qv, # qv
369
+ None, # out
370
+ cu_seqlens_q,
371
+ cu_seqlens_k,
372
+ None, # cu_seqlens_k_new
373
+ seqused_q,
374
+ seqused_k,
375
+ max_seqlen_q,
376
+ max_seqlen_k,
377
+ None, None, None, # page_table, kv_batch_idx, leftpad_k,
378
+ None, None, None, # rotary_cos/sin, seqlens_rotary
379
+ q_descale, k_descale, v_descale,
380
+ softmax_scale,
381
+ causal=causal,
382
+ window_size=window_size,
383
+ attention_chunk=attention_chunk,
384
+ softcap=softcap,
385
+ num_splits=num_splits,
386
+ pack_gqa=pack_gqa,
387
+ sm_margin=sm_margin,
388
+ )
389
+ # ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
390
+ ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
391
+ ctx.max_seqlen_q = max_seqlen_q
392
+ ctx.max_seqlen_k = max_seqlen_k
393
+ ctx.softmax_scale = softmax_scale
394
+ ctx.causal = causal
395
+ ctx.window_size = window_size
396
+ ctx.attention_chunk = attention_chunk
397
+ ctx.softcap = softcap
398
+ ctx.deterministic = deterministic
399
+ ctx.sm_margin = sm_margin
400
+ return out, softmax_lse
401
+
402
+ @staticmethod
403
+ def backward(ctx, dout, *args):
404
+ q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
405
+ assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
406
+ dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
407
+ _flash_attn_backward(
408
+ dout,
409
+ q,
410
+ k,
411
+ v,
412
+ out,
413
+ softmax_lse,
414
+ cu_seqlens_q,
415
+ cu_seqlens_k,
416
+ seqused_q,
417
+ seqused_k,
418
+ ctx.max_seqlen_q,
419
+ ctx.max_seqlen_k,
420
+ dq,
421
+ dk,
422
+ dv,
423
+ ctx.softmax_scale,
424
+ ctx.causal,
425
+ ctx.window_size,
426
+ ctx.softcap,
427
+ ctx.deterministic,
428
+ ctx.sm_margin,
429
+ )
430
+ dq = dq[..., : q.shape[-1]] # We could have padded the head dimension
431
+ dk = dk[..., : k.shape[-1]]
432
+ dv = dv[..., : v.shape[-1]]
433
+ return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
434
+
435
+
436
+ def flash_attn_qkvpacked_func(
437
+ qkv,
438
+ softmax_scale=None,
439
+ causal=False,
440
+ q_descale=None, k_descale=None, v_descale=None,
441
+ window_size=(-1, -1),
442
+ attention_chunk=0,
443
+ softcap=0.0,
444
+ deterministic=False,
445
+ num_heads_q=None,
446
+ sm_margin=0,
447
+ ):
448
+ """dropout_p should be set to 0.0 during evaluation
449
+ If Q, K, V are already stacked into 1 tensor, this function will be faster than
450
+ calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
451
+ of the gradients of Q, K, V.
452
+ For multi-query and grouped-query attention (MQA/GQA), please see
453
+ flash_attn_kvpacked_func and flash_attn_func.
454
+
455
+ If window_size != (-1, -1), implements sliding window local attention. Query at position i
456
+ will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
457
+
458
+ Arguments:
459
+ qkv: (batch_size, seqlen, 3, nheads, headdim)
460
+ dropout_p: float. Dropout probability.
461
+ softmax_scale: float. The scaling of QK^T before applying softmax.
462
+ Default to 1 / sqrt(headdim).
463
+ causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
464
+ window_size: (left, right). If not (-1, -1), implements sliding window local attention.
465
+ softcap: float. Anything > 0 activates softcapping attention.
466
+ alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
467
+ the attention score of query i and key j.
468
+ deterministic: bool. Whether to use the deterministic implementation of the backward pass,
469
+ which is slightly slower and uses more memory. The forward pass is always deterministic.
470
+ return_attn_probs: bool. Whether to return the attention probabilities. This option is for
471
+ testing only. The returned probabilities are not guaranteed to be correct
472
+ (they might not have the right scaling).
473
+ Return:
474
+ out: (batch_size, seqlen, nheads, headdim).
475
+ softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
476
+ logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
477
+ normalization factor).
478
+ S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
479
+ The output of softmax (possibly with different scaling). It also encodes the dropout
480
+ pattern (negative means that location was dropped, nonnegative means it was kept).
481
+ """
482
+ return FlashAttnQKVPackedFunc.apply(
483
+ qkv,
484
+ softmax_scale,
485
+ causal,
486
+ q_descale, k_descale, v_descale,
487
+ window_size,
488
+ attention_chunk,
489
+ softcap,
490
+ deterministic,
491
+ num_heads_q,
492
+ sm_margin,
493
+ )
494
+
495
+
496
+ def flash_attn_func(
497
+ q,
498
+ k,
499
+ v,
500
+ softmax_scale=None,
501
+ causal=False,
502
+ qv=None,
503
+ q_descale=None, k_descale=None, v_descale=None,
504
+ window_size=(-1, -1),
505
+ attention_chunk=0,
506
+ softcap=0.0,
507
+ num_splits=1,
508
+ pack_gqa=None,
509
+ deterministic=False,
510
+ sm_margin=0,
511
+ ):
512
+ """dropout_p should be set to 0.0 during evaluation
513
+ Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
514
+ than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
515
+ For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
516
+ 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
517
+
518
+ If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
519
+ For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
520
+ 1 1 1 1 0
521
+ 1 1 1 1 1
522
+ If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
523
+ 0 0
524
+ 0 0
525
+ 0 0
526
+ 1 0
527
+ 1 1
528
+ If the row of the mask is all zero, the output will be zero.
529
+
530
+ If window_size != (-1, -1), implements sliding window local attention. Query at position i
531
+ will only attend to keys between
532
+ [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
533
+
534
+ Arguments:
535
+ q: (batch_size, seqlen, nheads, headdim)
536
+ k: (batch_size, seqlen, nheads_k, headdim)
537
+ v: (batch_size, seqlen, nheads_k, headdim)
538
+ dropout_p: float. Dropout probability.
539
+ softmax_scale: float. The scaling of QK^T before applying softmax.
540
+ Default to 1 / sqrt(headdim).
541
+ causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
542
+ window_size: (left, right). If not (-1, -1), implements sliding window local attention.
543
+ alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
544
+ (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
545
+ is added to the attention score of query i and key j.
546
+ deterministic: bool. Whether to use the deterministic implementation of the backward pass,
547
+ which is slightly slower and uses more memory. The forward pass is always deterministic.
548
+ return_attn_probs: bool. Whether to return the attention probabilities. This option is for
549
+ testing only. The returned probabilities are not guaranteed to be correct
550
+ (they might not have the right scaling).
551
+ Return:
552
+ out: (batch_size, seqlen, nheads, headdim).
553
+ softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
554
+ logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
555
+ normalization factor).
556
+ """
557
+ return FlashAttnFunc.apply(
558
+ q,
559
+ k,
560
+ v,
561
+ softmax_scale,
562
+ causal,
563
+ qv,
564
+ q_descale, k_descale, v_descale,
565
+ window_size,
566
+ attention_chunk,
567
+ softcap,
568
+ num_splits,
569
+ pack_gqa,
570
+ deterministic,
571
+ sm_margin,
572
+ )
573
+
574
+
575
+ def flash_attn_varlen_func(
576
+ q,
577
+ k,
578
+ v,
579
+ cu_seqlens_q,
580
+ cu_seqlens_k,
581
+ max_seqlen_q,
582
+ max_seqlen_k,
583
+ seqused_q=None,
584
+ seqused_k=None,
585
+ softmax_scale=None,
586
+ causal=False,
587
+ qv=None,
588
+ q_descale=None, k_descale=None, v_descale=None,
589
+ window_size=(-1, -1),
590
+ attention_chunk=0,
591
+ softcap=0.0,
592
+ num_splits=1,
593
+ pack_gqa=None,
594
+ deterministic=False,
595
+ sm_margin=0,
596
+ ):
597
+ return FlashAttnVarlenFunc.apply(
598
+ q,
599
+ k,
600
+ v,
601
+ cu_seqlens_q,
602
+ cu_seqlens_k,
603
+ seqused_q,
604
+ seqused_k,
605
+ max_seqlen_q,
606
+ max_seqlen_k,
607
+ softmax_scale,
608
+ causal,
609
+ qv,
610
+ q_descale, k_descale, v_descale,
611
+ window_size,
612
+ attention_chunk,
613
+ softcap,
614
+ num_splits,
615
+ pack_gqa,
616
+ deterministic,
617
+ sm_margin,
618
+ )
619
+
620
+
621
+ def flash_attn_combine(out_partial, lse_partial, out=None, out_dtype=None):
622
+ return flash_attn_3_cuda.fwd_combine(out_partial, lse_partial, out, out_dtype)
623
+
624
+
625
+ def flash_attn_with_kvcache(
626
+ q,
627
+ k_cache,
628
+ v_cache,
629
+ k=None,
630
+ v=None,
631
+ qv=None,
632
+ rotary_cos=None,
633
+ rotary_sin=None,
634
+ cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
635
+ cache_batch_idx: Optional[torch.Tensor] = None,
636
+ cache_leftpad: Optional[torch.Tensor] = None,
637
+ page_table: Optional[torch.Tensor] = None,
638
+ cu_seqlens_q: Optional[torch.Tensor] = None,
639
+ cu_seqlens_k_new: Optional[torch.Tensor] = None,
640
+ max_seqlen_q: Optional[int] = None,
641
+ rotary_seqlens: Optional[torch.Tensor] = None,
642
+ q_descale: Optional[torch.Tensor] = None,
643
+ k_descale: Optional[torch.Tensor] = None,
644
+ v_descale: Optional[torch.Tensor] = None,
645
+ softmax_scale=None,
646
+ causal=False,
647
+ window_size=(-1, -1), # -1 means infinite context window
648
+ attention_chunk=0,
649
+ softcap=0.0, # 0.0 means deactivated
650
+ rotary_interleaved=True,
651
+ scheduler_metadata=None,
652
+ num_splits=0, # Can be tuned for speed
653
+ pack_gqa=None, # Can be tuned for speed
654
+ sm_margin=0, # Can be tuned if some SMs are used for communication
655
+ return_softmax_lse=False,
656
+ ):
657
+ """
658
+ If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
659
+ k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
660
+ the previous step, and update them with the new keys/values from the current step, and do
661
+ attention with the updated cache, all in 1 kernel.
662
+
663
+ If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
664
+ For example, the KV cache could be pre-allocated with the max sequence length, and you can use
665
+ cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
666
+
667
+ Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
668
+ rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
669
+ If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
670
+ and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
671
+ If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
672
+ indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).
673
+
674
+ See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
675
+
676
+ Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
677
+ than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
678
+ For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
679
+ 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
680
+
681
+ If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
682
+ For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
683
+ 1 1 1 1 0
+         1 1 1 1 1
+     If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
+         0 0
+         0 0
+         0 0
+         1 0
+         1 1
+     If a row of the mask is all zero, the output for that row will be zero.
+
+     If window_size != (-1, -1), implements sliding window local attention. Query at position i
+     will only attend to keys between
+     [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
+
+     Note: Does not support the backward pass.
+
+     Arguments:
+         q: (batch_size, seqlen, nheads, headdim)
+         k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table,
+             or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache).
+             page_block_size must be a multiple of 256.
+         v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table,
+             or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache).
+         k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
+             k with k_cache, starting at the indices specified by cache_seqlens.
+         v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k.
+         qv [optional]: (batch_size, seqlen, nheads, headdim_v)
+         rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
+             to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
+         rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
+         cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
+             KV cache.
+         cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
+             If None, we assume the batch indices are [0, 1, 2, ..., batch_size - 1].
+             If the indices are not distinct and k and v are provided, the values updated in the cache
+             might come from any of the duplicate indices.
+         cache_leftpad: (batch_size,), dtype torch.int32. The index at which the KV cache starts
+             for each sequence. If None, assume 0.
+         page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
+         softmax_scale: float. The scaling of QK^T before applying softmax.
+             Defaults to 1 / sqrt(headdim).
+         causal: bool. Whether to apply a causal attention mask (e.g., for auto-regressive modeling).
+         window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+         softcap: float. Anything > 0 activates softcapping attention.
+         rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
+             If True, rotary embedding combines dimensions 0 & 1, 2 & 3, etc. If False,
+             rotary embedding combines dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
+             (i.e. GPT-NeoX style).
+         num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
+             If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
+             to automatically determine the number of splits.
+             Don't change this unless you know what you are doing.
+         return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.
+
+     Return:
+         out: (batch_size, seqlen, nheads, headdim).
+         softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
+             logsumexp of each row of the matrix QK^T * scaling (i.e., the log of the softmax
+             normalization factor).
+     """
+     assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
+     assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
+     if softmax_scale is None:
+         softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
+     if cache_seqlens is not None and isinstance(cache_seqlens, int):
+         cache_seqlens = torch.full(
+             (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
+         )
+         cache_seqlens = maybe_contiguous(cache_seqlens)
+     out, softmax_lse, *rest = _flash_attn_forward(
+         q,
+         k_cache,
+         v_cache,
+         k,
+         v,
+         qv,
+         None,  # out
+         cu_seqlens_q,
+         None,  # cu_seqlens_k
+         cu_seqlens_k_new,
+         None,  # seqused_q
+         cache_seqlens,
+         max_seqlen_q,
+         None,  # max_seqlen_k
+         page_table,
+         cache_batch_idx,
+         cache_leftpad,
+         rotary_cos,
+         rotary_sin,
+         rotary_seqlens,
+         q_descale, k_descale, v_descale,
+         softmax_scale,
+         causal=causal,
+         window_size=window_size,
+         attention_chunk=attention_chunk,
+         softcap=softcap,
+         rotary_interleaved=rotary_interleaved,
+         scheduler_metadata=scheduler_metadata,
+         num_splits=num_splits,
+         pack_gqa=pack_gqa,
+         sm_margin=sm_margin,
+     )
+     # return (out, softmax_lse) if return_softmax_lse else out
+     return (out, softmax_lse, *rest) if return_softmax_lse else out
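
As a rough illustration of the bottom-right-aligned causal mask and the sliding-window range documented in the docstring above, the following standalone sketch builds the same boolean mask in plain PyTorch. It is not part of the committed file, and `reference_mask` is a hypothetical helper name.

```python
# Illustrative sketch: reproduce the bottom-right-aligned causal mask and
# sliding-window range described in the flash_attn_with_kvcache docstring.
import torch

def reference_mask(seqlen_q, seqlen_k, causal=True, window_size=(-1, -1)):
    # Query position i is aligned to key position i + seqlen_k - seqlen_q
    # (bottom-right alignment), so the last query row always sees the last key.
    q_idx = torch.arange(seqlen_q).view(-1, 1) + (seqlen_k - seqlen_q)
    k_idx = torch.arange(seqlen_k).view(1, -1)
    mask = torch.ones(seqlen_q, seqlen_k, dtype=torch.bool)
    if causal:
        mask &= k_idx <= q_idx
    if window_size[0] >= 0:
        mask &= k_idx >= q_idx - window_size[0]
    if window_size[1] >= 0:
        mask &= k_idx <= q_idx + window_size[1]
    return mask.int()

print(reference_mask(5, 2))  # matches the seqlen_q = 5, seqlen_k = 2 example above
```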
+
+
+ def get_scheduler_metadata(
+     batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim,
+     cache_seqlens: torch.Tensor,
+     qkv_dtype=torch.bfloat16,
+     headdim_v=None,
+     cu_seqlens_q: Optional[torch.Tensor] = None,
+     cu_seqlens_k_new: Optional[torch.Tensor] = None,
+     cache_leftpad: Optional[torch.Tensor] = None,
+     page_size: Optional[int] = None,
+     max_seqlen_k_new=0,
+     causal=False,
+     window_size=(-1, -1),  # -1 means infinite context window
+     attention_chunk=0,
+     has_softcap=False,
+     num_splits=0,  # Can be tuned for speed
+     pack_gqa=None,  # Can be tuned for speed
+     sm_margin=0,  # Can be tuned if some SMs are used for communication
+ ):
+     cache_seqlens = maybe_contiguous(cache_seqlens)
+     if headdim_v is None:
+         headdim_v = headdim
+     scheduler_metadata = flash_attn_3_cuda.get_scheduler_metadata(
+         batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, headdim_v,
+         qkv_dtype,
+         cache_seqlens,
+         cu_seqlens_q,
+         None,  # cu_seqlens_k
+         cu_seqlens_k_new,
+         None,  # seqused_q
+         cache_leftpad,
+         page_size,
+         max_seqlen_k_new,
+         causal,
+         window_size[0], window_size[1],
+         attention_chunk,
+         has_softcap,
+         num_splits,
+         pack_gqa,
+         sm_margin,
+     )
+     return scheduler_metadata
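
Since `get_scheduler_metadata` simply forwards its arguments to the compiled kernel's scheduler, one plausible usage pattern is to precompute the metadata once per decoding configuration and pass it to `flash_attn_with_kvcache` via `scheduler_metadata`. The sketch below is illustrative rather than part of this commit; it assumes the built package is importable as `flash_attn3` and uses placeholder shapes.

```python
# Illustrative sketch: reuse precomputed scheduler metadata across decode steps.
# Shapes and values are placeholder assumptions; requires a CUDA device.
import torch
from flash_attn3 import flash_attn_with_kvcache, get_scheduler_metadata

batch, nheads, nheads_k, headdim = 4, 32, 8, 128
max_seqlen_k = 4096
cache_seqlens = torch.full((batch,), 1024, dtype=torch.int32, device="cuda")

# Computed once for this (batch, head, length, dtype) configuration.
metadata = get_scheduler_metadata(
    batch, 1, max_seqlen_k, nheads, nheads_k, headdim,
    cache_seqlens, qkv_dtype=torch.bfloat16, causal=True,
)

q = torch.randn(batch, 1, nheads, headdim, dtype=torch.bfloat16, device="cuda")
k_cache = torch.zeros(batch, max_seqlen_k, nheads_k, headdim, dtype=torch.bfloat16, device="cuda")
v_cache = torch.zeros_like(k_cache)

# Single decode step reusing the precomputed schedule.
out = flash_attn_with_kvcache(
    q, k_cache, v_cache,
    cache_seqlens=cache_seqlens,
    causal=True,
    scheduler_metadata=metadata,
)
```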
build/torch28-cxx11-cu129-aarch64-linux/flash_attn3/__init__.py ADDED
@@ -0,0 +1,17 @@
+ from .flash_attn_interface import (
+     flash_attn_combine,
+     flash_attn_func,
+     flash_attn_qkvpacked_func,
+     flash_attn_varlen_func,
+     flash_attn_with_kvcache,
+     get_scheduler_metadata,
+ )
+
+ __all__ = [
+     "flash_attn_combine",
+     "flash_attn_func",
+     "flash_attn_qkvpacked_func",
+     "flash_attn_varlen_func",
+     "flash_attn_with_kvcache",
+     "get_scheduler_metadata",
+ ]
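
The `__init__.py` above re-exports the public entry points for this build variant. These per-build directories (one per torch / C++ ABI / CUDA / architecture combination) are what the Hugging Face `kernels` loader selects between at runtime. A minimal sketch follows, assuming the kernel is published under a repository id like `kernels-community/flash-attn3` (a placeholder, not confirmed by this diff).

```python
# Illustrative sketch: load a prebuilt variant via the Hugging Face `kernels`
# loader, which picks the build directory matching the local environment.
from kernels import get_kernel

flash_attn3 = get_kernel("kernels-community/flash-attn3")  # placeholder repo id

# The loaded module exposes the symbols re-exported by __init__.py above.
print([name for name in ("flash_attn_combine", "flash_attn_func",
                         "flash_attn_qkvpacked_func", "flash_attn_varlen_func",
                         "flash_attn_with_kvcache", "get_scheduler_metadata")
       if hasattr(flash_attn3, name)])
```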
build/torch28-cxx11-cu129-aarch64-linux/flash_attn3/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (438 Bytes)

build/torch28-cxx11-cu129-aarch64-linux/flash_attn3/__pycache__/_ops.cpython-313.pyc ADDED
Binary file (530 Bytes)

build/torch28-cxx11-cu129-aarch64-linux/flash_attn3/__pycache__/flash_attn_interface.cpython-313.pyc ADDED
Binary file (26.2 kB)