Improve tests for MPS

Changed files:

- tests/kernels/conftest.py       +0  -1
- tests/kernels/test_attention.py +16 -3
- tests/kernels/test_cache.py     +11 -5
- tests/kernels/utils.py          +11 -16
tests/kernels/conftest.py

@@ -36,7 +36,6 @@ def create_kv_caches_with_random(
     seed: int = 0,
     device: Optional[str] = "cuda",
 ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
-
     if cache_dtype == "fp8" and head_size % 16:
         raise ValueError(
             f"Does not support key cache of type fp8 with head_size {head_size}"
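Aside from the dropped blank line, this hunk shows the fp8 head-size guard in context: fp8 KV caches are rejected unless the head size is 16-aligned. A minimal sketch of that check in isolation (the standalone function name is illustrative, not part of the diff):

```python
def check_fp8_head_size(cache_dtype: str, head_size: int) -> None:
    # The fp8 KV-cache layout used by this helper requires head_size to
    # be a multiple of 16; anything else is rejected up front.
    if cache_dtype == "fp8" and head_size % 16:
        raise ValueError(
            f"Does not support key cache of type fp8 with head_size {head_size}"
        )


check_fp8_head_size("fp8", 64)  # fine: 64 % 16 == 0
# check_fp8_head_size("fp8", 72)  # raises ValueError: 72 % 16 == 8
```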
tests/kernels/test_attention.py

@@ -43,6 +43,7 @@ if current_platform.is_mps():
 else:
     DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
 
+
 def ref_masked_attention(
     query: torch.Tensor,
     key: torch.Tensor,
@@ -232,7 +233,11 @@ def test_paged_attention(
                 64,
                 0,
             ),
-            cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]),
+            cond=(
+                head_size == HEAD_SIZES[0]
+                and block_size == BLOCK_SIZES[0]
+                and not device.startswith("mps")
+            ),
         )
 
     elif version in ("v2", "rocm"):
@@ -295,7 +300,11 @@ def test_paged_attention(
                 64,
                 0,
             ),
-            cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]),
+            cond=(
+                head_size == HEAD_SIZES[0]
+                and block_size == BLOCK_SIZES[0]
+                and not device.startswith("mps")
+            ),
         )
 
     else:
@@ -340,7 +349,11 @@ def test_paged_attention(
                 k_scale,
                 v_scale,
             ),
-            cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]),
+            cond=(
+                head_size == HEAD_SIZES[0]
+                and block_size == BLOCK_SIZES[0]
+                and not device.startswith("mps")
+            ),
         )
 
     else:
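Each of the three hunks threads the same predicate into `opcheck`'s `cond=` keyword: run the op check only for one representative head/block size, and never on MPS. A standalone sketch of that predicate (the size lists are illustrative placeholders, not the test module's actual values):

```python
HEAD_SIZES = [64, 80, 128]  # illustrative; the real lists live in the test module
BLOCK_SIZES = [16, 32]


def should_run_opcheck(head_size: int, block_size: int, device: str) -> bool:
    # opcheck is slow, so gate it to a single (head_size, block_size)
    # combination, and skip MPS entirely, where torch.library.opcheck
    # currently fails with placeholder-storage errors.
    return (
        head_size == HEAD_SIZES[0]
        and block_size == BLOCK_SIZES[0]
        and not device.startswith("mps")
    )


assert should_run_opcheck(64, 16, "cuda:0")
assert not should_run_opcheck(64, 16, "mps")
assert not should_run_opcheck(80, 16, "cuda:0")
```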
tests/kernels/test_cache.py

@@ -60,7 +60,9 @@ def test_copy_blocks(
     if kv_cache_dtype == "fp8" and head_size % 16:
         pytest.skip()
     current_platform.seed_everything(seed)
-    torch.set_default_device(device)
+    # Don't set MPS as default device to avoid placeholder storage error
+    if not device.startswith("mps"):
+        torch.set_default_device(device)
     # Generate random block mappings where each source block is mapped to two
     # destination blocks.
     assert 2 * num_mappings <= num_blocks
@@ -144,13 +146,15 @@ def test_reshape_and_cache(
     if kv_cache_dtype == "fp8" and head_size % 16:
         pytest.skip()
     current_platform.seed_everything(seed)
-    torch.set_default_device(device)
+    # Don't set MPS as default device to avoid placeholder storage error
+    if not device.startswith("mps"):
+        torch.set_default_device(device)
     # Create a random slot mapping.
     num_slots = block_size * num_blocks
     slot_mapping_lst = random.sample(range(num_slots), num_tokens)
-    slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long)
+    slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)
 
-    qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype)
+    qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype, device=device)
     _, key, value = qkv.unbind(dim=1)
 
     # Create the KV caches.
@@ -262,7 +266,9 @@ def test_reshape_and_cache_flash(
     if current_platform.is_mps() and kv_cache_dtype == "fp8":
         pytest.skip("reshape_and_cache_flash doesn't support FP8 on MPS")
     current_platform.seed_everything(seed)
-    torch.set_default_device(device)
+    # Don't set MPS as default device to avoid placeholder storage error
+    if not device.startswith("mps"):
+        torch.set_default_device(device)
 
     # Create a random slot mapping.
     num_slots = block_size * num_blocks
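All three hunks apply the same pattern: skip `torch.set_default_device` on MPS, where it triggers the placeholder storage error during opcheck, and pass `device=` explicitly when building tensors so they still land on the intended device. A minimal sketch of the pattern (the helper name and shapes are illustrative):

```python
import torch


def build_inputs(device: str, num_tokens: int = 8, num_heads: int = 2, head_size: int = 64):
    # Making MPS the default device breaks torch.library.opcheck with a
    # placeholder-storage error, so only set the default elsewhere...
    if not device.startswith("mps"):
        torch.set_default_device(device)
    # ...and pin every tensor explicitly so MPS runs still allocate there.
    qkv = torch.randn(num_tokens, 3, num_heads, head_size, device=device)
    return qkv.unbind(dim=1)


query, key, value = build_inputs("cpu")  # works for "mps" too when available
```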
tests/kernels/utils.py

@@ -40,10 +40,18 @@ def fp8_allclose(
     """
     torch._refs._check_close_args(name="torch.allclose", a=a, b=b, rtol=rtol, atol=atol)
 
+    # MPS doesn't support float64, so use float32 for comparison
+    if a.device.type == "mps" or b.device.type == "mps":
+        a_cmp = a.float()
+        b_cmp = b.float()
+    else:
+        a_cmp = a.double()
+        b_cmp = b.double()
+
     return bool(
         torch.all(
             torch.isclose(
-                a.double(), b.double(), rtol=rtol, atol=atol, equal_nan=equal_nan
+                a_cmp, b_cmp, rtol=rtol, atol=atol, equal_nan=equal_nan
             )
         ).item()
     )
@@ -68,25 +76,12 @@ def opcheck(
     *,
     test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS,
     raise_exception: bool = True,
-    cond: bool = True
+    cond: bool = True,
 ) -> Dict[str, str]:
     with unittest.mock.patch("torch.allclose", new=fp8_allclose):
         if not cond:
             return {}
-
-        # Check if any arguments are on MPS device and skip opcheck if so
-        # as MPS has issues with placeholder storage allocation in opcheck
-        def is_mps_tensor(x):
-            return hasattr(x, 'device') and x.device.type == 'mps'
-
-        def check_args_for_mps(args):
-            if isinstance(args, (list, tuple)):
-                return any(check_args_for_mps(arg) for arg in args)
-            return is_mps_tensor(args)
-
-        if check_args_for_mps(args):
-            return {}
-
+
         return torch.library.opcheck(
             op, args, kwargs, test_utils=test_utils, raise_exception=raise_exception
         )
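The `fp8_allclose` change works because float32 is available on every backend exercised here while float64 is unavailable on MPS, so the comparison dtype is chosen per input; the ad-hoc MPS argument scan in `opcheck` then becomes unnecessary, since callers fold the MPS skip into `cond=` instead. A standalone sketch of just the dtype-selection logic (the helper is illustrative, not part of the diff):

```python
import torch


def comparison_dtype(a: torch.Tensor, b: torch.Tensor) -> torch.dtype:
    # MPS tensors can't be upcast to float64 (the backend doesn't
    # implement it), so compare in float32 there; use float64 everywhere
    # else for the tightest tolerance behavior.
    if a.device.type == "mps" or b.device.type == "mps":
        return torch.float32
    return torch.float64


a = torch.ones(2)
b = torch.ones(2)
assert comparison_dtype(a, b) is torch.float64  # CPU inputs compare in float64
```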