Commit 3de321c
Parent(s): 71d5a9d
update
trellis2/modules/attention/full_attn.py CHANGED

@@ -117,7 +117,8 @@ def scaled_dot_product_attention(*args, **kwargs):
         if num_all_args == 1:
             out = flash_attn_3.flash_attn_qkvpacked_func(qkv)
         elif num_all_args == 2:
-
+            k, v = kv.unbind(dim=2)
+            out = flash_attn_3.flash_attn_func(q, k, v)
         elif num_all_args == 3:
             out = flash_attn_3.flash_attn_func(q, k, v)
     elif config.BACKEND == 'sdpa':
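Note: the 2-argument (kv-packed) branch now unbinds the packed kv tensor along dim 2 and calls flash_attn_3.flash_attn_func on separate q, k, v. Below is a minimal, self-contained sketch of that unbind step; the tensor shapes and the PyTorch SDPA stand-in are illustrative assumptions, not part of the commit.

# Sketch (assumed shapes): splitting a packed kv tensor so an unpacked
# attention kernel can be called.  PyTorch SDPA is used here only as a
# stand-in for flash_attn_3.flash_attn_func.
import torch
import torch.nn.functional as F

N, L, H, C = 2, 16, 4, 32              # batch, seq len, heads, head dim (assumed)
q = torch.randn(N, L, H, C)            # queries: [N, L, H, C]
kv = torch.randn(N, L, 2, H, C)        # packed key/value: [N, L, 2, H, C]

k, v = kv.unbind(dim=2)                # two views, each of shape [N, L, H, C]

# Stand-in for flash_attn_3.flash_attn_func(q, k, v); SDPA expects [N, H, L, C].
out = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
).transpose(1, 2)                      # back to [N, L, H, C]
print(out.shape)                       # torch.Size([2, 16, 4, 32])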
trellis2/modules/sparse/attention/full_attn.py CHANGED

@@ -197,14 +197,20 @@ def sparse_scaled_dot_product_attention(*args, **kwargs):
         if 'flash_attn_3' not in globals():
             import flash_attn_interface as flash_attn_3
         cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device)
-        if num_all_args in [2, 3]:
-            cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device)
         if num_all_args == 1:
-
+            q, k, v = qkv.unbind(dim=1)
+            cu_seqlens_kv = cu_seqlens_q.clone()
+            max_q_seqlen = max_kv_seqlen = max(q_seqlen)
         elif num_all_args == 2:
-
+            k, v = kv.unbind(dim=1)
+            cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device)
+            max_q_seqlen = max(q_seqlen)
+            max_kv_seqlen = max(kv_seqlen)
         elif num_all_args == 3:
-
+            cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device)
+            max_q_seqlen = max(q_seqlen)
+            max_kv_seqlen = max(kv_seqlen)
+        out = flash_attn_3.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_q_seqlen, max_kv_seqlen)
     else:
         raise ValueError(f"Unknown attention module: {config.ATTN}")
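Note: the sparse path now prepares cu_seqlens_kv, max_q_seqlen, and max_kv_seqlen in each branch (qkv-packed, kv-packed, separate q/k/v) so that a single flash_attn_3.flash_attn_varlen_func call can follow. Below is a minimal sketch of the cu_seqlens bookkeeping, using flash-attn's usual varlen convention (tokens of all sequences packed along dim 0, int32 offsets of length batch+1); the example lengths and the reference slicing loop are assumptions for illustration, not code from the commit.

# Sketch of the cu_seqlens bookkeeping the diff sets up for the varlen kernel.
import torch

q_seqlen = [5, 3, 7]                   # per-sequence lengths (assumed example)
H, C = 4, 32
total = sum(q_seqlen)
q_packed = torch.randn(total, H, C)    # packed queries: [total_tokens, H, C]

cu_seqlens_q = torch.cat(
    [torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]
).int()                                # tensor([ 0,  5,  8, 15], dtype=torch.int32)
max_q_seqlen = max(q_seqlen)           # 7

# The offsets let a kernel (or a reference loop) recover each sequence:
for i in range(len(q_seqlen)):
    start, end = cu_seqlens_q[i].item(), cu_seqlens_q[i + 1].item()
    seq_i = q_packed[start:end]        # [q_seqlen[i], H, C]
    assert seq_i.shape[0] == q_seqlen[i]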