JeffreyXiang committed
Commit 3de321c · Parent: 71d5a9d
trellis2/modules/attention/full_attn.py CHANGED
@@ -117,7 +117,8 @@ def scaled_dot_product_attention(*args, **kwargs):
         if num_all_args == 1:
             out = flash_attn_3.flash_attn_qkvpacked_func(qkv)
         elif num_all_args == 2:
-            out = flash_attn_3.flash_attn_kvpacked_func(q, kv)
+            k, v = kv.unbind(dim=2)
+            out = flash_attn_3.flash_attn_func(q, k, v)
         elif num_all_args == 3:
             out = flash_attn_3.flash_attn_func(q, k, v)
     elif config.BACKEND == 'sdpa':
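
For reference, the `kv` tensor that the removed `flash_attn_kvpacked_func` call consumed packs keys and values along dim 2, i.e. it has shape `[batch, seqlen_kv, 2, heads, head_dim]`, so `kv.unbind(dim=2)` yields a `k` and a `v` of shape `[batch, seqlen_kv, heads, head_dim]` each, which is the per-tensor layout that `flash_attn_func` takes. A minimal sketch of that equivalence, using PyTorch's built-in SDPA as a stand-in for the FlashAttention-3 kernel (the shapes and sizes below are illustrative, not taken from the repo):

```python
import torch
import torch.nn.functional as F

N, Lq, Lkv, H, D = 2, 16, 24, 4, 32
q = torch.randn(N, Lq, H, D)           # [batch, seqlen_q, heads, head_dim]
kv = torch.randn(N, Lkv, 2, H, D)      # packed k/v, as passed to the kvpacked variant

# Unbinding along dim=2 splits the packed tensor into separate k and v,
# each [N, Lkv, H, D], matching what flash_attn_func expects.
k, v = kv.unbind(dim=2)

# Reference attention via PyTorch SDPA (which wants [batch, heads, seq, dim]).
out = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
).transpose(1, 2)                      # back to [N, Lq, H, D]

assert out.shape == (N, Lq, H, D)
```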
trellis2/modules/sparse/attention/full_attn.py CHANGED
@@ -197,14 +197,20 @@ def sparse_scaled_dot_product_attention(*args, **kwargs):
         if 'flash_attn_3' not in globals():
             import flash_attn_interface as flash_attn_3
         cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device)
-        if num_all_args in [2, 3]:
-            cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device)
         if num_all_args == 1:
-            out = flash_attn_3.flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens_q, max(q_seqlen))
+            q, k, v = qkv.unbind(dim=1)
+            cu_seqlens_kv = cu_seqlens_q.clone()
+            max_q_seqlen = max_kv_seqlen = max(q_seqlen)
         elif num_all_args == 2:
-            out = flash_attn_3.flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
+            k, v = kv.unbind(dim=1)
+            cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device)
+            max_q_seqlen = max(q_seqlen)
+            max_kv_seqlen = max(kv_seqlen)
         elif num_all_args == 3:
-            out = flash_attn_3.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
+            cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device)
+            max_q_seqlen = max(q_seqlen)
+            max_kv_seqlen = max(kv_seqlen)
+        out = flash_attn_3.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_q_seqlen, max_kv_seqlen)
     else:
         raise ValueError(f"Unknown attention module: {config.ATTN}")
 
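
In the varlen path the sparse tokens from all samples are already flattened into single `[total_tokens, heads, head_dim]` tensors, and the `cu_seqlens_*` arrays mark where each sample's tokens begin and end. That is what lets the rewritten branch funnel all three argument forms into one `flash_attn_varlen_func` call after unbinding the packed inputs. Below is a rough sketch of how those cumulative-length tensors are built (mirroring the `torch.cat`/`torch.cumsum` lines in the diff) and what they delimit, with a naive per-sample SDPA loop standing in for the fused varlen kernel; the sizes and the loop itself are illustrative only:

```python
import torch
import torch.nn.functional as F

H, D = 4, 32
q_seqlen = [5, 3, 7]             # tokens per sample in the sparse batch
kv_seqlen = [6, 2, 4]

q = torch.randn(sum(q_seqlen), H, D)    # [total_q_tokens, heads, head_dim]
k = torch.randn(sum(kv_seqlen), H, D)
v = torch.randn(sum(kv_seqlen), H, D)

# Cumulative sequence lengths, as in the diff: [0, l0, l0+l1, ...], int32.
cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int()
cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int()

# Reference: run attention sample-by-sample over the token ranges that
# cu_seqlens delimits; flash_attn_varlen_func does this in one fused kernel.
outs = []
for i in range(len(q_seqlen)):
    qs, qe = cu_seqlens_q[i].item(), cu_seqlens_q[i + 1].item()
    ks, ke = cu_seqlens_kv[i].item(), cu_seqlens_kv[i + 1].item()
    qi = q[qs:qe].transpose(0, 1)       # [H, Lq_i, D]
    ki = k[ks:ke].transpose(0, 1)
    vi = v[ks:ke].transpose(0, 1)
    outs.append(F.scaled_dot_product_attention(qi, ki, vi).transpose(0, 1))
out = torch.cat(outs, dim=0)            # [total_q_tokens, H, D]

assert out.shape == (sum(q_seqlen), H, D)
```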