Enable ROCm build
Browse files
- build.toml +52 -1
- flake.lock +6 -6
build.toml
CHANGED
|
@@ -15,7 +15,25 @@ src = [
|
|
| 15 |
"cuda-utils/cuda_utils_kernels.cu",
|
| 16 |
]
|
| 17 |
depends = []
|
| 18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
[kernel.paged_attention]
|
| 21 |
backend = "cuda"
|
|
@@ -40,6 +58,39 @@ src = [
|
|
| 40 |
include = [ "cuda-utils", "paged-attention" ]
|
| 41 |
depends = [ "torch" ]
|
| 42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
|
| 44 |
[kernel.paged_attention_metal]
|
| 45 |
backend = "metal"
|
|
|
|
| 15 |
"cuda-utils/cuda_utils_kernels.cu",
|
| 16 |
]
|
| 17 |
depends = []
|
| 18 |
+
|
| 19 |
+
[kernel.cuda_utils_rocm]
|
| 20 |
+
backend = "rocm"
|
| 21 |
+
rocm-archs = [
|
| 22 |
+
"gfx906",
|
| 23 |
+
"gfx908",
|
| 24 |
+
"gfx90a",
|
| 25 |
+
"gfx940",
|
| 26 |
+
"gfx941",
|
| 27 |
+
"gfx942",
|
| 28 |
+
"gfx1030",
|
| 29 |
+
"gfx1100",
|
| 30 |
+
"gfx1101",
|
| 31 |
+
]
|
| 32 |
+
src = [
|
| 33 |
+
"cuda-utils/cuda_utils.h",
|
| 34 |
+
"cuda-utils/cuda_utils_kernels.cu",
|
| 35 |
+
]
|
| 36 |
+
depends = ["torch"]
|
| 37 |
|
| 38 |
[kernel.paged_attention]
|
| 39 |
backend = "cuda"
|
|
|
|
| 58 |
include = [ "cuda-utils", "paged-attention" ]
|
| 59 |
depends = [ "torch" ]
|
| 60 |
|
| 61 |
+
[kernel.paged_attention_rocm]
|
| 62 |
+
backend = "rocm"
|
| 63 |
+
rocm-archs = [
|
| 64 |
+
"gfx906",
|
| 65 |
+
"gfx908",
|
| 66 |
+
"gfx90a",
|
| 67 |
+
"gfx940",
|
| 68 |
+
"gfx941",
|
| 69 |
+
"gfx942",
|
| 70 |
+
"gfx1030",
|
| 71 |
+
"gfx1100",
|
| 72 |
+
"gfx1101",
|
| 73 |
+
]
|
| 74 |
+
src = [
|
| 75 |
+
"cuda-utils/cuda_utils.h",
|
| 76 |
+
"paged-attention/attention/attention_dtypes.h",
|
| 77 |
+
"paged-attention/attention/attention_generic.cuh",
|
| 78 |
+
"paged-attention/attention/attention_kernels.cuh",
|
| 79 |
+
"paged-attention/attention/attention_utils.cuh",
|
| 80 |
+
"paged-attention/attention/dtype_bfloat16.cuh",
|
| 81 |
+
"paged-attention/attention/dtype_float16.cuh",
|
| 82 |
+
"paged-attention/attention/dtype_float32.cuh",
|
| 83 |
+
"paged-attention/attention/dtype_fp8.cuh",
|
| 84 |
+
"paged-attention/attention/paged_attention_v1.cu",
|
| 85 |
+
"paged-attention/attention/paged_attention_v2.cu",
|
| 86 |
+
"paged-attention/cache_kernels.cu",
|
| 87 |
+
"paged-attention/cuda_compat.h",
|
| 88 |
+
"paged-attention/dispatch_utils.h",
|
| 89 |
+
"paged-attention/quantization/fp8/amd/quant_utils.cuh",
|
| 90 |
+
"paged-attention/quantization/fp8/nvidia/quant_utils.cuh",
|
| 91 |
+
]
|
| 92 |
+
include = [ "cuda-utils", "paged-attention" ]
|
| 93 |
+
depends = [ "torch" ]
|
| 94 |
|
| 95 |
[kernel.paged_attention_metal]
|
| 96 |
backend = "metal"
|
flake.lock
CHANGED
|
@@ -73,11 +73,11 @@
|
|
| 73 |
"nixpkgs": "nixpkgs"
|
| 74 |
},
|
| 75 |
"locked": {
|
| 76 |
-
"lastModified":
|
| 77 |
-
"narHash": "sha256-
|
| 78 |
"owner": "huggingface",
|
| 79 |
"repo": "hf-nix",
|
| 80 |
-
"rev": "
|
| 81 |
"type": "github"
|
| 82 |
},
|
| 83 |
"original": {
|
|
@@ -98,11 +98,11 @@
|
|
| 98 |
]
|
| 99 |
},
|
| 100 |
"locked": {
|
| 101 |
-
"lastModified":
|
| 102 |
-
"narHash": "sha256
|
| 103 |
"owner": "huggingface",
|
| 104 |
"repo": "kernel-builder",
|
| 105 |
-
"rev": "
|
| 106 |
"type": "github"
|
| 107 |
},
|
| 108 |
"original": {
|
|
|
|
| 73 |
"nixpkgs": "nixpkgs"
|
| 74 |
},
|
| 75 |
"locked": {
|
| 76 |
+
"lastModified": 1751968576,
|
| 77 |
+
"narHash": "sha256-cmKrlWpNTG/hq1bCaHXfbdm9T+Y6V+5//EHAVc1TLBE=",
|
| 78 |
"owner": "huggingface",
|
| 79 |
"repo": "hf-nix",
|
| 80 |
+
"rev": "3fcd1e1b46da91b6691261640ffd6b7123d0cb9e",
|
| 81 |
"type": "github"
|
| 82 |
},
|
| 83 |
"original": {
|
|
|
|
| 98 |
]
|
| 99 |
},
|
| 100 |
"locked": {
|
| 101 |
+
"lastModified": 1753256281,
|
| 102 |
+
"narHash": "sha256-CfL3Fyf2ih7OtyL7ScZUCwOeCj+gjlRyPykhR6Zbt3I=",
|
| 103 |
"owner": "huggingface",
|
| 104 |
"repo": "kernel-builder",
|
| 105 |
+
"rev": "dcbbdf2d3c8e78b27321b205b2c9d67ffce6a706",
|
| 106 |
"type": "github"
|
| 107 |
},
|
| 108 |
"original": {
|