Port vLLM attention kernels
- .gitattributes +1 -0
- README.md +4 -0
- build.toml +44 -0
- cuda-utils/cuda_utils_kernels.cu +29 -0
- flake.nix +14 -0
- paged-attention/attention/attention_dtypes.h +7 -0
- paged-attention/attention/attention_generic.cuh +65 -0
- paged-attention/attention/attention_kernels.cuh +676 -0
- paged-attention/attention/attention_utils.cuh +57 -0
- paged-attention/attention/dtype_bfloat16.cuh +463 -0
- paged-attention/attention/dtype_float16.cuh +504 -0
- paged-attention/attention/dtype_float32.cuh +251 -0
- paged-attention/attention/dtype_fp8.cuh +41 -0
- paged-attention/attention/paged_attention_v1.cu +196 -0
- paged-attention/attention/paged_attention_v2.cu +206 -0
- paged-attention/cache_kernels.cu +419 -0
- paged-attention/cuda_compat.h +49 -0
- paged-attention/dispatch_utils.h +49 -0
- paged-attention/quantization/fp8/amd/hip_float8.h +137 -0
- paged-attention/quantization/fp8/amd/hip_float8_impl.h +316 -0
- paged-attention/quantization/fp8/amd/quant_utils.cuh +577 -0
- paged-attention/quantization/fp8/nvidia/quant_utils.cuh +573 -0
- tests/kernels/__init__.py +0 -0
- tests/kernels/allclose_default.py +14 -0
- tests/kernels/conftest.py +158 -0
- tests/kernels/test_attention.py +418 -0
- tests/kernels/test_cache.py +486 -0
- tests/kernels/utils.py +92 -0
- torch-ext/attention/__init__.py +21 -0
- torch-ext/attention/_custom_ops.py +173 -0
- torch-ext/attention/platforms.py +62 -0
- torch-ext/registration.h +27 -0
- torch-ext/torch_binding.cpp +95 -0
- torch-ext/torch_binding.h +56 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.so filter=lfs diff=lfs merge=lfs -text
README.md
CHANGED
@@ -1,3 +1,7 @@
 ---
 license: apache-2.0
 ---
+
+## attention
+
+Paged attention kernels from [vLLM](https://github.com/vllm-project/).
build.toml
ADDED
@@ -0,0 +1,44 @@
+[general]
+version = "0.0.1"
+
+[torch]
+name = "attention"
+src = [
+  "torch-ext/registration.h",
+  "torch-ext/torch_binding.cpp",
+  "torch-ext/torch_binding.h"
+]
+pyroot = "torch-ext"
+
+[kernel.cuda_utils]
+capabilities = [ "7.0", "7.2", "7.5", "8.0", "8.6", "8.7", "8.9", "9.0" ]
+src = [
+  "cuda-utils/cuda_utils_kernels.cu",
+]
+depends = []
+
+
+[kernel.paged_attention]
+capabilities = [ "7.0", "7.2", "7.5", "8.0", "8.6", "8.7", "8.9", "9.0" ]
+src = [
+  "paged-attention/attention/attention_dtypes.h",
+  "paged-attention/attention/attention_generic.cuh",
+  "paged-attention/attention/attention_kernels.cuh",
+  "paged-attention/attention/attention_utils.cuh",
+  "paged-attention/attention/dtype_bfloat16.cuh",
+  "paged-attention/attention/dtype_float16.cuh",
+  "paged-attention/attention/dtype_float32.cuh",
+  "paged-attention/attention/dtype_fp8.cuh",
+  "paged-attention/attention/paged_attention_v1.cu",
+  "paged-attention/attention/paged_attention_v2.cu",
+  "paged-attention/cache_kernels.cu",
+  "paged-attention/cuda_compat.h",
+  "paged-attention/dispatch_utils.h",
+  "paged-attention/quantization/fp8/amd/hip_float8.h",
+  "paged-attention/quantization/fp8/amd/hip_float8_impl.h",
+  "paged-attention/quantization/fp8/amd/quant_utils.cuh",
+  "paged-attention/quantization/fp8/nvidia/quant_utils.cuh",
+]
+include = [ "." ]
+depends = [ "torch" ]
+
cuda-utils/cuda_utils_kernels.cu
ADDED
@@ -0,0 +1,29 @@
+#ifdef USE_ROCM
+#include <hip/hip_runtime.h>
+#include <hip/hip_runtime_api.h>
+#endif
+int64_t get_device_attribute(int64_t attribute, int64_t device_id) {
+  int device, value;
+  if (device_id < 0) {
+    cudaGetDevice(&device);
+  } else {
+    device = device_id;
+  }
+  cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute),
+                         device);
+  return value;
+}
+
+int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id) {
+  int64_t attribute;
+  // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
+  // cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74
+
+#ifdef USE_ROCM
+  attribute = hipDeviceAttributeMaxSharedMemoryPerBlock;
+#else
+  attribute = cudaDevAttrMaxSharedMemoryPerBlockOptin;
+#endif
+
+  return get_device_attribute(attribute, device_id);
+}
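These helpers are plain host-side wrappers over `cudaDeviceGetAttribute`. As a rough sketch of how they could be exercised outside the extension, here is a hypothetical standalone harness; the `main` function and the attribute chosen for the second query are illustrative and not part of the port:

// Hypothetical harness for the helpers above (illustrative only).
#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

int64_t get_device_attribute(int64_t attribute, int64_t device_id);
int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id);

int main() {
  // -1 means "use the current device", mirroring the helper's convention.
  int64_t smem = get_max_shared_memory_per_block_device_attribute(-1);
  // Any cudaDeviceAttr value can be forwarded as a plain integer.
  int64_t sm_count =
      get_device_attribute(cudaDevAttrMultiProcessorCount, /*device_id=*/0);
  std::printf("opt-in shared memory per block: %lld bytes, SMs: %lld\n",
              static_cast<long long>(smem), static_cast<long long>(sm_count));
  return 0;
}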
flake.nix
ADDED
@@ -0,0 +1,14 @@
+{
+  description = "Flake for attention kernels";
+
+  inputs = {
+    kernel-builder.url = "git+ssh://[email protected]/huggingface/kernel-builder";
+  };
+
+  outputs =
+    {
+      self,
+      kernel-builder,
+    }:
+    kernel-builder.lib.genFlakeOutputs ./.;
+}
paged-attention/attention/attention_dtypes.h
ADDED
@@ -0,0 +1,7 @@
+#pragma once
+
+#include "attention_generic.cuh"
+#include "dtype_float16.cuh"
+#include "dtype_float32.cuh"
+#include "dtype_bfloat16.cuh"
+#include "dtype_fp8.cuh"
paged-attention/attention/attention_generic.cuh
ADDED
@@ -0,0 +1,65 @@
+/*
+ * Adapted from
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
+ * Copyright (c) 2023, The vLLM team.
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include <stdint.h>
+
+namespace vllm {
+
+// A vector type to store Q, K, V elements.
+template <typename T, int VEC_SIZE>
+struct Vec {};
+
+// A vector type to store FP32 accumulators.
+template <typename T>
+struct FloatVec {};
+
+// Template vector operations.
+template <typename Acc, typename A, typename B>
+inline __device__ Acc mul(A a, B b);
+
+template <typename T>
+inline __device__ float sum(T v);
+
+template <typename T>
+inline __device__ float dot(T a, T b) {
+  return sum(mul<T, T, T>(a, b));
+}
+
+template <typename A, typename T>
+inline __device__ float dot(T a, T b) {
+  return sum(mul<A, T, T>(a, b));
+}
+
+template <typename T>
+inline __device__ void zero(T& dst) {
+  constexpr int WORDS = sizeof(T) / 4;
+  union {
+    T raw;
+    uint32_t words[WORDS];
+  } tmp;
+
+#pragma unroll
+  for (int ii = 0; ii < WORDS; ++ii) {
+    tmp.words[ii] = 0u;
+  }
+  dst = tmp.raw;
+}
+
+}  // namespace vllm
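`Vec`, `FloatVec`, `mul`, and `sum` above are only primary templates; the `dtype_*.cuh` headers added in this commit specialize them per element type so that `dot` and `zero` operate on packed vectors. A minimal sketch of that specialization pattern for FP32 is shown below; the real specializations live in `dtype_float32.cuh`, so treat these particular definitions as illustrative:

// Illustrative sketch of the per-dtype specialization pattern; the actual
// FP32 specializations are provided by dtype_float32.cuh in this commit.
#include "attention_generic.cuh"

namespace vllm {

template <>
struct Vec<float, 4> {
  using Type = float4;  // four FP32 lanes, fetched as one 16-byte vector
};

template <>
struct FloatVec<float4> {
  using Type = float4;  // FP32 already accumulates in the same layout
};

template <>
inline __device__ float4 mul<float4, float4, float4>(float4 a, float4 b) {
  return make_float4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
}

template <>
inline __device__ float sum(float4 v) {
  return v.x + v.y + v.z + v.w;  // reduce the four lanes to a scalar
}

}  // namespace vllm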
paged-attention/attention/attention_kernels.cuh
ADDED
@@ -0,0 +1,676 @@
+/*
+ * Adapted from
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+ * Copyright (c) 2023, The vLLM team.
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <torch/all.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <algorithm>
+
+#include "attention_dtypes.h"
+#include "attention_utils.cuh"
+
+#ifdef USE_ROCM
+#include <hip/hip_bf16.h>
+#include "../quantization/fp8/amd/quant_utils.cuh"
+typedef __hip_bfloat16 __nv_bfloat16;
+#else
+#include "../quantization/fp8/nvidia/quant_utils.cuh"
+#endif
+
+#ifndef USE_ROCM
+#define WARP_SIZE 32
+#else
+#define WARP_SIZE warpSize
+#endif
+
+#define MAX(a, b) ((a) > (b) ? (a) : (b))
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
+
+namespace vllm {
+
+// Utility function for attention softmax.
+template <int NUM_WARPS>
+inline __device__ float block_sum(float* red_smem, float sum) {
+  // Decompose the thread index into warp / lane.
+  int warp = threadIdx.x / WARP_SIZE;
+  int lane = threadIdx.x % WARP_SIZE;
+
+  // Compute the sum per warp.
+#pragma unroll
+  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
+    sum += VLLM_SHFL_XOR_SYNC(sum, mask);
+  }
+
+  // Warp leaders store the data to shared memory.
+  if (lane == 0) {
+    red_smem[warp] = sum;
+  }
+
+  // Make sure the data is in shared memory.
+  __syncthreads();
+
+  // The warps compute the final sums.
+  if (lane < NUM_WARPS) {
+    sum = red_smem[lane];
+  }
+
+  // Parallel reduction inside the warp.
+#pragma unroll
+  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
+    sum += VLLM_SHFL_XOR_SYNC(sum, mask);
+  }
+
+  // Broadcast to other threads.
+  return VLLM_SHFL_SYNC(sum, 0);
+}
+
+// TODO(woosuk): Merge the last two dimensions of the grid.
+// Grid: (num_heads, num_seqs, max_num_partitions).
+template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
+          int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
+          bool IS_BLOCK_SPARSE,
+          int PARTITION_SIZE = 0>  // Zero means no partitioning.
+__device__ void paged_attention_kernel(
+    float* __restrict__ exp_sums,  // [num_seqs, num_heads, max_num_partitions]
+    float* __restrict__ max_logits,  // [num_seqs, num_heads,
+                                     // max_num_partitions]
+    scalar_t* __restrict__ out,  // [num_seqs, num_heads, max_num_partitions,
+                                 // head_size]
+    const scalar_t* __restrict__ q,       // [num_seqs, num_heads, head_size]
+    const cache_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads,
+                                          // head_size/x, block_size, x]
+    const cache_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads,
+                                          // head_size, block_size]
+    const int num_kv_heads,               // [num_heads]
+    const float scale,
+    const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
+    const int* __restrict__ seq_lens,      // [num_seqs]
+    const int max_num_blocks_per_seq,
+    const float* __restrict__ alibi_slopes,  // [num_heads]
+    const int q_stride, const int kv_block_stride, const int kv_head_stride,
+    const float* k_scale, const float* v_scale, const int tp_rank,
+    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
+    const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
+  const int seq_idx = blockIdx.y;
+  const int partition_idx = blockIdx.z;
+  const int max_num_partitions = gridDim.z;
+  constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
+  const int seq_len = seq_lens[seq_idx];
+  if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= seq_len) {
+    // No work to do. Terminate the thread block.
+    return;
+  }
+
+  const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
+  const int num_blocks_per_partition =
+      USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks;
+
+  // [start_block_idx, end_block_idx) is the range of blocks to process.
+  const int start_block_idx =
+      USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
+  const int end_block_idx =
+      MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks);
+  const int num_blocks = end_block_idx - start_block_idx;
+
+  // [start_token_idx, end_token_idx) is the range of tokens to process.
+  const int start_token_idx = start_block_idx * BLOCK_SIZE;
+  const int end_token_idx =
+      MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len);
+  const int num_tokens = end_token_idx - start_token_idx;
+
+  constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
+  constexpr int NUM_THREAD_GROUPS =
+      NUM_THREADS / THREAD_GROUP_SIZE;  // Note: This assumes THREAD_GROUP_SIZE
+                                        // divides NUM_THREADS
+  assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
+  constexpr int NUM_TOKENS_PER_THREAD_GROUP =
+      DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
+  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
+  const int thread_idx = threadIdx.x;
+  const int warp_idx = thread_idx / WARP_SIZE;
+  const int lane = thread_idx % WARP_SIZE;
+
+  const int head_idx = blockIdx.x;
+  const int num_heads = gridDim.x;
+  const int num_queries_per_kv = num_heads / num_kv_heads;
+  const int kv_head_idx = head_idx / num_queries_per_kv;
+  const float alibi_slope =
+      alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
+
+  // A vector type to store a part of a key or a query.
+  // The vector size is configured in such a way that the threads in a thread
+  // group fetch or compute 16 bytes at a time. For example, if the size of a
+  // thread group is 4 and the data type is half, then the vector size is 16 /
+  // (4 * sizeof(half)) == 2.
+  constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
+  using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
+  using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
+  using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type;
+
+  constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
+  constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
+
+  const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
+  const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;
+
+  // Load the query to registers.
+  // Each thread in a thread group has a different part of the query.
+  // For example, if the the thread group size is 4, then the first thread in
+  // the group has 0, 4, 8, ... th vectors of the query, and the second thread
+  // has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because
+  // q is split from a qkv tensor, it may not be contiguous.
+  const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
+  __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
+#pragma unroll
+  for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD;
+       i += NUM_THREAD_GROUPS) {
+    const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
+    q_vecs[thread_group_offset][i] =
+        *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
+  }
+  __syncthreads();  // TODO(naed90): possible speedup if this is replaced with a
+                    // memory wall right before we use q_vecs
+
+  // Memory planning.
+  extern __shared__ char shared_mem[];
+  // NOTE(woosuk): We use FP32 for the softmax logits for better accuracy.
+  float* logits = reinterpret_cast<float*>(shared_mem);
+  // Workspace for reduction.
+  __shared__ float red_smem[2 * NUM_WARPS];
+
+  // x == THREAD_GROUP_SIZE * VEC_SIZE
+  // Each thread group fetches x elements from the key at a time.
+  constexpr int x = 16 / sizeof(cache_t);
+  float qk_max = -FLT_MAX;
+
+  // Iterate over the key blocks.
+  // Each warp fetches a block of keys for each iteration.
+  // Each thread group in a warp fetches a key from the block, and computes
+  // dot product with the query.
+  const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
+
+  // blocksparse specific vars
+  int bs_block_offset;
+  int q_bs_block_id;
+  if constexpr (IS_BLOCK_SPARSE) {
+    // const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len,
+    // blocksparse_block_size);
+    q_bs_block_id = (seq_len - 1) / blocksparse_block_size;
+    if (blocksparse_head_sliding_step >= 0)
+      // sliding on q heads
+      bs_block_offset =
+          (tp_rank * num_heads + head_idx) * blocksparse_head_sliding_step + 1;
+    else
+      // sliding on kv heads
+      bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) *
+                            (-blocksparse_head_sliding_step) +
+                        1;
+  }
+
+  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
+       block_idx += NUM_WARPS) {
+    // NOTE(woosuk): The block number is stored in int32. However, we cast it to
+    // int64 because int32 can lead to overflow when this variable is multiplied
+    // by large numbers (e.g., kv_block_stride).
+    // For blocksparse attention: skip computation on blocks that are not
+    // attended
+    if constexpr (IS_BLOCK_SPARSE) {
+      const int k_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size;
+      const bool is_remote =
+          ((k_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0);
+      const bool is_local =
+          (k_bs_block_id > q_bs_block_id - blocksparse_local_blocks);
+      if (!is_remote && !is_local) {
+        for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
+          const int physical_block_offset =
+              (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
+          const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
+
+          if (thread_group_offset == 0) {
+            // NOTE(linxihui): assign very large number to skipped tokens to
+            // avoid contribution to the sumexp softmax normalizer. This will
+            // not be used at computing sum(softmax*v) as the blocks will be
+            // skipped.
+            logits[token_idx - start_token_idx] = -FLT_MAX;
+          }
+        }
+        continue;
+      }
+    }
+    const int64_t physical_block_number =
+        static_cast<int64_t>(block_table[block_idx]);
+
+    // Load a key to registers.
+    // Each thread in a thread group has a different part of the key.
+    // For example, if the the thread group size is 4, then the first thread in
+    // the group has 0, 4, 8, ... th vectors of the key, and the second thread
+    // has 1, 5, 9, ... th vectors of the key, and so on.
+    for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
+      const int physical_block_offset =
+          (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
+      const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
+      K_vec k_vecs[NUM_VECS_PER_THREAD];
+
+#pragma unroll
+      for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
+        const cache_t* k_ptr =
+            k_cache + physical_block_number * kv_block_stride +
+            kv_head_idx * kv_head_stride + physical_block_offset * x;
+        const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
+        const int offset1 = (vec_idx * VEC_SIZE) / x;
+        const int offset2 = (vec_idx * VEC_SIZE) % x;
+
+        if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
+          k_vecs[j] = *reinterpret_cast<const K_vec*>(
+              k_ptr + offset1 * BLOCK_SIZE * x + offset2);
+        } else {
+          // Vector conversion from Quant_vec to K_vec.
+          Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
+              k_ptr + offset1 * BLOCK_SIZE * x + offset2);
+          k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(
+              k_vec_quant, *k_scale);
+        }
+      }
+
+      // Compute dot product.
+      // This includes a reduction across the threads in the same thread group.
+      float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(
+                             q_vecs[thread_group_offset], k_vecs);
+      // Add the ALiBi bias if slopes are given.
+      qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0;
+
+      if (thread_group_offset == 0) {
+        // Store the partial reductions to shared memory.
+        // NOTE(woosuk): It is required to zero out the masked logits.
+        const bool mask = token_idx >= seq_len;
+        logits[token_idx - start_token_idx] = mask ? 0.f : qk;
+        // Update the max value.
+        qk_max = mask ? qk_max : fmaxf(qk_max, qk);
+      }
+    }
+  }
+
+  // Perform reduction across the threads in the same warp to get the
+  // max qk value for each "warp" (not across the thread block yet).
+  // The 0-th thread of each thread group already has its max qk value.
+#pragma unroll
+  for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
+    qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
+  }
+  if (lane == 0) {
+    red_smem[warp_idx] = qk_max;
+  }
+  __syncthreads();
+
+  // TODO(woosuk): Refactor this part.
+  // Get the max qk value for the sequence.
+  qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
+#pragma unroll
+  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
+    qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
+  }
+  // Broadcast the max qk value to all threads.
+  qk_max = VLLM_SHFL_SYNC(qk_max, 0);
+
+  // Get the sum of the exp values.
+  float exp_sum = 0.f;
+  for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
+    float val = __expf(logits[i] - qk_max);
+    logits[i] = val;
+    exp_sum += val;
+  }
+  exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);
+
+  // Compute softmax.
+  const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
+  for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
+    logits[i] *= inv_sum;
+  }
+  __syncthreads();
+
+  // If partitioning is enabled, store the max logit and exp_sum.
+  if (USE_PARTITIONING && thread_idx == 0) {
+    float* max_logits_ptr = max_logits +
+                            seq_idx * num_heads * max_num_partitions +
+                            head_idx * max_num_partitions + partition_idx;
+    *max_logits_ptr = qk_max;
+    float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions +
+                          head_idx * max_num_partitions + partition_idx;
+    *exp_sums_ptr = exp_sum;
+  }
+
+  // Each thread will fetch 16 bytes from the value cache at a time.
+  constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
+  using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
+  using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
+  using V_quant_vec = typename Vec<cache_t, V_VEC_SIZE>::Type;
+  using Float_L_vec = typename FloatVec<L_vec>::Type;
+
+  constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
+  constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
+  constexpr int NUM_ROWS_PER_THREAD =
+      DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);
+
+  // NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
+  float accs[NUM_ROWS_PER_THREAD];
+#pragma unroll
+  for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+    accs[i] = 0.f;
+  }
+
+  scalar_t zero_value;
+  zero(zero_value);
+  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
+       block_idx += NUM_WARPS) {
+    // NOTE(woosuk): The block number is stored in int32. However, we cast it to
+    // int64 because int32 can lead to overflow when this variable is multiplied
+    // by large numbers (e.g., kv_block_stride).
+    // For blocksparse attention: skip computation on blocks that are not
+    // attended
+    if constexpr (IS_BLOCK_SPARSE) {
+      int v_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size;
+      if (!((v_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0) &&
+          !((v_bs_block_id > q_bs_block_id - blocksparse_local_blocks))) {
+        continue;
+      }
+    }
+    const int64_t physical_block_number =
+        static_cast<int64_t>(block_table[block_idx]);
+    const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
+    const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
+    L_vec logits_vec;
+    from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx -
+                                                           start_token_idx));
+
+    const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
+                           kv_head_idx * kv_head_stride;
+#pragma unroll
+    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
+      if (row_idx < HEAD_SIZE) {
+        const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
+        V_vec v_vec;
+
+        if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
+          v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
+        } else {
+          V_quant_vec v_quant_vec =
+              *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
+          // Vector conversion from V_quant_vec to V_vec.
+          v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec,
+                                                                    *v_scale);
+        }
+        if (block_idx == num_seq_blocks - 1) {
+          // NOTE(woosuk): When v_vec contains the tokens that are out of the
+          // context, we should explicitly zero out the values since they may
+          // contain NaNs. See
+          // https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
+          scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
+#pragma unroll
+          for (int j = 0; j < V_VEC_SIZE; j++) {
+            v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value;
+          }
+        }
+        accs[i] += dot(logits_vec, v_vec);
+      }
+    }
+  }
+
+  // Perform reduction within each warp.
+#pragma unroll
+  for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+    float acc = accs[i];
+#pragma unroll
+    for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
+      acc += VLLM_SHFL_XOR_SYNC(acc, mask);
+    }
+    accs[i] = acc;
+  }
+
+  // NOTE(woosuk): A barrier is required because the shared memory space for
+  // logits is reused for the output.
+  __syncthreads();
+
+  // Perform reduction across warps.
+  float* out_smem = reinterpret_cast<float*>(shared_mem);
+#pragma unroll
+  for (int i = NUM_WARPS; i > 1; i /= 2) {
+    int mid = i / 2;
+    // Upper warps write to shared memory.
+    if (warp_idx >= mid && warp_idx < i) {
+      float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
+#pragma unroll
+      for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+        const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
+        if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
+          dst[row_idx] = accs[i];
+        }
+      }
+    }
+    __syncthreads();
+
+    // Lower warps update the output.
+    if (warp_idx < mid) {
+      const float* src = &out_smem[warp_idx * HEAD_SIZE];
+#pragma unroll
+      for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+        const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
+        if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
+          accs[i] += src[row_idx];
+        }
+      }
+    }
+    __syncthreads();
+  }
+
+  // Write the final output.
+  if (warp_idx == 0) {
+    scalar_t* out_ptr =
+        out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
+        head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE;
+#pragma unroll
+    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
+      if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
+        from_float(*(out_ptr + row_idx), accs[i]);
+      }
+    }
+  }
+}
+
+// Grid: (num_heads, num_seqs, 1).
+template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
+          int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
+          bool IS_BLOCK_SPARSE>
+__global__ void paged_attention_v1_kernel(
+    scalar_t* __restrict__ out,            // [num_seqs, num_heads, head_size]
+    const scalar_t* __restrict__ q,        // [num_seqs, num_heads, head_size]
+    const cache_t* __restrict__ k_cache,   // [num_blocks, num_kv_heads,
+                                           // head_size/x, block_size, x]
+    const cache_t* __restrict__ v_cache,   // [num_blocks, num_kv_heads,
+                                           // head_size, block_size]
+    const int num_kv_heads,                // [num_heads]
+    const float scale,
+    const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
+    const int* __restrict__ seq_lens,      // [num_seqs]
+    const int max_num_blocks_per_seq,
+    const float* __restrict__ alibi_slopes,  // [num_heads]
+    const int q_stride, const int kv_block_stride, const int kv_head_stride,
+    const float* k_scale, const float* v_scale, const int tp_rank,
+    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
+    const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
+  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
+                         KV_DTYPE, IS_BLOCK_SPARSE>(
+      /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache,
+      v_cache, num_kv_heads, scale, block_tables, seq_lens,
+      max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride,
+      kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks,
+      blocksparse_vert_stride, blocksparse_block_size,
+      blocksparse_head_sliding_step);
+}
+
+// Grid: (num_heads, num_seqs, max_num_partitions).
+template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
+          int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
+          bool IS_BLOCK_SPARSE,
+          int PARTITION_SIZE>
+__global__ void paged_attention_v2_kernel(
+    float* __restrict__ exp_sums,  // [num_seqs, num_heads, max_num_partitions]
+    float* __restrict__ max_logits,   // [num_seqs, num_heads,
+                                      // max_num_partitions]
+    scalar_t* __restrict__ tmp_out,   // [num_seqs, num_heads,
+                                      // max_num_partitions, head_size]
+    const scalar_t* __restrict__ q,   // [num_seqs, num_heads, head_size]
+    const cache_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads,
+                                          // head_size/x, block_size, x]
+    const cache_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads,
+                                          // head_size, block_size]
+    const int num_kv_heads,               // [num_heads]
+    const float scale,
+    const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
+    const int* __restrict__ seq_lens,      // [num_seqs]
+    const int max_num_blocks_per_seq,
+    const float* __restrict__ alibi_slopes,  // [num_heads]
+    const int q_stride, const int kv_block_stride, const int kv_head_stride,
+    const float* k_scale, const float* v_scale, const int tp_rank,
+    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
+    const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
+  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
+                         KV_DTYPE, IS_BLOCK_SPARSE, PARTITION_SIZE>(
+      exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
+      block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride,
+      kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank,
+      blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size,
+      blocksparse_head_sliding_step);
+}
+
+// Grid: (num_heads, num_seqs).
+template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS,
+          int PARTITION_SIZE>
+__global__ void paged_attention_v2_reduce_kernel(
+    scalar_t* __restrict__ out,          // [num_seqs, num_heads, head_size]
+    const float* __restrict__ exp_sums,  // [num_seqs, num_heads,
+                                         // max_num_partitions]
+    const float* __restrict__ max_logits,  // [num_seqs, num_heads,
+                                           // max_num_partitions]
+    const scalar_t* __restrict__ tmp_out,  // [num_seqs, num_heads,
+                                           // max_num_partitions, head_size]
+    const int* __restrict__ seq_lens,      // [num_seqs]
+    const int max_num_partitions) {
+  const int num_heads = gridDim.x;
+  const int head_idx = blockIdx.x;
+  const int seq_idx = blockIdx.y;
+  const int seq_len = seq_lens[seq_idx];
+  const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE);
+  if (num_partitions == 1) {
+    // No need to reduce. Only copy tmp_out to out.
+    scalar_t* out_ptr =
+        out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
+    const scalar_t* tmp_out_ptr =
+        tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
+        head_idx * max_num_partitions * HEAD_SIZE;
+    for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) {
+      out_ptr[i] = tmp_out_ptr[i];
+    }
+    // Terminate the thread block.
+    return;
+  }
+
+  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
+  const int warp_idx = threadIdx.x / WARP_SIZE;
+  const int lane = threadIdx.x % WARP_SIZE;
+
+  // Size: 2 * num_partitions.
+  extern __shared__ char shared_mem[];
+  // Workspace for reduction.
+  __shared__ float red_smem[2 * NUM_WARPS];
+
+  // Load max logits to shared memory.
+  float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
+  const float* max_logits_ptr = max_logits +
+                                seq_idx * num_heads * max_num_partitions +
+                                head_idx * max_num_partitions;
+  float max_logit = -FLT_MAX;
+  for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
+    const float l = max_logits_ptr[i];
+    shared_max_logits[i] = l;
+    max_logit = fmaxf(max_logit, l);
+  }
+  __syncthreads();
+
+  // Get the global max logit.
+  // Reduce within the warp.
+#pragma unroll
+  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
+    max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
+  }
+  if (lane == 0) {
+    red_smem[warp_idx] = max_logit;
+  }
+  __syncthreads();
+  // Reduce across warps.
+  max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
+#pragma unroll
+  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
+    max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
+  }
+  // Broadcast the max value to all threads.
+  max_logit = VLLM_SHFL_SYNC(max_logit, 0);
+
+  // Load rescaled exp sums to shared memory.
+  float* shared_exp_sums =
+      reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
+  const float* exp_sums_ptr = exp_sums +
+                              seq_idx * num_heads * max_num_partitions +
+                              head_idx * max_num_partitions;
+  float global_exp_sum = 0.0f;
+  for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
+    float l = shared_max_logits[i];
+    float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit);
+    global_exp_sum += rescaled_exp_sum;
+    shared_exp_sums[i] = rescaled_exp_sum;
+  }
+  __syncthreads();
+  global_exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], global_exp_sum);
+  const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);
+
+  // Aggregate tmp_out to out.
+  const scalar_t* tmp_out_ptr =
+      tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
+      head_idx * max_num_partitions * HEAD_SIZE;
+  scalar_t* out_ptr =
+      out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
+#pragma unroll
+  for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
+    float acc = 0.0f;
+    for (int j = 0; j < num_partitions; ++j) {
+      acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] *
+             inv_global_exp_sum;
+    }
+    from_float(out_ptr[i], acc);
+  }
+}
+
+}  // namespace vllm
+
+#undef WARP_SIZE
+#undef MAX
+#undef MIN
+#undef DIVIDE_ROUND_UP
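The sizing constants at the top of `paged_attention_kernel` decide how a warp walks one KV-cache block. The small compile-time sketch below works the arithmetic out for one assumed configuration (FP16 queries, `BLOCK_SIZE = 16`, `HEAD_SIZE = 128`, 128 threads per block); the numbers are illustrative, not a required setup:

// Worked example of the sizing arithmetic in paged_attention_kernel, assuming
// scalar_t = half, BLOCK_SIZE = 16, NUM_THREADS = 128, HEAD_SIZE = 128, and
// WARP_SIZE = 32. These values are illustrative only.
#include <cuda_fp16.h>

constexpr int WARP_SIZE_EX = 32;
constexpr int BLOCK_SIZE_EX = 16;
constexpr int HEAD_SIZE_EX = 128;

// One thread group cooperates on one token of one KV block.
constexpr int THREAD_GROUP_SIZE_EX = WARP_SIZE_EX / BLOCK_SIZE_EX;        // 2
// Each group fetches 16 bytes of key per step: 16 / (2 * sizeof(half)) = 4.
constexpr int VEC_SIZE_EX = 16 / (THREAD_GROUP_SIZE_EX * sizeof(__half));
// Head-dimension elements and vectors owned by each thread of the group.
constexpr int NUM_ELEMS_PER_THREAD_EX = HEAD_SIZE_EX / THREAD_GROUP_SIZE_EX;
constexpr int NUM_VECS_PER_THREAD_EX = NUM_ELEMS_PER_THREAD_EX / VEC_SIZE_EX;

static_assert(THREAD_GROUP_SIZE_EX == 2, "one warp covers a 16-token block");
static_assert(VEC_SIZE_EX == 4, "each thread reads 4 halves (8 bytes) per step");
static_assert(NUM_ELEMS_PER_THREAD_EX == 64, "half the head dim per thread");
static_assert(NUM_VECS_PER_THREAD_EX == 16, "16 key vectors per thread per token");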
paged-attention/attention/attention_utils.cuh
ADDED
@@ -0,0 +1,57 @@
+/*
+ * Adapted from
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+ * Copyright (c) 2023, The vLLM team.
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include "../cuda_compat.h"
+#include "attention_dtypes.h"
+
+#include <float.h>
+#include <type_traits>
+
+namespace vllm {
+
+// Q*K^T operation.
+template <int THREAD_GROUP_SIZE, typename Vec, int N>
+inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
+  using A_vec = typename FloatVec<Vec>::Type;
+  // Compute the parallel products for Q*K^T (treat vector lanes separately).
+  A_vec qk_vec = mul<A_vec, Vec, Vec>(q[0], k[0]);
+#pragma unroll
+  for (int ii = 1; ii < N; ++ii) {
+    qk_vec = vllm::fma(q[ii], k[ii], qk_vec);
+  }
+
+  // Finalize the reduction across lanes.
+  float qk = sum(qk_vec);
+#pragma unroll
+  for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) {
+    qk += VLLM_SHFL_XOR_SYNC(qk, mask);
+  }
+  return qk;
+}
+
+template <typename T, int THREAD_GROUP_SIZE>
+struct Qk_dot {
+  template <typename Vec, int N>
+  static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) {
+    return qk_dot_<THREAD_GROUP_SIZE>(q, k);
+  }
+};
+
+}  // namespace vllm
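The tail of `qk_dot_` is a butterfly reduction over the thread group, built on `VLLM_SHFL_XOR_SYNC` from `paged-attention/cuda_compat.h` (not shown in this diff hunk). The sketch below reproduces the same pattern with the raw CUDA intrinsic, assuming the compat macro maps to `__shfl_xor_sync` on NVIDIA GPUs; the demo kernel and harness are hypothetical:

// Hypothetical demo of the thread-group butterfly reduction used by qk_dot_.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void group_sum_demo(const float* in, float* out) {
  // Each group of 4 consecutive lanes sums its 4 values; every lane in the
  // group ends up holding the group total, mirroring the THREAD_GROUP_SIZE
  // reduction at the end of qk_dot_.
  constexpr int THREAD_GROUP_SIZE = 4;
  float v = in[threadIdx.x];
#pragma unroll
  for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) {
    v += __shfl_xor_sync(0xffffffffu, v, mask);
  }
  out[threadIdx.x] = v;
}

int main() {
  float h_buf[32], *d_in, *d_out;
  for (int i = 0; i < 32; ++i) h_buf[i] = 1.0f;  // every group should sum to 4
  cudaMalloc(&d_in, sizeof(h_buf));
  cudaMalloc(&d_out, sizeof(h_buf));
  cudaMemcpy(d_in, h_buf, sizeof(h_buf), cudaMemcpyHostToDevice);
  group_sum_demo<<<1, 32>>>(d_in, d_out);
  cudaMemcpy(h_buf, d_out, sizeof(h_buf), cudaMemcpyDeviceToHost);
  std::printf("lane 0 group sum = %f\n", h_buf[0]);  // expect 4.0
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}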
paged-attention/attention/dtype_bfloat16.cuh
ADDED
@@ -0,0 +1,463 @@
+/*
+ * Adapted from
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+ * and
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
+ * Copyright (c) 2023, The vLLM team.
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include "attention_generic.cuh"
+#include "dtype_float32.cuh"
+
+#ifndef USE_ROCM
+#include <cuda_bf16.h>
+#include <cuda_fp16.h>
+#else
+#include <hip/hip_bf16.h>
+#include <hip/hip_fp16.h>
+
+typedef __hip_bfloat162 __nv_bfloat162;
+typedef __hip_bfloat16 __nv_bfloat16;
+#endif
+
+#include <stdint.h>
+
+namespace vllm {
+
+// Define custom BF16 vector data types.
+struct bf16_4_t {
+  __nv_bfloat162 x;
+  __nv_bfloat162 y;
+};
+
+struct bf16_8_t {
+  __nv_bfloat162 x;
+  __nv_bfloat162 y;
+  __nv_bfloat162 z;
+  __nv_bfloat162 w;
+};
+
+// BF16 vector types for Q, K, V.
+template <>
+struct Vec<__nv_bfloat16, 1> {
+  using Type = __nv_bfloat16;
+};
+template <>
+struct Vec<__nv_bfloat16, 2> {
+  using Type = __nv_bfloat162;
+};
+template <>
+struct Vec<__nv_bfloat16, 4> {
+  using Type = bf16_4_t;
+};
+template <>
+struct Vec<__nv_bfloat16, 8> {
+  using Type = bf16_8_t;
+};
+
+// FP32 accumulator vector types corresponding to Vec.
+template <>
+struct FloatVec<__nv_bfloat16> {
+  using Type = float;
+};
+template <>
+struct FloatVec<__nv_bfloat162> {
+  using Type = float2;
+};
+template <>
+struct FloatVec<bf16_4_t> {
+  using Type = Float4_;
+};
+template <>
+struct FloatVec<bf16_8_t> {
+  using Type = Float8_;
+};
+
+// Utility functions for type conversions.
+inline __device__ float2 bf1622float2(const __nv_bfloat162 val) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
+  return __bfloat1622float2(val);
+#endif
+  __builtin_unreachable();  // Suppress missing return statement warning
+}
+
+inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
+  return __bfloat162bfloat162(val);
+#endif
+  __builtin_unreachable();  // Suppress missing return statement warning
+}
+
+// Vector addition.
+inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
+#ifndef USE_ROCM
+  return a + b;
+#else
+  return __hadd(a, b);
+#endif
+#endif
+  __builtin_unreachable();  // Suppress missing return statement warning
+}
+
+inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
+  return __hadd2(a, b);
+#endif
+  __builtin_unreachable();  // Suppress missing return statement warning
+}
+
+inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) {
+  bf16_4_t c;
+  c.x = add(a.x, b.x);
+  c.y = add(a.y, b.y);
+  return c;
+}
+
+inline __device__ bf16_8_t add(bf16_8_t a, bf16_8_t b) {
+  bf16_8_t c;
+  c.x = add(a.x, b.x);
+  c.y = add(a.y, b.y);
+  c.z = add(a.z, b.z);
+  c.w = add(a.w, b.w);
+  return c;
+}
+
+inline __device__ float2 add(__nv_bfloat162 a, float2 fb) {
+  float2 fa = bf1622float2(a);
+  return add(fa, fb);
+}
+
+inline __device__ Float4_ add(bf16_4_t a, Float4_ fb) {
+  Float4_ fc;
+  fc.x = add(a.x, fb.x);
+  fc.y = add(a.y, fb.y);
+  return fc;
+}
+
+inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) {
+  Float8_ fc;
+  fc.x = add(a.x, fb.x);
+  fc.y = add(a.y, fb.y);
+  fc.z = add(a.z, fb.z);
+  fc.w = add(a.w, fb.w);
+  return fc;
+}
+
+// Vector multiplication.
+template <>
+inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
+  return __hmul(a, b);
+#endif
+  __builtin_unreachable();  // Suppress missing return statement warning
+}
+
+template <>
+inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
+  return __hmul2(a, b);
+#endif
+  __builtin_unreachable();  // Suppress missing return statement warning
+}
+
+template <>
+inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
+  return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
+}
+
+template <>
+inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) {
+  bf16_4_t c;
+  c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
+  c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
+  return c;
+}
+
+template <>
+inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) {
+  __nv_bfloat162 s = bf162bf162(a);
+  bf16_4_t c;
+  c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
+  c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
+  return c;
+}
+
+template <>
+inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) {
+  bf16_8_t c;
+  c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
+  c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
+  c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
+  c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
+  return c;
+}
+
+template <>
+inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) {
|
| 224 |
+
__nv_bfloat162 s = bf162bf162(a);
|
| 225 |
+
bf16_8_t c;
|
| 226 |
+
c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
|
| 227 |
+
c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
|
| 228 |
+
c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.z);
|
| 229 |
+
c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.w);
|
| 230 |
+
return c;
|
| 231 |
+
}
|
| 232 |
+
|
| 233 |
+
template <>
|
| 234 |
+
inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) {
|
| 235 |
+
float fa = __bfloat162float(a);
|
| 236 |
+
float fb = __bfloat162float(b);
|
| 237 |
+
return fa * fb;
|
| 238 |
+
}
|
| 239 |
+
|
| 240 |
+
template <>
|
| 241 |
+
inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
|
| 242 |
+
float2 fa = bf1622float2(a);
|
| 243 |
+
float2 fb = bf1622float2(b);
|
| 244 |
+
return mul<float2, float2, float2>(fa, fb);
|
| 245 |
+
}
|
| 246 |
+
|
| 247 |
+
template <>
|
| 248 |
+
inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
|
| 249 |
+
return mul<float2, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
|
| 250 |
+
}
|
| 251 |
+
|
| 252 |
+
template <>
|
| 253 |
+
inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) {
|
| 254 |
+
Float4_ fc;
|
| 255 |
+
fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
|
| 256 |
+
fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
|
| 257 |
+
return fc;
|
| 258 |
+
}
|
| 259 |
+
|
| 260 |
+
template <>
|
| 261 |
+
inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) {
|
| 262 |
+
__nv_bfloat162 s = bf162bf162(a);
|
| 263 |
+
Float4_ fc;
|
| 264 |
+
fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
|
| 265 |
+
fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
|
| 266 |
+
return fc;
|
| 267 |
+
}
|
| 268 |
+
|
| 269 |
+
template <>
|
| 270 |
+
inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) {
|
| 271 |
+
Float8_ fc;
|
| 272 |
+
fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
|
| 273 |
+
fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
|
| 274 |
+
fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
|
| 275 |
+
fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
|
| 276 |
+
return fc;
|
| 277 |
+
}
|
| 278 |
+
|
| 279 |
+
template <>
|
| 280 |
+
inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) {
|
| 281 |
+
__nv_bfloat162 s = bf162bf162(a);
|
| 282 |
+
Float8_ fc;
|
| 283 |
+
fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
|
| 284 |
+
fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
|
| 285 |
+
fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.z);
|
| 286 |
+
fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.w);
|
| 287 |
+
return fc;
|
| 288 |
+
}
|
| 289 |
+
|
| 290 |
+
// Vector fused multiply-add.
|
| 291 |
+
inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b,
|
| 292 |
+
__nv_bfloat162 c) {
|
| 293 |
+
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
| 294 |
+
assert(false);
|
| 295 |
+
#else
|
| 296 |
+
return __hfma2(a, b, c);
|
| 297 |
+
#endif
|
| 298 |
+
__builtin_unreachable(); // Suppress missing return statement warning
|
| 299 |
+
}
|
| 300 |
+
|
| 301 |
+
inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b,
|
| 302 |
+
__nv_bfloat162 c) {
|
| 303 |
+
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
| 304 |
+
assert(false);
|
| 305 |
+
#else
|
| 306 |
+
return __hfma2(bf162bf162(a), b, c);
|
| 307 |
+
#endif
|
| 308 |
+
__builtin_unreachable(); // Suppress missing return statement warning
|
| 309 |
+
}
|
| 310 |
+
|
| 311 |
+
inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) {
|
| 312 |
+
bf16_4_t d;
|
| 313 |
+
d.x = fma(a.x, b.x, c.x);
|
| 314 |
+
d.y = fma(a.y, b.y, c.y);
|
| 315 |
+
return d;
|
| 316 |
+
}
|
| 317 |
+
|
| 318 |
+
inline __device__ bf16_4_t fma(__nv_bfloat16 a, bf16_4_t b, bf16_4_t c) {
|
| 319 |
+
__nv_bfloat162 s = bf162bf162(a);
|
| 320 |
+
bf16_4_t d;
|
| 321 |
+
d.x = fma(s, b.x, c.x);
|
| 322 |
+
d.y = fma(s, b.y, c.y);
|
| 323 |
+
return d;
|
| 324 |
+
}
|
| 325 |
+
|
| 326 |
+
inline __device__ bf16_8_t fma(bf16_8_t a, bf16_8_t b, bf16_8_t c) {
|
| 327 |
+
bf16_8_t d;
|
| 328 |
+
d.x = fma(a.x, b.x, c.x);
|
| 329 |
+
d.y = fma(a.y, b.y, c.y);
|
| 330 |
+
d.z = fma(a.z, b.z, c.z);
|
| 331 |
+
d.w = fma(a.w, b.w, c.w);
|
| 332 |
+
return d;
|
| 333 |
+
}
|
| 334 |
+
|
| 335 |
+
inline __device__ bf16_8_t fma(__nv_bfloat16 a, bf16_8_t b, bf16_8_t c) {
|
| 336 |
+
__nv_bfloat162 s = bf162bf162(a);
|
| 337 |
+
bf16_8_t d;
|
| 338 |
+
d.x = fma(s, b.x, c.x);
|
| 339 |
+
d.y = fma(s, b.y, c.y);
|
| 340 |
+
d.z = fma(s, b.z, c.z);
|
| 341 |
+
d.w = fma(s, b.w, c.w);
|
| 342 |
+
return d;
|
| 343 |
+
}
|
| 344 |
+
|
| 345 |
+
inline __device__ float fma(__nv_bfloat16 a, __nv_bfloat16 b, float fc) {
|
| 346 |
+
return __bfloat162float(a) * __bfloat162float(b) + fc;
|
| 347 |
+
}
|
| 348 |
+
|
| 349 |
+
inline __device__ float2 fma(__nv_bfloat162 a, __nv_bfloat162 b, float2 fc) {
|
| 350 |
+
float2 fa = bf1622float2(a);
|
| 351 |
+
float2 fb = bf1622float2(b);
|
| 352 |
+
return fma(fa, fb, fc);
|
| 353 |
+
}
|
| 354 |
+
|
| 355 |
+
inline __device__ float2 fma(__nv_bfloat16 a, __nv_bfloat162 b, float2 fc) {
|
| 356 |
+
return fma(bf162bf162(a), b, fc);
|
| 357 |
+
}
|
| 358 |
+
|
| 359 |
+
inline __device__ Float4_ fma(bf16_4_t a, bf16_4_t b, Float4_ fc) {
|
| 360 |
+
Float4_ fd;
|
| 361 |
+
fd.x = fma(a.x, b.x, fc.x);
|
| 362 |
+
fd.y = fma(a.y, b.y, fc.y);
|
| 363 |
+
return fd;
|
| 364 |
+
}
|
| 365 |
+
|
| 366 |
+
inline __device__ Float4_ fma(__nv_bfloat16 a, bf16_4_t b, Float4_ fc) {
|
| 367 |
+
__nv_bfloat162 s = bf162bf162(a);
|
| 368 |
+
Float4_ fd;
|
| 369 |
+
fd.x = fma(s, b.x, fc.x);
|
| 370 |
+
fd.y = fma(s, b.y, fc.y);
|
| 371 |
+
return fd;
|
| 372 |
+
}
|
| 373 |
+
|
| 374 |
+
inline __device__ Float8_ fma(bf16_8_t a, bf16_8_t b, Float8_ fc) {
|
| 375 |
+
Float8_ fd;
|
| 376 |
+
fd.x = fma(a.x, b.x, fc.x);
|
| 377 |
+
fd.y = fma(a.y, b.y, fc.y);
|
| 378 |
+
fd.z = fma(a.z, b.z, fc.z);
|
| 379 |
+
fd.w = fma(a.w, b.w, fc.w);
|
| 380 |
+
return fd;
|
| 381 |
+
}
|
| 382 |
+
|
| 383 |
+
inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) {
|
| 384 |
+
__nv_bfloat162 s = bf162bf162(a);
|
| 385 |
+
Float8_ fd;
|
| 386 |
+
fd.x = fma(s, b.x, fc.x);
|
| 387 |
+
fd.y = fma(s, b.y, fc.y);
|
| 388 |
+
fd.z = fma(s, b.z, fc.z);
|
| 389 |
+
fd.w = fma(s, b.w, fc.w);
|
| 390 |
+
return fd;
|
| 391 |
+
}
|
| 392 |
+
|
| 393 |
+
// Vector sum.
|
| 394 |
+
template <>
|
| 395 |
+
inline __device__ float sum(__nv_bfloat16 v) {
|
| 396 |
+
return __bfloat162float(v);
|
| 397 |
+
}
|
| 398 |
+
|
| 399 |
+
template <>
|
| 400 |
+
inline __device__ float sum(__nv_bfloat162 v) {
|
| 401 |
+
float2 vf = bf1622float2(v);
|
| 402 |
+
return vf.x + vf.y;
|
| 403 |
+
}
|
| 404 |
+
|
| 405 |
+
template <>
|
| 406 |
+
inline __device__ float sum(bf16_4_t v) {
|
| 407 |
+
return sum(v.x) + sum(v.y);
|
| 408 |
+
}
|
| 409 |
+
|
| 410 |
+
template <>
|
| 411 |
+
inline __device__ float sum(bf16_8_t v) {
|
| 412 |
+
return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w);
|
| 413 |
+
}
|
| 414 |
+
|
| 415 |
+
// From float32 to bfloat16.
|
| 416 |
+
inline __device__ void from_float(__nv_bfloat16& dst, float src) {
|
| 417 |
+
dst = __float2bfloat16(src);
|
| 418 |
+
}
|
| 419 |
+
|
| 420 |
+
inline __device__ void from_float(__nv_bfloat162& dst, float2 src) {
|
| 421 |
+
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
| 422 |
+
assert(false);
|
| 423 |
+
#else
|
| 424 |
+
dst = __float22bfloat162_rn(src);
|
| 425 |
+
#endif
|
| 426 |
+
}
|
| 427 |
+
|
| 428 |
+
inline __device__ void from_float(bf16_4_t& dst, Float4_ src) {
|
| 429 |
+
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
| 430 |
+
assert(false);
|
| 431 |
+
#else
|
| 432 |
+
dst.x = __float22bfloat162_rn(src.x);
|
| 433 |
+
dst.y = __float22bfloat162_rn(src.y);
|
| 434 |
+
#endif
|
| 435 |
+
}
|
| 436 |
+
|
| 437 |
+
inline __device__ void from_float(bf16_8_t& dst, Float8_ src) {
|
| 438 |
+
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
| 439 |
+
assert(false);
|
| 440 |
+
#else
|
| 441 |
+
dst.x = __float22bfloat162_rn(src.x);
|
| 442 |
+
dst.y = __float22bfloat162_rn(src.y);
|
| 443 |
+
dst.z = __float22bfloat162_rn(src.z);
|
| 444 |
+
dst.w = __float22bfloat162_rn(src.w);
|
| 445 |
+
#endif
|
| 446 |
+
}
|
| 447 |
+
|
| 448 |
+
// From bfloat16 to float32.
|
| 449 |
+
inline __device__ float to_float(__nv_bfloat16 u) {
|
| 450 |
+
return __bfloat162float(u);
|
| 451 |
+
}
|
| 452 |
+
|
| 453 |
+
// Zero-out a variable.
|
| 454 |
+
inline __device__ void zero(__nv_bfloat16& dst) {
|
| 455 |
+
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
| 456 |
+
assert(false);
|
| 457 |
+
#else
|
| 458 |
+
// Same as CUDART_ZERO_BF16 introduced in CUDA 12.2.
|
| 459 |
+
dst = __ushort_as_bfloat16((unsigned short)0x0000U);
|
| 460 |
+
#endif
|
| 461 |
+
}
|
| 462 |
+
|
| 463 |
+
} // namespace vllm
|
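The Vec/FloatVec specializations above are what the templated attention kernels resolve against: Vec<__nv_bfloat16, N> names the packed register type for an N-element load, FloatVec names its fp32 accumulator, and the overloaded add/mul/fma/sum do the arithmetic. As a rough illustration of how they compose (a sketch only, not part of this commit, assuming attention_dtypes.h pulls in these headers and the generic mul/sum templates declared in attention_generic.cuh), a Q·K partial dot over eight bf16 lanes could look like:

// Sketch only — not part of the ported sources.
#include "attention_dtypes.h"

namespace vllm {

// Multiply eight bf16 lanes of q and k element-wise in fp32 and reduce.
inline __device__ float bf16_qk_dot8(const bf16_8_t& q, const bf16_8_t& k) {
  using Acc = FloatVec<bf16_8_t>::Type;           // Float8_
  Acc prod = mul<Acc, bf16_8_t, bf16_8_t>(q, k);  // per-lane products in fp32
  return sum(prod);                               // horizontal add of 8 floats
}

}  // namespace vllm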
paged-attention/attention/dtype_float16.cuh
ADDED
@@ -0,0 +1,504 @@
/*
 * Adapted from
 * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
 * and
 * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
 * Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include "attention_generic.cuh"
#include "dtype_float32.cuh"

#ifdef USE_ROCM
  #include <hip/hip_fp16.h>
#endif

#include <stdint.h>

namespace vllm {

// FP16 vector types for Q, K, V.
template <>
struct Vec<uint16_t, 1> {
  using Type = uint16_t;
};
template <>
struct Vec<uint16_t, 2> {
  using Type = uint32_t;
};
template <>
struct Vec<uint16_t, 4> {
  using Type = uint2;
};
template <>
struct Vec<uint16_t, 8> {
  using Type = uint4;
};

// FP32 accumulator vector types corresponding to Vec.
template <>
struct FloatVec<uint16_t> {
  using Type = float;
};
template <>
struct FloatVec<uint32_t> {
  using Type = float2;
};
template <>
struct FloatVec<uint2> {
  using Type = Float4_;
};
template <>
struct FloatVec<uint4> {
  using Type = Float8_;
};

// Utility functions for type conversions.
inline __device__ uint32_t h0_h0(uint16_t a) {
#ifndef USE_ROCM
  uint32_t b;
  asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a));
  return b;
#else
  union {
    uint32_t u32;
    uint16_t u16[2];
  } tmp;
  tmp.u16[0] = a;
  tmp.u16[1] = a;
  return tmp.u32;
#endif
}

inline __device__ float half_to_float(uint16_t h) {
  float f;
#ifndef USE_ROCM
  asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h));
#else
  asm volatile("v_cvt_f32_f16 %0, %1;" : "=v"(f) : "v"(h));
#endif
  return f;
}

inline __device__ float2 half2_to_float2(uint32_t v) {
#ifndef USE_ROCM
  uint16_t lo, hi;
  asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v));
  return make_float2(half_to_float(lo), half_to_float(hi));
#else
  union {
    uint32_t u32;
    uint16_t u16[2];
  } tmp;
  tmp.u32 = v;
  float2 ret;
  ret.x = half_to_float(tmp.u16[0]);
  ret.y = half_to_float(tmp.u16[1]);
  return ret;
#endif
}

inline __device__ uint16_t float_to_half(float f) {
  union {
    uint32_t u32;
    uint16_t u16[2];
  } tmp;
#ifndef USE_ROCM
  asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f));
#else
  asm volatile("v_cvt_f16_f32 %0, %1;\n" : "=v"(tmp.u32) : "v"(f));
#endif
  return tmp.u16[0];
}

inline __device__ uint32_t float2_to_half2(float2 f) {
  union {
    uint32_t u32;
    uint16_t u16[2];
  } tmp;
#ifndef USE_ROCM
  #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n"
               : "=r"(tmp.u32)
               : "f"(f.y), "f"(f.x));
  #else
  asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
  asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
  #endif
#else
  tmp.u16[0] = float_to_half(f.x);
  tmp.u16[1] = float_to_half(f.y);
#endif
  return tmp.u32;
}

// Vector addition.
inline __device__ uint16_t add(uint16_t a, uint16_t b) {
  uint16_t c;
#ifndef USE_ROCM
  asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
#else
  asm volatile("v_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
#endif
  return c;
}

inline __device__ uint32_t add(uint32_t a, uint32_t b) {
  uint32_t c;
#ifndef USE_ROCM
  asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
#else
  asm volatile("v_pk_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
#endif
  return c;
}

inline __device__ uint2 add(uint2 a, uint2 b) {
  uint2 c;
  c.x = add(a.x, b.x);
  c.y = add(a.y, b.y);
  return c;
}

inline __device__ uint4 add(uint4 a, uint4 b) {
  uint4 c;
  c.x = add(a.x, b.x);
  c.y = add(a.y, b.y);
  c.z = add(a.z, b.z);
  c.w = add(a.w, b.w);
  return c;
}

inline __device__ float2 add(uint32_t a, float2 fb) {
  float2 fa = half2_to_float2(a);
  return add(fa, fb);
}

inline __device__ Float4_ add(uint2 a, Float4_ fb) {
  Float4_ fc;
  fc.x = add(a.x, fb.x);
  fc.y = add(a.y, fb.y);
  return fc;
}

inline __device__ Float8_ add(uint4 a, Float8_ fb) {
  Float8_ fc;
  fc.x = add(a.x, fb.x);
  fc.y = add(a.y, fb.y);
  fc.z = add(a.z, fb.z);
  fc.w = add(a.w, fb.w);
  return fc;
}

// Vector multiplication.
template <>
inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
  uint16_t c;
#ifndef USE_ROCM
  asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
#else
  asm volatile("v_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
#endif
  return c;
}

template <>
inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
  uint32_t c;
#ifndef USE_ROCM
  asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
#else
  asm volatile("v_pk_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
#endif
  return c;
}

template <>
inline __device__ uint32_t mul(uint16_t a, uint32_t b) {
  return mul<uint32_t, uint32_t, uint32_t>(h0_h0(a), b);
}

template <>
inline __device__ uint2 mul(uint2 a, uint2 b) {
  uint2 c;
  c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
  c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
  return c;
}

template <>
inline __device__ uint2 mul(uint16_t a, uint2 b) {
  uint32_t s = h0_h0(a);
  uint2 c;
  c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
  c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
  return c;
}

template <>
inline __device__ uint4 mul(uint4 a, uint4 b) {
  uint4 c;
  c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
  c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
  c.z = mul<uint32_t, uint32_t, uint32_t>(a.z, b.z);
  c.w = mul<uint32_t, uint32_t, uint32_t>(a.w, b.w);
  return c;
}

template <>
inline __device__ uint4 mul(uint16_t a, uint4 b) {
  uint32_t s = h0_h0(a);
  uint4 c;
  c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
  c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
  c.z = mul<uint32_t, uint32_t, uint32_t>(s, b.z);
  c.w = mul<uint32_t, uint32_t, uint32_t>(s, b.w);
  return c;
}

template <>
inline __device__ float mul(uint16_t a, uint16_t b) {
  float fa = half_to_float(a);
  float fb = half_to_float(b);
  return fa * fb;
}

template <>
inline __device__ float2 mul(uint32_t a, uint32_t b) {
  float2 fa = half2_to_float2(a);
  float2 fb = half2_to_float2(b);
  return mul<float2, float2, float2>(fa, fb);
}

template <>
inline __device__ float2 mul(uint16_t a, uint32_t b) {
  return mul<float2, uint32_t, uint32_t>(h0_h0(a), b);
}

template <>
inline __device__ Float4_ mul(uint2 a, uint2 b) {
  Float4_ fc;
  fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
  fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
  return fc;
}

template <>
inline __device__ Float4_ mul(uint16_t a, uint2 b) {
  uint32_t s = h0_h0(a);
  Float4_ fc;
  fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
  fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
  return fc;
}

template <>
inline __device__ Float8_ mul(uint4 a, uint4 b) {
  Float8_ fc;
  fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
  fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
  fc.z = mul<float2, uint32_t, uint32_t>(a.z, b.z);
  fc.w = mul<float2, uint32_t, uint32_t>(a.w, b.w);
  return fc;
}

template <>
inline __device__ Float8_ mul(uint16_t a, uint4 b) {
  uint32_t s = h0_h0(a);
  Float8_ fc;
  fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
  fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
  fc.z = mul<float2, uint32_t, uint32_t>(s, b.z);
  fc.w = mul<float2, uint32_t, uint32_t>(s, b.w);
  return fc;
}

// Vector fused multiply-add.
inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) {
  uint32_t d;
#ifndef USE_ROCM
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
               : "=r"(d)
               : "r"(a), "r"(b), "r"(c));
#else
  asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n"
               : "=v"(d)
               : "v"(a), "v"(b), "v"(c));
#endif
  return d;
}

inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) {
  return fma(h0_h0(a), b, c);
}

inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) {
  uint2 d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  return d;
}

inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c) {
  uint32_t s = h0_h0(a);
  uint2 d;
  d.x = fma(s, b.x, c.x);
  d.y = fma(s, b.y, c.y);
  return d;
}

inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c) {
  uint4 d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  d.z = fma(a.z, b.z, c.z);
  d.w = fma(a.w, b.w, c.w);
  return d;
}

inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c) {
  uint32_t s = h0_h0(a);
  uint4 d;
  d.x = fma(s, b.x, c.x);
  d.y = fma(s, b.y, c.y);
  d.z = fma(s, b.z, c.z);
  d.w = fma(s, b.w, c.w);
  return d;
}

inline __device__ float fma(uint16_t a, uint16_t b, float fc) {
  float fa = half_to_float(a);
  float fb = half_to_float(b);
  return fa * fb + fc;
}

inline __device__ float2 fma(uint32_t a, uint32_t b, float2 fc) {
  float2 fa = half2_to_float2(a);
  float2 fb = half2_to_float2(b);
  return fma(fa, fb, fc);
}

inline __device__ float2 fma(uint16_t a, uint32_t b, float2 fc) {
  return fma(h0_h0(a), b, fc);
}

inline __device__ Float4_ fma(uint2 a, uint2 b, Float4_ fc) {
  Float4_ fd;
  fd.x = fma(a.x, b.x, fc.x);
  fd.y = fma(a.y, b.y, fc.y);
  return fd;
}

inline __device__ Float4_ fma(uint16_t a, uint2 b, Float4_ fc) {
  uint32_t s = h0_h0(a);
  Float4_ fd;
  fd.x = fma(s, b.x, fc.x);
  fd.y = fma(s, b.y, fc.y);
  return fd;
}

inline __device__ Float8_ fma(uint4 a, uint4 b, Float8_ fc) {
  Float8_ fd;
  fd.x = fma(a.x, b.x, fc.x);
  fd.y = fma(a.y, b.y, fc.y);
  fd.z = fma(a.z, b.z, fc.z);
  fd.w = fma(a.w, b.w, fc.w);
  return fd;
}

inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) {
  uint32_t s = h0_h0(a);
  Float8_ fd;
  fd.x = fma(s, b.x, fc.x);
  fd.y = fma(s, b.y, fc.y);
  fd.z = fma(s, b.z, fc.z);
  fd.w = fma(s, b.w, fc.w);
  return fd;
}

// Vector sum.
template <>
inline __device__ float sum(uint16_t v) {
  return half_to_float(v);
}

template <>
inline __device__ float sum(uint32_t v) {
  float2 tmp = half2_to_float2(v);
  return tmp.x + tmp.y;
}

template <>
inline __device__ float sum(uint2 v) {
  uint32_t c = add(v.x, v.y);
  return sum(c);
}

template <>
inline __device__ float sum(uint4 v) {
  uint32_t c = add(v.x, v.y);
  c = add(c, v.z);
  c = add(c, v.w);
  return sum(c);
}

// From float32 to float16.
inline __device__ void from_float(uint16_t& dst, float src) {
  dst = float_to_half(src);
}

inline __device__ void from_float(uint32_t& dst, float2 src) {
  dst = float2_to_half2(src);
}

inline __device__ void from_float(uint2& dst, Float4_ src) {
  dst.x = float2_to_half2(src.x);
  dst.y = float2_to_half2(src.y);
}

inline __device__ void from_float(uint4& dst, Float8_ src) {
  dst.x = float2_to_half2(src.x);
  dst.y = float2_to_half2(src.y);
  dst.z = float2_to_half2(src.z);
  dst.w = float2_to_half2(src.w);
}

// From float16 to float32.
inline __device__ float to_float(uint16_t u) { return half_to_float(u); }

inline __device__ float2 to_float(uint32_t u) { return half2_to_float2(u); }

inline __device__ Float4_ to_float(uint2 u) {
  Float4_ tmp;
  tmp.x = half2_to_float2(u.x);
  tmp.y = half2_to_float2(u.y);
  return tmp;
}

inline __device__ Float8_ to_float(uint4 u) {
  Float8_ tmp;
  tmp.x = half2_to_float2(u.x);
  tmp.y = half2_to_float2(u.y);
  tmp.z = half2_to_float2(u.z);
  tmp.w = half2_to_float2(u.w);
  return tmp;
}

// Zero-out a variable.
inline __device__ void zero(uint16_t& dst) { dst = uint16_t(0); }

}  // namespace vllm
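The fp16 path keeps half values in raw uint16_t/uint32_t/uint2/uint4 words and converts through the inline PTX above, with ROCm fallbacks. A minimal sketch of how these helpers chain together (illustrative only, not part of this commit; assumes attention_dtypes.h includes this header and that the fma overloads above are visible):

// Sketch only — not part of the ported sources.
#include "attention_dtypes.h"

namespace vllm {

// Round-trip two floats through one packed half2 word.
inline __device__ float2 half2_roundtrip(float2 v) {
  uint32_t packed = float2_to_half2(v);  // f16x2 cvt on SM80+, scalar cvt otherwise
  return half2_to_float2(packed);        // back to fp32
}

// acc += s * x over eight fp16 lanes (uint4 = 8 halves), accumulating in fp32.
inline __device__ Float8_ scale_accumulate_fp16x8(uint16_t s, uint4 x, Float8_ acc) {
  return fma(s, x, acc);                 // broadcasts s via h0_h0 internally
}

}  // namespace vllm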
paged-attention/attention/dtype_float32.cuh
ADDED
@@ -0,0 +1,251 @@
/*
 * Adapted from
 * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
 * and
 * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
 * Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include "attention_generic.cuh"

#include <stdint.h>

namespace vllm {

// Define custom FP32 vector data types.
struct Float4_ {
  float2 x;
  float2 y;
};

struct Float8_ {
  float2 x;
  float2 y;
  float2 z;
  float2 w;
};

// FP32 vector types for Q, K, V.
template <>
struct Vec<float, 1> {
  using Type = float;
};
template <>
struct Vec<float, 2> {
  using Type = float2;
};
template <>
struct Vec<float, 4> {
  using Type = float4;
};

// FP32 accumulator vector types corresponding to Vec.
template <>
struct FloatVec<float> {
  using Type = float;
};
template <>
struct FloatVec<float2> {
  using Type = float2;
};
template <>
struct FloatVec<float4> {
  using Type = float4;
};

// Vector addition.
inline __device__ float add(float a, float b) { return a + b; }

inline __device__ float2 add(float2 a, float2 b) {
  float2 c;
  c.x = add(a.x, b.x);
  c.y = add(a.y, b.y);
  return c;
}

inline __device__ float4 add(float4 a, float4 b) {
  float4 c;
  c.x = add(a.x, b.x);
  c.y = add(a.y, b.y);
  c.z = add(a.z, b.z);
  c.w = add(a.w, b.w);
  return c;
}

// Vector multiplication.
template <>
inline __device__ float mul<float, float>(float a, float b) {
  return a * b;
}

template <>
inline __device__ float2 mul(float2 a, float2 b) {
  float2 c;
  c.x = a.x * b.x;
  c.y = a.y * b.y;
  return c;
}

template <>
inline __device__ float2 mul(float a, float2 b) {
  float2 c;
  c.x = a * b.x;
  c.y = a * b.y;
  return c;
}

template <>
inline __device__ float4 mul(float4 a, float4 b) {
  float4 c;
  c.x = a.x * b.x;
  c.y = a.y * b.y;
  c.z = a.z * b.z;
  c.w = a.w * b.w;
  return c;
}

template <>
inline __device__ float4 mul(float a, float4 b) {
  float4 c;
  c.x = a * b.x;
  c.y = a * b.y;
  c.z = a * b.z;
  c.w = a * b.w;
  return c;
}

// Vector fused multiply-add.
inline __device__ float fma(float a, float b, float c) { return a * b + c; }

inline __device__ float2 fma(float2 a, float2 b, float2 c) {
  float2 d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  return d;
}

inline __device__ float2 fma(float a, float2 b, float2 c) {
  float2 d;
  d.x = fma(a, b.x, c.x);
  d.y = fma(a, b.y, c.y);
  return d;
}

inline __device__ float4 fma(float4 a, float4 b, float4 c) {
  float4 d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  d.z = fma(a.z, b.z, c.z);
  d.w = fma(a.w, b.w, c.w);
  return d;
}

inline __device__ float4 fma(float a, float4 b, float4 c) {
  float4 d;
  d.x = fma(a, b.x, c.x);
  d.y = fma(a, b.y, c.y);
  d.z = fma(a, b.z, c.z);
  d.w = fma(a, b.w, c.w);
  return d;
}

inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c) {
  Float4_ d;
  d.x = fma(a, b.x, c.x);
  d.y = fma(a, b.y, c.y);
  return d;
}

inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) {
  Float8_ d;
  d.x = fma(a, b.x, c.x);
  d.y = fma(a, b.y, c.y);
  d.z = fma(a, b.z, c.z);
  d.w = fma(a, b.w, c.w);
  return d;
}

// Vector sum.
template <>
inline __device__ float sum(float v) {
  return v;
}

template <>
inline __device__ float sum(float2 v) {
  return v.x + v.y;
}

template <>
inline __device__ float sum(float4 v) {
  return v.x + v.y + v.z + v.w;
}

template <>
inline __device__ float sum(Float4_ v) {
  return v.x.x + v.x.y + v.y.x + v.y.y;
}

template <>
inline __device__ float sum(Float8_ v) {
  return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y;
}

// Vector dot product.
inline __device__ float dot(float a, float b) { return a * b; }

inline __device__ float dot(float2 a, float2 b) {
  float2 c = mul<float2, float2, float2>(a, b);
  return c.x + c.y;
}

inline __device__ float dot(Float4_ a, Float4_ b) {
  float2 acc = mul<float2, float2, float2>(a.x, b.x);
  acc = fma(a.y, b.y, acc);
  return acc.x + acc.y;
}

inline __device__ float dot(Float8_ a, Float8_ b) {
  float2 acc = mul<float2, float2, float2>(a.x, b.x);
  acc = fma(a.y, b.y, acc);
  acc = fma(a.z, b.z, acc);
  acc = fma(a.w, b.w, acc);
  return acc.x + acc.y;
}

// From float to float.
inline __device__ void from_float(float& dst, float src) { dst = src; }

inline __device__ void from_float(float2& dst, float2 src) { dst = src; }

inline __device__ void from_float(float4& dst, float4 src) { dst = src; }

// From float to float.
inline __device__ float to_float(float u) { return u; }

inline __device__ float2 to_float(float2 u) { return u; }

inline __device__ float4 to_float(float4 u) { return u; }

inline __device__ Float4_ to_float(Float4_ u) { return u; }

inline __device__ Float8_ to_float(Float8_ u) { return u; }

// Zero-out a variable.
inline __device__ void zero(float& dst) { dst = 0.f; }

}  // namespace vllm
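dtype_float32.cuh supplies both the plain fp32 Vec types and the Float4_/Float8_ accumulator structs that the fp16 and bf16 paths convert into, plus the dot helpers used for the Q·K reduction. For example (a sketch only, not part of this commit, assuming attention_dtypes.h includes this header):

// Sketch only — not part of the ported sources.
#include "attention_dtypes.h"

namespace vllm {

// A scaled dot product in pure fp32, using only the helpers defined above.
inline __device__ float scaled_qk_dot_fp32(float scale, const Float8_& q, const Float8_& k) {
  return scale * dot(q, k);  // dot = element-wise mul + fma chain + horizontal add
}

}  // namespace vllm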
paged-attention/attention/dtype_fp8.cuh
ADDED
@@ -0,0 +1,41 @@
#pragma once

#include "attention_generic.cuh"

#include <stdint.h>
#ifdef ENABLE_FP8
  #ifndef USE_ROCM
    #include <cuda_fp8.h>
  #endif  // USE_ROCM
#endif    // ENABLE_FP8

namespace vllm {

enum class Fp8KVCacheDataType {
  kAuto = 0,
  kFp8E4M3 = 1,
  kFp8E5M2 = 2,
};

// fp8 vector types for quantization of kv cache
template <>
struct Vec<uint8_t, 1> {
  using Type = uint8_t;
};

template <>
struct Vec<uint8_t, 2> {
  using Type = uint16_t;
};

template <>
struct Vec<uint8_t, 4> {
  using Type = uint32_t;
};

template <>
struct Vec<uint8_t, 8> {
  using Type = uint2;
};

}  // namespace vllm
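The fp8 header only fixes the KV-cache data-type enum and how many quantized bytes ride in each register type; the actual scaling and conversion live in the quantization/fp8 quant_utils.cuh headers added elsewhere in this commit. A compile-time sanity check of the packing (illustrative only, not from the commit; assumes attention_dtypes.h includes this header):

// Sketch only — not part of the ported sources.
#include "attention_dtypes.h"

namespace vllm {

// Eight fp8 values travel as one uint2 (8 bytes); four fit in a uint32_t.
static_assert(sizeof(Vec<uint8_t, 8>::Type) == 8, "8 x fp8 should pack into a uint2");
static_assert(sizeof(Vec<uint8_t, 4>::Type) == 4, "4 x fp8 should pack into a uint32_t");

}  // namespace vllm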
paged-attention/attention/paged_attention_v1.cu
ADDED
@@ -0,0 +1,196 @@
/*
 * Adapted from
 * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
 * Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "attention_kernels.cuh"

#ifndef USE_ROCM
  #define WARP_SIZE 32
#else
  #define WARP_SIZE warpSize
#endif

#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))

#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE)                                 \
  VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(                      \
      ((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE,         \
                                              BLOCK_SIZE, NUM_THREADS,       \
                                              KV_DTYPE, IS_BLOCK_SPARSE>),   \
      shared_mem_size);                                                      \
  vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE,         \
                                  NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE>    \
      <<<grid, block, shared_mem_size, stream>>>(                            \
          out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads,  \
          scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq,     \
          alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride,       \
          k_scale_ptr, v_scale_ptr, tp_rank, blocksparse_local_blocks,       \
          blocksparse_vert_stride, blocksparse_block_size,                   \
          blocksparse_head_sliding_step);

// TODO(woosuk): Tune NUM_THREADS.
template <typename T, typename CACHE_T, int BLOCK_SIZE,
          vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE,
          int NUM_THREADS = 128>
void paged_attention_v1_launcher(
    torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
    torch::Tensor& value_cache, int num_kv_heads, float scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
    const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
    torch::Tensor& v_scale, const int tp_rank,
    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
    const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
  int num_seqs = query.size(0);
  int num_heads = query.size(1);
  int head_size = query.size(2);
  int max_num_blocks_per_seq = block_tables.size(1);
  int q_stride = query.stride(0);
  int kv_block_stride = key_cache.stride(0);
  int kv_head_stride = key_cache.stride(1);

  [[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
  assert(head_size % thread_group_size == 0);

  // NOTE: alibi_slopes is optional.
  const float* alibi_slopes_ptr =
      alibi_slopes
          ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
          : nullptr;

  T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
  T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
  CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
  CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
  int* block_tables_ptr = block_tables.data_ptr<int>();
  int* seq_lens_ptr = seq_lens.data_ptr<int>();
  const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
  const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());

  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  int padded_max_seq_len =
      DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
  int logits_size = padded_max_seq_len * sizeof(float);
  int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
  // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
  // Keep that in sync with the logic here!
  int shared_mem_size = std::max(logits_size, outputs_size);

  dim3 grid(num_heads, num_seqs, 1);
  dim3 block(NUM_THREADS);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  switch (head_size) {
    // NOTE(woosuk): To reduce the compilation time, we only compile for the
    // head sizes that we use in the model. However, we can easily extend this
    // to support any head size which is a multiple of 16.
    case 32:
      LAUNCH_PAGED_ATTENTION_V1(32);
      break;
    case 64:
      LAUNCH_PAGED_ATTENTION_V1(64);
      break;
    case 80:
      LAUNCH_PAGED_ATTENTION_V1(80);
      break;
    case 96:
      LAUNCH_PAGED_ATTENTION_V1(96);
      break;
    case 112:
      LAUNCH_PAGED_ATTENTION_V1(112);
      break;
    case 120:
      LAUNCH_PAGED_ATTENTION_V1(120);
      break;
    case 128:
      LAUNCH_PAGED_ATTENTION_V1(128);
      break;
    case 192:
      LAUNCH_PAGED_ATTENTION_V1(192);
      break;
    case 256:
      LAUNCH_PAGED_ATTENTION_V1(256);
      break;
    default:
      TORCH_CHECK(false, "Unsupported head size: ", head_size);
      break;
  }
}

#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE)  \
  paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE,              \
                              IS_BLOCK_SPARSE>(                              \
      out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
      seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank,        \
      blocksparse_local_blocks, blocksparse_vert_stride,                     \
      blocksparse_block_size, blocksparse_head_sliding_step);

#define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE)   \
  if (is_block_sparse) {                                                     \
    CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true);         \
  } else {                                                                   \
    CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false);        \
  }

// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE)                    \
  switch (block_size) {                                                      \
    case 8:                                                                  \
      CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE);                    \
      break;                                                                 \
    case 16:                                                                 \
      CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE);                   \
      break;                                                                 \
    case 32:                                                                 \
      CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE);                   \
      break;                                                                 \
    default:                                                                 \
      TORCH_CHECK(false, "Unsupported block size: ", block_size);            \
      break;                                                                 \
  }

void paged_attention_v1(
    torch::Tensor& out,    // [num_seqs, num_heads, head_size]
    torch::Tensor& query,  // [num_seqs, num_heads, head_size]
    torch::Tensor&
        key_cache,  // [num_blocks, num_heads, head_size/x, block_size, x]
    torch::Tensor&
        value_cache,       // [num_blocks, num_heads, head_size, block_size]
    int64_t num_kv_heads,  // [num_heads]
    double scale,
    torch::Tensor& block_tables,  // [num_seqs, max_num_blocks_per_seq]
    torch::Tensor& seq_lens,      // [num_seqs]
    int64_t block_size, int64_t max_seq_len,
    const std::optional<torch::Tensor>& alibi_slopes,
    const std::string& kv_cache_dtype, torch::Tensor& k_scale,
    torch::Tensor& v_scale, const int64_t tp_rank,
    const int64_t blocksparse_local_blocks,
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
    const int64_t blocksparse_head_sliding_step) {
  const bool is_block_sparse = (blocksparse_vert_stride > 1);

  DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
                             CALL_V1_LAUNCHER_BLOCK_SIZE)
}

#undef WARP_SIZE
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP
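The v1 launcher sizes dynamic shared memory as the larger of the softmax-logits buffer and the per-warp reduction workspace, and the in-code comment asks that the Python-side max-seq-len check stay in sync with this arithmetic. Worked through for one concrete configuration (a host-side illustration with assumed example values, not code from this commit):

// Sketch only — not part of the ported sources.
#include <algorithm>
#include <cstdio>

int main() {
  // Assumed example configuration.
  const int BLOCK_SIZE = 16, NUM_THREADS = 128, WARP_SIZE = 32;
  const int max_seq_len = 4096, head_size = 128;

  // DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE
  const int padded_max_seq_len = (max_seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE * BLOCK_SIZE;
  const int logits_size = padded_max_seq_len * (int)sizeof(float);                    // softmax logits
  const int outputs_size = (NUM_THREADS / WARP_SIZE / 2) * head_size * (int)sizeof(float);  // warp reduction

  // Here: 16384 bytes vs 1024 bytes, so the logits buffer dominates.
  std::printf("shared_mem_size = %d bytes\n", std::max(logits_size, outputs_size));
  return 0;
}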
paged-attention/attention/paged_attention_v2.cu
ADDED
@@ -0,0 +1,206 @@
/*
 * Adapted from
 * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
 * Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "attention_kernels.cuh"

#ifndef USE_ROCM
  #define WARP_SIZE 32
#else
  #define WARP_SIZE warpSize
#endif

#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))

#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE)                                  \
  vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE,          \
                                  NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE,     \
                                  PARTITION_SIZE>                             \
      <<<grid, block, shared_mem_size, stream>>>(                             \
          exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
          value_cache_ptr, num_kv_heads, scale, block_tables_ptr,             \
          seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride,   \
          kv_block_stride, kv_head_stride, k_scale_ptr, v_scale_ptr, tp_rank, \
          blocksparse_local_blocks, blocksparse_vert_stride,                  \
          blocksparse_block_size, blocksparse_head_sliding_step);             \
  vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS,           \
                                         PARTITION_SIZE>                      \
      <<<reduce_grid, block, reduce_shared_mem_size, stream>>>(               \
          out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr,   \
          max_num_partitions);

template <typename T, typename CACHE_T, int BLOCK_SIZE,
          vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE,
          int NUM_THREADS = 128, int PARTITION_SIZE = 512>
void paged_attention_v2_launcher(
    torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
    torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
    torch::Tensor& value_cache, int num_kv_heads, float scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
    const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
    torch::Tensor& v_scale, const int tp_rank,
    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
    const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
  int num_seqs = query.size(0);
  int num_heads = query.size(1);
  int head_size = query.size(2);
  int max_num_blocks_per_seq = block_tables.size(1);
  int q_stride = query.stride(0);
  int kv_block_stride = key_cache.stride(0);
  int kv_head_stride = key_cache.stride(1);

  [[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
  assert(head_size % thread_group_size == 0);

  // NOTE: alibi_slopes is optional.
  const float* alibi_slopes_ptr =
      alibi_slopes
          ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
          : nullptr;

  T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
  float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
  float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
  T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
|
| 82 |
+
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
|
| 83 |
+
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
|
| 84 |
+
CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
|
| 85 |
+
int* block_tables_ptr = block_tables.data_ptr<int>();
|
| 86 |
+
int* seq_lens_ptr = seq_lens.data_ptr<int>();
|
| 87 |
+
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
|
| 88 |
+
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
|
| 89 |
+
|
| 90 |
+
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
|
| 91 |
+
int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
|
| 92 |
+
int logits_size = PARTITION_SIZE * sizeof(float);
|
| 93 |
+
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
|
| 94 |
+
|
| 95 |
+
// For paged attention v2 kernel.
|
| 96 |
+
dim3 grid(num_heads, num_seqs, max_num_partitions);
|
| 97 |
+
int shared_mem_size = std::max(logits_size, outputs_size);
|
| 98 |
+
// For paged attention v2 reduce kernel.
|
| 99 |
+
dim3 reduce_grid(num_heads, num_seqs);
|
| 100 |
+
int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
|
| 101 |
+
|
| 102 |
+
dim3 block(NUM_THREADS);
|
| 103 |
+
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
|
| 104 |
+
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
| 105 |
+
switch (head_size) {
|
| 106 |
+
// NOTE(woosuk): To reduce the compilation time, we only compile for the
|
| 107 |
+
// head sizes that we use in the model. However, we can easily extend this
|
| 108 |
+
// to support any head size which is a multiple of 16.
|
| 109 |
+
case 32:
|
| 110 |
+
LAUNCH_PAGED_ATTENTION_V2(32);
|
| 111 |
+
break;
|
| 112 |
+
case 64:
|
| 113 |
+
LAUNCH_PAGED_ATTENTION_V2(64);
|
| 114 |
+
break;
|
| 115 |
+
case 80:
|
| 116 |
+
LAUNCH_PAGED_ATTENTION_V2(80);
|
| 117 |
+
break;
|
| 118 |
+
case 96:
|
| 119 |
+
LAUNCH_PAGED_ATTENTION_V2(96);
|
| 120 |
+
break;
|
| 121 |
+
case 112:
|
| 122 |
+
LAUNCH_PAGED_ATTENTION_V2(112);
|
| 123 |
+
break;
|
| 124 |
+
case 120:
|
| 125 |
+
LAUNCH_PAGED_ATTENTION_V2(120);
|
| 126 |
+
break;
|
| 127 |
+
case 128:
|
| 128 |
+
LAUNCH_PAGED_ATTENTION_V2(128);
|
| 129 |
+
break;
|
| 130 |
+
case 192:
|
| 131 |
+
LAUNCH_PAGED_ATTENTION_V2(192);
|
| 132 |
+
break;
|
| 133 |
+
case 256:
|
| 134 |
+
LAUNCH_PAGED_ATTENTION_V2(256);
|
| 135 |
+
break;
|
| 136 |
+
default:
|
| 137 |
+
TORCH_CHECK(false, "Unsupported head size: ", head_size);
|
| 138 |
+
break;
|
| 139 |
+
}
|
| 140 |
+
}
|
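The launch geometry above is the heart of V2: each sequence is cut into PARTITION_SIZE (default 512) token partitions, the main kernel runs one thread block per (head, sequence, partition) and writes per-partition max logits, exp sums, and partial outputs, and the reduce kernel then merges the partitions per (head, sequence). The host-only sketch below repeats the same arithmetic with illustrative sizes (128 threads, head_size 128, max_seq_len 4096, 4 sequences, 5 heads); none of the numbers come from a real model.

// Host-side sketch of the V2 launch geometry arithmetic (no CUDA required).
#include <algorithm>
#include <cstdio>

int main() {
  const int num_threads = 128, partition_size = 512, warp_size = 32;
  const int num_seqs = 4, num_heads = 5, head_size = 128, max_seq_len = 4096;

  const int num_warps = num_threads / warp_size;                         // 4
  const int max_num_partitions =
      (max_seq_len + partition_size - 1) / partition_size;               // 8
  const int logits_size = partition_size * sizeof(float);                // 2048 B
  const int outputs_size = (num_warps / 2) * head_size * sizeof(float);  // 1024 B
  const int shared_mem_size = std::max(logits_size, outputs_size);       // 2048 B
  const int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);

  std::printf("main grid   = (%d, %d, %d), smem = %d bytes\n",
              num_heads, num_seqs, max_num_partitions, shared_mem_size);
  std::printf("reduce grid = (%d, %d),    smem = %d bytes\n",
              num_heads, num_seqs, reduce_shared_mem_size);
}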
#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE)   \
  paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE,              \
                              IS_BLOCK_SPARSE>(                               \
      out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache,      \
      num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
      k_scale, v_scale, tp_rank, blocksparse_local_blocks,                    \
      blocksparse_vert_stride, blocksparse_block_size,                        \
      blocksparse_head_sliding_step);

#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
  if (is_block_sparse) {                                                   \
    CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true);       \
  } else {                                                                 \
    CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false);      \
  }

// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE)         \
  switch (block_size) {                                           \
    case 8:                                                       \
      CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE);         \
      break;                                                      \
    case 16:                                                      \
      CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE);        \
      break;                                                      \
    case 32:                                                      \
      CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE);        \
      break;                                                      \
    default:                                                      \
      TORCH_CHECK(false, "Unsupported block size: ", block_size); \
      break;                                                      \
  }

void paged_attention_v2(
    torch::Tensor& out,         // [num_seqs, num_heads, head_size]
    torch::Tensor& exp_sums,    // [num_seqs, num_heads, max_num_partitions]
    torch::Tensor& max_logits,  // [num_seqs, num_heads, max_num_partitions]
    torch::Tensor&
        tmp_out,  // [num_seqs, num_heads, max_num_partitions, head_size]
    torch::Tensor& query,  // [num_seqs, num_heads, head_size]
    torch::Tensor&
        key_cache,  // [num_blocks, num_heads, head_size/x, block_size, x]
    torch::Tensor&
        value_cache,  // [num_blocks, num_heads, head_size, block_size]
    int64_t num_kv_heads,  // [num_heads]
    double scale,
    torch::Tensor& block_tables,  // [num_seqs, max_num_blocks_per_seq]
    torch::Tensor& seq_lens,      // [num_seqs]
    int64_t block_size, int64_t max_seq_len,
    const std::optional<torch::Tensor>& alibi_slopes,
    const std::string& kv_cache_dtype, torch::Tensor& k_scale,
    torch::Tensor& v_scale, const int64_t tp_rank,
    const int64_t blocksparse_local_blocks,
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
    const int64_t blocksparse_head_sliding_step) {
  const bool is_block_sparse = (blocksparse_vert_stride > 1);
  DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
                             CALL_V2_LAUNCHER_BLOCK_SIZE)
}

#undef WARP_SIZE
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP
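paged_attention_v2 takes its scratch space by reference, so the caller has to allocate exp_sums, max_logits, and tmp_out with a max_num_partitions dimension that matches the 512-token partitioning above (the launcher reads exp_sums and max_logits through float pointers, while tmp_out shares the query/output dtype). Below is a hedged libtorch sketch of that allocation; the batch and head sizes are illustrative, and only the shapes and dtypes are taken from the code and comments above.

// Hedged sketch (libtorch C++ API): allocating the V2 scratch tensors.
#include <torch/torch.h>

int main() {
  const int64_t num_seqs = 4, num_heads = 5, head_size = 128;
  const int64_t max_seq_len = 4096, partition_size = 512;
  const int64_t max_num_partitions =
      (max_seq_len + partition_size - 1) / partition_size;  // 8

  auto half_cuda = torch::dtype(torch::kFloat16).device(torch::kCUDA);
  auto float_cuda = torch::dtype(torch::kFloat32).device(torch::kCUDA);

  auto out = torch::empty({num_seqs, num_heads, head_size}, half_cuda);
  auto exp_sums =
      torch::empty({num_seqs, num_heads, max_num_partitions}, float_cuda);
  auto max_logits = torch::empty_like(exp_sums);
  auto tmp_out = torch::empty(
      {num_seqs, num_heads, max_num_partitions, head_size}, half_cuda);
  // out, exp_sums, max_logits, and tmp_out are then passed to
  // paged_attention_v2(...) together with query, the KV caches, block_tables,
  // seq_lens, and the kv-cache scale tensors.
  return 0;
}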
paged-attention/cache_kernels.cu
ADDED
@@ -0,0 +1,419 @@
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "cuda_compat.h"
#include "dispatch_utils.h"

#ifdef USE_ROCM
  #include "quantization/fp8/amd/quant_utils.cuh"
#else
  #include "quantization/fp8/nvidia/quant_utils.cuh"
#endif

#include <algorithm>
#include <cassert>
#include <map>
#include <vector>

#ifdef USE_ROCM
  #include <hip/hip_bf16.h>
typedef __hip_bfloat16 __nv_bfloat16;
#endif

void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
                 const torch::Tensor& block_mapping) {
  torch::Device src_device = src.device();
  torch::Device dst_device = dst.device();
  cudaMemcpyKind memcpy_type;
  if (src_device.is_cuda() && dst_device.is_cuda()) {
    TORCH_CHECK(src_device.index() == dst_device.index(),
                "src and dst must be on the same GPU");
    memcpy_type = cudaMemcpyDeviceToDevice;
  } else if (src_device.is_cuda() && dst_device.is_cpu()) {
    memcpy_type = cudaMemcpyDeviceToHost;
  } else if (src_device.is_cpu() && dst_device.is_cuda()) {
    memcpy_type = cudaMemcpyHostToDevice;
  } else {
    TORCH_CHECK(false, "Invalid device combination");
  }

  // NOTE(youkaichao): keep in mind that `block_mapping` should be
  // a cpu tensor, otherwise every `item` call will require a gpu-cpu
  // synchronization.
  TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU");

  char* src_ptr = static_cast<char*>(src.data_ptr());
  char* dst_ptr = static_cast<char*>(dst.data_ptr());

  const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
  const at::cuda::OptionalCUDAGuard device_guard(
      src_device.is_cuda() ? src_device : dst_device);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // NOTE(woosuk): This can be slow if the number of blocks is large.
  const int64_t num_blocks = block_mapping.size(0);
  for (size_t i = 0; i < num_blocks; i++) {
    int64_t src_block_number = block_mapping[i][0].item<int64_t>();
    int64_t dst_block_number = block_mapping[i][1].item<int64_t>();
    int64_t src_offset = src_block_number * block_size_in_bytes;
    int64_t dst_offset = dst_block_number * block_size_in_bytes;
    cudaMemcpyAsync(dst_ptr + dst_offset, src_ptr + src_offset,
                    block_size_in_bytes, memcpy_type, stream);
  }
}
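swap_blocks moves whole cache blocks between devices with one cudaMemcpyAsync per (src, dst) pair, and the pair list must live on the CPU so the per-row .item() calls do not force a GPU sync. The sketch below is a hedged libtorch usage example of building that mapping; the cache shape is illustrative, and only the (num_pairs, 2) int64 CPU layout of block_mapping comes from the code above.

// Hedged usage sketch: swapping two GPU cache blocks out to a pinned CPU cache.
// The declaration matches the definition above (normally it comes from
// torch_binding.h); linking requires compiling against this file.
#include <torch/torch.h>

void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
                 const torch::Tensor& block_mapping);

int main() {
  // Key cache: [num_blocks, num_heads, head_size/x, block_size, x].
  auto gpu_cache = torch::rand(
      {32, 8, 16, 16, 8}, torch::dtype(torch::kFloat16).device(torch::kCUDA));
  auto cpu_cache = torch::empty(
      gpu_cache.sizes(),
      torch::dtype(torch::kFloat16).device(torch::kCPU).pinned_memory(true));

  // Swap GPU block 3 into CPU block 0, and GPU block 7 into CPU block 1.
  auto block_mapping =
      torch::tensor({3, 0, 7, 1}, torch::kInt64).reshape({2, 2});  // CPU tensor
  swap_blocks(gpu_cache, cpu_cache, block_mapping);
  return 0;
}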
namespace vllm {

// Grid: (num_layers, num_pairs)
template <typename scalar_t>
__global__ void copy_blocks_kernel(int64_t* key_cache_ptrs,
                                   int64_t* value_cache_ptrs,
                                   const int64_t* __restrict__ block_mapping,
                                   const int numel_per_block) {
  const int layer_idx = blockIdx.x;
  const int pair_idx = blockIdx.y;

  scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]);
  scalar_t* value_cache =
      reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
  int64_t src_block_number = block_mapping[2 * pair_idx];
  int64_t dst_block_number = block_mapping[2 * pair_idx + 1];

  const int64_t src_block_offset = src_block_number * numel_per_block;
  const int64_t dst_block_offset = dst_block_number * numel_per_block;
  for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
    int64_t src_offset = src_block_offset + i;
    int64_t dst_offset = dst_block_offset + i;
    key_cache[dst_offset] = key_cache[src_offset];
  }
  for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
    int64_t src_offset = src_block_offset + i;
    int64_t dst_offset = dst_block_offset + i;
    value_cache[dst_offset] = value_cache[src_offset];
  }
}

}  // namespace vllm

// Note: the key_caches and value_caches vectors are constant but
// not the Tensors they contain. The vectors need to be const refs
// in order to satisfy pytorch's C++ operator registration code.
void copy_blocks(std::vector<torch::Tensor> const& key_caches,
                 std::vector<torch::Tensor> const& value_caches,
                 const torch::Tensor& block_mapping) {
  int num_layers = key_caches.size();
  TORCH_CHECK(num_layers == value_caches.size());
  if (num_layers == 0) {
    return;
  }
  torch::Device cache_device = key_caches[0].device();
  TORCH_CHECK(cache_device.is_cuda());

  // Create data structures for the kernel.
  // Create an array of pointers to the key and value caches.
  int64_t key_cache_ptrs[num_layers];
  int64_t value_cache_ptrs[num_layers];
  for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
    key_cache_ptrs[layer_idx] =
        reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
    value_cache_ptrs[layer_idx] =
        reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
  }

  // block_mapping is a 2D tensor with shape (num_pairs, 2).
  int num_pairs = block_mapping.size(0);

  // Move the data structures to the GPU.
  // NOTE: This synchronizes the CPU and GPU.
  torch::Tensor key_cache_ptrs_tensor =
      torch::from_blob(key_cache_ptrs, {num_layers}, torch::kInt64)
          .to(cache_device);
  torch::Tensor value_cache_ptrs_tensor =
      torch::from_blob(value_cache_ptrs, {num_layers}, torch::kInt64)
          .to(cache_device);

  // Launch the kernel.
  const int numel_per_block = key_caches[0][0].numel();
  dim3 grid(num_layers, num_pairs);
  dim3 block(std::min(1024, numel_per_block));
  const at::cuda::OptionalCUDAGuard device_guard(cache_device);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
      key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
        vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
            key_cache_ptrs_tensor.data_ptr<int64_t>(),
            value_cache_ptrs_tensor.data_ptr<int64_t>(),
            block_mapping.data_ptr<int64_t>(), numel_per_block);
      }));
}

namespace vllm {

template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void reshape_and_cache_kernel(
    const scalar_t* __restrict__ key,    // [num_tokens, num_heads, head_size]
    const scalar_t* __restrict__ value,  // [num_tokens, num_heads, head_size]
    cache_t* __restrict__ key_cache,     // [num_blocks, num_heads, head_size/x,
                                         // block_size, x]
    cache_t* __restrict__ value_cache,   // [num_blocks, num_heads, head_size,
                                         // block_size]
    const int64_t* __restrict__ slot_mapping,  // [num_tokens]
    const int key_stride, const int value_stride, const int num_heads,
    const int head_size, const int block_size, const int x,
    const float* k_scale, const float* v_scale) {
  const int64_t token_idx = blockIdx.x;
  const int64_t slot_idx = slot_mapping[token_idx];
  if (slot_idx < 0) {
    // Padding token that should be ignored.
    return;
  }

  const int64_t block_idx = slot_idx / block_size;
  const int64_t block_offset = slot_idx % block_size;

  const int n = num_heads * head_size;
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    const int64_t src_key_idx = token_idx * key_stride + i;
    const int64_t src_value_idx = token_idx * value_stride + i;

    const int head_idx = i / head_size;
    const int head_offset = i % head_size;
    const int x_idx = head_offset / x;
    const int x_offset = head_offset % x;

    const int64_t tgt_key_idx =
        block_idx * num_heads * (head_size / x) * block_size * x +
        head_idx * (head_size / x) * block_size * x + x_idx * block_size * x +
        block_offset * x + x_offset;
    const int64_t tgt_value_idx =
        block_idx * num_heads * head_size * block_size +
        head_idx * head_size * block_size + head_offset * block_size +
        block_offset;
    scalar_t tgt_key = key[src_key_idx];
    scalar_t tgt_value = value[src_value_idx];
    if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
      key_cache[tgt_key_idx] = tgt_key;
      value_cache[tgt_value_idx] = tgt_value;
    } else {
      key_cache[tgt_key_idx] =
          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, *k_scale);
      value_cache[tgt_value_idx] =
          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, *v_scale);
    }
  }
}

template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void reshape_and_cache_flash_kernel(
    const scalar_t* __restrict__ key,    // [num_tokens, num_heads, head_size]
    const scalar_t* __restrict__ value,  // [num_tokens, num_heads, head_size]
    cache_t* __restrict__ key_cache,     // [num_blocks, block_size, num_heads,
                                         // head_size]
    cache_t* __restrict__ value_cache,   // [num_blocks, block_size, num_heads,
                                         // head_size]
    const int64_t* __restrict__ slot_mapping,  // [num_tokens]
    const int block_stride, const int key_stride, const int value_stride,
    const int num_heads, const int head_size, const int block_size,
    const float* k_scale, const float* v_scale) {
  const int64_t token_idx = blockIdx.x;
  const int64_t slot_idx = slot_mapping[token_idx];
  // NOTE: slot_idx can be -1 if the token is padded
  if (slot_idx < 0) {
    return;
  }
  const int64_t block_idx = slot_idx / block_size;
  const int64_t block_offset = slot_idx % block_size;
  const int n = num_heads * head_size;
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    const int64_t src_key_idx = token_idx * key_stride + i;
    const int64_t src_value_idx = token_idx * value_stride + i;
    const int head_idx = i / head_size;
    const int head_offset = i % head_size;
    const int64_t tgt_key_value_idx = block_idx * block_stride +
                                      block_offset * num_heads * head_size +
                                      head_idx * head_size + head_offset;
    scalar_t tgt_key = key[src_key_idx];
    scalar_t tgt_value = value[src_value_idx];
    if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
      key_cache[tgt_key_value_idx] = tgt_key;
      value_cache[tgt_key_value_idx] = tgt_value;
    } else {
      key_cache[tgt_key_value_idx] =
          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, *k_scale);
      value_cache[tgt_key_value_idx] =
          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, *v_scale);
    }
  }
}
}  // namespace vllm

// KV_T is the stored data type of kv-cache.
// CACHE_T is the data type of key and value tensors.
// KV_DTYPE is the real data type of kv-cache.
#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, KV_DTYPE)               \
  vllm::reshape_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE>             \
      <<<grid, block, 0, stream>>>(                                   \
          reinterpret_cast<KV_T*>(key.data_ptr()),                    \
          reinterpret_cast<KV_T*>(value.data_ptr()),                  \
          reinterpret_cast<CACHE_T*>(key_cache.data_ptr()),           \
          reinterpret_cast<CACHE_T*>(value_cache.data_ptr()),         \
          slot_mapping.data_ptr<int64_t>(), key_stride, value_stride, \
          num_heads, head_size, block_size, x,                        \
          reinterpret_cast<const float*>(k_scale.data_ptr()),         \
          reinterpret_cast<const float*>(v_scale.data_ptr()));

void reshape_and_cache(
    torch::Tensor& key,    // [num_tokens, num_heads, head_size]
    torch::Tensor& value,  // [num_tokens, num_heads, head_size]
    torch::Tensor&
        key_cache,  // [num_blocks, num_heads, head_size/x, block_size, x]
    torch::Tensor&
        value_cache,  // [num_blocks, num_heads, head_size, block_size]
    torch::Tensor& slot_mapping,  // [num_tokens]
    const std::string& kv_cache_dtype, torch::Tensor& k_scale,
    torch::Tensor& v_scale) {
  int num_tokens = key.size(0);
  int num_heads = key.size(1);
  int head_size = key.size(2);
  int block_size = key_cache.size(3);
  int x = key_cache.size(4);

  int key_stride = key.stride(0);
  int value_stride = value.stride(0);

  dim3 grid(num_tokens);
  dim3 block(std::min(num_heads * head_size, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
                             CALL_RESHAPE_AND_CACHE)
}

// KV_T is the stored data type of kv-cache.
// CACHE_T is the data type of key and value tensors.
// KV_DTYPE is the real data type of kv-cache.
#define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE)         \
  vllm::reshape_and_cache_flash_kernel<KV_T, CACHE_T, KV_DTYPE>       \
      <<<grid, block, 0, stream>>>(                                   \
          reinterpret_cast<KV_T*>(key.data_ptr()),                    \
          reinterpret_cast<KV_T*>(value.data_ptr()),                  \
          reinterpret_cast<CACHE_T*>(key_cache.data_ptr()),           \
          reinterpret_cast<CACHE_T*>(value_cache.data_ptr()),         \
          slot_mapping.data_ptr<int64_t>(), block_stride, key_stride, \
          value_stride, num_heads, head_size, block_size,             \
          reinterpret_cast<const float*>(k_scale.data_ptr()),         \
          reinterpret_cast<const float*>(v_scale.data_ptr()));

void reshape_and_cache_flash(
    torch::Tensor& key,        // [num_tokens, num_heads, head_size]
    torch::Tensor& value,      // [num_tokens, num_heads, head_size]
    torch::Tensor& key_cache,  // [num_blocks, block_size, num_heads, head_size]
    torch::Tensor&
        value_cache,  // [num_blocks, block_size, num_heads, head_size]
    torch::Tensor& slot_mapping,  // [num_tokens] or [num_actual_tokens]
    const std::string& kv_cache_dtype, torch::Tensor& k_scale,
    torch::Tensor& v_scale) {
  // NOTE(woosuk): In vLLM V1, key.size(0) can be different from
  // slot_mapping.size(0) because of padding for CUDA graphs.
  // In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because
  // both include padding.
  // In vLLM V1, however, key.size(0) can be larger than slot_mapping.size(0)
  // since key includes padding for CUDA graphs, while slot_mapping does not.
  // In this case, slot_mapping.size(0) represents the actual number of tokens
  // before padding.
  // For compatibility with both cases, we use slot_mapping.size(0) as the
  // number of tokens.
  int num_tokens = slot_mapping.size(0);
  int num_heads = key.size(1);
  int head_size = key.size(2);
  int block_size = key_cache.size(1);

  int key_stride = key.stride(0);
  int value_stride = value.stride(0);
  int block_stride = key_cache.stride(0);
  TORCH_CHECK(key_cache.stride(0) == value_cache.stride(0));

  dim3 grid(num_tokens);
  dim3 block(std::min(num_heads * head_size, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
                             CALL_RESHAPE_AND_CACHE_FLASH);
}

namespace vllm {

template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache,
                                   Tout* __restrict__ dst_cache,
                                   const float scale,
                                   const int64_t block_stride) {
  const int64_t block_idx = blockIdx.x;
  for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
    int64_t idx = block_idx * block_stride + i;
    dst_cache[idx] =
        fp8::scaled_convert<Tout, Tin, kv_dt>(src_cache[idx], scale);
  }
}

}  // namespace vllm

#define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE)                                \
  vllm::convert_fp8_kernel<Tout, Tin, KV_DTYPE><<<grid, block, 0, stream>>>( \
      reinterpret_cast<Tin*>(src_cache.data_ptr()),                          \
      reinterpret_cast<Tout*>(dst_cache.data_ptr()), scale, block_stride);

// Only for testing.
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
                 const double scale, const std::string& kv_cache_dtype) {
  torch::Device src_device = src_cache.device();
  torch::Device dst_device = dst_cache.device();
  TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU")
  TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU")
  TORCH_CHECK(src_device.index() == dst_device.index(),
              "src and dst must be on the same GPU");
  at::cuda::OptionalCUDAGuard device_guard(src_device);

  int64_t num_blocks = src_cache.size(0);
  int64_t block_stride = src_cache.stride(0);

  dim3 grid(num_blocks);
  dim3 block(std::min(block_stride, int64_t(512)));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  if (kv_cache_dtype == "auto") {
    if (src_cache.dtype() == at::ScalarType::Float) {
      CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kAuto);
    } else if (src_cache.dtype() == at::ScalarType::Half) {
      CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto);
    } else if (src_cache.dtype() == at::ScalarType::BFloat16) {
      CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto);
    } else if (dst_cache.dtype() == at::ScalarType::Float) {
      CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kAuto);
    } else if (dst_cache.dtype() == at::ScalarType::Half) {
      CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kAuto);
    } else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
      CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kAuto);
    }
  } else if (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3") {
    if (src_cache.dtype() == at::ScalarType::Float) {
      CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kFp8E4M3);
    } else if (src_cache.dtype() == at::ScalarType::Half) {
      CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
    } else if (src_cache.dtype() == at::ScalarType::BFloat16) {
      CALL_CONVERT_FP8(uint8_t, __nv_bfloat16,
                       vllm::Fp8KVCacheDataType::kFp8E4M3);
    } else if (dst_cache.dtype() == at::ScalarType::Float) {
      CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
    } else if (dst_cache.dtype() == at::ScalarType::Half) {
      CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
    } else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
      CALL_CONVERT_FP8(__nv_bfloat16, uint8_t,
                       vllm::Fp8KVCacheDataType::kFp8E4M3);
    }
  } else {
    TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype);
  }
}
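The tgt_key_idx arithmetic in reshape_and_cache_kernel flattens (block, head, head_offset, slot-within-block) into the [num_blocks, num_heads, head_size/x, block_size, x] key-cache layout, keeping the innermost x elements of the head dimension contiguous for vectorized loads. The host-only sketch below, with small illustrative sizes, checks that the per-block part of that mapping hits every element of a block exactly once.

// Host-side sanity sketch of the key-cache index arithmetic above.
#include <cassert>
#include <cstdint>
#include <set>

int main() {
  const int num_heads = 2, head_size = 8, block_size = 4, x = 2;
  std::set<int64_t> seen;
  for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
    for (int head_offset = 0; head_offset < head_size; ++head_offset) {
      for (int block_offset = 0; block_offset < block_size; ++block_offset) {
        const int x_idx = head_offset / x;
        const int x_offset = head_offset % x;
        // Same formula as tgt_key_idx, with the block_idx term dropped.
        const int64_t idx = head_idx * (head_size / x) * block_size * x +
                            x_idx * block_size * x + block_offset * x +
                            x_offset;
        seen.insert(idx);
      }
    }
  }
  // Every element of the block is addressed exactly once.
  assert(static_cast<int>(seen.size()) == num_heads * head_size * block_size);
  return 0;
}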
paged-attention/cuda_compat.h
ADDED
@@ -0,0 +1,49 @@
#pragma once

#ifdef USE_ROCM
  #include <hip/hip_runtime.h>
#endif

#ifndef USE_ROCM
  #define WARP_SIZE 32
#else
  #define WARP_SIZE warpSize
#endif

#ifndef USE_ROCM
  #define VLLM_LDG(arg) __ldg(arg)
#else
  #define VLLM_LDG(arg) *(arg)
#endif

#ifndef USE_ROCM
  #define VLLM_SHFL_XOR_SYNC(var, lane_mask) \
    __shfl_xor_sync(uint32_t(-1), var, lane_mask)
  #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \
    __shfl_xor_sync(uint32_t(-1), var, lane_mask, width)
#else
  #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
  #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \
    __shfl_xor(var, lane_mask, width)
#endif

#ifndef USE_ROCM
  #define VLLM_SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane)
#else
  #define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane)
#endif

#ifndef USE_ROCM
  #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) \
    __shfl_down_sync(uint32_t(-1), var, lane_delta)
#else
  #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down(var, lane_delta)
#endif

#ifndef USE_ROCM
  #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
    cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL)
#else
  #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
    hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)
#endif
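These macros paper over the CUDA vs. ROCm differences in warp size and warp-shuffle intrinsics; the attention kernels use them for their warp-level reductions. Below is a minimal hedged CUDA sketch of a butterfly (XOR) warp sum written against VLLM_SHFL_XOR_SYNC, assuming the paged-attention directory is on the include path and the CUDA (non-ROCm) branch of the macros; the kernel name and buffers are illustrative.

// Hedged sketch: warp-level sum reduction via the portability macros above.
#include <cstdio>
#include <cuda_runtime.h>
#include "cuda_compat.h"

__global__ void warp_sum_kernel(const float* in, float* out) {
  float v = in[threadIdx.x];
  // Butterfly reduction: after log2(WARP_SIZE) steps every lane holds the sum.
  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
    v += VLLM_SHFL_XOR_SYNC(v, mask);
  }
  if (threadIdx.x == 0) *out = v;
}

int main() {
  float h_in[32], *d_in, *d_out, h_out;
  for (int i = 0; i < 32; ++i) h_in[i] = 1.0f;  // expected sum: 32
  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
  warp_sum_kernel<<<1, 32>>>(d_in, d_out);
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  std::printf("sum = %f\n", h_out);
}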
paged-attention/dispatch_utils.h
ADDED
@@ -0,0 +1,49 @@
/*
 * Adapted from
 * https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/Dispatch.h
 */
#pragma once

#include <torch/all.h>

#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...)         \
  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)  \
  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)

#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))

// TODO(luka/varun): use FP8_TYPE macro after refactoring
#ifndef USE_ROCM
  #define VLLM_DISPATCH_CASE_QUANT_TYPES(...)                     \
    AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__)  \
    AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
#else
  #define VLLM_DISPATCH_CASE_QUANT_TYPES(...)                       \
    AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__)  \
    AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
#endif

#define VLLM_DISPATCH_QUANT_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_QUANT_TYPES(__VA_ARGS__))

#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...)   \
  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)    \
  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)     \
  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)

#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME,                               \
                     VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))

#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...)         \
  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)  \
  AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)  \
  AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)   \
  AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)

#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
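These wrappers build on ATen's AT_DISPATCH_SWITCH/AT_DISPATCH_CASE so that a runtime ScalarType selects a compile-time scalar_t inside the lambda; copy_blocks in cache_kernels.cu is one real consumer. The hedged sketch below shows the typical usage shape; print_elem_size is illustrative and not part of the ported sources, and the include path is assumed to contain the paged-attention directory.

// Hedged usage sketch of VLLM_DISPATCH_FLOATING_TYPES.
#include <cstdio>
#include <torch/all.h>
#include "dispatch_utils.h"

void print_elem_size(const torch::Tensor& t) {
  VLLM_DISPATCH_FLOATING_TYPES(t.scalar_type(), "print_elem_size", ([&] {
    // scalar_t is float, at::Half, or at::BFloat16 depending on t's dtype.
    std::printf("element size: %zu bytes\n", sizeof(scalar_t));
  }));
}

int main() {
  print_elem_size(torch::zeros({4}, torch::kFloat32));   // 4 bytes
  print_elem_size(torch::zeros({4}, torch::kBFloat16));  // 2 bytes
}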
paged-attention/quantization/fp8/amd/hip_float8.h
ADDED
@@ -0,0 +1,137 @@
#pragma once

#ifdef __HIPCC__
  #include <hip/hip_runtime.h>
#else
  #include <type_traits>
  #include <stdint.h>
  #include <math.h>
  #include <iostream>
#endif

#include "hip_float8_impl.h"

struct alignas(1) hip_fp8 {
  struct from_bits_t {};
  HIP_FP8_HOST_DEVICE static constexpr from_bits_t from_bits() {
    return from_bits_t();
  }
  uint8_t data;

  hip_fp8() = default;
  HIP_FP8_HOST_DEVICE constexpr hip_fp8(const hip_fp8&) = default;
  HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v) = delete;
  explicit HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v, from_bits_t)
      : data(v) {}

#ifdef __HIP__MI300__
  // NOTE: ON-DEVICE... always optimal bias
  explicit HIP_FP8_DEVICE hip_fp8(float v)
      : data(hip_fp8_impl::to_fp8_from_fp32(v)) {}

  explicit HIP_FP8_DEVICE hip_fp8(_Float16 v)
      : hip_fp8(static_cast<float>(v)) {}

  // Host only implementation using s/w simulation
  explicit HIP_FP8_HOST
#else   // __HIP__MI300__
  // both Host and DEVICE for non-MI300 using s/w simulation
  explicit HIP_FP8_HOST_DEVICE
#endif  // __HIP__MI300__
  hip_fp8(float v) {
    data = hip_fp8_impl::to_float8<4, 3, float, true /*negative_zero_nan*/,
                                   true /*clip*/>(v);
  }

  explicit HIP_FP8_HOST_DEVICE hip_fp8(double v)
      : hip_fp8(static_cast<float>(v)) {}

#ifdef __HIP__MI300__
  // upcast using device specific intrinsic
  explicit inline HIP_FP8_DEVICE operator float() const {
    float fval;
    uint32_t i32val = static_cast<uint32_t>(data);

    // upcast
    asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0"
                 : "=v"(fval)
                 : "v"(i32val));

    return fval;
  }

  explicit inline HIP_FP8_HOST operator float() const
#else   // __HIP__MI300__
  explicit inline HIP_FP8_HOST_DEVICE operator float() const
#endif  // __HIP__MI300__
  {
    return hip_fp8_impl::from_float8<4, 3, float, true /*negative_zero_nan*/>(
        data);
  }
};

namespace std {
inline hip_fp8 sin(hip_fp8 a) { return hip_fp8(sinf(float(a))); }
inline hip_fp8 cos(hip_fp8 a) { return hip_fp8(cosf(float(a))); }
HIP_FP8_HOST_DEVICE constexpr hip_fp8 real(const hip_fp8& a) { return a; }
}  // namespace std

// Special operator overloading
inline std::ostream& operator<<(std::ostream& os, const hip_fp8& f8) {
  return os << float(f8);
}

// all + operator overloading with mixed types
// mixed types, always converts to f32, does computation in f32, and returns
// float
inline HIP_FP8_HOST_DEVICE float operator+(const float fa, hip_fp8 b) {
  return (fa + float(b));
}

inline HIP_FP8_HOST_DEVICE float operator+(hip_fp8 a, const float fb) {
  return (float(a) + fb);
}

inline HIP_FP8_HOST_DEVICE hip_fp8 operator+(hip_fp8 a, hip_fp8 b) {
  return hip_fp8(float(a) + float(b));
}

inline HIP_FP8_HOST_DEVICE hip_fp8& operator+=(hip_fp8& a, hip_fp8 b) {
  return a = hip_fp8(float(a) + float(b));
}

// overloading multiplication, always returns float,
inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, hip_fp8 b) {
  return float(a) * float(b);
}

inline HIP_FP8_HOST_DEVICE float operator*(float a, hip_fp8 b) {
  return (a * float(b));
}

inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, float b) {
  return (float(a) * b);
}

inline HIP_FP8_HOST_DEVICE float operator*(int32_t a, hip_fp8 b) {
  return ((float)a * float(b));
}

inline HIP_FP8_HOST_DEVICE float operator*(double a, hip_fp8 b) {
  return ((float)a * float(b));
}

// overloading for compare
inline HIP_FP8_HOST_DEVICE bool operator==(hip_fp8 a, hip_fp8 b) {
  return (a.data == b.data);
}
inline HIP_FP8_HOST_DEVICE bool operator!=(hip_fp8 a, hip_fp8 b) {
  return (a.data != b.data);
}

inline HIP_FP8_HOST_DEVICE bool operator>=(hip_fp8 a, hip_fp8 b) {
  return static_cast<float>(a) >= static_cast<float>(b);
}
inline HIP_FP8_HOST_DEVICE bool operator>(hip_fp8 a, hip_fp8 b) {
  return static_cast<float>(a) > static_cast<float>(b);
}
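hip_fp8 stores a single e4m3 byte (4 exponent bits, 3 mantissa bits, the NANOO variant where 0x80 is NaN, with clipping on overflow) and routes all arithmetic through float. On hosts, or on devices other than MI300, both directions use the software to_float8/from_float8 path, so the round trip can be exercised without any GPU. A minimal hedged host-only sketch, assuming the paged-attention directory is on the include path and a GCC/Clang host compiler; the input value is illustrative, and with a 3-bit mantissa the value should come back as 3.25, the nearest representable point to pi on the 0.25 grid covering [2, 4).

// Hedged host-only sketch of the hip_fp8 round trip.
#include <iostream>
#include "quantization/fp8/amd/hip_float8.h"

int main() {
  hip_fp8 q(3.1415926f);               // quantize: float -> 8-bit e4m3
  float back = static_cast<float>(q);  // dequantize through from_float8
  std::cout << "stored byte = 0x" << std::hex << int(q.data) << std::dec
            << ", round-trip value = " << back << "\n";  // expect 3.25
}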
paged-attention/quantization/fp8/amd/hip_float8_impl.h
ADDED
@@ -0,0 +1,316 @@
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#if defined(__HIPCC__) && \
|
| 4 |
+
(defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
| 5 |
+
#define __HIP__MI300__
|
| 6 |
+
#endif
|
| 7 |
+
|
| 8 |
+
#ifdef __HIPCC__
|
| 9 |
+
#define HIP_FP8_HOST_DEVICE __host__ __device__
|
| 10 |
+
#define HIP_FP8_HOST __host__
|
| 11 |
+
#define HIP_FP8_DEVICE __device__
|
| 12 |
+
#else
|
| 13 |
+
#define HIP_FP8_HOST_DEVICE
|
| 14 |
+
#define HIP_FP8_HOST
|
| 15 |
+
#define HIP_FP8_DEVICE
|
| 16 |
+
#endif
|
| 17 |
+
|
| 18 |
+
namespace hip_fp8_impl {
|
| 19 |
+
|
| 20 |
+
#ifdef __HIP__MI300__
|
| 21 |
+
HIP_FP8_DEVICE uint8_t to_fp8_from_fp32(float v) {
|
| 22 |
+
uint8_t i8data;
|
| 23 |
+
union {
|
| 24 |
+
float fval;
|
| 25 |
+
uint32_t i32val;
|
| 26 |
+
uint8_t i8val[4]; // NOTE: not endian independent
|
| 27 |
+
} val;
|
| 28 |
+
|
| 29 |
+
uint32_t ival = 0;
|
| 30 |
+
val.fval = v;
|
| 31 |
+
|
| 32 |
+
if ((val.i32val & 0x7F800000) !=
|
| 33 |
+
0x7F800000) { /// propagate NAN/INF, no clipping
|
| 34 |
+
val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0);
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival,
|
| 38 |
+
false); // false -> WORD0
|
| 39 |
+
val.i32val = ival;
|
| 40 |
+
i8data = val.i8val[0];
|
| 41 |
+
|
| 42 |
+
return i8data;
|
| 43 |
+
}
|
| 44 |
+
#endif // __HIP__MI300__
|
| 45 |
+
|
| 46 |
+
HIP_FP8_HOST inline int clz(uint32_t x) { return __builtin_clz(x); }
|
| 47 |
+
#if defined(__HIPCC__) || defined(__CUDA_ARCH__)
|
| 48 |
+
HIP_FP8_DEVICE inline int clz(uint32_t x) { return __clz(x); }
|
| 49 |
+
#endif
|
| 50 |
+
|
| 51 |
+
template <int we, int wm, typename T, bool negative_zero_nan, bool clip>
|
| 52 |
+
HIP_FP8_HOST_DEVICE uint8_t to_float8(T _x, bool stoch = false,
|
| 53 |
+
uint32_t rng = 0) {
|
| 54 |
+
#ifdef __HIPCC__
|
| 55 |
+
constexpr bool is_half = std::is_same<T, _Float16>::value;
|
| 56 |
+
#else
|
| 57 |
+
constexpr bool is_half = false;
|
| 58 |
+
#endif
|
| 59 |
+
constexpr bool is_float = std::is_same<T, float>::value;
|
| 60 |
+
static_assert(wm + we == 7, "wm+we==7");
|
| 61 |
+
static_assert(is_half || is_float, "Only half and float can be cast to f8");
|
| 62 |
+
|
| 63 |
+
const int mfmt = (sizeof(T) == 4) ? 23 : 10;
|
| 64 |
+
uint32_t x;
|
| 65 |
+
if (sizeof(T) == 4) {
|
| 66 |
+
x = reinterpret_cast<uint32_t&>(_x);
|
| 67 |
+
} else {
|
| 68 |
+
x = reinterpret_cast<uint16_t&>(_x);
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
uint32_t head, mantissa;
|
| 72 |
+
int exponent, bias;
|
| 73 |
+
uint32_t sign;
|
| 74 |
+
|
| 75 |
+
if (sizeof(T) == 4) {
|
| 76 |
+
head = x & 0xFF800000;
|
| 77 |
+
mantissa = x & 0x7FFFFF;
|
| 78 |
+
exponent = (head >> 23) & 0xFF;
|
| 79 |
+
sign = head >> 31;
|
| 80 |
+
bias = 127;
|
| 81 |
+
} else {
|
| 82 |
+
head = x & 0xFC00;
|
| 83 |
+
mantissa = x & 0x3FF;
|
| 84 |
+
exponent = (head >> 10) & 0x1F;
|
| 85 |
+
sign = head >> 15;
|
| 86 |
+
bias = 15;
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
uint32_t signed_inf = (sign << 7) + (((1 << we) - 1) << wm);
|
| 90 |
+
|
| 91 |
+
// Deal with inf and NaNs
|
| 92 |
+
if (negative_zero_nan) {
|
| 93 |
+
if (sizeof(T) == 4) {
|
| 94 |
+
if ((x & 0x7F800000) == 0x7F800000) {
|
| 95 |
+
return 0x80;
|
| 96 |
+
}
|
| 97 |
+
} else {
|
| 98 |
+
// if(__hisinf(x) || __hisnan(x))
|
| 99 |
+
if ((x & 0x7C00) == 0x7C00) {
|
| 100 |
+
return 0x80;
|
| 101 |
+
}
|
| 102 |
+
}
|
| 103 |
+
} else {
|
| 104 |
+
if (sizeof(T) == 4) {
|
| 105 |
+
if ((x & 0x7F800000) == 0x7F800000) {
|
| 106 |
+
return signed_inf + (mantissa != 0 ? 1 : 0);
|
| 107 |
+
}
|
| 108 |
+
} else {
|
| 109 |
+
if ((x & 0x7C00) == 0x7C00) {
|
| 110 |
+
return signed_inf + (mantissa != 0 ? 1 : 0);
|
| 111 |
+
}
|
| 112 |
+
}
|
| 113 |
+
}
|
| 114 |
+
if (x == 0) {
|
| 115 |
+
return 0;
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
// First need to check if it is normal or denorm as there is a difference of
|
| 119 |
+
// implicit 1 Then need to adjust the exponent to align with the F8 exponent,
|
| 120 |
+
// in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng
|
| 121 |
+
// to mantissa and truncate. And for RNE, no need to add rng. Then probably
|
| 122 |
+
// need to check whether there is carry and adjust exponent and mantissa again
|
| 123 |
+
|
| 124 |
+
// For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent
|
| 125 |
+
// bits
|
| 126 |
+
const int f8_bias = (1 << (we - 1)) - 1 + (negative_zero_nan ? 1 : 0);
|
| 127 |
+
const int f8_denormal_act_exponent =
|
| 128 |
+
1 - f8_bias; // actual exponent of f8 denormal
|
| 129 |
+
// act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
|
| 130 |
+
// f8_exponent is the converted f8 exponent with bias encoding
|
| 131 |
+
// exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
|
| 132 |
+
// the difference needs to be adjusted and mantissa shifted
|
| 133 |
+
int act_exponent, f8_exponent, exponent_diff;
|
| 134 |
+
|
| 135 |
+
if (exponent == 0) { // fp32/fp16 is in denormal.
|
| 136 |
+
/* fp32 denormal is below 2^-127 so it is usually not a concern here, we
|
| 137 |
+
mostly concern fp16 here. In this case, f8 is usually in denormal. But there
|
| 138 |
+
could be exceptions. fp16 denormal has exponent bias 15 while bf8 with NANOO has
|
| 139 |
+
exponent bias 16. It means that there are some numbers in fp16 denormal but they
|
| 140 |
+
are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
|
| 141 |
+
where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8
|
| 142 |
+
(NANOO) normal. In this case, the fp16 mantissa should be shift left by 1 */
|
| 143 |
+
act_exponent = exponent - bias + 1;
|
| 144 |
+
exponent_diff =
|
| 145 |
+
f8_denormal_act_exponent -
|
| 146 |
+
act_exponent; // actual exponent is exponent-bias+1 as it is denormal
|
| 147 |
+
} else { // fp32/fp16 is normal with implicit 1
|
| 148 |
+
act_exponent = exponent - bias;
|
| 149 |
+
if (act_exponent <= f8_denormal_act_exponent) {
|
| 150 |
+
/* This is the case where fp32/fp16 is normal but it is in f8 denormal
|
| 151 |
+
range. For example fp8 nanoo mode, denormal exponent is -7, but if the
|
| 152 |
+
fp32/fp16 actual exponent is -7, it is actually larger due to the implicit 1,
|
| 153 |
+
Therefore it needs to be adjust to -6 and mantissa shift right by 1.
|
| 154 |
+
So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
|
| 155 |
+
exponent_diff = f8_denormal_act_exponent - act_exponent;
|
| 156 |
+
} else { // both fp32/fp16 and f8 are in normal range
|
| 157 |
+
exponent_diff = 0; // exponent_diff=0 does not mean there is no
|
| 158 |
+
// difference for this case, act_exponent could be
|
| 159 |
+
// larger. Just that it does not need shift mantissa
|
| 160 |
+
}
|
| 161 |
+
mantissa += (1 << mfmt); // Add the implicit 1 into mantissa
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
bool midpoint = (mantissa & ((1 << (mfmt - wm + exponent_diff)) - 1)) ==
|
| 165 |
+
static_cast<uint32_t>(1 << (mfmt - wm + exponent_diff - 1));
|
| 166 |
+
/* This part is a bit tricky. The judgment of whether it is a tie needs to be
|
| 167 |
+
done before we shift right as shift right could rip off some residual part
|
| 168 |
+
and make something not midpoint look like midpoint. For example, the fp16
|
| 169 |
+
number 0x1002 (0 00100 0000000010), it is larger than midpoint, but after
|
| 170 |
+
shift right by 4 bits, it would look like midpoint.
|
| 171 |
+
*/
|
| 172 |
+
|
| 173 |
+
if (exponent_diff > 0) {
|
| 174 |
+
mantissa >>= exponent_diff;
|
| 175 |
+
} else if (exponent_diff == -1) {
|
| 176 |
+
mantissa <<= -exponent_diff;
|
| 177 |
+
}
|
| 178 |
+
bool implicit_one = mantissa & (1 << mfmt);
|
| 179 |
+
// if there is no implicit 1, it means the f8 is denormal and need to adjust
|
| 180 |
+
// to denorm exponent
|
| 181 |
+
f8_exponent = (act_exponent + exponent_diff) /*actual f8 exponent*/ +
|
| 182 |
+
f8_bias - (implicit_one ? 0 : 1);
|
| 183 |
+
|
| 184 |
+
// Now we have the exponent and mantissa adjusted
|
| 185 |
+
uint32_t drop_mask = (1 << (mfmt - wm)) - 1;
|
| 186 |
+
bool odd = mantissa & (1 << (mfmt - wm)); // if the least significant bit
|
| 187 |
+
// that is not truncated is 1
|
| 188 |
+
mantissa +=
|
| 189 |
+
(stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) &
|
| 190 |
+
drop_mask;
|
| 191 |
+
|
| 192 |
+
// Now we deal with overflow
|
| 193 |
+
if (f8_exponent == 0) {
|
| 194 |
+
if ((1 << mfmt) & mantissa) {
|
| 195 |
+
f8_exponent = 1; // denormal overflow to become normal, promote exponent
|
| 196 |
+
}
|
| 197 |
+
} else {
|
| 198 |
+
if ((1 << (mfmt + 1)) & mantissa) {
|
| 199 |
+
mantissa >>= 1;
|
| 200 |
+
f8_exponent++;
|
| 201 |
+
}
|
| 202 |
+
}
|
| 203 |
+
|
| 204 |
+
mantissa >>= (mfmt - wm);
|
| 205 |
+
|
| 206 |
+
// above range: quantize to maximum possible float of the same sign
|
| 207 |
+
const int max_exp = (1 << we) - (negative_zero_nan ? 1 : 2);
|
| 208 |
+
if (f8_exponent > max_exp) {
|
| 209 |
+
if (clip) {
|
| 210 |
+
mantissa = (1 << wm) - 1;
|
| 211 |
+
f8_exponent = max_exp;
|
| 212 |
+
} else {
|
| 213 |
+
return signed_inf;
|
| 214 |
+
}
|
| 215 |
+
}
|
| 216 |
+
|
| 217 |
+
if (f8_exponent == 0 && mantissa == 0) {
|
| 218 |
+
return negative_zero_nan ? 0 : (sign << 7);
|
| 219 |
+
}
|
| 220 |
+
mantissa &= (1 << wm) - 1;
|
| 221 |
+
return (sign << 7) | (f8_exponent << wm) | mantissa;
|
| 222 |
+
}
|
| 223 |
+
|
| 224 |
+
template <int we, int wm, typename T = float, bool negative_zero_nan = true>
|
| 225 |
+
inline HIP_FP8_HOST_DEVICE T from_float8(uint8_t x) {
|
| 226 |
+
#ifdef __HIPCC__
|
| 227 |
+
constexpr bool is_half = std::is_same<T, _Float16>::value;
|
| 228 |
+
#else
|
| 229 |
+
constexpr bool is_half = false;
|
| 230 |
+
#endif
|
| 231 |
+
constexpr bool is_float = std::is_same<T, float>::value;
|
| 232 |
+
static_assert(is_half || is_float, "only half and float are supported");
|
| 233 |
+
|
| 234 |
+
constexpr int weo = is_half ? 5 : 8;
|
| 235 |
+
constexpr int wmo = is_half ? 10 : (is_float ? 23 : 7);
|
| 236 |
+
|
| 237 |
+
T fInf, fNegInf, fNaN, fNeg0;
|
| 238 |
+
|
| 239 |
+
#ifdef __HIPCC__
|
| 240 |
+
if (is_half) {
|
| 241 |
+
const uint16_t ihInf = 0x7C00;
|
| 242 |
+
const uint16_t ihNegInf = 0xFC00;
|
| 243 |
+
const uint16_t ihNaN = 0x7C01;
|
| 244 |
+
const uint16_t ihNeg0 = 0x8000;
|
| 245 |
+
fInf = reinterpret_cast<const _Float16&>(ihInf);
|
| 246 |
+
fNegInf = reinterpret_cast<const _Float16&>(ihNegInf);
|
| 247 |
+
fNaN = reinterpret_cast<const _Float16&>(ihNaN);
|
| 248 |
+
fNeg0 = reinterpret_cast<const _Float16&>(ihNeg0);
|
| 249 |
+
} else
|
| 250 |
+
#endif
|
| 251 |
+
if (is_float) {
|
| 252 |
+
const uint32_t ifInf = 0x7F800000;
|
| 253 |
+
const uint32_t ifNegInf = 0xFF800000;
|
| 254 |
+
const uint32_t ifNaN = 0x7F800001;
|
| 255 |
+
const uint32_t ifNeg0 = 0x80000000;
|
| 256 |
+
fInf = reinterpret_cast<const float&>(ifInf);
|
| 257 |
+
fNegInf = reinterpret_cast<const float&>(ifNegInf);
|
| 258 |
+
fNaN = reinterpret_cast<const float&>(ifNaN);
|
| 259 |
+
fNeg0 = reinterpret_cast<const float&>(ifNeg0);
|
| 260 |
+
}
|
| 261 |
+
|
| 262 |
+
if (x == 0) {
|
| 263 |
+
return 0;
|
| 264 |
+
}
|
| 265 |
+
|
| 266 |
+
uint32_t sign = x >> 7;
|
| 267 |
+
uint32_t mantissa = x & ((1 << wm) - 1);
|
| 268 |
+
int exponent = (x & 0x7F) >> wm;
|
| 269 |
+
if (negative_zero_nan) {
|
| 270 |
+
if (x == 0x80) {
|
| 271 |
+
return fNaN;
|
| 272 |
+
}
|
| 273 |
+
} else {
|
| 274 |
+
if (x == 0x80) {
|
| 275 |
+
return fNeg0;
|
| 276 |
+
}
|
| 277 |
+
if (exponent == ((1 << we) - 1)) {
|
| 278 |
+
return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN;
|
| 279 |
+
}
|
| 280 |
+
}
|
| 281 |
+
typename std::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type retval;
|
| 282 |
+
if (we == 5 && is_half && !negative_zero_nan) {
|
| 283 |
+
retval = x << 8;
|
| 284 |
+
return reinterpret_cast<const T&>(retval);
|
| 285 |
+
}
|
| 286 |
+
|
| 287 |
+
const int exp_low_cutoff =
|
| 288 |
+
(1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0);
|
| 289 |
+
|
| 290 |
+
// subnormal input
|
| 291 |
+
if (exponent == 0) {
|
| 292 |
+
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
|
| 293 |
+
int sh = 1 + clz(mantissa) - (32 - wm);
|
| 294 |
+
mantissa <<= sh;
|
| 295 |
+
exponent += 1 - sh;
|
| 296 |
+
mantissa &= ((1 << wm) - 1);
|
| 297 |
+
}
|
| 298 |
+
exponent += exp_low_cutoff - 1;
|
| 299 |
+
mantissa <<= wmo - wm;
|
| 300 |
+
|
| 301 |
+
// subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
|
| 302 |
+
if (exponent <= 0) {
|
| 303 |
+
mantissa |= 1 << wmo;
|
| 304 |
+
mantissa >>= 1 - exponent;
|
| 305 |
+
exponent = 0;
|
| 306 |
+
}
|
| 307 |
+
|
| 308 |
+
if (sizeof(T) == 2) {
|
| 309 |
+
retval = (sign << 15) | (exponent << 10) | mantissa;
|
| 310 |
+
} else {
|
| 311 |
+
retval = (sign << 31) | (exponent << 23) | mantissa;
|
| 312 |
+
}
|
| 313 |
+
return reinterpret_cast<const T&>(retval);
|
| 314 |
+
}
|
| 315 |
+
|
| 316 |
+
} // namespace hip_fp8_impl
|
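Editor's note on the tie check described in the comment above: the following standalone sketch (not part of the ported header) reproduces the 0x1002 example with plain integers. The bit counts are this editor's reading of the surrounding code: for an fp16 input and an fp8 target with wm = 3, a total of mfmt - wm + exponent_diff = 7 + 4 = 11 bits are ultimately discarded, and the tie test must inspect all 11 of them before the 4-bit denormal shift throws the low 4 away.

#include <cstdint>
#include <cstdio>

// Classify the low `k` bits that are about to be dropped, using the
// round-to-nearest-even convention of to_float8.
static const char* classify(uint32_t mantissa, int k) {
  uint32_t dropped = mantissa & ((1u << k) - 1);
  uint32_t midpoint = 1u << (k - 1);
  if (dropped > midpoint) return "above midpoint: round up";
  if (dropped < midpoint) return "below midpoint: round down";
  return "exact midpoint: tie, round to even";
}

int main() {
  // fp16 0x1002: mantissa with the implicit one = 0b100'0000'0010 = 0x402.
  uint32_t mantissa = 0x402;

  // Judged before any shift (all 11 dropped bits visible): strictly above
  // the midpoint, so the value must round up unconditionally.
  std::printf("before shift: %s\n", classify(mantissa, 11));

  // Judged after first shifting right by 4: the residual 0b0010 is gone and
  // the remaining 7 dropped bits look like an exact tie, which round-to-even
  // could then resolve the wrong way.
  std::printf("after shift:  %s\n", classify(mantissa >> 4, 7));
  return 0;
}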
paged-attention/quantization/fp8/amd/quant_utils.cuh
ADDED
|
@@ -0,0 +1,577 @@
|
| 1 |
+
#pragma once
|
| 2 |
+
#include "hip_float8.h"
|
| 3 |
+
|
| 4 |
+
#include <hip/hip_fp16.h>
|
| 5 |
+
#include <hip/hip_bf16.h>
|
| 6 |
+
#include <hip/hip_bfloat16.h>
|
| 7 |
+
|
| 8 |
+
#include "../../../attention/dtype_fp8.cuh"
|
| 9 |
+
#include "../../../attention/dtype_float32.cuh"
|
| 10 |
+
#include "../../../attention/dtype_bfloat16.cuh"
|
| 11 |
+
|
| 12 |
+
namespace vllm {
|
| 13 |
+
#ifdef USE_ROCM
|
| 14 |
+
|
| 15 |
+
namespace fp8 {
|
| 16 |
+
#ifdef ENABLE_FP8
|
| 17 |
+
|
| 18 |
+
template <typename Tout, typename Tin>
|
| 19 |
+
__inline__ __device__ Tout vec_conversion(const Tin& x) {
|
| 20 |
+
return x;
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
template <typename Tout, typename Tin>
|
| 24 |
+
__inline__ __device__ Tout scaled_vec_conversion(const Tin& x,
|
| 25 |
+
const float scale) {
|
| 26 |
+
return x;
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
// fp8 -> half
|
| 30 |
+
template <>
|
| 31 |
+
__inline__ __device__ uint16_t
|
| 32 |
+
vec_conversion<uint16_t, uint8_t>(const uint8_t& a) {
|
| 33 |
+
hip_fp8 f8{a, hip_fp8::from_bits()};
|
| 34 |
+
__half_raw res;
|
| 35 |
+
res.data = static_cast<float>(f8);
|
| 36 |
+
return res.x;
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
// fp8x2 -> half2
|
| 40 |
+
template <>
|
| 41 |
+
__inline__ __device__ uint32_t
|
| 42 |
+
vec_conversion<uint32_t, uint16_t>(const uint16_t& a) {
|
| 43 |
+
#if defined(__HIP__MI300__) && \
|
| 44 |
+
defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
|
| 45 |
+
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
|
| 46 |
+
union {
|
| 47 |
+
__half2_raw h2r;
|
| 48 |
+
uint32_t ui32;
|
| 49 |
+
} tmp;
|
| 50 |
+
tmp.h2r.x.data = f2[0];
|
| 51 |
+
tmp.h2r.y.data = f2[1];
|
| 52 |
+
return tmp.ui32;
|
| 53 |
+
#else
|
| 54 |
+
union {
|
| 55 |
+
uint16_t u16[2];
|
| 56 |
+
uint32_t u32;
|
| 57 |
+
} tmp;
|
| 58 |
+
|
| 59 |
+
tmp.u16[0] = vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a));
|
| 60 |
+
tmp.u16[1] = vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a >> 8U));
|
| 61 |
+
return tmp.u32;
|
| 62 |
+
#endif
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
// fp8x4 -> half2x2
|
| 66 |
+
template <>
|
| 67 |
+
__inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a) {
|
| 68 |
+
union {
|
| 69 |
+
uint2 u32x2;
|
| 70 |
+
uint32_t u32[2];
|
| 71 |
+
} tmp;
|
| 72 |
+
tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a);
|
| 73 |
+
tmp.u32[1] = vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U));
|
| 74 |
+
return tmp.u32x2;
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
// fp8x8 -> half2x4
|
| 78 |
+
template <>
|
| 79 |
+
__inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a) {
|
| 80 |
+
union {
|
| 81 |
+
uint4 u64x2;
|
| 82 |
+
uint2 u64[2];
|
| 83 |
+
} tmp;
|
| 84 |
+
tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x);
|
| 85 |
+
tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y);
|
| 86 |
+
return tmp.u64x2;
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
using __nv_bfloat16 = __hip_bfloat16;
|
| 90 |
+
|
| 91 |
+
// fp8 -> __nv_bfloat16
|
| 92 |
+
template <>
|
| 93 |
+
__inline__ __device__ __nv_bfloat16
|
| 94 |
+
vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) {
|
| 95 |
+
hip_fp8 f8{a, hip_fp8::from_bits()};
|
| 96 |
+
float f{f8};
|
| 97 |
+
return __float2bfloat16(f);
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
using __nv_bfloat162 = __hip_bfloat162;
|
| 101 |
+
|
| 102 |
+
// fp8x2 -> __nv_bfloat162
|
| 103 |
+
template <>
|
| 104 |
+
__inline__ __device__ __nv_bfloat162
|
| 105 |
+
vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a) {
|
| 106 |
+
__nv_bfloat162 res;
|
| 107 |
+
res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a);
|
| 108 |
+
res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U));
|
| 109 |
+
return res;
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
// fp8x4 -> bf16_4_t
|
| 113 |
+
template <>
|
| 114 |
+
__inline__ __device__ bf16_4_t
|
| 115 |
+
vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a) {
|
| 116 |
+
bf16_4_t res;
|
| 117 |
+
res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a);
|
| 118 |
+
res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U));
|
| 119 |
+
return res;
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
// fp8x8 -> bf16_8_t
|
| 123 |
+
template <>
|
| 124 |
+
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a) {
|
| 125 |
+
bf16_4_t tmp1, tmp2;
|
| 126 |
+
tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x);
|
| 127 |
+
tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y);
|
| 128 |
+
bf16_8_t res;
|
| 129 |
+
res.x = tmp1.x;
|
| 130 |
+
res.y = tmp1.y;
|
| 131 |
+
res.z = tmp2.x;
|
| 132 |
+
res.w = tmp2.y;
|
| 133 |
+
return res;
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
// fp8 -> float
|
| 137 |
+
template <>
|
| 138 |
+
__inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a) {
|
| 139 |
+
hip_fp8 fp8{a, hip_fp8::from_bits()};
|
| 140 |
+
return static_cast<float>(fp8);
|
| 141 |
+
}
|
| 142 |
+
|
| 143 |
+
// fp8x2 -> float2
|
| 144 |
+
template <>
|
| 145 |
+
__inline__ __device__ float2
|
| 146 |
+
vec_conversion<float2, uint16_t>(const uint16_t& a) {
|
| 147 |
+
#if defined(__HIP__MI300__) && \
|
| 148 |
+
defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
|
| 149 |
+
float2 res;
|
| 150 |
+
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
|
| 151 |
+
res.x = f2[0];
|
| 152 |
+
res.y = f2[1];
|
| 153 |
+
return res;
|
| 154 |
+
#else
|
| 155 |
+
float2 res;
|
| 156 |
+
res.x = vec_conversion<float, uint8_t>(static_cast<uint8_t>(a));
|
| 157 |
+
res.y = vec_conversion<float, uint8_t>(static_cast<uint8_t>(a >> 8U));
|
| 158 |
+
return res;
|
| 159 |
+
#endif
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
// fp8x4 -> float4
|
| 163 |
+
template <>
|
| 164 |
+
__inline__ __device__ Float4_
|
| 165 |
+
vec_conversion<Float4_, uint32_t>(const uint32_t& a) {
|
| 166 |
+
Float4_ res;
|
| 167 |
+
res.x = vec_conversion<float2, uint16_t>((uint16_t)a);
|
| 168 |
+
res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U));
|
| 169 |
+
return res;
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
// fp8x8 -> float8
|
| 173 |
+
template <>
|
| 174 |
+
__inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a) {
|
| 175 |
+
Float4_ tmp1, tmp2;
|
| 176 |
+
tmp1 = vec_conversion<Float4_, uint32_t>(a.x);
|
| 177 |
+
tmp2 = vec_conversion<Float4_, uint32_t>(a.y);
|
| 178 |
+
Float8_ res;
|
| 179 |
+
res.x = tmp1.x;
|
| 180 |
+
res.y = tmp1.y;
|
| 181 |
+
res.z = tmp2.x;
|
| 182 |
+
res.w = tmp2.y;
|
| 183 |
+
return res;
|
| 184 |
+
}
|
| 185 |
+
|
| 186 |
+
// half -> fp8
|
| 187 |
+
template <>
|
| 188 |
+
__inline__ __device__ uint8_t
|
| 189 |
+
vec_conversion<uint8_t, uint16_t>(const uint16_t& a) {
|
| 190 |
+
__half_raw tmp;
|
| 191 |
+
tmp.x = a;
|
| 192 |
+
|
| 193 |
+
hip_fp8 f8{static_cast<float>(tmp.data)};
|
| 194 |
+
return f8.data;
|
| 195 |
+
}
|
| 196 |
+
|
| 197 |
+
// bf16 -> fp8
|
| 198 |
+
template <>
|
| 199 |
+
__inline__ __device__ uint8_t
|
| 200 |
+
vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a) {
|
| 201 |
+
hip_fp8 res{__bfloat162float(a)};
|
| 202 |
+
return res.data;
|
| 203 |
+
}
|
| 204 |
+
|
| 205 |
+
// float -> fp8
|
| 206 |
+
template <>
|
| 207 |
+
__inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a) {
|
| 208 |
+
hip_fp8 f8(a);
|
| 209 |
+
return f8.data;
|
| 210 |
+
}
|
| 211 |
+
|
| 212 |
+
// fp8x4 -> float4
|
| 213 |
+
template <>
|
| 214 |
+
__inline__ __device__ float4
|
| 215 |
+
vec_conversion<float4, uint32_t>(const uint32_t& a) {
|
| 216 |
+
Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
|
| 217 |
+
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
|
| 218 |
+
return res;
|
| 219 |
+
}
|
| 220 |
+
|
| 221 |
+
// float2 -> half2
|
| 222 |
+
template <>
|
| 223 |
+
__inline__ __device__ uint32_t
|
| 224 |
+
vec_conversion<uint32_t, float2>(const float2& a) {
|
| 225 |
+
union {
|
| 226 |
+
half2 float16;
|
| 227 |
+
uint32_t uint32;
|
| 228 |
+
};
|
| 229 |
+
|
| 230 |
+
float16 = __float22half2_rn(a);
|
| 231 |
+
return uint32;
|
| 232 |
+
}
|
| 233 |
+
|
| 234 |
+
// Float4 -> half2x2
|
| 235 |
+
template <>
|
| 236 |
+
__inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a) {
|
| 237 |
+
uint2 b;
|
| 238 |
+
float2 val;
|
| 239 |
+
val.x = a.x.x;
|
| 240 |
+
val.y = a.x.y;
|
| 241 |
+
b.x = vec_conversion<uint32_t, float2>(val);
|
| 242 |
+
|
| 243 |
+
val.x = a.y.x;
|
| 244 |
+
val.y = a.y.y;
|
| 245 |
+
b.y = vec_conversion<uint32_t, float2>(val);
|
| 246 |
+
return b;
|
| 247 |
+
}
|
| 248 |
+
|
| 249 |
+
// Float4 -> float4
|
| 250 |
+
template <>
|
| 251 |
+
__inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a) {
|
| 252 |
+
float4 b;
|
| 253 |
+
b.x = a.x.x;
|
| 254 |
+
b.y = a.x.y;
|
| 255 |
+
b.z = a.y.x;
|
| 256 |
+
b.w = a.y.y;
|
| 257 |
+
return b;
|
| 258 |
+
}
|
| 259 |
+
|
| 260 |
+
// Float8 -> half2x4
|
| 261 |
+
template <>
|
| 262 |
+
__inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a) {
|
| 263 |
+
uint4 b;
|
| 264 |
+
b.x = vec_conversion<uint32_t, float2>(a.x);
|
| 265 |
+
b.y = vec_conversion<uint32_t, float2>(a.y);
|
| 266 |
+
b.z = vec_conversion<uint32_t, float2>(a.z);
|
| 267 |
+
b.w = vec_conversion<uint32_t, float2>(a.w);
|
| 268 |
+
return b;
|
| 269 |
+
}
|
| 270 |
+
|
| 271 |
+
// float2 -> bfloat162
|
| 272 |
+
template <>
|
| 273 |
+
__inline__ __device__ __nv_bfloat162
|
| 274 |
+
vec_conversion<__nv_bfloat162, float2>(const float2& a) {
|
| 275 |
+
__nv_bfloat162 b = __float22bfloat162_rn(a);
|
| 276 |
+
return b;
|
| 277 |
+
}
|
| 278 |
+
|
| 279 |
+
// Float4 -> bfloat162x2
|
| 280 |
+
template <>
|
| 281 |
+
__inline__ __device__ bf16_4_t
|
| 282 |
+
vec_conversion<bf16_4_t, Float4_>(const Float4_& a) {
|
| 283 |
+
bf16_4_t b;
|
| 284 |
+
b.x = __float22bfloat162_rn(a.x);
|
| 285 |
+
b.y = __float22bfloat162_rn(a.y);
|
| 286 |
+
return b;
|
| 287 |
+
}
|
| 288 |
+
|
| 289 |
+
// Float8 -> bfloat162x4
|
| 290 |
+
template <>
|
| 291 |
+
__inline__ __device__ bf16_8_t
|
| 292 |
+
vec_conversion<bf16_8_t, Float8_>(const Float8_& a) {
|
| 293 |
+
bf16_8_t b;
|
| 294 |
+
b.x = __float22bfloat162_rn(a.x);
|
| 295 |
+
b.y = __float22bfloat162_rn(a.y);
|
| 296 |
+
b.z = __float22bfloat162_rn(a.z);
|
| 297 |
+
b.w = __float22bfloat162_rn(a.w);
|
| 298 |
+
return b;
|
| 299 |
+
}
|
| 300 |
+
|
| 301 |
+
/* Scaled and vectorized conversions, for data exchange between high and low
|
| 302 |
+
precision domains
|
| 303 |
+
|
| 304 |
+
Convention of the scale in API, e.g: FP8_data = Quantization(
|
| 305 |
+
High_Precision_data / scale ) s.t. Quantize(HP / scale) => FP8 Dequant(FP8) *
|
| 306 |
+
scale => HP
|
| 307 |
+
|
| 308 |
+
*/
|
| 309 |
+
|
| 310 |
+
// fp8 -> half
|
| 311 |
+
template <>
|
| 312 |
+
__inline__ __device__ uint16_t
|
| 313 |
+
scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, const float scale) {
|
| 314 |
+
hip_fp8 f8{a, hip_fp8::from_bits()};
|
| 315 |
+
__half_raw res;
|
| 316 |
+
res.data = static_cast<float>(f8) * scale;
|
| 317 |
+
return res.x;
|
| 318 |
+
}
|
| 319 |
+
|
| 320 |
+
// fp8x2 -> half2
|
| 321 |
+
template <>
|
| 322 |
+
__inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>(
|
| 323 |
+
const uint16_t& a, const float scale) {
|
| 324 |
+
#if defined(__HIP__MI300__) && \
|
| 325 |
+
defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
|
| 326 |
+
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
|
| 327 |
+
union {
|
| 328 |
+
__half2_raw h2r;
|
| 329 |
+
uint32_t ui32;
|
| 330 |
+
} tmp;
|
| 331 |
+
tmp.h2r.x.data = f2[0] * scale;
|
| 332 |
+
tmp.h2r.y.data = f2[1] * scale;
|
| 333 |
+
return tmp.ui32;
|
| 334 |
+
#else
|
| 335 |
+
union {
|
| 336 |
+
uint16_t u16[2];
|
| 337 |
+
uint32_t u32;
|
| 338 |
+
} tmp;
|
| 339 |
+
|
| 340 |
+
tmp.u16[0] =
|
| 341 |
+
scaled_vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a), scale);
|
| 342 |
+
tmp.u16[1] = scaled_vec_conversion<uint16_t, uint8_t>(
|
| 343 |
+
static_cast<uint8_t>(a >> 8U), scale);
|
| 344 |
+
return tmp.u32;
|
| 345 |
+
#endif
|
| 346 |
+
}
|
| 347 |
+
|
| 348 |
+
// fp8x4 -> half2x2
|
| 349 |
+
template <>
|
| 350 |
+
__inline__ __device__ uint2
|
| 351 |
+
scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, const float scale) {
|
| 352 |
+
union {
|
| 353 |
+
uint2 u32x2;
|
| 354 |
+
uint32_t u32[2];
|
| 355 |
+
} tmp;
|
| 356 |
+
tmp.u32[0] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale);
|
| 357 |
+
tmp.u32[1] =
|
| 358 |
+
scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), scale);
|
| 359 |
+
return tmp.u32x2;
|
| 360 |
+
}
|
| 361 |
+
|
| 362 |
+
// fp8x8 -> half2x4
|
| 363 |
+
template <>
|
| 364 |
+
__inline__ __device__ uint4
|
| 365 |
+
scaled_vec_conversion<uint4, uint2>(const uint2& a, const float scale) {
|
| 366 |
+
union {
|
| 367 |
+
uint4 u64x2;
|
| 368 |
+
uint2 u64[2];
|
| 369 |
+
} tmp;
|
| 370 |
+
tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale);
|
| 371 |
+
tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale);
|
| 372 |
+
return tmp.u64x2;
|
| 373 |
+
}
|
| 374 |
+
|
| 375 |
+
using __nv_bfloat16 = __hip_bfloat16;
|
| 376 |
+
|
| 377 |
+
// fp8 -> __nv_bfloat16
|
| 378 |
+
template <>
|
| 379 |
+
__inline__ __device__ __nv_bfloat16
|
| 380 |
+
scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a,
|
| 381 |
+
const float scale) {
|
| 382 |
+
hip_fp8 f8{a, hip_fp8::from_bits()};
|
| 383 |
+
float f{f8};
|
| 384 |
+
return __float2bfloat16(f * scale);
|
| 385 |
+
}
|
| 386 |
+
|
| 387 |
+
using __nv_bfloat162 = __hip_bfloat162;
|
| 388 |
+
|
| 389 |
+
// fp8x2 -> __nv_bfloat162
|
| 390 |
+
template <>
|
| 391 |
+
__inline__ __device__ __nv_bfloat162
|
| 392 |
+
scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a,
|
| 393 |
+
const float scale) {
|
| 394 |
+
__nv_bfloat162 res;
|
| 395 |
+
res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale);
|
| 396 |
+
res.y =
|
| 397 |
+
scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), scale);
|
| 398 |
+
return res;
|
| 399 |
+
}
|
| 400 |
+
|
| 401 |
+
// fp8x4 -> bf16_4_t
|
| 402 |
+
template <>
|
| 403 |
+
__inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(
|
| 404 |
+
const uint32_t& a, const float scale) {
|
| 405 |
+
bf16_4_t res;
|
| 406 |
+
res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale);
|
| 407 |
+
res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U),
|
| 408 |
+
scale);
|
| 409 |
+
return res;
|
| 410 |
+
}
|
| 411 |
+
|
| 412 |
+
// fp8x8 -> bf16_8_t
|
| 413 |
+
template <>
|
| 414 |
+
__inline__ __device__ bf16_8_t
|
| 415 |
+
scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, const float scale) {
|
| 416 |
+
bf16_4_t tmp1, tmp2;
|
| 417 |
+
tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale);
|
| 418 |
+
tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale);
|
| 419 |
+
bf16_8_t res;
|
| 420 |
+
res.x = tmp1.x;
|
| 421 |
+
res.y = tmp1.y;
|
| 422 |
+
res.z = tmp2.x;
|
| 423 |
+
res.w = tmp2.y;
|
| 424 |
+
return res;
|
| 425 |
+
}
|
| 426 |
+
|
| 427 |
+
// fp8 -> float
|
| 428 |
+
template <>
|
| 429 |
+
__inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
|
| 430 |
+
const uint8_t& a, const float scale) {
|
| 431 |
+
hip_fp8 fp8{a, hip_fp8::from_bits()};
|
| 432 |
+
return static_cast<float>(fp8) * scale;
|
| 433 |
+
}
|
| 434 |
+
|
| 435 |
+
// fp8x2 -> float2
|
| 436 |
+
template <>
|
| 437 |
+
__inline__ __device__ float2
|
| 438 |
+
scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, const float scale) {
|
| 439 |
+
#if defined(__HIP__MI300__) && \
|
| 440 |
+
defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
|
| 441 |
+
float2 res;
|
| 442 |
+
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
|
| 443 |
+
res.x = f2[0] * scale;
|
| 444 |
+
res.y = f2[1] * scale;
|
| 445 |
+
return res;
|
| 446 |
+
#else
|
| 447 |
+
float2 res;
|
| 448 |
+
res.x = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a), scale);
|
| 449 |
+
res.y = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a >> 8U),
|
| 450 |
+
scale);
|
| 451 |
+
return res;
|
| 452 |
+
#endif
|
| 453 |
+
}
|
| 454 |
+
|
| 455 |
+
// fp8x4 -> float4
|
| 456 |
+
template <>
|
| 457 |
+
__inline__ __device__ Float4_
|
| 458 |
+
scaled_vec_conversion<Float4_, uint32_t>(const uint32_t& a, const float scale) {
|
| 459 |
+
Float4_ res;
|
| 460 |
+
res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale);
|
| 461 |
+
res.y = scaled_vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), scale);
|
| 462 |
+
return res;
|
| 463 |
+
}
|
| 464 |
+
|
| 465 |
+
// fp8x8 -> float8
|
| 466 |
+
template <>
|
| 467 |
+
__inline__ __device__ Float8_
|
| 468 |
+
scaled_vec_conversion<Float8_, uint2>(const uint2& a, const float scale) {
|
| 469 |
+
Float4_ tmp1, tmp2;
|
| 470 |
+
tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale);
|
| 471 |
+
tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale);
|
| 472 |
+
Float8_ res;
|
| 473 |
+
res.x = tmp1.x;
|
| 474 |
+
res.y = tmp1.y;
|
| 475 |
+
res.z = tmp2.x;
|
| 476 |
+
res.w = tmp2.y;
|
| 477 |
+
return res;
|
| 478 |
+
}
|
| 479 |
+
|
| 480 |
+
/* Quantize(HP / scale) => FP8 */
|
| 481 |
+
|
| 482 |
+
// TODO(Hai): add vectorized versions
|
| 483 |
+
|
| 484 |
+
// half -> fp8
|
| 485 |
+
template <>
|
| 486 |
+
__inline__ __device__ uint8_t
|
| 487 |
+
scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, const float scale) {
|
| 488 |
+
__half_raw tmp;
|
| 489 |
+
tmp.x = a;
|
| 490 |
+
|
| 491 |
+
hip_fp8 f8{static_cast<float>(tmp.data) / scale};
|
| 492 |
+
return f8.data;
|
| 493 |
+
}
|
| 494 |
+
|
| 495 |
+
// bf16 -> fp8
|
| 496 |
+
template <>
|
| 497 |
+
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
|
| 498 |
+
const __nv_bfloat16& a, const float scale) {
|
| 499 |
+
hip_fp8 res{__bfloat162float(a) / scale};
|
| 500 |
+
return res.data;
|
| 501 |
+
}
|
| 502 |
+
|
| 503 |
+
// float -> fp8
|
| 504 |
+
template <>
|
| 505 |
+
__inline__ __device__ uint8_t
|
| 506 |
+
scaled_vec_conversion<uint8_t, float>(const float& a, const float scale) {
|
| 507 |
+
hip_fp8 f8(a / scale);
|
| 508 |
+
return f8.data;
|
| 509 |
+
}
|
| 510 |
+
|
| 511 |
+
// fp8x4 -> float4
|
| 512 |
+
template <>
|
| 513 |
+
__inline__ __device__ float4
|
| 514 |
+
scaled_vec_conversion<float4, uint32_t>(const uint32_t& a, const float scale) {
|
| 515 |
+
Float4_ tmp = scaled_vec_conversion<Float4_, uint32_t>(a, scale);
|
| 516 |
+
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
|
| 517 |
+
return res;
|
| 518 |
+
}
|
| 519 |
+
#endif // ENABLE_FP8
|
| 520 |
+
|
| 521 |
+
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
|
| 522 |
+
__inline__ __device__ Tout convert(const Tin& x) {
|
| 523 |
+
#ifdef ENABLE_FP8
|
| 524 |
+
if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
|
| 525 |
+
return vec_conversion<Tout, Tin>(x);
|
| 526 |
+
}
|
| 527 |
+
#endif
|
| 528 |
+
assert(false);
|
| 529 |
+
return {}; // Squash missing return statement warning
|
| 530 |
+
}
|
| 531 |
+
|
| 532 |
+
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
|
| 533 |
+
__inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
|
| 534 |
+
#ifdef ENABLE_FP8
|
| 535 |
+
if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
|
| 536 |
+
return scaled_vec_conversion<Tout, Tin>(x, scale);
|
| 537 |
+
}
|
| 538 |
+
#endif
|
| 539 |
+
assert(false);
|
| 540 |
+
return {}; // Squash missing return statement warning
|
| 541 |
+
}
|
| 542 |
+
|
| 543 |
+
// The following macro is used to dispatch the conversion function based on
|
| 544 |
+
// the data type of the key and value cache. The FN is a macro that calls a
|
| 545 |
+
// function with template<typename scalar_t, typename cache_t,
|
| 546 |
+
// Fp8KVCacheDataType kv_dt>.
|
| 547 |
+
#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \
|
| 548 |
+
if (KV_DTYPE == "auto") { \
|
| 549 |
+
if (SRC_DTYPE == at::ScalarType::Float) { \
|
| 550 |
+
FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \
|
| 551 |
+
} else if (SRC_DTYPE == at::ScalarType::Half) { \
|
| 552 |
+
FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \
|
| 553 |
+
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
|
| 554 |
+
FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \
|
| 555 |
+
} else { \
|
| 556 |
+
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
|
| 557 |
+
} \
|
| 558 |
+
} else { \
|
| 559 |
+
if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \
|
| 560 |
+
if (SRC_DTYPE == at::ScalarType::Float) { \
|
| 561 |
+
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
|
| 562 |
+
} else if (SRC_DTYPE == at::ScalarType::Half) { \
|
| 563 |
+
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
|
| 564 |
+
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
|
| 565 |
+
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
|
| 566 |
+
} else { \
|
| 567 |
+
TORCH_CHECK(false, \
|
| 568 |
+
"Unsupported input type of kv cache: ", SRC_DTYPE); \
|
| 569 |
+
} \
|
| 570 |
+
} else { \
|
| 571 |
+
TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
|
| 572 |
+
} \
|
| 573 |
+
}
|
| 574 |
+
|
| 575 |
+
} // namespace fp8
|
| 576 |
+
#endif // USE_ROCM
|
| 577 |
+
} // namespace vllm
|
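Editor's note: both this ROCm header and the CUDA header that follows implement the scale convention documented in the comment block above (quantize divides by scale, dequantize multiplies by it). A hypothetical round-trip kernel, not part of this commit, is sketched below; it assumes the header is on the include path, that ENABLE_FP8 is defined, and that fp16 values are carried as raw uint16_t bits, as the kernels here do.

#include "quant_utils.cuh"  // assumed include path for this header

// Hypothetical round trip: fp16 (raw uint16_t bits) -> fp8 E4M3 -> fp16.
__global__ void fp8_roundtrip(const uint16_t* key_in, uint8_t* kv_cache,
                              uint16_t* key_out, float scale, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n) return;

  // Quantize: the high-precision value is divided by `scale` before the
  // fp8 cast (scaled_vec_conversion<uint8_t, uint16_t>).
  kv_cache[i] = vllm::fp8::scaled_convert<uint8_t, uint16_t,
                    vllm::Fp8KVCacheDataType::kFp8E4M3>(key_in[i], scale);

  // Dequantize: the stored fp8 value is multiplied by `scale` on the way
  // back to fp16 (scaled_vec_conversion<uint16_t, uint8_t>).
  key_out[i] = vllm::fp8::scaled_convert<uint16_t, uint8_t,
                    vllm::Fp8KVCacheDataType::kFp8E4M3>(kv_cache[i], scale);
}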
paged-attention/quantization/fp8/nvidia/quant_utils.cuh
ADDED
|
@@ -0,0 +1,573 @@
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include "../../../attention/attention_dtypes.h"
|
| 4 |
+
#include <assert.h>
|
| 5 |
+
#include <float.h>
|
| 6 |
+
#include <stdint.h>
|
| 7 |
+
#include <type_traits>
|
| 8 |
+
|
| 9 |
+
namespace vllm {
|
| 10 |
+
#ifndef USE_ROCM
|
| 11 |
+
|
| 12 |
+
namespace fp8 {
|
| 13 |
+
#ifdef ENABLE_FP8
|
| 14 |
+
|
| 15 |
+
#if 0 // Disable the following code to reduce the binary size.
|
| 16 |
+
template <typename Tout, typename Tin>
|
| 17 |
+
__inline__ __device__ Tout
|
| 18 |
+
vec_conversion(const Tin &x, const __nv_fp8_interpretation_t fp8_type) {
|
| 19 |
+
return x;
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
// fp8 -> half
|
| 23 |
+
template <>
|
| 24 |
+
__inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>(
|
| 25 |
+
const uint8_t &a, const __nv_fp8_interpretation_t fp8_type) {
|
| 26 |
+
__half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
|
| 27 |
+
return res.x;
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
// fp8x2 -> half2
|
| 31 |
+
template <>
|
| 32 |
+
__inline__ __device__ uint32_t vec_conversion<uint32_t, uint16_t>(
|
| 33 |
+
const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
|
| 34 |
+
union {
|
| 35 |
+
uint16_t u16[2];
|
| 36 |
+
uint32_t u32;
|
| 37 |
+
} tmp;
|
| 38 |
+
__half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, fp8_type);
|
| 39 |
+
tmp.u16[0] = res.x;
|
| 40 |
+
tmp.u16[1] = res.y;
|
| 41 |
+
return tmp.u32;
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
// fp8x4 -> half2x2
|
| 45 |
+
template <>
|
| 46 |
+
__inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(
|
| 47 |
+
const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
|
| 48 |
+
union {
|
| 49 |
+
uint2 u32x2;
|
| 50 |
+
uint32_t u32[2];
|
| 51 |
+
} tmp;
|
| 52 |
+
tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a, fp8_type);
|
| 53 |
+
tmp.u32[1] =
|
| 54 |
+
vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), fp8_type);
|
| 55 |
+
return tmp.u32x2;
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
// fp8x8 -> half2x4
|
| 59 |
+
template <>
|
| 60 |
+
__inline__ __device__ uint4 vec_conversion<uint4, uint2>(
|
| 61 |
+
const uint2 &a, const __nv_fp8_interpretation_t fp8_type) {
|
| 62 |
+
union {
|
| 63 |
+
uint4 u64x2;
|
| 64 |
+
uint2 u64[2];
|
| 65 |
+
} tmp;
|
| 66 |
+
tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x, fp8_type);
|
| 67 |
+
tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y, fp8_type);
|
| 68 |
+
return tmp.u64x2;
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
// fp8 -> __nv_bfloat16
|
| 72 |
+
template <>
|
| 73 |
+
__inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(
|
| 74 |
+
const uint8_t &a, const __nv_fp8_interpretation_t fp8_type) {
|
| 75 |
+
// Note there is no direct convert function from fp8 to bf16.
|
| 76 |
+
// fp8 -> half
|
| 77 |
+
__half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
|
| 78 |
+
// half -> float -> bf16
|
| 79 |
+
float tmp = half_to_float(res.x);
|
| 80 |
+
return __float2bfloat16(tmp);
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
// fp8x2 -> __nv_bfloat162
|
| 84 |
+
template <>
|
| 85 |
+
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(
|
| 86 |
+
const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
|
| 87 |
+
__nv_bfloat162 res;
|
| 88 |
+
res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, fp8_type);
|
| 89 |
+
res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), fp8_type);
|
| 90 |
+
return res;
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
// fp8x4 -> bf16_4_t
|
| 94 |
+
template <>
|
| 95 |
+
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, uint32_t>(
|
| 96 |
+
const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
|
| 97 |
+
bf16_4_t res;
|
| 98 |
+
res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, fp8_type);
|
| 99 |
+
res.y =
|
| 100 |
+
vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), fp8_type);
|
| 101 |
+
return res;
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
// fp8x8 -> bf16_8_t
|
| 105 |
+
template <>
|
| 106 |
+
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(
|
| 107 |
+
const uint2 &a, const __nv_fp8_interpretation_t fp8_type) {
|
| 108 |
+
bf16_4_t tmp1, tmp2;
|
| 109 |
+
tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x, fp8_type);
|
| 110 |
+
tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y, fp8_type);
|
| 111 |
+
bf16_8_t res;
|
| 112 |
+
res.x = tmp1.x;
|
| 113 |
+
res.y = tmp1.y;
|
| 114 |
+
res.z = tmp2.x;
|
| 115 |
+
res.w = tmp2.y;
|
| 116 |
+
return res;
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
// fp8 -> float
|
| 120 |
+
template <>
|
| 121 |
+
__inline__ __device__ float
|
| 122 |
+
vec_conversion<float, uint8_t>(const uint8_t &a,
|
| 123 |
+
const __nv_fp8_interpretation_t fp8_type) {
|
| 124 |
+
// fp8 -> half
|
| 125 |
+
uint16_t tmp = vec_conversion<uint16_t, uint8_t>(a, fp8_type);
|
| 126 |
+
// half -> float
|
| 127 |
+
return half_to_float(tmp);
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
// fp8x2 -> float2
|
| 131 |
+
template <>
|
| 132 |
+
__inline__ __device__ float2 vec_conversion<float2, uint16_t>(
|
| 133 |
+
const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
|
| 134 |
+
// fp8x2 -> half2
|
| 135 |
+
uint32_t tmp = vec_conversion<uint32_t, uint16_t>(a, fp8_type);
|
| 136 |
+
// half2 -> float2
|
| 137 |
+
return half2_to_float2(tmp);
|
| 138 |
+
}
|
| 139 |
+
|
| 140 |
+
// fp8x4 -> float4
|
| 141 |
+
template <>
|
| 142 |
+
__inline__ __device__ Float4_ vec_conversion<Float4_, uint32_t>(
|
| 143 |
+
const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
|
| 144 |
+
Float4_ res;
|
| 145 |
+
res.x = vec_conversion<float2, uint16_t>((uint16_t)a, fp8_type);
|
| 146 |
+
res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), fp8_type);
|
| 147 |
+
return res;
|
| 148 |
+
}
|
| 149 |
+
|
| 150 |
+
// fp8x8 -> float8
|
| 151 |
+
template <>
|
| 152 |
+
__inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(
|
| 153 |
+
const uint2 &a, const __nv_fp8_interpretation_t fp8_type) {
|
| 154 |
+
Float4_ tmp1, tmp2;
|
| 155 |
+
tmp1 = vec_conversion<Float4_, uint32_t>(a.x, fp8_type);
|
| 156 |
+
tmp2 = vec_conversion<Float4_, uint32_t>(a.y, fp8_type);
|
| 157 |
+
Float8_ res;
|
| 158 |
+
res.x = tmp1.x;
|
| 159 |
+
res.y = tmp1.y;
|
| 160 |
+
res.z = tmp2.x;
|
| 161 |
+
res.w = tmp2.y;
|
| 162 |
+
return res;
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
// half -> fp8
|
| 166 |
+
template <>
|
| 167 |
+
__inline__ __device__ uint8_t vec_conversion<uint8_t, uint16_t>(
|
| 168 |
+
const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
|
| 169 |
+
__half_raw tmp;
|
| 170 |
+
tmp.x = a;
|
| 171 |
+
__nv_fp8_storage_t res =
|
| 172 |
+
__nv_cvt_halfraw_to_fp8(tmp, __NV_SATFINITE, fp8_type);
|
| 173 |
+
return (uint8_t)res;
|
| 174 |
+
}
|
| 175 |
+
|
| 176 |
+
// bf16 -> fp8
|
| 177 |
+
template <>
|
| 178 |
+
__inline__ __device__ uint8_t vec_conversion<uint8_t, __nv_bfloat16>(
|
| 179 |
+
const __nv_bfloat16 &a, const __nv_fp8_interpretation_t fp8_type) {
|
| 180 |
+
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
| 181 |
+
assert(false);
|
| 182 |
+
#else
|
| 183 |
+
__nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8(
|
| 184 |
+
__nv_bfloat16_raw(a), __NV_SATFINITE, fp8_type);
|
| 185 |
+
return (uint8_t)res;
|
| 186 |
+
#endif
|
| 187 |
+
}
|
| 188 |
+
|
| 189 |
+
// float -> fp8
|
| 190 |
+
template <>
|
| 191 |
+
__inline__ __device__ uint8_t vec_conversion<uint8_t, float>(
|
| 192 |
+
const float &a, const __nv_fp8_interpretation_t fp8_type) {
|
| 193 |
+
__nv_fp8_storage_t res = __nv_cvt_float_to_fp8(a, __NV_SATFINITE, fp8_type);
|
| 194 |
+
return (uint8_t)res;
|
| 195 |
+
}
|
| 196 |
+
|
| 197 |
+
// fp8x4 -> float4
|
| 198 |
+
template <>
|
| 199 |
+
__inline__ __device__ float4 vec_conversion<float4, uint32_t>(
|
| 200 |
+
const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
|
| 201 |
+
Float4_ tmp = vec_conversion<Float4_, uint32_t>(a, fp8_type);
|
| 202 |
+
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
|
| 203 |
+
return res;
|
| 204 |
+
}
|
| 205 |
+
|
| 206 |
+
template <>
|
| 207 |
+
__inline__ __device__ uint32_t vec_conversion<uint32_t, float2>(
|
| 208 |
+
const float2 &a, const __nv_fp8_interpretation_t fp8_type) {
|
| 209 |
+
union {
|
| 210 |
+
half2 float16;
|
| 211 |
+
uint32_t uint32;
|
| 212 |
+
};
|
| 213 |
+
|
| 214 |
+
float16 = __float22half2_rn(a);
|
| 215 |
+
return uint32;
|
| 216 |
+
}
|
| 217 |
+
|
| 218 |
+
template <>
|
| 219 |
+
__inline__ __device__ uint2 vec_conversion<uint2, Float4_>(
|
| 220 |
+
const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) {
|
| 221 |
+
uint2 b;
|
| 222 |
+
float2 val;
|
| 223 |
+
val.x = a.x.x;
|
| 224 |
+
val.y = a.x.y;
|
| 225 |
+
b.x = vec_conversion<uint32_t, float2>(val, fp8_type);
|
| 226 |
+
|
| 227 |
+
val.x = a.y.x;
|
| 228 |
+
val.y = a.y.y;
|
| 229 |
+
b.y = vec_conversion<uint32_t, float2>(val, fp8_type);
|
| 230 |
+
|
| 231 |
+
return b;
|
| 232 |
+
}
|
| 233 |
+
|
| 234 |
+
template <>
|
| 235 |
+
__inline__ __device__ float4 vec_conversion<float4, Float4_>(
|
| 236 |
+
const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) {
|
| 237 |
+
float4 b;
|
| 238 |
+
b.x = a.x.x;
|
| 239 |
+
b.y = a.x.y;
|
| 240 |
+
b.z = a.y.x;
|
| 241 |
+
b.w = a.y.y;
|
| 242 |
+
return b;
|
| 243 |
+
}
|
| 244 |
+
|
| 245 |
+
template <>
|
| 246 |
+
__inline__ __device__ uint4 vec_conversion<uint4, Float8_>(
|
| 247 |
+
const Float8_ &a, const __nv_fp8_interpretation_t fp8_type) {
|
| 248 |
+
uint4 b;
|
| 249 |
+
b.x = vec_conversion<uint32_t, float2>(a.x, fp8_type);
|
| 250 |
+
b.y = vec_conversion<uint32_t, float2>(a.y, fp8_type);
|
| 251 |
+
b.z = vec_conversion<uint32_t, float2>(a.z, fp8_type);
|
| 252 |
+
b.w = vec_conversion<uint32_t, float2>(a.w, fp8_type);
|
| 253 |
+
return b;
|
| 254 |
+
}
|
| 255 |
+
|
| 256 |
+
template <>
|
| 257 |
+
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(
|
| 258 |
+
const float2 &a, const __nv_fp8_interpretation_t fp8_type) {
|
| 259 |
+
__nv_bfloat162 b;
|
| 260 |
+
from_float(b, a);
|
| 261 |
+
return b;
|
| 262 |
+
}
|
| 263 |
+
|
| 264 |
+
template <>
|
| 265 |
+
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, Float4_>(
|
| 266 |
+
const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) {
|
| 267 |
+
bf16_4_t b;
|
| 268 |
+
from_float(b, a);
|
| 269 |
+
return b;
|
| 270 |
+
}
|
| 271 |
+
|
| 272 |
+
template <>
|
| 273 |
+
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(
|
| 274 |
+
const Float8_ &a, const __nv_fp8_interpretation_t fp8_type) {
|
| 275 |
+
bf16_8_t b;
|
| 276 |
+
from_float(b, a);
|
| 277 |
+
return b;
|
| 278 |
+
}
|
| 279 |
+
#endif
|
| 280 |
+
|
| 281 |
+
/* Scaled and vectorized conversions, for data exchange between high and low
|
| 282 |
+
precision domains Convention of the scale in API, e.g: FP8_data =
|
| 283 |
+
Quantization( High_Precision_data / scale ) s.t. Quantize(HP / scale) => FP8
|
| 284 |
+
Dequant(FP8) * scale => HP
|
| 285 |
+
*/
|
| 286 |
+
|
| 287 |
+
template <typename Tout, typename Tin>
|
| 288 |
+
__inline__ __device__ Tout scaled_vec_conversion(
|
| 289 |
+
const Tin& x, const float scale, const __nv_fp8_interpretation_t fp8_type) {
|
| 290 |
+
return x;
|
| 291 |
+
}
|
| 292 |
+
|
| 293 |
+
// fp8 -> half
|
| 294 |
+
template <>
|
| 295 |
+
__inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, uint8_t>(
|
| 296 |
+
const uint8_t& a, const float scale,
|
| 297 |
+
const __nv_fp8_interpretation_t fp8_type) {
|
| 298 |
+
__half_raw tmp = __nv_cvt_fp8_to_halfraw(a, fp8_type);
|
| 299 |
+
return float_to_half(half_to_float(tmp.x) * scale);
|
| 300 |
+
}
|
| 301 |
+
|
| 302 |
+
// fp8x2 -> half2
|
| 303 |
+
template <>
|
| 304 |
+
__inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>(
|
| 305 |
+
const uint16_t& a, const float scale,
|
| 306 |
+
const __nv_fp8_interpretation_t fp8_type) {
|
| 307 |
+
union {
|
| 308 |
+
uint16_t u16[2];
|
| 309 |
+
uint32_t u32;
|
| 310 |
+
} tmp;
|
| 311 |
+
__half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, fp8_type);
|
| 312 |
+
tmp.u16[0] = float_to_half(half_to_float(res.x) * scale);
|
| 313 |
+
tmp.u16[1] = float_to_half(half_to_float(res.y) * scale);
|
| 314 |
+
return tmp.u32;
|
| 315 |
+
}
|
| 316 |
+
|
| 317 |
+
// fp8x4 -> half2x2
|
| 318 |
+
template <>
|
| 319 |
+
__inline__ __device__ uint2 scaled_vec_conversion<uint2, uint32_t>(
|
| 320 |
+
const uint32_t& a, const float scale,
|
| 321 |
+
const __nv_fp8_interpretation_t fp8_type) {
|
| 322 |
+
union {
|
| 323 |
+
uint2 u32x2;
|
| 324 |
+
uint32_t u32[2];
|
| 325 |
+
} tmp;
|
| 326 |
+
tmp.u32[0] =
|
| 327 |
+
scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale, fp8_type);
|
| 328 |
+
tmp.u32[1] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U),
|
| 329 |
+
scale, fp8_type);
|
| 330 |
+
return tmp.u32x2;
|
| 331 |
+
}
|
| 332 |
+
|
| 333 |
+
// fp8x8 -> half2x4
|
| 334 |
+
template <>
|
| 335 |
+
__inline__ __device__ uint4
|
| 336 |
+
scaled_vec_conversion<uint4, uint2>(const uint2& a, const float scale,
|
| 337 |
+
const __nv_fp8_interpretation_t fp8_type) {
|
| 338 |
+
union {
|
| 339 |
+
uint4 u64x2;
|
| 340 |
+
uint2 u64[2];
|
| 341 |
+
} tmp;
|
| 342 |
+
tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale, fp8_type);
|
| 343 |
+
tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale, fp8_type);
|
| 344 |
+
return tmp.u64x2;
|
| 345 |
+
}
|
| 346 |
+
|
| 347 |
+
// fp8 -> __nv_bfloat16
|
| 348 |
+
template <>
|
| 349 |
+
__inline__ __device__ __nv_bfloat16
|
| 350 |
+
scaled_vec_conversion<__nv_bfloat16, uint8_t>(
|
| 351 |
+
const uint8_t& a, const float scale,
|
| 352 |
+
const __nv_fp8_interpretation_t fp8_type) {
|
| 353 |
+
// Note there is no direct convert function from fp8 to bf16.
|
| 354 |
+
// fp8 -> half
|
| 355 |
+
__half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
|
| 356 |
+
// half -> float -> bf16
|
| 357 |
+
float tmp = half_to_float(res.x);
|
| 358 |
+
return __float2bfloat16(tmp * scale);
|
| 359 |
+
}
|
| 360 |
+
|
| 361 |
+
// fp8x2 -> __nv_bfloat162
|
| 362 |
+
template <>
|
| 363 |
+
__inline__ __device__ __nv_bfloat162
|
| 364 |
+
scaled_vec_conversion<__nv_bfloat162, uint16_t>(
|
| 365 |
+
const uint16_t& a, const float scale,
|
| 366 |
+
const __nv_fp8_interpretation_t fp8_type) {
|
| 367 |
+
__nv_bfloat162 res;
|
| 368 |
+
res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale,
|
| 369 |
+
fp8_type);
|
| 370 |
+
res.y = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U),
|
| 371 |
+
scale, fp8_type);
|
| 372 |
+
return res;
|
| 373 |
+
}
|
| 374 |
+
|
| 375 |
+
// fp8x4 -> bf16_4_t
|
| 376 |
+
template <>
|
| 377 |
+
__inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(
|
| 378 |
+
const uint32_t& a, const float scale,
|
| 379 |
+
const __nv_fp8_interpretation_t fp8_type) {
|
| 380 |
+
bf16_4_t res;
|
| 381 |
+
res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale,
|
| 382 |
+
fp8_type);
|
| 383 |
+
res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U),
|
| 384 |
+
scale, fp8_type);
|
| 385 |
+
return res;
|
| 386 |
+
}
|
| 387 |
+
|
| 388 |
+
// fp8x8 -> bf16_8_t
|
| 389 |
+
template <>
|
| 390 |
+
__inline__ __device__ bf16_8_t scaled_vec_conversion<bf16_8_t, uint2>(
|
| 391 |
+
const uint2& a, const float scale,
|
| 392 |
+
const __nv_fp8_interpretation_t fp8_type) {
|
| 393 |
+
bf16_4_t tmp1, tmp2;
|
| 394 |
+
tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale, fp8_type);
|
| 395 |
+
tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale, fp8_type);
|
| 396 |
+
bf16_8_t res;
|
| 397 |
+
res.x = tmp1.x;
|
| 398 |
+
res.y = tmp1.y;
|
| 399 |
+
res.z = tmp2.x;
|
| 400 |
+
res.w = tmp2.y;
|
| 401 |
+
return res;
|
| 402 |
+
}
|
| 403 |
+
|
| 404 |
+
// fp8 -> float
|
| 405 |
+
template <>
|
| 406 |
+
__inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
|
| 407 |
+
const uint8_t& a, const float scale,
|
| 408 |
+
const __nv_fp8_interpretation_t fp8_type) {
|
| 409 |
+
// fp8 -> half
|
| 410 |
+
__half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
|
| 411 |
+
uint16_t tmp = res.x;
|
| 412 |
+
|
| 413 |
+
// half -> float
|
| 414 |
+
return half_to_float(tmp) * scale;
|
| 415 |
+
}
|
| 416 |
+
|
| 417 |
+
// fp8x2 -> float2
|
| 418 |
+
template <>
|
| 419 |
+
__inline__ __device__ float2 scaled_vec_conversion<float2, uint16_t>(
|
| 420 |
+
const uint16_t& a, const float scale,
|
| 421 |
+
const __nv_fp8_interpretation_t fp8_type) {
|
| 422 |
+
// fp8x2 -> half2
|
| 423 |
+
uint32_t tmp = scaled_vec_conversion<uint32_t, uint16_t>(a, scale, fp8_type);
|
| 424 |
+
// half2 -> float2
|
| 425 |
+
return half2_to_float2(tmp);
|
| 426 |
+
}
|
| 427 |
+
|
| 428 |
+
// fp8x4 -> float4
|
| 429 |
+
template <>
|
| 430 |
+
__inline__ __device__ Float4_ scaled_vec_conversion<Float4_, uint32_t>(
|
| 431 |
+
const uint32_t& a, const float scale,
|
| 432 |
+
const __nv_fp8_interpretation_t fp8_type) {
|
| 433 |
+
Float4_ res;
|
| 434 |
+
res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale, fp8_type);
|
| 435 |
+
res.y = scaled_vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), scale,
|
| 436 |
+
fp8_type);
|
| 437 |
+
return res;
|
| 438 |
+
}
|
| 439 |
+
|
| 440 |
+
// fp8x8 -> float8
|
| 441 |
+
template <>
|
| 442 |
+
__inline__ __device__ Float8_ scaled_vec_conversion<Float8_, uint2>(
|
| 443 |
+
const uint2& a, const float scale,
|
| 444 |
+
const __nv_fp8_interpretation_t fp8_type) {
|
| 445 |
+
Float4_ tmp1, tmp2;
|
| 446 |
+
tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale, fp8_type);
|
| 447 |
+
tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale, fp8_type);
|
| 448 |
+
Float8_ res;
|
| 449 |
+
res.x = tmp1.x;
|
| 450 |
+
res.y = tmp1.y;
|
| 451 |
+
res.z = tmp2.x;
|
| 452 |
+
res.w = tmp2.y;
|
| 453 |
+
return res;
|
| 454 |
+
}
|
| 455 |
+
|
| 456 |
+
// half -> fp8
|
| 457 |
+
template <>
|
| 458 |
+
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, uint16_t>(
|
| 459 |
+
const uint16_t& a, const float scale,
|
| 460 |
+
const __nv_fp8_interpretation_t fp8_type) {
|
| 461 |
+
__nv_fp8_storage_t res =
|
| 462 |
+
__nv_cvt_float_to_fp8(half_to_float(a) / scale, __NV_SATFINITE, fp8_type);
|
| 463 |
+
return (uint8_t)res;
|
| 464 |
+
}
|
| 465 |
+
|
| 466 |
+
// bf16 -> fp8
|
| 467 |
+
template <>
|
| 468 |
+
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
|
| 469 |
+
const __nv_bfloat16& a, const float scale,
|
| 470 |
+
const __nv_fp8_interpretation_t fp8_type) {
|
| 471 |
+
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
| 472 |
+
assert(false);
|
| 473 |
+
#else
|
| 474 |
+
__nv_fp8_storage_t res = __nv_cvt_float_to_fp8(__bfloat162float(a) / scale,
|
| 475 |
+
__NV_SATFINITE, fp8_type);
|
| 476 |
+
return (uint8_t)res;
|
| 477 |
+
#endif
|
| 478 |
+
__builtin_unreachable(); // Suppress missing return statement warning
|
| 479 |
+
}
|
| 480 |
+
|
| 481 |
+
// float -> fp8
|
| 482 |
+
template <>
|
| 483 |
+
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, float>(
|
| 484 |
+
const float& a, const float scale,
|
| 485 |
+
const __nv_fp8_interpretation_t fp8_type) {
|
| 486 |
+
__nv_fp8_storage_t res =
|
| 487 |
+
__nv_cvt_float_to_fp8(a / scale, __NV_SATFINITE, fp8_type);
|
| 488 |
+
return (uint8_t)res;
|
| 489 |
+
}
|
| 490 |
+
|
| 491 |
+
// fp8x4 -> float4
|
| 492 |
+
template <>
|
| 493 |
+
__inline__ __device__ float4 scaled_vec_conversion<float4, uint32_t>(
|
| 494 |
+
const uint32_t& a, const float scale,
|
| 495 |
+
const __nv_fp8_interpretation_t fp8_type) {
|
| 496 |
+
Float4_ tmp = scaled_vec_conversion<Float4_, uint32_t>(a, scale, fp8_type);
|
| 497 |
+
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
|
| 498 |
+
return res;
|
| 499 |
+
}
|
| 500 |
+
#endif // ENABLE_FP8
|
| 501 |
+
|
| 502 |
+
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
|
| 503 |
+
__inline__ __device__ Tout convert(const Tin& x) {
|
| 504 |
+
#if 0 // Disable the following code to reduce the binary size.
|
| 505 |
+
if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
|
| 506 |
+
return vec_conversion<Tout, Tin>(x, __NV_E4M3);
|
| 507 |
+
} else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) {
|
| 508 |
+
return vec_conversion<Tout, Tin>(x, __NV_E5M2);
|
| 509 |
+
}
|
| 510 |
+
#endif
|
| 511 |
+
assert(false);
|
| 512 |
+
__builtin_unreachable(); // Suppress missing return statement warning
|
| 513 |
+
}
|
| 514 |
+
|
| 515 |
+
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
|
| 516 |
+
__inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
|
| 517 |
+
#ifdef ENABLE_FP8
|
| 518 |
+
if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
|
| 519 |
+
return scaled_vec_conversion<Tout, Tin>(x, scale, __NV_E4M3);
|
| 520 |
+
} else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) {
|
| 521 |
+
return scaled_vec_conversion<Tout, Tin>(x, scale, __NV_E5M2);
|
| 522 |
+
}
|
| 523 |
+
#endif
|
| 524 |
+
assert(false);
|
| 525 |
+
__builtin_unreachable(); // Suppress missing return statement warning
|
| 526 |
+
}
|
| 527 |
+
|
| 528 |
+
// The following macro is used to dispatch the conversion function based on
|
| 529 |
+
// the data type of the key and value cache. The FN is a macro that calls a
|
| 530 |
+
// function with template<typename scalar_t, typename cache_t,
|
| 531 |
+
// Fp8KVCacheDataType kv_dt>.
|
| 532 |
+
#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \
|
| 533 |
+
if (KV_DTYPE == "auto") { \
|
| 534 |
+
if (SRC_DTYPE == at::ScalarType::Float) { \
|
| 535 |
+
FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \
|
| 536 |
+
} else if (SRC_DTYPE == at::ScalarType::Half) { \
|
| 537 |
+
FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \
|
| 538 |
+
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
|
| 539 |
+
FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \
|
| 540 |
+
} else { \
|
| 541 |
+
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
|
| 542 |
+
} \
|
| 543 |
+
} else { \
|
| 544 |
+
if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \
|
| 545 |
+
if (SRC_DTYPE == at::ScalarType::Float) { \
|
| 546 |
+
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
|
| 547 |
+
} else if (SRC_DTYPE == at::ScalarType::Half) { \
|
| 548 |
+
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
|
| 549 |
+
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
|
| 550 |
+
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
|
| 551 |
+
} else { \
|
| 552 |
+
TORCH_CHECK(false, \
|
| 553 |
+
"Unsupported input type of kv cache: ", SRC_DTYPE); \
|
| 554 |
+
} \
|
| 555 |
+
} else if (KV_DTYPE == "fp8_e5m2") { \
|
| 556 |
+
if (SRC_DTYPE == at::ScalarType::Float) { \
|
| 557 |
+
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
|
| 558 |
+
} else if (SRC_DTYPE == at::ScalarType::Half) { \
|
| 559 |
+
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
|
| 560 |
+
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
|
| 561 |
+
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
|
| 562 |
+
} else { \
|
| 563 |
+
TORCH_CHECK(false, \
|
| 564 |
+
"Unsupported input type of kv cache: ", SRC_DTYPE); \
|
| 565 |
+
} \
|
| 566 |
+
} else { \
|
| 567 |
+
TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
|
| 568 |
+
} \
|
| 569 |
+
}
|
| 570 |
+
|
| 571 |
+
} // namespace fp8
|
| 572 |
+
#endif // not USE_ROCM
|
| 573 |
+
} // namespace vllm
|
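Editor's note: DISPATCH_BY_KV_CACHE_DTYPE, defined in both quant_utils.cuh headers above, expects FN to be a macro that receives the resolved (scalar_t, cache_t, kv_dt) triple. The sketch below shows a hypothetical dispatch site, not code from this commit: my_cache_kernel, CALL_MY_CACHE_KERNEL, and quantize_to_cache are made-up names and the include choices are assumptions; the commit's real call sites live in paged-attention/cache_kernels.cu.

#include <torch/all.h>      // TORCH_CHECK and at::ScalarType (assumed header)
#include "quant_utils.cuh"  // DISPATCH_BY_KV_CACHE_DTYPE, scaled_convert

// Hypothetical kernel templated exactly as the macro comment requires.
template <typename scalar_t, typename cache_t, vllm::Fp8KVCacheDataType kv_dt>
__global__ void my_cache_kernel(const scalar_t* src, cache_t* dst,
                                float scale, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n) return;
  if constexpr (kv_dt == vllm::Fp8KVCacheDataType::kAuto) {
    dst[i] = src[i];  // "auto": cache keeps the source dtype, no quantization
  } else {
    dst[i] = vllm::fp8::scaled_convert<cache_t, scalar_t, kv_dt>(src[i], scale);
  }
}

// FN: expanded by the dispatcher with the concrete types it selected.
#define CALL_MY_CACHE_KERNEL(SCALAR_T, CACHE_T, KV_DTYPE)               \
  my_cache_kernel<SCALAR_T, CACHE_T, KV_DTYPE>                          \
      <<<grid, block, 0, stream>>>(static_cast<const SCALAR_T*>(src),   \
                                   static_cast<CACHE_T*>(dst), scale, n);

// Host-side entry point: selects the kernel instantiation from the runtime
// source dtype and the kv-cache dtype string ("auto", "fp8", "fp8_e4m3", ...).
void quantize_to_cache(const void* src, void* dst, int n, float scale,
                       at::ScalarType src_dtype,
                       const std::string& kv_cache_dtype,
                       dim3 grid, dim3 block, cudaStream_t stream) {
  DISPATCH_BY_KV_CACHE_DTYPE(src_dtype, kv_cache_dtype, CALL_MY_CACHE_KERNEL);
}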
tests/kernels/__init__.py
ADDED
|
File without changes
|
tests/kernels/allclose_default.py
ADDED
|
@@ -0,0 +1,14 @@
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
# Reference default values of atol and rtol are from
|
| 4 |
+
# https://github.com/pytorch/pytorch/blob/6d96beb6bec24d73ee3f080bac54d2104068f675/test/test_transformers.py#L67
|
| 5 |
+
default_atol = {torch.float16: 1e-3, torch.bfloat16: 1e-3, torch.float: 1e-5}
|
| 6 |
+
default_rtol = {torch.float16: 1e-3, torch.bfloat16: 1.6e-2, torch.float: 1.3e-6}
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def get_default_atol(output) -> float:
|
| 10 |
+
return default_atol[output.dtype]
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def get_default_rtol(output) -> float:
|
| 14 |
+
return default_rtol[output.dtype]
|
tests/kernels/conftest.py
ADDED
|
@@ -0,0 +1,158 @@
|
| 1 |
+
from typing import List, Optional, Tuple, Union
|
| 2 |
+
|
| 3 |
+
import attention as ops
|
| 4 |
+
import pytest
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@pytest.fixture()
|
| 9 |
+
def kv_cache_factory():
|
| 10 |
+
return create_kv_caches_with_random
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@pytest.fixture()
|
| 14 |
+
def kv_cache_factory_flashinfer():
|
| 15 |
+
return create_kv_caches_with_random_flash
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
STR_DTYPE_TO_TORCH_DTYPE = {
|
| 19 |
+
"half": torch.half,
|
| 20 |
+
"bfloat16": torch.bfloat16,
|
| 21 |
+
"float": torch.float,
|
| 22 |
+
"fp8": torch.uint8,
|
| 23 |
+
"fp8_e4m3": torch.uint8,
|
| 24 |
+
"fp8_e5m2": torch.uint8,
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def create_kv_caches_with_random(
|
| 29 |
+
num_blocks: int,
|
| 30 |
+
block_size: int,
|
| 31 |
+
num_layers: int,
|
| 32 |
+
num_heads: int,
|
| 33 |
+
head_size: int,
|
| 34 |
+
cache_dtype: Optional[Union[str, torch.dtype]],
|
| 35 |
+
model_dtype: Optional[Union[str, torch.dtype]] = None,
|
| 36 |
+
seed: int = 0,
|
| 37 |
+
device: Optional[str] = "cuda",
|
| 38 |
+
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
|
| 39 |
+
|
| 40 |
+
if cache_dtype == "fp8" and head_size % 16:
|
| 41 |
+
raise ValueError(
|
| 42 |
+
f"Does not support key cache of type fp8 with head_size {head_size}"
|
| 43 |
+
)
|
| 44 |
+
from attention.platforms import current_platform
|
| 45 |
+
|
| 46 |
+
current_platform.seed_everything(seed)
|
| 47 |
+
|
| 48 |
+
torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
|
| 49 |
+
|
| 50 |
+
scale = head_size**-0.5
|
| 51 |
+
x = 16 // torch.tensor([], dtype=torch_dtype).element_size()
|
| 52 |
+
key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
|
| 53 |
+
key_caches: List[torch.Tensor] = []
|
| 54 |
+
for _ in range(num_layers):
|
| 55 |
+
key_cache = torch.empty(size=key_cache_shape, dtype=torch_dtype, device=device)
|
| 56 |
+
if cache_dtype in ["auto", "half", "bfloat16", "float"]:
|
| 57 |
+
key_cache.uniform_(-scale, scale)
|
| 58 |
+
elif cache_dtype == "fp8":
|
| 59 |
+
_generate_random_fp8(key_cache, -scale, scale)
|
| 60 |
+
else:
|
| 61 |
+
raise ValueError(f"Does not support key cache of type {cache_dtype}")
|
| 62 |
+
key_caches.append(key_cache)
|
| 63 |
+
|
| 64 |
+
value_cache_shape = (num_blocks, num_heads, head_size, block_size)
|
| 65 |
+
value_caches: List[torch.Tensor] = []
|
| 66 |
+
for _ in range(num_layers):
|
| 67 |
+
value_cache = torch.empty(
|
| 68 |
+
size=value_cache_shape, dtype=torch_dtype, device=device
|
| 69 |
+
)
|
| 70 |
+
if cache_dtype in ["auto", "half", "bfloat16", "float"]:
|
| 71 |
+
value_cache.uniform_(-scale, scale)
|
| 72 |
+
elif cache_dtype == "fp8":
|
| 73 |
+
_generate_random_fp8(value_cache, -scale, scale)
|
| 74 |
+
else:
|
| 75 |
+
raise ValueError(f"Does not support value cache of type {cache_dtype}")
|
| 76 |
+
value_caches.append(value_cache)
|
| 77 |
+
return key_caches, value_caches
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def create_kv_caches_with_random_flash(
|
| 81 |
+
num_blocks: int,
|
| 82 |
+
block_size: int,
|
| 83 |
+
num_layers: int,
|
| 84 |
+
num_heads: int,
|
| 85 |
+
head_size: int,
|
| 86 |
+
cache_dtype: Optional[Union[str, torch.dtype]],
|
| 87 |
+
model_dtype: Optional[Union[str, torch.dtype]] = None,
|
| 88 |
+
seed: int = 0,
|
| 89 |
+
device: Optional[str] = "cuda",
|
| 90 |
+
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
|
| 91 |
+
from attention.platforms import current_platform
|
| 92 |
+
|
| 93 |
+
current_platform.seed_everything(seed)
|
| 94 |
+
|
| 95 |
+
torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
|
| 96 |
+
key_value_cache_shape = (num_blocks, 2, block_size, num_heads, head_size)
|
| 97 |
+
scale = head_size**-0.5
|
| 98 |
+
|
| 99 |
+
key_caches: List[torch.Tensor] = []
|
| 100 |
+
value_caches: List[torch.Tensor] = []
|
| 101 |
+
|
| 102 |
+
for _ in range(num_layers):
|
| 103 |
+
key_value_cache = torch.empty(
|
| 104 |
+
size=key_value_cache_shape, dtype=torch_dtype, device=device
|
| 105 |
+
)
|
| 106 |
+
if cache_dtype in ["auto", "half", "bfloat16", "float"]:
|
| 107 |
+
key_value_cache.uniform_(-scale, scale)
|
| 108 |
+
elif cache_dtype == "fp8":
|
| 109 |
+
_generate_random_fp8(key_value_cache, -scale, scale)
|
| 110 |
+
else:
|
| 111 |
+
raise ValueError(f"Does not support key cache of type {cache_dtype}")
|
| 112 |
+
key_caches.append(key_value_cache[:, 0])
|
| 113 |
+
value_caches.append(key_value_cache[:, 1])
|
| 114 |
+
return key_caches, value_caches
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def get_kv_cache_torch_dtype(
|
| 118 |
+
cache_dtype: Optional[Union[str, torch.dtype]],
|
| 119 |
+
model_dtype: Optional[Union[str, torch.dtype]] = None,
|
| 120 |
+
) -> torch.dtype:
|
| 121 |
+
if isinstance(cache_dtype, str):
|
| 122 |
+
if cache_dtype == "auto":
|
| 123 |
+
if isinstance(model_dtype, str):
|
| 124 |
+
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
|
| 125 |
+
elif isinstance(model_dtype, torch.dtype):
|
| 126 |
+
torch_dtype = model_dtype
|
| 127 |
+
else:
|
| 128 |
+
raise ValueError(f"Invalid model dtype: {model_dtype}")
|
| 129 |
+
elif cache_dtype in ["half", "bfloat16", "float"]:
|
| 130 |
+
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
|
| 131 |
+
elif cache_dtype == "fp8":
|
| 132 |
+
torch_dtype = torch.uint8
|
| 133 |
+
else:
|
| 134 |
+
raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
|
| 135 |
+
elif isinstance(cache_dtype, torch.dtype):
|
| 136 |
+
torch_dtype = cache_dtype
|
| 137 |
+
else:
|
| 138 |
+
raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
|
| 139 |
+
return torch_dtype
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def _generate_random_fp8(
|
| 143 |
+
tensor: torch.Tensor,
|
| 144 |
+
low: float,
|
| 145 |
+
high: float,
|
| 146 |
+
) -> None:
|
| 147 |
+
# NOTE(zhaoyang): Due to NaN and Inf representation for fp8 data type,
|
| 148 |
+
# it may occur Inf or NaN if we directly use torch.randint
|
| 149 |
+
# to generate random data for fp8 data.
|
| 150 |
+
# For example, s.11111.00 in fp8e5m2 format represents Inf.
|
| 151 |
+
# | E4M3 | E5M2
|
| 152 |
+
# -----|-------------|-------------------
|
| 153 |
+
# Inf | N/A | s.11111.00
|
| 154 |
+
# NaN | s.1111.111 | s.11111.{01,10,11}
|
| 155 |
+
tensor_tmp = torch.empty_like(tensor, dtype=torch.float16)
|
| 156 |
+
tensor_tmp.uniform_(low, high)
|
| 157 |
+
ops.convert_fp8(tensor, tensor_tmp)
|
| 158 |
+
del tensor_tmp
|
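For reference, the layouts produced above: `create_kv_caches_with_random` returns paged key caches shaped `(num_blocks, num_heads, head_size // x, block_size, x)` with `x = 16 // element_size`, and value caches shaped `(num_blocks, num_heads, head_size, block_size)`; the flash variant packs K and V together and hands back `(num_blocks, block_size, num_heads, head_size)` views. A minimal sketch of requesting one layer directly (parameter values are made up; assumes the compiled `attention` package and a CUDA device are available):

```python
import torch

from tests.kernels.conftest import create_kv_caches_with_random

key_caches, value_caches = create_kv_caches_with_random(
    num_blocks=128,
    block_size=16,
    num_layers=1,
    num_heads=8,
    head_size=128,
    cache_dtype="auto",        # keep the model dtype, no fp8 quantization
    model_dtype=torch.float16,
    seed=0,
    device="cuda",
)

key_cache, value_cache = key_caches[0], value_caches[0]
# x = 16 // element_size(float16) = 8
assert key_cache.shape == (128, 8, 128 // 8, 16, 8)
assert value_cache.shape == (128, 8, 128, 16)
```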
tests/kernels/test_attention.py
ADDED
@@ -0,0 +1,418 @@
import random
from typing import List, Optional, Tuple

import attention as ops
import pytest
import torch
from attention.platforms import current_platform

from .allclose_default import get_default_atol, get_default_rtol
from .utils import get_max_shared_memory_bytes, opcheck

FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# This will change depending on the compute capability.
# - 512 as a buffer
MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
# There may not be enough gpu memory due to large NUM_BLOCKS.
# Reduce NUM_BLOCKS when it happens.
NUM_BLOCKS = 4321  # Arbitrary values for testing
PARTITION_SIZE = 512
# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
DTYPES = (
    [torch.half, torch.bfloat16, torch.float]
    if not current_platform.is_rocm()
    else [torch.half, torch.bfloat16]
)
NUM_GEN_SEQS = [7]  # Arbitrary values for testing
NUM_PREFILL_SEQS = [3]  # Arbitrary values for testing
NUM_HEADS = [(40, 40), (64, 8)]  # Arbitrary values for testing

# This should be kept in sync with get_supported_head_sizes() in
# vllm.attention.ops.paged_attn.PagedAttention
HEAD_SIZES = [32, 64, 80, 96, 112, 120, 128, 192, 256]

BLOCK_SIZES = [16, 32]
USE_ALIBI = [False, True]
KV_CACHE_DTYPE = ["auto", "fp8"]
SEEDS = [0]
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]


def ref_masked_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
    if attn_mask is not None:
        attn_weights = attn_weights + attn_mask.float()
    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
    out = torch.einsum("hqk,khd->qhd", attn_weights, value)
    return out


def ref_single_query_cached_kv_attention(
    output: torch.Tensor,
    query: torch.Tensor,
    num_queries_per_kv: int,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    scale: float,
    alibi_slopes: Optional[torch.Tensor],
) -> None:
    num_query_heads = query.shape[1]
    num_kv_heads = value_cache.shape[1]
    head_size = value_cache.shape[2]
    block_size = value_cache.shape[3]
    num_seqs = query.shape[0]

    block_tables_lst = block_tables.cpu().tolist()
    seq_lens_lst = seq_lens.cpu().tolist()
    for i in range(num_seqs):
        q = query[i].unsqueeze(0)
        block_table = block_tables_lst[i]
        seq_len = int(seq_lens_lst[i])

        keys_lst: List[torch.Tensor] = []
        values_lst: List[torch.Tensor] = []
        for j in range(seq_len):
            block_number = int(block_table[j // block_size])
            block_offset = j % block_size

            k = key_cache[block_number, :, :, block_offset, :]
            k = k.reshape(num_kv_heads, head_size)
            keys_lst.append(k)

            v = value_cache[block_number, :, :, block_offset]
            values_lst.append(v)
        keys = torch.stack(keys_lst, dim=0)
        values = torch.stack(values_lst, dim=0)
        if num_queries_per_kv > 1:
            # Handle MQA and GQA
            keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
            values = torch.repeat_interleave(values, num_queries_per_kv, dim=1)

        alibi_bias = None
        if alibi_slopes is not None:
            # Create the ALiBi bias used in the paged attention kernel.
            position_ids = torch.arange(seq_len).int()
            alibi_bias = (position_ids - seq_len + 1).float()
            alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(1, 1, -1)

        out = ref_masked_attention(q, keys, values, scale, alibi_bias)
        out = out.view(num_query_heads, head_size)
        output[i].copy_(out, non_blocking=True)


@pytest.mark.parametrize(
    "version", ["v1", "v2"] if not current_platform.is_rocm() else ["v1", "v2", "rocm"]
)
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("use_alibi", USE_ALIBI)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_paged_attention(
    kv_cache_factory,
    version: str,
    num_seqs: int,
    num_heads: Tuple[int, int],
    head_size: int,
    use_alibi: bool,
    block_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    seed: int,
    device: str,
) -> None:
    if (kv_cache_dtype == "fp8" and head_size % 16) or (
        version == "rocm" and head_size not in (64, 128)
    ):
        pytest.skip()

    current_platform.seed_everything(seed)
    torch.set_default_device(device)
    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
    query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
    query.uniform_(-scale, scale)

    assert num_query_heads % num_kv_heads == 0
    num_queries_per_kv = num_query_heads // num_kv_heads
    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)

    seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
    seq_lens[-1] = MAX_SEQ_LEN
    max_seq_len = max(seq_lens)
    seq_lens = torch.tensor(seq_lens, dtype=torch.int)

    # Create the block tables.
    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
    block_tables_lst: List[List[int]] = []
    for _ in range(num_seqs):
        block_table = [
            random.randint(0, NUM_BLOCKS - 1) for _ in range(max_num_blocks_per_seq)
        ]
        block_tables_lst.append(block_table)

    block_tables = torch.tensor(block_tables_lst, dtype=torch.int)

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(
        NUM_BLOCKS,
        block_size,
        1,
        num_kv_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        seed,
        device,
    )
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Using default kv_scale
    k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)

    # Call the paged attention kernel.
    output = torch.empty_like(query)
    if version == "v1":
        ops.paged_attention_v1(
            output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            block_tables,
            seq_lens,
            block_size,
            max_seq_len,
            alibi_slopes,
            kv_cache_dtype,
            k_scale,
            v_scale,
        )

        opcheck(
            ops.ops.paged_attention_v1,
            (
                output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                block_tables,
                seq_lens,
                block_size,
                max_seq_len,
                alibi_slopes,
                kv_cache_dtype,
                k_scale,
                v_scale,
                0,
                0,
                0,
                64,
                0,
            ),
            cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]),
        )

    elif version in ("v2", "rocm"):
        num_partitions = (max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE
        assert PARTITION_SIZE % block_size == 0
        num_seqs, num_heads, head_size = output.shape
        tmp_output = torch.empty(
            size=(num_seqs, num_heads, num_partitions, head_size),
            dtype=output.dtype,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_heads, num_partitions),
            dtype=torch.float32,
        )
        max_logits = torch.empty_like(exp_sums)
        if version == "v2":
            ops.paged_attention_v2(
                output,
                exp_sums,
                max_logits,
                tmp_output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                block_tables,
                seq_lens,
                block_size,
                max_seq_len,
                alibi_slopes,
                kv_cache_dtype,
                k_scale,
                v_scale,
            )

            opcheck(
                ops.ops.paged_attention_v2,
                (
                    output,
                    exp_sums,
                    max_logits,
                    tmp_output,
                    query,
                    key_cache,
                    value_cache,
                    num_kv_heads,
                    scale,
                    block_tables,
                    seq_lens,
                    block_size,
                    max_seq_len,
                    alibi_slopes,
                    kv_cache_dtype,
                    k_scale,
                    v_scale,
                    0,
                    0,
                    0,
                    64,
                    0,
                ),
                cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]),
            )

        else:
            ops.paged_attention_rocm(
                output,
                exp_sums,
                max_logits,
                tmp_output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                block_tables,
                seq_lens,
                block_size,
                max_seq_len,
                alibi_slopes,
                kv_cache_dtype,
                k_scale,
                v_scale,
            )

            opcheck(
                torch.ops._rocm_C.paged_attention,
                (
                    output,
                    exp_sums,
                    max_logits,
                    tmp_output,
                    query,
                    key_cache,
                    value_cache,
                    num_kv_heads,
                    scale,
                    block_tables,
                    seq_lens,
                    block_size,
                    max_seq_len,
                    alibi_slopes,
                    kv_cache_dtype,
                    k_scale,
                    v_scale,
                ),
                cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]),
            )

    else:
        raise AssertionError(f"Unknown version: {version}")

    # Run the reference implementation.
    if kv_cache_dtype == "fp8":
        # Convert cache data back to dtype.
        x = 16 // torch.tensor([], dtype=dtype).element_size()
        key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, block_size, x)
        dequantized_key_cache = torch.empty(
            size=key_cache_shape, dtype=dtype, device=device
        )
        ops.convert_fp8(dequantized_key_cache, key_cache)
        key_cache = dequantized_key_cache

        value_cache_shape = value_cache.shape
        dequantized_value_cache = torch.empty(
            size=value_cache_shape, dtype=dtype, device=device
        )
        ops.convert_fp8(dequantized_value_cache, value_cache)
        value_cache = dequantized_value_cache

    ref_output = torch.empty_like(query)
    ref_single_query_cached_kv_attention(
        ref_output,
        query,
        num_queries_per_kv,
        key_cache,
        value_cache,
        block_tables,
        seq_lens,
        scale,
        alibi_slopes,
    )

    # NOTE(woosuk): Due to the kernel-level differences in the two
    # implementations, there is a small numerical difference in the two
    # outputs. Thus, we use a relaxed tolerance for the test.
    atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
    rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5

    # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
    # so we use a relaxed tolerance for the test.
    atol, rtol = 1e-3, 1e-5
    if kv_cache_dtype == "fp8":
        atol, rtol = 1e-2, 1e-5
    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)


def ref_multi_query_kv_attention(
    cu_seq_lens: List[int],
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    dtype: torch.dtype,
) -> torch.Tensor:
    num_seqs = len(cu_seq_lens) - 1
    ref_outputs: List[torch.Tensor] = []
    for i in range(num_seqs):
        start_idx = cu_seq_lens[i]
        end_idx = cu_seq_lens[i + 1]
        seq_len = end_idx - start_idx

        # Create attention mask.
        attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype), diagonal=1)
        attn_mask = attn_mask * torch.finfo(dtype).min
        attn_mask = attn_mask.to(dtype=dtype)

        ref_output = ref_masked_attention(
            query[start_idx:end_idx],
            key[start_idx:end_idx],
            value[start_idx:end_idx],
            scale,
            attn_mask=attn_mask,
        )
        ref_outputs.append(ref_output)

    return torch.cat(ref_outputs, dim=0)
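The test above also doubles as a usage reference for the v1 kernel: one query token per sequence, one block table row per sequence, and the paged caches from the fixture. A condensed sketch of a single decode step follows; the shapes and values are illustrative only, the caches are random placeholders, and it assumes the compiled `attention` package and a CUDA device:

```python
import torch
import attention as ops

num_seqs, num_query_heads, num_kv_heads = 2, 8, 8
head_size, block_size, num_blocks = 128, 16, 32
scale = head_size ** -0.5

query = torch.randn(num_seqs, num_query_heads, head_size,
                    dtype=torch.float16, device="cuda")
output = torch.empty_like(query)

# Paged caches and block tables normally come from the serving engine;
# here they are random tensors in the layout the kernel expects.
x = 16 // query.element_size()
key_cache = torch.randn(num_blocks, num_kv_heads, head_size // x, block_size, x,
                        dtype=torch.float16, device="cuda")
value_cache = torch.randn(num_blocks, num_kv_heads, head_size, block_size,
                          dtype=torch.float16, device="cuda")
block_tables = torch.zeros(num_seqs, 4, dtype=torch.int, device="cuda")
seq_lens = torch.tensor([5, 9], dtype=torch.int, device="cuda")
k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")

ops.paged_attention_v1(
    output, query, key_cache, value_cache,
    num_kv_heads, scale, block_tables, seq_lens,
    block_size, int(seq_lens.max()),
    None,          # alibi_slopes
    "auto",        # kv_cache_dtype
    k_scale, v_scale,
)
```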
tests/kernels/test_cache.py
ADDED
@@ -0,0 +1,486 @@
import random
from typing import List, Tuple

import attention as ops
import pytest
import torch
from attention.platforms import current_platform

from .utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck

COPYING_DIRECTION = [("cuda", "cpu"), ("cuda", "cuda"), ("cpu", "cuda")]
DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [42]  # Arbitrary values for testing
NUM_LAYERS = [1]  # Arbitrary values for testing
NUM_HEADS = [8]  # Arbitrary values for testing
HEAD_SIZES = [64, 80, 120, 256]
BLOCK_SIZES = [8, 16, 32]

# Arbitrary values for testing
# Don't make it too large, e.g. [1024, 36000] will OOM.
NUM_BLOCKS = [1024, 10000]

NUM_MAPPINGS = [256]  # Arbitrary values for testing
SEEDS = [0]
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]

# We assume fp8 is always enabled for testing.
KV_CACHE_DTYPE = ["auto", "fp8"]


@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_layers", NUM_LAYERS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_copy_blocks(
    kv_cache_factory,
    num_mappings: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    kv_cache_dtype: str,
    device: str,
) -> None:
    if kv_cache_dtype == "fp8" and head_size % 16:
        pytest.skip()
    current_platform.seed_everything(seed)
    torch.set_default_device(device)
    # Generate random block mappings where each source block is mapped to two
    # destination blocks.
    assert 2 * num_mappings <= num_blocks
    src_blocks = random.sample(range(num_blocks), num_mappings)
    remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
    dst_blocks = random.sample(remaining_blocks, 2 * num_mappings)
    block_mapping: List[Tuple[int, int]] = []
    for i in range(num_mappings):
        src = src_blocks[i]
        dst1 = dst_blocks[2 * i]
        dst2 = dst_blocks[2 * i + 1]
        block_mapping.append((src, dst1))
        block_mapping.append((src, dst2))

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(
        num_blocks,
        block_size,
        num_layers,
        num_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        seed,
        device,
    )

    # Clone the KV caches.
    cloned_key_caches = [key_cache.clone() for key_cache in key_caches]
    cloned_value_caches = [value_cache.clone() for value_cache in value_caches]

    # Call the copy blocks kernel.
    block_mapping_tensor = torch.tensor(
        block_mapping, dtype=torch.int64, device=device
    ).view(-1, 2)

    opcheck(
        ops.ops.copy_blocks,
        (key_caches, value_caches, block_mapping_tensor),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
        cond=(head_size == HEAD_SIZES[0]),
    )
    ops.copy_blocks(key_caches, value_caches, block_mapping_tensor)

    # Run the reference implementation.
    for src, dst in block_mapping:
        for cloned_key_cache in cloned_key_caches:
            cloned_key_cache[dst].copy_(cloned_key_cache[src])
        for cloned_value_cache in cloned_value_caches:
            cloned_value_cache[dst].copy_(cloned_value_cache[src])

    # Compare the results.
    for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
        torch.testing.assert_close(key_cache, cloned_key_cache)
    for value_cache, cloned_value_cache in zip(value_caches, cloned_value_caches):
        torch.testing.assert_close(value_cache, cloned_value_cache)


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_reshape_and_cache(
    kv_cache_factory,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    if kv_cache_dtype == "fp8" and head_size % 16:
        pytest.skip()
    current_platform.seed_everything(seed)
    torch.set_default_device(device)
    # Create a random slot mapping.
    num_slots = block_size * num_blocks
    slot_mapping_lst = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long)

    qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype)
    _, key, value = qkv.unbind(dim=1)

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(
        num_blocks,
        block_size,
        1,
        num_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        seed,
        device,
    )
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Clone the KV caches.
    if kv_cache_dtype == "fp8":
        cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(cloned_key_cache, key_cache)
        cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
        ops.convert_fp8(cloned_value_cache, value_cache)
    else:
        cloned_key_cache = key_cache.clone()
        cloned_value_cache = value_cache.clone()

    # Using default kv_scale
    k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)

    # Call the reshape_and_cache kernel.
    opcheck(
        ops.ops.reshape_and_cache,
        (
            key,
            value,
            key_cache,
            value_cache,
            slot_mapping,
            kv_cache_dtype,
            k_scale,
            v_scale,
        ),
        cond=(head_size == HEAD_SIZES[0]),
    )
    ops.reshape_and_cache(
        key,
        value,
        key_cache,
        value_cache,
        slot_mapping,
        kv_cache_dtype,
        k_scale,
        v_scale,
    )

    if kv_cache_dtype == "fp8":
        result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(result_key_cache, key_cache)
        result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
        ops.convert_fp8(result_value_cache, value_cache)

    # Run the reference implementation.
    reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
    block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
    block_indices_lst = block_indices.cpu().tolist()
    block_offsets = slot_mapping % block_size
    block_offsets_lst = block_offsets.cpu().tolist()
    for i in range(num_tokens):
        block_idx = block_indices_lst[i]
        block_offset = block_offsets_lst[i]
        cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
        cloned_value_cache[block_idx, :, :, block_offset] = value[i]

    if kv_cache_dtype == "fp8":
        torch.testing.assert_close(
            result_key_cache, cloned_key_cache, atol=0.001, rtol=0.1
        )
        torch.testing.assert_close(
            result_value_cache, cloned_value_cache, atol=0.001, rtol=0.1
        )
    else:
        torch.testing.assert_close(key_cache, cloned_key_cache)
        torch.testing.assert_close(value_cache, cloned_value_cache)


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_reshape_and_cache_flash(
    kv_cache_factory_flashinfer,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    current_platform.seed_everything(seed)
    torch.set_default_device(device)

    # Create a random slot mapping.
    num_slots = block_size * num_blocks
    slot_mapping_lst = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)

    qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype, device=device)
    _, key, value = qkv.unbind(dim=1)

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory_flashinfer(
        num_blocks,
        block_size,
        1,
        num_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        device=device,
    )
    key_cache, value_cache = key_caches[0].contiguous(), value_caches[0].contiguous()
    del key_caches
    del value_caches

    k_scale = (key.amax() / 256.0).to(torch.float32)
    v_scale = (value.amax() / 256.0).to(torch.float32)

    # Clone the KV caches.
    if kv_cache_dtype == "fp8":
        cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(cloned_key_cache, key_cache, k_scale, kv_cache_dtype)
        cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
        ops.convert_fp8(cloned_value_cache, value_cache, v_scale, kv_cache_dtype)
    else:
        cloned_key_cache = key_cache.clone()
        cloned_value_cache = value_cache.clone()

    # Call the reshape_and_cache kernel.
    opcheck(
        ops.ops.reshape_and_cache_flash,
        (
            key,
            value,
            key_cache,
            value_cache,
            slot_mapping,
            kv_cache_dtype,
            k_scale,
            v_scale,
        ),
        cond=(head_size == HEAD_SIZES[0]),
    )
    ops.reshape_and_cache_flash(
        key,
        value,
        key_cache,
        value_cache,
        slot_mapping,
        kv_cache_dtype,
        k_scale,
        v_scale,
    )

    if kv_cache_dtype == "fp8":
        result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(
            result_key_cache, key_cache, k_scale.item(), kv_dtype=kv_cache_dtype
        )
        result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
        ops.convert_fp8(
            result_value_cache, value_cache, v_scale.item(), kv_dtype=kv_cache_dtype
        )

    # Run the reference implementation.
    block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
    block_indices_lst = block_indices.cpu().tolist()
    block_offsets = slot_mapping % block_size
    block_offsets_lst = block_offsets.cpu().tolist()
    for i in range(num_tokens):
        block_idx = block_indices_lst[i]
        block_offset = block_offsets_lst[i]
        cloned_key_cache[block_idx, block_offset, :, :] = key[i]
        cloned_value_cache[block_idx, block_offset, :, :] = value[i]

    if kv_cache_dtype == "fp8":
        torch.testing.assert_close(
            result_key_cache, cloned_key_cache, atol=0.001, rtol=0.1
        )
        torch.testing.assert_close(
            result_value_cache, cloned_value_cache, atol=0.001, rtol=0.1
        )
    else:
        torch.testing.assert_close(key_cache, cloned_key_cache)
        torch.testing.assert_close(value_cache, cloned_value_cache)


@pytest.mark.parametrize("direction", COPYING_DIRECTION)
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_swap_blocks(
    kv_cache_factory,
    direction: Tuple[str, str],
    num_mappings: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    if kv_cache_dtype == "fp8" and "cpu" in direction:
        pytest.skip()
    if kv_cache_dtype == "fp8" and head_size % 16:
        pytest.skip()

    current_platform.seed_everything(seed)

    src_device = device if direction[0] == "cuda" else "cpu"
    dst_device = device if direction[1] == "cuda" else "cpu"

    src_blocks = random.sample(range(num_blocks), num_mappings)
    # For the same device, mapping must not overlap
    if src_device == dst_device:
        remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
        dst_blocks = random.sample(remaining_blocks, num_mappings)
    else:
        dst_blocks = random.sample(range(num_blocks), num_mappings)

    block_mapping = list(zip(src_blocks, dst_blocks))
    block_mapping_tensor = torch.tensor(
        block_mapping, dtype=torch.int64, device="cpu"
    ).view(-1, 2)

    # Create the KV caches on the first device.
    src_key_caches, src_value_caches = kv_cache_factory(
        num_blocks,
        block_size,
        1,
        num_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        seed,
        src_device,
    )

    # Create the KV caches on the second device.
    dist_key_caches, dist_value_caches = kv_cache_factory(
        num_blocks,
        block_size,
        1,
        num_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        seed,
        dst_device,
    )

    src_key_caches_clone = src_key_caches[0].clone()
    src_value_caches_clone = src_value_caches[0].clone()

    # Call the swap_blocks kernel.
    do_opcheck = head_size == HEAD_SIZES[0]
    opcheck(
        ops.ops.swap_blocks,
        (src_key_caches[0], dist_key_caches[0], block_mapping_tensor),
        cond=do_opcheck,
    )
    opcheck(
        ops.ops.swap_blocks,
        (src_value_caches[0], dist_value_caches[0], block_mapping_tensor),
        cond=do_opcheck,
    )

    ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping_tensor)
    ops.swap_blocks(src_value_caches[0], dist_value_caches[0], block_mapping_tensor)

    for src, dst in block_mapping:
        torch.testing.assert_close(
            src_key_caches_clone[src].cpu(), dist_key_caches[0][dst].cpu()
        )
        torch.testing.assert_close(
            src_value_caches_clone[src].cpu(), dist_value_caches[0][dst].cpu()
        )


@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_fp8_e4m3_conversion(
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    current_platform.seed_everything(seed)

    low = -224.0
    high = 224.0
    shape = (num_blocks, num_heads, head_size, block_size)
    cache = torch.empty(shape, dtype=dtype, device=device)
    cache.uniform_(low, high)

    cache_fp8 = torch.empty_like(cache, dtype=torch.uint8)
    ops.convert_fp8(cache_fp8, cache)

    converted_cache = torch.empty_like(cache)
    ops.convert_fp8(converted_cache, cache_fp8)

    torch.testing.assert_close(cache, converted_cache, atol=0.001, rtol=0.1)
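The reference loops in these tests spell out the cache addressing used by `reshape_and_cache`: each token's flat slot index is split into a block index and an in-block offset. A short standalone sketch of just that mapping (pure PyTorch, no kernel call; the slot values are made up):

```python
import torch

block_size = 16
slot_mapping = torch.tensor([3, 21, 40])  # flat slots assigned to three tokens

block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
block_offsets = slot_mapping % block_size

# token 0 -> block 0, offset 3; token 1 -> block 1, offset 5; token 2 -> block 2, offset 8
print(list(zip(block_indices.tolist(), block_offsets.tolist())))
# [(0, 3), (1, 5), (2, 8)]
```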
tests/kernels/utils.py
ADDED
@@ -0,0 +1,92 @@
"""Kernel test utils"""

import itertools
import random
import unittest.mock
from functools import lru_cache
from numbers import Number
from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union

import pytest
import torch
from torch._prims_common import TensorLikeType

# For now, disable "test_aot_dispatch_dynamic" since there are some
# bugs related to this test in PyTorch 2.4.
DEFAULT_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
    "test_schema",
    "test_autograd_registration",
    "test_faketensor",
)

ALL_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
    "test_schema",
    "test_autograd_registration",
    "test_faketensor",
    "test_aot_dispatch_dynamic",
)


# Copied/modified from torch._refs.__init__.py
def fp8_allclose(
    a: TensorLikeType,
    b: TensorLikeType,
    rtol: float = 1e-05,
    atol: float = 1e-08,
    equal_nan: bool = False,
) -> bool:
    """
    Reference implementation of torch.allclose
    """
    torch._refs._check_close_args(name="torch.allclose", a=a, b=b, rtol=rtol, atol=atol)

    return bool(
        torch.all(
            torch.isclose(
                a.double(), b.double(), rtol=rtol, atol=atol, equal_nan=equal_nan
            )
        ).item()
    )


def compute_max_diff(output, output_ref):
    return torch.mean(torch.abs(output - output_ref)) / torch.mean(
        torch.abs(output_ref)
    )


# A special version of op check that has a restricted default set of test_utils
# and a patched version of allclose that supports fp8 types.
def opcheck(
    op: Union[
        torch._ops.OpOverload,
        torch._ops.OpOverloadPacket,
        torch._library.custom_ops.CustomOpDef,
    ],
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    *,
    test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS,
    raise_exception: bool = True,
    cond: bool = True
) -> Dict[str, str]:
    with unittest.mock.patch("torch.allclose", new=fp8_allclose):
        return (
            torch.library.opcheck(
                op, args, kwargs, test_utils=test_utils, raise_exception=raise_exception
            )
            if cond
            else {}
        )


@lru_cache(maxsize=None)
def get_max_shared_memory_bytes(gpu: int = 0) -> int:
    """Returns the maximum shared memory per thread block in bytes."""
    from attention import ops

    max_shared_mem = ops.get_max_shared_memory_per_block_device_attribute(gpu)
    # A value of 0 would make MAX_SEQ_LEN negative and cause
    # test_attention.py to fail.
    assert max_shared_mem > 0, "max_shared_mem can not be zero"
    return int(max_shared_mem)
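`opcheck` above is a thin wrapper over `torch.library.opcheck` that patches `torch.allclose` with the fp8-aware variant and can be gated with `cond`. A minimal sketch of exercising it, assuming PyTorch 2.4+ where `torch.library.custom_op` and `torch.library.opcheck` are available; the toy `demo::add_one` op is a stand-in for the compiled kernel ops such as `ops.ops.copy_blocks` used in the tests:

```python
import torch

from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck


# Register a trivial custom op just to exercise the wrapper.
@torch.library.custom_op("demo::add_one", mutates_args=())
def add_one(x: torch.Tensor) -> torch.Tensor:
    return x + 1


@add_one.register_fake
def _(x):
    return torch.empty_like(x)


opcheck(
    add_one,
    (torch.randn(4),),
    test_utils=DEFAULT_OPCHECK_TEST_UTILS,  # skips test_aot_dispatch_dynamic
    cond=True,                              # with cond=False the call returns {} and checks nothing
)
```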
torch-ext/attention/__init__.py
ADDED
@@ -0,0 +1,21 @@
from ._custom_ops import (
    convert_fp8,
    copy_blocks,
    paged_attention_v1,
    paged_attention_v2,
    reshape_and_cache,
    reshape_and_cache_flash,
    swap_blocks,
)
from ._ops import ops

__all__ = [
    "convert_fp8",
    "copy_blocks",
    "ops",
    "paged_attention_v1",
    "paged_attention_v2",
    "reshape_and_cache",
    "reshape_and_cache_flash",
    "swap_blocks",
]
torch-ext/attention/_custom_ops.py
ADDED
@@ -0,0 +1,173 @@
from typing import List, Optional

import torch

from ._ops import ops


# page attention ops
def paged_attention_v1(
    out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    tp_rank: int = 0,
    blocksparse_local_blocks: int = 0,
    blocksparse_vert_stride: int = 0,
    blocksparse_block_size: int = 64,
    blocksparse_head_sliding_step: int = 0,
) -> None:
    ops.paged_attention_v1(
        out,
        query,
        key_cache,
        value_cache,
        num_kv_heads,
        scale,
        block_tables,
        seq_lens,
        block_size,
        max_seq_len,
        alibi_slopes,
        kv_cache_dtype,
        k_scale,
        v_scale,
        tp_rank,
        blocksparse_local_blocks,
        blocksparse_vert_stride,
        blocksparse_block_size,
        blocksparse_head_sliding_step,
    )


def paged_attention_v2(
    out: torch.Tensor,
    exp_sum: torch.Tensor,
    max_logits: torch.Tensor,
    tmp_out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    tp_rank: int = 0,
    blocksparse_local_blocks: int = 0,
    blocksparse_vert_stride: int = 0,
    blocksparse_block_size: int = 64,
    blocksparse_head_sliding_step: int = 0,
) -> None:
    ops.paged_attention_v2(
        out,
        exp_sum,
        max_logits,
        tmp_out,
        query,
        key_cache,
        value_cache,
        num_kv_heads,
        scale,
        block_tables,
        seq_lens,
        block_size,
        max_seq_len,
        alibi_slopes,
        kv_cache_dtype,
        k_scale,
        v_scale,
        tp_rank,
        blocksparse_local_blocks,
        blocksparse_vert_stride,
        blocksparse_block_size,
        blocksparse_head_sliding_step,
    )


def reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
) -> None:
    ops.reshape_and_cache(
        key,
        value,
        key_cache,
        value_cache,
        slot_mapping,
        kv_cache_dtype,
        k_scale,
        v_scale,
    )


def reshape_and_cache_flash(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: torch.Tensor,
    v_scale: torch.Tensor,
) -> None:
    ops.reshape_and_cache_flash(
        key,
        value,
        key_cache,
        value_cache,
        slot_mapping,
        kv_cache_dtype,
        k_scale,
        v_scale,
    )


def copy_blocks(
    key_caches: List[torch.Tensor],
    value_caches: List[torch.Tensor],
    block_mapping: torch.Tensor,
) -> None:
    ops.copy_blocks(key_caches, value_caches, block_mapping)


def swap_blocks(
    src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor
) -> None:
    ops.swap_blocks(src, dst, block_mapping)


def convert_fp8(
    output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8"
) -> None:
    ops.convert_fp8(output, input, scale, kv_dtype)


__all__ = [
    "convert_fp8",
    "paged_attention_v1",
    "paged_attention_v2",
    "reshape_and_cache",
    "copy_blocks",
]
torch-ext/attention/platforms.py
ADDED
@@ -0,0 +1,62 @@
import os
import random
from abc import ABC, abstractmethod
from functools import lru_cache, wraps
from typing import Callable, ParamSpec, TypeVar

import numpy as np
import torch

IS_ROCM = torch.version.hip is not None


class Platform(ABC):
    @classmethod
    def seed_everything(cls, seed: int) -> None:
        """
        Set the seed of each random module.
        `torch.manual_seed` will set seed on all devices.

        Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
        """
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    @abstractmethod
    def get_device_name(self, device_id: int = 0) -> str: ...

    @abstractmethod
    def is_cuda(self) -> bool: ...

    @abstractmethod
    def is_rocm(self) -> bool: ...


class CudaPlatform(Platform):
    @classmethod
    @lru_cache(maxsize=8)
    def get_device_name(cls, device_id: int = 0) -> str:
        return torch.cuda.get_device_name(0)

    def is_cuda(self) -> bool:
        return True

    def is_rocm(self) -> bool:
        return False


class RocmPlatform(Platform):
    @classmethod
    @lru_cache(maxsize=8)
    def get_device_name(cls, device_id: int = 0) -> str:
        return torch.cuda.get_device_name(device_id)

    def is_cuda(self) -> bool:
        return False

    def is_rocm(self) -> bool:
        return True


current_platform = RocmPlatform() if IS_ROCM else CudaPlatform()
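A small usage sketch of the platform shim above, assuming the package is importable as `attention`; the device-name query is guarded so the snippet also runs on machines without a GPU.

import torch

from attention.platforms import current_platform

# Seed Python, NumPy, and torch in one call.
current_platform.seed_everything(0)

print("ROCm build:", current_platform.is_rocm())
if torch.cuda.is_available():
    # Resolves to torch.cuda.get_device_name on both CUDA and ROCm builds.
    print("device:", current_platform.get_device_name())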
torch-ext/registration.h
ADDED
@@ -0,0 +1,27 @@
#pragma once

#include <Python.h>

#define _CONCAT(A, B) A##B
#define CONCAT(A, B) _CONCAT(A, B)

#define _STRINGIFY(A) #A
#define STRINGIFY(A) _STRINGIFY(A)

// A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME
// could be a macro instead of a literal token.
#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)

// A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME
// could be a macro instead of a literal token.
#define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) \
  TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE)

// REGISTER_EXTENSION allows the shared library to be loaded and initialized
// via python's import statement.
#define REGISTER_EXTENSION(NAME)                                               \
  PyMODINIT_FUNC CONCAT(PyInit_, NAME)() {                                     \
    static struct PyModuleDef module = {PyModuleDef_HEAD_INIT,                 \
                                        STRINGIFY(NAME), nullptr, 0, nullptr}; \
    return PyModule_Create(&module);                                           \
  }
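What REGISTER_EXTENSION buys is that the compiled shared object behaves like an ordinary CPython extension module: it exports PyInit_<NAME>, so a plain import loads it and, as a side effect, runs the TORCH_LIBRARY registration in torch_binding.cpp below. A hedged sketch, with a hypothetical module name:

import importlib

try:
    # Hypothetical name; the real one is whatever NAME the build passes to REGISTER_EXTENSION.
    ext = importlib.import_module("_attention_ext")
    # After this import, the ops defined in torch_binding.cpp are visible under torch.ops.
except ImportError:
    ext = None  # extension not built in this environment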
torch-ext/torch_binding.cpp
ADDED
@@ -0,0 +1,95 @@
#include <torch/library.h>

#include "registration.h"

#include "torch_binding.h"

// Note on op signatures:
// The X_meta signatures are for the meta functions corresponding to op X.
// They must be kept in sync with the signature for X. Generally, only
// functions that return Tensors require a meta function.
//
// See the following links for detailed docs on op registration and function
// schemas.
// https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ptttacy8y1u9
// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#annotations

TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  // Attention ops
  // Compute the attention between an input query and the cached
  // keys/values using PagedAttention.
  ops.def(
      "paged_attention_v1("
      " Tensor! out, Tensor query, Tensor key_cache,"
      " Tensor value_cache, int num_kv_heads, float scale,"
      " Tensor block_tables, Tensor seq_lens, int block_size,"
      " int max_seq_len, Tensor? alibi_slopes,"
      " str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
      " int tp_rank, int blocksparse_local_blocks,"
      " int blocksparse_vert_stride, int blocksparse_block_size,"
      " int blocksparse_head_sliding_step) -> ()");
  ops.impl("paged_attention_v1", torch::kCUDA, &paged_attention_v1);

  // PagedAttention V2.
  ops.def(
      "paged_attention_v2("
      " Tensor! out, Tensor! exp_sums, Tensor! max_logits,"
      " Tensor! tmp_out, Tensor query, Tensor key_cache,"
      " Tensor value_cache, int num_kv_heads, float scale,"
      " Tensor block_tables, Tensor seq_lens, int block_size,"
      " int max_seq_len, Tensor? alibi_slopes,"
      " str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
      " int tp_rank, int blocksparse_local_blocks,"
      " int blocksparse_vert_stride, int blocksparse_block_size,"
      " int blocksparse_head_sliding_step) -> ()");
  ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);

  // Swap in (out) the cache blocks from src to dst.
  ops.def(
      "swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()");
  ops.impl("swap_blocks", torch::kCUDA, &swap_blocks);

  // Copy the cache blocks from src to dst.
  ops.def(
      "copy_blocks(Tensor(a!)[] key_caches, Tensor[](b!) value_caches, "
      "Tensor block_mapping) -> ()");
  ops.impl("copy_blocks", torch::kCUDA, &copy_blocks);

  // Reshape the key and value tensors and cache them.
  ops.def(
      "reshape_and_cache(Tensor key, Tensor value,"
      " Tensor! key_cache, Tensor! value_cache,"
      " Tensor slot_mapping,"
      " str kv_cache_dtype,"
      " Tensor k_scale, Tensor v_scale) -> ()");
  ops.impl("reshape_and_cache", torch::kCUDA, &reshape_and_cache);

  // Reshape the key and value tensors and cache them.
  ops.def(
      "reshape_and_cache_flash(Tensor key, Tensor value,"
      " Tensor! key_cache,"
      " Tensor! value_cache,"
      " Tensor slot_mapping,"
      " str kv_cache_dtype,"
      " Tensor k_scale, Tensor v_scale) -> ()");
  ops.impl("reshape_and_cache_flash", torch::kCUDA,
           &reshape_and_cache_flash);

  // Gets the specified device attribute.
  ops.def("get_device_attribute(int attribute, int device_id) -> int");
  ops.impl("get_device_attribute", &get_device_attribute);

  // Gets the maximum shared memory per block device attribute.
  ops.def(
      "get_max_shared_memory_per_block_device_attribute(int device_id) -> int");
  ops.impl("get_max_shared_memory_per_block_device_attribute",
           &get_max_shared_memory_per_block_device_attribute);

  // Convert the key and value cache to fp8 data type.
  ops.def(
      "convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, "
      "str kv_cache_dtype) -> ()");
  ops.impl("convert_fp8", torch::kCUDA, &convert_fp8);
}

REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
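To make the registration pattern concrete, here is a rough Python-side analogue of one ops.def/ops.impl pair, using torch.library with a made-up namespace ("demo_ops") and a toy CPU implementation; the real extension registers the CUDA kernels instead, so treat this only as an illustration of the mechanism.

import torch

# Made-up namespace; mirrors ops.def(...) followed by ops.impl(..., torch::kCUDA, ...).
lib = torch.library.Library("demo_ops", "DEF")
lib.define("swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()")

def swap_blocks_cpu(src, dst, block_mapping):
    # Toy stand-in: copy whole blocks for each (src_block, dst_block) pair.
    for src_block, dst_block in block_mapping.tolist():
        dst[dst_block].copy_(src[src_block])

lib.impl("swap_blocks", swap_blocks_cpu, "CPU")

src = torch.randn(4, 16)
dst = torch.zeros(4, 16)
torch.ops.demo_ops.swap_blocks(src, dst, torch.tensor([[0, 2], [1, 3]]))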
torch-ext/torch_binding.h
ADDED
@@ -0,0 +1,56 @@
#pragma once

#include <torch/torch.h>

void paged_attention_v1(
    torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
    int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
    const std::string& kv_cache_dtype, torch::Tensor& k_scale,
    torch::Tensor& v_scale, const int64_t tp_rank,
    const int64_t blocksparse_local_blocks,
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
    const int64_t blocksparse_head_sliding_step);

void paged_attention_v2(
    torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
    torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
    int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
    const std::string& kv_cache_dtype, torch::Tensor& k_scale,
    torch::Tensor& v_scale, const int64_t tp_rank,
    const int64_t blocksparse_local_blocks,
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
    const int64_t blocksparse_head_sliding_step);

void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
                 const torch::Tensor& block_mapping);

// Note: the key_caches and value_caches vectors are constant but
// not the Tensors they contain. The vectors need to be const refs
// in order to satisfy pytorch's C++ operator registration code.
void copy_blocks(std::vector<torch::Tensor> const& key_caches,
                 std::vector<torch::Tensor> const& value_caches,
                 const torch::Tensor& block_mapping);

void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
                       torch::Tensor& key_cache, torch::Tensor& value_cache,
                       torch::Tensor& slot_mapping,
                       const std::string& kv_cache_dtype,
                       torch::Tensor& k_scale, torch::Tensor& v_scale);

void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
                             torch::Tensor& key_cache,
                             torch::Tensor& value_cache,
                             torch::Tensor& slot_mapping,
                             const std::string& kv_cache_dtype,
                             torch::Tensor& k_scale, torch::Tensor& v_scale);

int64_t get_device_attribute(int64_t attribute, int64_t device_id);

int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id);

void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
                 const double scale, const std::string& kv_cache_dtype);