medmekk (HF Staff) committed
Commit 596776b
1 Parent(s): a8031ce

rm core tests for now

Files changed (1)
  1. tests/test_core.py +0 -73
tests/test_core.py CHANGED
@@ -1,73 +0,0 @@
- import math
- import pytest
- import torch
-
- import sage_attention as sa
-
-
- cuda_available = torch.cuda.is_available()
-
-
- def current_sm():
-     if not cuda_available:
-         return None
-     major, minor = torch.cuda.get_device_capability(0)
-     return f"sm{major}{minor}"
-
-
- @pytest.mark.skipif(not cuda_available, reason="CUDA is required")
- @pytest.mark.parametrize("tensor_layout", ["HND", "NHD"])
- @pytest.mark.parametrize("head_dim", [64, 128])
- @pytest.mark.parametrize("return_lse", [False, True])
- def test_sageattn_runs_and_shapes(tensor_layout, head_dim, return_lse):
-     device = "cuda"
-     dtype = torch.float16
-
-     # Small, nontrivial shapes; pad path will be exercised for head_dim=64
-     if tensor_layout == "HND":
-         q = torch.randn(2, 6, 129, head_dim, dtype=dtype, device=device)
-         k = torch.randn(2, 3, 257, head_dim, dtype=dtype, device=device)
-         v = torch.randn(2, 3, 257, head_dim, dtype=dtype, device=device)
-         expected_o_shape = (2, 6, 129, head_dim)
-         expected_lse_shape = (2, 6, 129)
-     else:
-         q = torch.randn(2, 129, 6, head_dim, dtype=dtype, device=device)
-         k = torch.randn(2, 257, 3, head_dim, dtype=dtype, device=device)
-         v = torch.randn(2, 257, 3, head_dim, dtype=dtype, device=device)
-         expected_o_shape = (2, 129, 6, head_dim)
-         expected_lse_shape = (2, 6, 129)
-
-     sm = current_sm()
-
-     # Some backends may not be compiled on this GPU; skip gracefully if unsupported
-     try:
-         out = sa.sageattn(
-             q, k, v, tensor_layout=tensor_layout, is_causal=False, return_lse=return_lse
-         )
-     except ValueError as e:
-         if "Unsupported CUDA architecture" in str(e):
-             pytest.skip(f"Unsupported arch for this build: {sm}")
-         raise
-
-     if return_lse:
-         o, lse = out
-         assert lse.shape == expected_lse_shape and torch.isfinite(lse).all()
-     else:
-         o = out
-
-     assert o.shape == expected_o_shape
-     assert o.dtype == dtype
-     assert o.device.type == "cuda"
-
-
- @pytest.mark.skipif(not cuda_available, reason="CUDA is required")
- def test_sageattn_raises_on_unsupported_head_dim():
-     device = "cuda"
-     dtype = torch.float16
-     # head_dim > 128 should raise
-     q = torch.randn(1, 2, 8, 192, dtype=dtype, device=device)
-     k = torch.randn(1, 1, 8, 192, dtype=dtype, device=device)
-     v = torch.randn(1, 1, 8, 192, dtype=dtype, device=device)
-
-     with pytest.raises(ValueError):
-         sa.sageattn(q, k, v)