EricB (HF Staff) committed
Commit 8276f6a · 1 parent: 85dbf0e

Update flake.nix
build.toml CHANGED
@@ -1,5 +1,5 @@
 [general]
-name = "paged_attention"
+name = "kernels_paged_attention_metal"
 universal = false
 
 [torch]
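
The rename in [general].name changes the name under which the built extension and its Python package are published, so downstream code imports kernels_paged_attention_metal instead of paged_attention. A minimal sketch of loading such a build through the Hugging Face kernels loader, assuming the kernel is pushed to a Hub repository (the repo id below is hypothetical):

# Hypothetical sketch: fetch the renamed kernel from the Hub via the `kernels`
# loader. Replace the repo id with the repository that actually hosts this build.
from kernels import get_kernel

paged_attention = get_kernel("kernels-community/paged-attention-metal")  # hypothetical repo id
op = paged_attention.paged_attention_v1  # ops are re-exported under the new package name
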
build/torch27-metal-arm64-darwin/kernels_paged_attention_metal/__init__.py DELETED
@@ -1,21 +0,0 @@
-from ._custom_ops import (
-    convert_fp8,
-    copy_blocks,
-    paged_attention_v1,
-    paged_attention_v2,
-    reshape_and_cache,
-    reshape_and_cache_flash,
-    swap_blocks,
-)
-from ._ops import ops
-
-__all__ = [
-    "convert_fp8",
-    "copy_blocks",
-    "ops",
-    "paged_attention_v1",
-    "paged_attention_v2",
-    "reshape_and_cache",
-    "reshape_and_cache_flash",
-    "swap_blocks",
-]
build/torch27-metal-arm64-darwin/kernels_paged_attention_metal/_custom_ops.py DELETED
@@ -1,173 +0,0 @@
-from typing import List, Optional
-
-import torch
-
-from ._ops import ops
-
-
-# page attention ops
-def paged_attention_v1(
-    out: torch.Tensor,
-    query: torch.Tensor,
-    key_cache: torch.Tensor,
-    value_cache: torch.Tensor,
-    num_kv_heads: int,
-    scale: float,
-    block_tables: torch.Tensor,
-    seq_lens: torch.Tensor,
-    block_size: int,
-    max_seq_len: int,
-    alibi_slopes: Optional[torch.Tensor],
-    kv_cache_dtype: str,
-    k_scale: float,
-    v_scale: float,
-    tp_rank: int = 0,
-    blocksparse_local_blocks: int = 0,
-    blocksparse_vert_stride: int = 0,
-    blocksparse_block_size: int = 64,
-    blocksparse_head_sliding_step: int = 0,
-) -> None:
-    ops.paged_attention_v1(
-        out,
-        query,
-        key_cache,
-        value_cache,
-        num_kv_heads,
-        scale,
-        block_tables,
-        seq_lens,
-        block_size,
-        max_seq_len,
-        alibi_slopes,
-        kv_cache_dtype,
-        k_scale,
-        v_scale,
-        tp_rank,
-        blocksparse_local_blocks,
-        blocksparse_vert_stride,
-        blocksparse_block_size,
-        blocksparse_head_sliding_step,
-    )
-
-
-def paged_attention_v2(
-    out: torch.Tensor,
-    exp_sum: torch.Tensor,
-    max_logits: torch.Tensor,
-    tmp_out: torch.Tensor,
-    query: torch.Tensor,
-    key_cache: torch.Tensor,
-    value_cache: torch.Tensor,
-    num_kv_heads: int,
-    scale: float,
-    block_tables: torch.Tensor,
-    seq_lens: torch.Tensor,
-    block_size: int,
-    max_seq_len: int,
-    alibi_slopes: Optional[torch.Tensor],
-    kv_cache_dtype: str,
-    k_scale: float,
-    v_scale: float,
-    tp_rank: int = 0,
-    blocksparse_local_blocks: int = 0,
-    blocksparse_vert_stride: int = 0,
-    blocksparse_block_size: int = 64,
-    blocksparse_head_sliding_step: int = 0,
-) -> None:
-    ops.paged_attention_v2(
-        out,
-        exp_sum,
-        max_logits,
-        tmp_out,
-        query,
-        key_cache,
-        value_cache,
-        num_kv_heads,
-        scale,
-        block_tables,
-        seq_lens,
-        block_size,
-        max_seq_len,
-        alibi_slopes,
-        kv_cache_dtype,
-        k_scale,
-        v_scale,
-        tp_rank,
-        blocksparse_local_blocks,
-        blocksparse_vert_stride,
-        blocksparse_block_size,
-        blocksparse_head_sliding_step,
-    )
-
-
-def reshape_and_cache(
-    key: torch.Tensor,
-    value: torch.Tensor,
-    key_cache: torch.Tensor,
-    value_cache: torch.Tensor,
-    slot_mapping: torch.Tensor,
-    kv_cache_dtype: str,
-    k_scale: float,
-    v_scale: float,
-) -> None:
-    ops.reshape_and_cache(
-        key,
-        value,
-        key_cache,
-        value_cache,
-        slot_mapping,
-        kv_cache_dtype,
-        k_scale,
-        v_scale,
-    )
-
-
-def reshape_and_cache_flash(
-    key: torch.Tensor,
-    value: torch.Tensor,
-    key_cache: torch.Tensor,
-    value_cache: torch.Tensor,
-    slot_mapping: torch.Tensor,
-    kv_cache_dtype: str,
-    k_scale: torch.Tensor,
-    v_scale: torch.Tensor,
-) -> None:
-    ops.reshape_and_cache_flash(
-        key,
-        value,
-        key_cache,
-        value_cache,
-        slot_mapping,
-        kv_cache_dtype,
-        k_scale,
-        v_scale,
-    )
-
-
-def copy_blocks(
-    key_caches: List[torch.Tensor],
-    value_caches: List[torch.Tensor],
-    block_mapping: torch.Tensor,
-) -> None:
-    ops.copy_blocks(key_caches, value_caches, block_mapping)
-
-
-def swap_blocks(
-    src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor
-) -> None:
-    ops.swap_blocks(src, dst, block_mapping)
-
-
-def convert_fp8(
-    output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8"
-) -> None:
-    ops.convert_fp8(output, input, scale, kv_dtype)
-
-
-__all__ = [
-    "convert_fp8",
-    "paged_attention_v1",
-    "paged_attention_v2",
-    "reshape_and_cache",
-    "copy_blocks",
-]
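
These wrappers follow the vLLM-style paged-attention interface: v1 computes attention over the paged KV cache in one pass, while v2 splits long sequences into partitions and reduces exp_sums/max_logits across them. A common calling pattern is to dispatch between the two based on how many partitions the longest sequence needs. The sketch below assumes the renamed package is installed and importable; the 512-token partition size and the dispatch heuristic are assumptions borrowed from vLLM-style code, not taken from this repository.

from typing import Optional

import torch

# Assumes the package built from this repo is installed under its new name.
from kernels_paged_attention_metal import paged_attention_v1, paged_attention_v2

_PARTITION_SIZE = 512  # assumed v2 partition size, for illustration only


def paged_attention(
    out: torch.Tensor,
    query: torch.Tensor,            # [num_seqs, num_heads, head_size]
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
) -> None:
    """Use v1 for short contexts, v2 (partitioned reduction) for long ones."""
    num_seqs, num_heads, head_size = query.shape
    max_num_partitions = (max_seq_len + _PARTITION_SIZE - 1) // _PARTITION_SIZE
    if max_num_partitions == 1:
        paged_attention_v1(
            out, query, key_cache, value_cache, num_kv_heads, scale,
            block_tables, seq_lens, block_size, max_seq_len, alibi_slopes,
            kv_cache_dtype, k_scale, v_scale,
        )
    else:
        # v2 needs scratch buffers for the per-partition softmax statistics.
        tmp_out = torch.empty(
            (num_seqs, num_heads, max_num_partitions, head_size),
            dtype=out.dtype, device=out.device,
        )
        exp_sums = torch.empty(
            (num_seqs, num_heads, max_num_partitions),
            dtype=torch.float32, device=out.device,
        )
        max_logits = torch.empty_like(exp_sums)
        paged_attention_v2(
            out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache,
            num_kv_heads, scale, block_tables, seq_lens, block_size,
            max_seq_len, alibi_slopes, kv_cache_dtype, k_scale, v_scale,
        )
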
build/torch27-metal-arm64-darwin/kernels_paged_attention_metal/_ops.py DELETED
@@ -1,9 +0,0 @@
-import torch
-from . import _paged_attention_2ee8d65
-ops = torch.ops._paged_attention_2ee8d65
-
-def add_op_namespace_prefix(op_name: str):
-    """
-    Prefix op by namespace.
-    """
-    return f"_paged_attention_2ee8d65::{op_name}"
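
_ops.py binds the compiled extension under a namespace derived from the build hash, and add_op_namespace_prefix turns a bare op name into the fully qualified "namespace::op" string that torch.library APIs expect. A small sketch of one possible use, registering a fake (meta) implementation for shape propagation; it assumes the extension is loadable and is purely illustrative, since the repository itself does not ship such a registration.

import torch

# Assumes the package is installed; the fake registration below is an
# illustrative use of add_op_namespace_prefix, not code from this repo.
from kernels_paged_attention_metal._ops import add_op_namespace_prefix


@torch.library.register_fake(add_op_namespace_prefix("paged_attention_v1"))
def _(out, query, key_cache, value_cache, num_kv_heads, scale,
      block_tables, seq_lens, block_size, max_seq_len, alibi_slopes,
      kv_cache_dtype, k_scale, v_scale, tp_rank=0,
      blocksparse_local_blocks=0, blocksparse_vert_stride=0,
      blocksparse_block_size=64, blocksparse_head_sliding_step=0):
    # The op writes into `out` in place, so the fake returns nothing.
    return None
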
build/torch27-metal-arm64-darwin/kernels_paged_attention_metal/_paged_attention_2ee8d65.abi3.so DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:b732f1413d0dcaf27b13b139d95f7aa17aefc238baac9a7b26e2ba461ef69de8
-size 214800
build/torch27-metal-arm64-darwin/kernels_paged_attention_metal/_paged_attention_2ee8d65.metallib DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:c46eaf21c96da70c5227b2566308a8ef73ae09abf303278f40070dd4326ba0be
-size 4999876
build/torch27-metal-arm64-darwin/kernels_paged_attention_metal/platforms.py DELETED
@@ -1,92 +0,0 @@
-import os
-import random
-from abc import ABC, abstractmethod
-from functools import lru_cache, wraps
-from typing import Callable, ParamSpec, TypeVar
-
-import numpy as np
-import torch
-
-IS_ROCM = torch.version.hip is not None
-IS_MPS = torch.backends.mps.is_available()
-
-
-class Platform(ABC):
-    @classmethod
-    def seed_everything(cls, seed: int) -> None:
-        """
-        Set the seed of each random module.
-        `torch.manual_seed` will set seed on all devices.
-
-        Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
-        """
-        random.seed(seed)
-        np.random.seed(seed)
-        torch.manual_seed(seed)
-
-    @abstractmethod
-    def get_device_name(self, device_id: int = 0) -> str: ...
-
-    @abstractmethod
-    def is_cuda(self) -> bool: ...
-
-    @abstractmethod
-    def is_rocm(self) -> bool: ...
-
-    @abstractmethod
-    def is_mps(self) -> bool: ...
-
-
-class CudaPlatform(Platform):
-    @classmethod
-    @lru_cache(maxsize=8)
-    def get_device_name(cls, device_id: int = 0) -> str:
-        return torch.cuda.get_device_name(0)
-
-    def is_cuda(self) -> bool:
-        return True
-
-    def is_rocm(self) -> bool:
-        return False
-
-    def is_mps(self) -> bool:
-        return False
-
-
-class RocmPlatform(Platform):
-    @classmethod
-    @lru_cache(maxsize=8)
-    def get_device_name(cls, device_id: int = 0) -> str:
-        return torch.cuda.get_device_name(device_id)
-
-    def is_cuda(self) -> bool:
-        return False
-
-    def is_rocm(self) -> bool:
-        return True
-
-    def is_mps(self) -> bool:
-        return False
-
-
-class MpsPlatform(Platform):
-    @classmethod
-    @lru_cache(maxsize=8)
-    def get_device_name(cls, device_id: int = 0) -> str:
-        return torch.cuda.get_device_name(device_id)
-
-    def is_cuda(self) -> bool:
-        return False
-
-    def is_rocm(self) -> bool:
-        return False
-
-    def is_mps(self) -> bool:
-        return True
-
-current_platform = (
-    RocmPlatform() if IS_ROCM else
-    MpsPlatform() if IS_MPS else
-    CudaPlatform() if torch.cuda.is_available() else
-    None
-)
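
platforms.py selects a current_platform singleton at import time, preferring ROCm, then MPS, then CUDA, and falling back to None when no accelerator is present. A short usage sketch, assuming the package is installed; the branching below is an illustration of the Platform API above, not code from this repository.

import torch

from kernels_paged_attention_metal.platforms import current_platform

if current_platform is None:
    raise RuntimeError("no supported accelerator (CUDA, ROCm, or MPS) found")

current_platform.seed_everything(0)  # seeds Python, NumPy and torch RNGs

if current_platform.is_mps():
    device = torch.device("mps")
else:
    device = torch.device("cuda")
    print(f"running on {current_platform.get_device_name()}")
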
flake.lock DELETED
@@ -1,168 +0,0 @@
-{
-  "nodes": {
-    "flake-compat": {
-      "locked": {
-        "lastModified": 1747046372,
-        "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
-        "owner": "edolstra",
-        "repo": "flake-compat",
-        "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
-        "type": "github"
-      },
-      "original": {
-        "owner": "edolstra",
-        "repo": "flake-compat",
-        "type": "github"
-      }
-    },
-    "flake-compat_2": {
-      "locked": {
-        "lastModified": 1733328505,
-        "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
-        "owner": "edolstra",
-        "repo": "flake-compat",
-        "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec",
-        "type": "github"
-      },
-      "original": {
-        "owner": "edolstra",
-        "repo": "flake-compat",
-        "type": "github"
-      }
-    },
-    "flake-utils": {
-      "inputs": {
-        "systems": "systems"
-      },
-      "locked": {
-        "lastModified": 1731533236,
-        "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
-        "owner": "numtide",
-        "repo": "flake-utils",
-        "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
-        "type": "github"
-      },
-      "original": {
-        "owner": "numtide",
-        "repo": "flake-utils",
-        "type": "github"
-      }
-    },
-    "flake-utils_2": {
-      "inputs": {
-        "systems": "systems_2"
-      },
-      "locked": {
-        "lastModified": 1731533236,
-        "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
-        "owner": "numtide",
-        "repo": "flake-utils",
-        "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
-        "type": "github"
-      },
-      "original": {
-        "owner": "numtide",
-        "repo": "flake-utils",
-        "type": "github"
-      }
-    },
-    "hf-nix": {
-      "inputs": {
-        "flake-compat": "flake-compat_2",
-        "flake-utils": "flake-utils_2",
-        "nixpkgs": "nixpkgs"
-      },
-      "locked": {
-        "lastModified": 1750234878,
-        "narHash": "sha256-q9DRC9zdpzUf88qqg1qbhP1qgJbE2cMtn8oUmosuyT8=",
-        "owner": "huggingface",
-        "repo": "hf-nix",
-        "rev": "c7132f90763d756da3e77da62e01be0a4546dc57",
-        "type": "github"
-      },
-      "original": {
-        "owner": "huggingface",
-        "repo": "hf-nix",
-        "type": "github"
-      }
-    },
-    "kernel-builder": {
-      "inputs": {
-        "flake-compat": "flake-compat",
-        "flake-utils": "flake-utils",
-        "hf-nix": "hf-nix",
-        "nixpkgs": [
-          "kernel-builder",
-          "hf-nix",
-          "nixpkgs"
-        ]
-      },
-      "locked": {
-        "lastModified": 1750430211,
-        "narHash": "sha256-QEaSxFNjcqzBBB1WVYFBJ0/Uuol2k1kDSpuyoz/Slzc=",
-        "owner": "huggingface",
-        "repo": "kernel-builder",
-        "rev": "3616c38e5c1fc6cc382510eff12b9d54d6797e84",
-        "type": "github"
-      },
-      "original": {
-        "owner": "huggingface",
-        "repo": "kernel-builder",
-        "type": "github"
-      }
-    },
-    "nixpkgs": {
-      "locked": {
-        "lastModified": 1747820358,
-        "narHash": "sha256-fTqsZsUX6M3yeEvgyQvXcbGmT2CaRVyVwsi8eK29Oj4=",
-        "owner": "danieldk",
-        "repo": "nixpkgs",
-        "rev": "d3c1681180717528068082103bf323147de6ab0b",
-        "type": "github"
-      },
-      "original": {
-        "owner": "danieldk",
-        "ref": "cudatoolkit-12.9-kernel-builder",
-        "repo": "nixpkgs",
-        "type": "github"
-      }
-    },
-    "root": {
-      "inputs": {
-        "kernel-builder": "kernel-builder"
-      }
-    },
-    "systems": {
-      "locked": {
-        "lastModified": 1681028828,
-        "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
-        "owner": "nix-systems",
-        "repo": "default",
-        "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
-        "type": "github"
-      },
-      "original": {
-        "owner": "nix-systems",
-        "repo": "default",
-        "type": "github"
-      }
-    },
-    "systems_2": {
-      "locked": {
-        "lastModified": 1681028828,
-        "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
-        "owner": "nix-systems",
-        "repo": "default",
-        "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
-        "type": "github"
-      },
-      "original": {
-        "owner": "nix-systems",
-        "repo": "default",
-        "type": "github"
-      }
-    }
-  },
-  "root": "root",
-  "version": 7
-}