|
|
|
|
|
#include "hip/hip_runtime.h" |
|
|
#include "grouped_gemm.h" |
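// Two ROCm code paths live in this translation unit: a per-expert hipBLASLt
// loop compiled when __HIP_PLATFORM_AMD__ is defined, and a hipified CUTLASS
// grouped GEMM (with a looped hipBLAS fallback) in the #else branch below.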
|
|
|
|
|
#ifdef __HIP_PLATFORM_AMD__ |
|
|
|
|
|
#include "gpu_backend_hip.h" |
|
|
#include <ATen/hip/HIPContext.h> |
|
|
#include <hipblaslt/hipblaslt.h> |
|
|
#include <torch/autograd.h> |
|
|
#include <vector> |
|
|
#include <algorithm> |
|
|
#include <cctype> |
|
|
#include <cstdlib> |
|
|
#include <string> |
|
|
|
|
|
namespace grouped_gemm { |
|
|
namespace { |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Returns true when the MEGABLOCKS_GG_USE_HIPBLASLT environment variable is
// set to a truthy value ("1", "true", "yes" or "on", case-insensitive). The
// result is computed once and cached for the lifetime of the process.
bool use_hipblaslt_backend() {
|
|
static int cached = [] { |
|
|
const char* raw = std::getenv("MEGABLOCKS_GG_USE_HIPBLASLT"); |
|
|
if (raw == nullptr) { |
|
|
return 0; |
|
|
} |
|
|
std::string value(raw); |
|
|
std::transform(value.begin(), value.end(), value.begin(), [](unsigned char c) { |
|
|
return static_cast<char>(std::tolower(c)); |
|
|
}); |
|
|
if (value == "1" || value == "true" || value == "yes" || value == "on") { |
|
|
return 1; |
|
|
} |
|
|
return 0; |
|
|
}(); |
|
|
return cached == 1; |
|
|
} |
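// Example (assumed shell invocation, not part of this file):
//   MEGABLOCKS_GG_USE_HIPBLASLT=1 python your_training_script.py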
|
|
|
|
|
// Converts a failing hipBLASLt status into a TORCH_CHECK error that reports
// the offending expression.
inline void hipblaslt_check(hipblasStatus_t status, const char* expr) {
|
|
TORCH_CHECK(status == HIPBLAS_STATUS_SUCCESS, "hipBLASLt call failed with status ", status, " when executing ", expr); |
|
|
} |
|
|
#define HIPBLASLT_CHECK(cmd) hipblaslt_check((cmd), #cmd) |
|
|
|
|
|
// Lazily creates and caches a single process-wide hipBLASLt handle.
hipblasLtHandle_t hipblaslt_handle() {
|
|
static hipblasLtHandle_t handle = [] { |
|
|
hipblasLtHandle_t h; |
|
|
HIPBLASLT_CHECK(hipblasLtCreate(&h)); |
|
|
return h; |
|
|
}(); |
|
|
return handle; |
|
|
} |
|
|
|
|
|
// Runs one D = op_a(A) * op_b(B) + beta * C matmul through hipBLASLt, where
// beta is 1 when `accumulate` is set and 0 otherwise. All operands are bf16
// and row-major with fp32 accumulation; the rows/cols arguments describe each
// matrix as stored, and the op flags request transposition inside the kernel.
// No workspace is requested.
void hipblaslt_run_matmul(const void* a_ptr,
|
|
const void* b_ptr, |
|
|
const void* c_ptr, |
|
|
void* d_ptr, |
|
|
int64_t rows_a, |
|
|
int64_t cols_a, |
|
|
int64_t rows_b, |
|
|
int64_t cols_b, |
|
|
int64_t rows_d, |
|
|
int64_t cols_d, |
|
|
int64_t lda, |
|
|
int64_t ldb, |
|
|
int64_t ldc, |
|
|
int64_t ldd, |
|
|
hipblasOperation_t op_a, |
|
|
hipblasOperation_t op_b, |
|
|
bool accumulate) { |
|
|
if (rows_a == 0 || cols_a == 0 || rows_b == 0 || cols_b == 0 || rows_d == 0 || cols_d == 0) |
|
|
return; |
|
|
|
|
|
auto handle = hipblaslt_handle(); |
|
|
auto stream = c10::hip::getCurrentHIPStream(); |
|
|
|
|
|
hipblasLtMatmulDesc_t matmul_desc; |
|
|
HIPBLASLT_CHECK(hipblasLtMatmulDescCreate(&matmul_desc, HIPBLAS_COMPUTE_32F, HIP_R_32F)); |
|
|
HIPBLASLT_CHECK(hipblasLtMatmulDescSetAttribute( |
|
|
matmul_desc, HIPBLASLT_MATMUL_DESC_TRANSA, &op_a, sizeof(op_a))); |
|
|
HIPBLASLT_CHECK(hipblasLtMatmulDescSetAttribute( |
|
|
matmul_desc, HIPBLASLT_MATMUL_DESC_TRANSB, &op_b, sizeof(op_b))); |
|
|
hipblasLtPointerMode_t pointer_mode = HIPBLASLT_POINTER_MODE_HOST; |
|
|
HIPBLASLT_CHECK(hipblasLtMatmulDescSetAttribute( |
|
|
matmul_desc, HIPBLASLT_MATMUL_DESC_POINTER_MODE, &pointer_mode, sizeof(pointer_mode))); |
|
|
|
|
|
hipblasLtOrder_t order = HIPBLASLT_ORDER_ROW; |
|
|
|
|
|
hipblasLtMatrixLayout_t layout_a; |
|
|
HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&layout_a, HIP_R_16BF, rows_a, cols_a, lda)); |
|
|
HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute( |
|
|
layout_a, HIPBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(order))); |
|
|
|
|
|
hipblasLtMatrixLayout_t layout_b; |
|
|
HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&layout_b, HIP_R_16BF, rows_b, cols_b, ldb)); |
|
|
HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute( |
|
|
layout_b, HIPBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(order))); |
|
|
|
|
|
hipblasLtMatrixLayout_t layout_c; |
|
|
HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&layout_c, HIP_R_16BF, rows_d, cols_d, ldc)); |
|
|
HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute( |
|
|
layout_c, HIPBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(order))); |
|
|
|
|
|
hipblasLtMatrixLayout_t layout_d; |
|
|
HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&layout_d, HIP_R_16BF, rows_d, cols_d, ldd)); |
|
|
HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute( |
|
|
layout_d, HIPBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(order))); |
|
|
|
|
|
// Request a single heuristic algorithm that needs no workspace.
hipblasLtMatmulPreference_t preference;
|
|
HIPBLASLT_CHECK(hipblasLtMatmulPreferenceCreate(&preference)); |
|
|
uint64_t workspace_size = 0; |
|
|
HIPBLASLT_CHECK(hipblasLtMatmulPreferenceSetAttribute( |
|
|
preference, |
|
|
HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, |
|
|
&workspace_size, |
|
|
sizeof(workspace_size))); |
|
|
|
|
|
hipblasLtMatmulHeuristicResult_t heuristic{}; |
|
|
int returned_results = 0; |
|
|
HIPBLASLT_CHECK(hipblasLtMatmulAlgoGetHeuristic( |
|
|
handle, |
|
|
matmul_desc, |
|
|
layout_a, |
|
|
layout_b, |
|
|
layout_c, |
|
|
layout_d, |
|
|
preference, |
|
|
1, |
|
|
&heuristic, |
|
|
&returned_results)); |
|
|
TORCH_CHECK(returned_results > 0, "hipBLASLt could not find a suitable algorithm"); |
|
|
|
|
|
// beta selects between overwriting D (0.0f) and accumulating into C (1.0f).
const float alpha = 1.0f;
|
|
const float beta = accumulate ? 1.0f : 0.0f; |
|
|
|
|
|
HIPBLASLT_CHECK(hipblasLtMatmul(handle, |
|
|
matmul_desc, |
|
|
&alpha, |
|
|
a_ptr, |
|
|
layout_a, |
|
|
b_ptr, |
|
|
layout_b, |
|
|
&beta, |
|
|
c_ptr, |
|
|
layout_c, |
|
|
d_ptr, |
|
|
layout_d, |
|
|
&heuristic.algo, |
|
|
nullptr, |
|
|
0, |
|
|
stream)); |
|
|
|
|
|
hipblasLtMatmulPreferenceDestroy(preference); |
|
|
hipblasLtMatrixLayoutDestroy(layout_d); |
|
|
hipblasLtMatrixLayoutDestroy(layout_c); |
|
|
hipblasLtMatrixLayoutDestroy(layout_b); |
|
|
hipblasLtMatrixLayoutDestroy(layout_a); |
|
|
hipblasLtMatmulDescDestroy(matmul_desc); |
|
|
} |
|
|
|
|
|
}  // namespace
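// Grouped GEMM driven by `batch_sizes` (per-expert token counts on the CPU).
// Shapes handled, all bf16 CUDA tensors:
//   trans_a:  a [tokens, h_in], b [tokens, h_out]         -> out [experts, h_in, h_out]
//   trans_b:  a [tokens, h_in], b [experts, h_out, h_in]  -> out [tokens, h_out]
//   default:  a [tokens, h_in], b [experts, h_in, h_out]  -> out [tokens, h_out]
// Each expert owns a contiguous block of rows of `a`. When
// use_hipblaslt_backend() is false, a float32 torch::matmul fallback is used
// per expert instead of hipBLASLt.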
|
|
|
|
|
torch::Tensor hipblaslt_gmm_internal(torch::Tensor a, |
|
|
torch::Tensor b, |
|
|
torch::Tensor batch_sizes, |
|
|
bool trans_a, |
|
|
bool trans_b, |
|
|
c10::optional<torch::Tensor> c_opt) { |
|
|
torch::NoGradGuard no_grad; |
|
|
TORCH_CHECK(a.is_cuda(), "hipblaslt_gmm requires CUDA tensors"); |
|
|
TORCH_CHECK(b.is_cuda(), "hipblaslt_gmm requires CUDA tensors"); |
|
|
TORCH_CHECK(a.scalar_type() == torch::kBFloat16, "hipblaslt_gmm expects BF16 inputs"); |
|
|
TORCH_CHECK(b.scalar_type() == torch::kBFloat16, "hipblaslt_gmm expects BF16 weights"); |
|
|
TORCH_CHECK(batch_sizes.device().is_cpu(), "batch_sizes must reside on CPU"); |
|
|
|
|
|
a = a.contiguous(); |
|
|
b = b.contiguous(); |
|
|
|
|
|
auto device = a.device(); |
|
|
auto dtype = a.scalar_type(); |
|
|
const bool use_hip = use_hipblaslt_backend(); |
|
|
|
|
|
// Inclusive prefix sum of the per-expert token counts: expert e's rows of
// `a` end at prefix[e].
const auto counts_ptr = batch_sizes.data_ptr<int64_t>();
|
|
const int64_t num_experts = batch_sizes.size(0); |
|
|
std::vector<int64_t> prefix(num_experts); |
|
|
int64_t running = 0; |
|
|
for (int64_t i = 0; i < num_experts; ++i) { |
|
|
running += counts_ptr[i]; |
|
|
prefix[i] = running; |
|
|
} |
|
|
const int64_t tokens = num_experts ? prefix.back() : 0; |
|
|
TORCH_CHECK(a.size(0) == tokens, "tokens mismatch with batch sizes"); |
|
|
|
|
|
torch::Tensor out; |
|
|
// trans_a: per-expert products out[e] = a_e^T @ b_e.
if (trans_a) {
|
|
const int64_t hidden_in = a.size(1); |
|
|
const int64_t hidden_out = b.size(1); |
|
|
out = c_opt.value_or(torch::empty({num_experts, hidden_in, hidden_out}, |
|
|
a.options().dtype(dtype))); |
|
|
TORCH_CHECK(out.is_contiguous(), "Output tensor must be contiguous"); |
|
|
|
|
|
auto b_contig = b.contiguous(); |
|
|
|
|
|
if (use_hip) { |
|
|
int64_t start = 0; |
|
|
for (int64_t expert = 0; expert < num_experts; ++expert) { |
|
|
const int64_t end = prefix[expert]; |
|
|
const int64_t rows = end - start; |
|
|
auto out_chunk = out.select(0, expert); |
|
|
if (rows == 0) { |
|
|
out_chunk.zero_(); |
|
|
start = end; |
|
|
continue; |
|
|
} |
|
|
|
|
|
auto a_chunk = a.narrow(0, start, rows).contiguous(); |
|
|
auto b_chunk = b_contig.narrow(0, start, rows).contiguous(); |
|
|
|
|
|
hipblaslt_run_matmul(a_chunk.data_ptr(), |
|
|
b_chunk.data_ptr(), |
|
|
out_chunk.data_ptr(), |
|
|
out_chunk.data_ptr(), |
|
|
rows, |
|
|
hidden_in, |
|
|
rows, |
|
|
hidden_out, |
|
|
hidden_in, |
|
|
hidden_out, |
|
|
hidden_in, |
|
|
hidden_out, |
|
|
hidden_out, |
|
|
hidden_out, |
|
|
HIPBLAS_OP_T, |
|
|
HIPBLAS_OP_N, |
|
|
false); |
|
|
start = end; |
|
|
} |
|
|
} else { |
|
|
int64_t start = 0; |
|
|
for (int64_t expert = 0; expert < num_experts; ++expert) { |
|
|
const int64_t end = prefix[expert]; |
|
|
const int64_t rows = end - start; |
|
|
auto out_chunk = out.select(0, expert); |
|
|
if (rows == 0) { |
|
|
out_chunk.zero_(); |
|
|
start = end; |
|
|
continue; |
|
|
} |
|
|
|
|
|
auto a_slice = a.narrow(0, start, rows); |
|
|
auto b_slice = b_contig.narrow(0, start, rows); |
|
|
|
|
|
auto a_f32 = a_slice.contiguous().to(torch::kFloat32); |
|
|
auto b_f32 = b_slice.contiguous().to(torch::kFloat32); |
|
|
|
|
|
auto prod = torch::matmul(a_f32.transpose(0, 1), b_f32); |
|
|
auto prod_bf16 = prod.to(dtype); |
|
|
|
|
|
out_chunk.copy_(prod_bf16); |
|
|
start = end; |
|
|
} |
|
|
} |
|
|
return out; |
|
|
} |
|
|
|
|
|
// trans_b: rows for expert e are a_e @ b[e]^T.
if (trans_b) {
|
|
const int64_t hidden_in = a.size(1); |
|
|
const int64_t hidden_out = b.size(1); |
|
|
out = c_opt.value_or(torch::empty({tokens, hidden_out}, a.options())); |
|
|
TORCH_CHECK(out.is_contiguous(), "Output tensor must be contiguous"); |
|
|
|
|
|
auto b_contig = b.contiguous(); |
|
|
|
|
|
if (use_hip) { |
|
|
int64_t start = 0; |
|
|
for (int64_t expert = 0; expert < num_experts; ++expert) { |
|
|
const int64_t end = prefix[expert]; |
|
|
const int64_t rows = end - start; |
|
|
if (rows == 0) { |
|
|
start = end; |
|
|
continue; |
|
|
} |
|
|
auto a_chunk = a.narrow(0, start, rows).contiguous(); |
|
|
auto b_chunk = b_contig.select(0, expert).contiguous(); |
|
|
auto out_chunk = out.narrow(0, start, rows); |
|
|
|
|
|
hipblaslt_run_matmul(a_chunk.data_ptr(), |
|
|
b_chunk.data_ptr(), |
|
|
out_chunk.data_ptr(), |
|
|
out_chunk.data_ptr(), |
|
|
rows, |
|
|
hidden_in, |
|
|
hidden_out, |
|
|
hidden_in, |
|
|
rows, |
|
|
hidden_out, |
|
|
hidden_in, |
|
|
hidden_in, |
|
|
hidden_out, |
|
|
hidden_out, |
|
|
HIPBLAS_OP_N, |
|
|
HIPBLAS_OP_T, |
|
|
false); |
|
|
start = end; |
|
|
} |
|
|
} else { |
|
|
int64_t start = 0; |
|
|
for (int64_t expert = 0; expert < num_experts; ++expert) { |
|
|
const int64_t end = prefix[expert]; |
|
|
const int64_t rows = end - start; |
|
|
if (rows == 0) { |
|
|
start = end; |
|
|
continue; |
|
|
} |
|
|
auto a_slice = a.narrow(0, start, rows); |
|
|
auto b_slice = b_contig.select(0, expert); |
|
|
auto out_chunk = out.narrow(0, start, rows); |
|
|
|
|
|
auto a_f32 = a_slice.contiguous().to(torch::kFloat32); |
|
|
auto b_f32 = b_slice.contiguous().to(torch::kFloat32); |
|
|
|
|
|
auto prod = torch::matmul(a_f32, b_f32.transpose(0, 1)); |
|
|
auto prod_bf16 = prod.to(dtype); |
|
|
|
|
|
out_chunk.copy_(prod_bf16); |
|
|
start = end; |
|
|
} |
|
|
} |
|
|
return out; |
|
|
} |
|
|
|
|
|
// Default path: rows for expert e are a_e @ b[e].
const int64_t hidden_in = a.size(1);
const int64_t hidden_out = b.size(2);
out = c_opt.value_or(torch::empty({tokens, hidden_out}, a.options()));
|
|
TORCH_CHECK(out.is_contiguous(), "Output tensor must be contiguous"); |
|
|
|
|
|
auto b_contig = b.contiguous(); |
|
|
|
|
|
if (use_hip) { |
|
|
int64_t start = 0; |
|
|
for (int64_t expert = 0; expert < num_experts; ++expert) { |
|
|
const int64_t end = prefix[expert]; |
|
|
const int64_t rows = end - start; |
|
|
if (rows == 0) { |
|
|
start = end; |
|
|
continue; |
|
|
} |
|
|
auto a_chunk = a.narrow(0, start, rows).contiguous(); |
|
|
auto b_chunk = b_contig.select(0, expert).contiguous(); |
|
|
auto out_chunk = out.narrow(0, start, rows); |
|
|
|
|
|
hipblaslt_run_matmul(a_chunk.data_ptr(),
b_chunk.data_ptr(),
out_chunk.data_ptr(),
out_chunk.data_ptr(),
rows,
hidden_in,
hidden_in,
hidden_out,
rows,
hidden_out,
hidden_in,
hidden_out,
hidden_out,
hidden_out,
HIPBLAS_OP_N,
HIPBLAS_OP_N,
false);
|
|
start = end; |
|
|
} |
|
|
} else { |
|
|
int64_t start = 0; |
|
|
for (int64_t expert = 0; expert < num_experts; ++expert) { |
|
|
const int64_t end = prefix[expert]; |
|
|
const int64_t rows = end - start; |
|
|
if (rows == 0) { |
|
|
start = end; |
|
|
continue; |
|
|
} |
|
|
auto a_slice = a.narrow(0, start, rows); |
|
|
auto b_slice = b_contig.select(0, expert); |
|
|
auto out_chunk = out.narrow(0, start, rows); |
|
|
|
|
|
auto a_f32 = a_slice.contiguous().to(torch::kFloat32); |
|
|
auto b_f32 = b_slice.contiguous().to(torch::kFloat32); |
|
|
|
|
|
auto prod = torch::matmul(a_f32, b_f32); |
|
|
auto prod_bf16 = prod.to(dtype); |
|
|
|
|
|
out_chunk.copy_(prod_bf16); |
|
|
start = end; |
|
|
} |
|
|
} |
|
|
return out; |
|
|
} |
|
|
|
|
|
// GroupedGemm entry point for the hipBLASLt backend: moves batch_sizes to
// the CPU if needed, runs the grouped GEMM, and copies the result into `c`
// when the internal routine had to allocate a fresh output tensor.
void GroupedGemm(torch::Tensor a,
|
|
torch::Tensor b, |
|
|
torch::Tensor c, |
|
|
torch::Tensor batch_sizes, |
|
|
bool trans_a, |
|
|
bool trans_b) { |
|
|
if (!batch_sizes.device().is_cpu()) { |
|
|
batch_sizes = batch_sizes.cpu(); |
|
|
} |
|
|
TORCH_CHECK(c.is_contiguous(), "Output tensor must be contiguous"); |
|
|
auto result = hipblaslt_gmm_internal(a, b, batch_sizes, trans_a, trans_b, c); |
|
|
if (!c.is_alias_of(result)) { |
|
|
c.copy_(result); |
|
|
} |
|
|
} |
|
|
|
|
|
}  // namespace grouped_gemm
|
|
|
|
|
#else  // !__HIP_PLATFORM_AMD__
|
|
#include "fill_arguments_hip.cuh" |
|
|
|
|
|
#include <ATen/hip/HIPContext.h> |
|
|
#include <ATen/hip/detail/KernelUtils.h> |
|
|
#include <c10/util/BFloat16.h> |
|
|
#include <ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h> |
|
|
#include <hipcub/hipcub.hpp> |
|
|
#include <torch/torch.h> |
|
|
|
|
|
#include "cutlass/bfloat16.h" |
|
|
#include "cutlass/complex.h" |
|
|
#include "cutlass/gemm/kernel/gemm_grouped.h" |
|
|
#include "cutlass/gemm/kernel/default_gemm_grouped.h" |
|
|
#include "cutlass/gemm/device/gemm_grouped.h" |
|
|
|
|
|
#include <algorithm>  // std::stable_sort
#include <numeric>    // std::iota
#include <string>
#include <type_traits>
#include <vector>
|
|
|
|
|
namespace grouped_gemm { |
|
|
|
|
|
#define CUDA_CALL(code) \ |
|
|
do { \ |
|
|
hipError_t status = code; \ |
|
|
std::string err = hipGetErrorString(status); \ |
|
|
TORCH_CHECK(status == hipSuccess, err); \ |
|
|
} while (0) |
|
|
|
|
|
#define CUBLAS_CALL(code) \ |
|
|
do { \ |
|
|
hipblasStatus_t status = code; \ |
|
|
TORCH_CHECK(status == HIPBLAS_STATUS_SUCCESS, "hipBLAS error: ", status); \
|
|
} while (0) |
|
|
|
|
|
#define GROUPED_GEMM_STRINGIFY_HELPER(x) #x |
|
|
#define GROUPED_GEMM_STRINGIFY(x) \ |
|
|
GROUPED_GEMM_STRINGIFY_HELPER(x) |
|
|
|
|
|
// A transposed operand is presented to CUTLASS as a column-major view of its
// row-major storage.
template <bool trans>
|
|
using GroupedGemmInputLayout = std::conditional_t<trans, ::cutlass::layout::ColumnMajor, ::cutlass::layout::RowMajor>; |
|
|
|
|
|
// Default bf16 tensor-op configuration for SM80-class targets; it supplies
// the operand alignments, tile shapes, epilogue and stage count reused below.
using GroupedGemmConfig = ::cutlass::gemm::device::DefaultGemmConfiguration<
|
|
::cutlass::arch::OpClassTensorOp, |
|
|
::cutlass::arch::Sm80, |
|
|
::cutlass::bfloat16_t, |
|
|
::cutlass::bfloat16_t, |
|
|
::cutlass::bfloat16_t, |
|
|
float |
|
|
>; |
|
|
|
|
|
|
|
|
// Grouped GEMM kernel: bf16 operands and output with fp32 accumulation;
// operand layouts follow the transpose flags and the output is row-major.
template <bool trans_a, bool trans_b>
|
|
using GroupedGemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< |
|
|
|
|
|
::cutlass::bfloat16_t, |
|
|
GroupedGemmInputLayout<trans_a>, |
|
|
::cutlass::ComplexTransform::kNone, |
|
|
GroupedGemmConfig::kAlignmentA, |
|
|
|
|
|
::cutlass::bfloat16_t, |
|
|
GroupedGemmInputLayout<trans_b>, |
|
|
::cutlass::ComplexTransform::kNone, |
|
|
GroupedGemmConfig::kAlignmentB, |
|
|
|
|
|
::cutlass::bfloat16_t, |
|
|
::cutlass::layout::RowMajor, |
|
|
float, |
|
|
::cutlass::arch::OpClassTensorOp, |
|
|
::cutlass::arch::Sm80, |
|
|
GroupedGemmConfig::ThreadblockShape, |
|
|
GroupedGemmConfig::WarpShape, |
|
|
GroupedGemmConfig::InstructionShape, |
|
|
GroupedGemmConfig::EpilogueOutputOp, |
|
|
|
|
|
|
|
|
|
|
|
::cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, |
|
|
|
|
|
GroupedGemmConfig::kStages>::GemmKernel; |
|
|
|
|
|
template <bool trans_a, bool trans_b> |
|
|
using GemmGrouped = ::cutlass::gemm::device::GemmGrouped<GroupedGemmKernel<trans_a, trans_b>>; |
|
|
|
|
|
// Copies a host std::vector into a freshly allocated device byte tensor on
// the current HIP stream.
template <typename T>
|
|
torch::Tensor CopyToDevice(const std::vector<T> &x, const torch::Device &device) { |
|
|
size_t bytes = x.size() * sizeof(T); |
|
|
auto options = torch::TensorOptions().dtype(torch::kInt8).device(device); |
|
|
torch::Tensor out = torch::empty(bytes, options); |
|
|
|
|
|
CUDA_CALL(hipMemcpyAsync(out.data_ptr(), |
|
|
x.data(), bytes, |
|
|
hipMemcpyHostToDevice, |
|
|
c10::hip::getCurrentHIPStreamMasqueradingAsCUDA())); |
|
|
return out; |
|
|
} |
|
|
|
|
|
// Permutes `data` in place so that element i takes the value that was at
// indices[i]; used below to apply the descending-k ordering to every
// per-problem array.
template <typename T>
static void ReorderArray(T* data, const std::vector<size_t>& indices) {
|
|
|
|
|
std::vector<T> copy(data, data + indices.size()); |
|
|
for (size_t i = 0; i < indices.size(); ++i) { |
|
|
data[i] = copy.at(indices[i]); |
|
|
} |
|
|
} |
|
|
|
|
|
// Allocates an uninitialized device byte buffer sized for `numel` elements
// of type T; used to stage raw argument arrays on the GPU.
template <typename T>
torch::Tensor TypedEmpty(size_t numel, const torch::Device& device) {
|
|
return torch::empty(numel * sizeof(T), torch::dtype(torch::kInt8).device(device)); |
|
|
} |
|
|
|
|
|
// Device-side buffers (leading dimensions, operand pointers, problem sizes)
// plus the threadblock count consumed by the CUTLASS grouped GEMM.
struct RawGemmArguments {
|
|
torch::Tensor lda, ldb, ldc, ptr_a, ptr_b, ptr_c, problem_sizes; |
|
|
int threadblock_count{}; |
|
|
}; |
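// MakeArgumentsOnDevice allocates empty device buffers for the argument
// arrays (limited to kMaxExperts experts); they are populated later by the
// FillArguments kernel because batch_sizes lives on the GPU in that path.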
|
|
|
|
|
template < |
|
|
typename Gemm, |
|
|
typename ElementA, typename ElementB, typename ElementC |
|
|
> |
|
|
RawGemmArguments MakeArgumentsOnDevice(int num_experts, const torch::Device& device) { |
|
|
TORCH_CHECK( |
|
|
num_experts <= kMaxExperts, |
|
|
"At most ", kMaxExperts, |
|
|
" experts are supported when batch_sizes is a CUDA tensor, but got ", num_experts |
|
|
); |
|
|
|
|
|
return RawGemmArguments { |
|
|
.lda = TypedEmpty<int64_t>(num_experts, device), |
|
|
.ldb = TypedEmpty<int64_t>(num_experts, device), |
|
|
.ldc = TypedEmpty<int64_t>(num_experts, device), |
|
|
.ptr_a = TypedEmpty<ElementA*>(num_experts, device), |
|
|
.ptr_b = TypedEmpty<ElementB*>(num_experts, device), |
|
|
.ptr_c = TypedEmpty<ElementC*>(num_experts, device), |
|
|
.problem_sizes = TypedEmpty<cutlass::gemm::GemmCoord>(num_experts, device), |
|
|
|
|
|
|
|
|
// Problem shapes are unknown on the host in this path, so use the
// device-wide threadblock count that CUTLASS reports as sufficient.
.threadblock_count = Gemm::sufficient(),
|
|
}; |
|
|
} |
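// MakeArgumentsOnHost builds the per-expert pointers, leading dimensions and
// problem sizes on the CPU (batch_sizes is a host tensor here) and copies
// them to the device. kDynamicK marks the variable-k (trans_a) case, where
// batch_sizes sets k instead of m for each problem.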
|
|
|
|
|
template < |
|
|
bool kDynamicK, |
|
|
typename Gemm, |
|
|
typename ElementA, typename ElementB, typename ElementC, |
|
|
typename LayoutA, typename LayoutB, typename LayoutC |
|
|
> |
|
|
RawGemmArguments MakeArgumentsOnHost(torch::Tensor a, |
|
|
torch::Tensor b, |
|
|
torch::Tensor c, |
|
|
torch::Tensor batch_sizes, |
|
|
::cutlass::gemm::GemmCoord coord_template, |
|
|
int64_t num_experts) { |
|
|
std::vector<::cutlass::gemm::GemmCoord> problem_sizes_host(num_experts); |
|
|
|
|
|
|
|
|
std::vector<int64_t> lda_host(num_experts), ldb_host(num_experts), ldc_host(num_experts); |
|
|
int64_t elements_a = 0, elements_b = 0, elements_c = 0; |
|
|
|
|
|
std::vector<ElementA *> ptr_a_host(num_experts), ptr_b_host(num_experts), ptr_c_host(num_experts); |
|
|
|
|
|
for (int i = 0; i < num_experts; ++i) { |
|
|
auto& problem = problem_sizes_host[i]; |
|
|
problem = coord_template; |
|
|
(kDynamicK ? problem.k() : problem.m()) = batch_sizes.data_ptr<int64_t>()[i]; |
|
|
|
|
|
lda_host[i] = LayoutA::packed({problem.m(), problem.k()}).stride(0); |
|
|
ldb_host[i] = LayoutB::packed({problem.k(), problem.n()}).stride(0); |
|
|
ldc_host[i] = LayoutC::packed({problem.m(), problem.n()}).stride(0); |
|
|
|
|
|
ptr_a_host[i] = (ElementA*)a.data_ptr() + elements_a; |
|
|
ptr_b_host[i] = (ElementB*)b.data_ptr() + elements_b; |
|
|
ptr_c_host[i] = (ElementC*)c.data_ptr() + elements_c; |
|
|
|
|
|
elements_a += problem.m() * problem.k(); |
|
|
elements_b += problem.k() * problem.n(); |
|
|
elements_c += problem.m() * problem.n(); |
|
|
|
|
|
if (problem.k() == 0) {
// No tokens were routed to this expert, so the grouped kernel would never
// touch its output. Zero that block explicitly and collapse the problem to
// 0x0 so it is skipped.
CUDA_CALL(hipMemsetAsync(ptr_c_host[i], |
|
|
0, |
|
|
problem.m() * problem.n() * sizeof(ElementC), |
|
|
c10::hip::getCurrentHIPStreamMasqueradingAsCUDA())); |
|
|
|
|
|
problem.m() = 0; |
|
|
problem.n() = 0; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
// Sort problems by descending k (tokens per expert) so larger problems are
// scheduled first, a common load-balancing heuristic for grouped kernels.
if (kDynamicK) {
|
|
std::vector<size_t> indices(num_experts); |
|
|
std::iota(indices.begin(), indices.end(), 0); |
|
|
std::stable_sort(indices.begin(), indices.end(), [&problem_sizes_host](size_t i, size_t j) { |
|
|
return problem_sizes_host[i].k() > problem_sizes_host[j].k(); |
|
|
}); |
|
|
|
|
|
ReorderArray(problem_sizes_host.data(), indices); |
|
|
ReorderArray(lda_host.data(), indices); |
|
|
ReorderArray(ldb_host.data(), indices); |
|
|
ReorderArray(ldc_host.data(), indices); |
|
|
ReorderArray(ptr_a_host.data(), indices); |
|
|
ReorderArray(ptr_b_host.data(), indices); |
|
|
ReorderArray(ptr_c_host.data(), indices); |
|
|
} |
|
|
|
|
|
|
|
|
return RawGemmArguments { |
|
|
.lda = CopyToDevice(lda_host, a.device()), |
|
|
.ldb = CopyToDevice(ldb_host, a.device()), |
|
|
.ldc = CopyToDevice(ldc_host, a.device()), |
|
|
.ptr_a = CopyToDevice(ptr_a_host, a.device()), |
|
|
.ptr_b = CopyToDevice(ptr_b_host, a.device()), |
|
|
.ptr_c = CopyToDevice(ptr_c_host, a.device()), |
|
|
.problem_sizes = CopyToDevice(problem_sizes_host, a.device()), |
|
|
|
|
|
|
|
|
// Problem sizes are known here, so size the launch to what CUTLASS deems
// sufficient for exactly these problems.
.threadblock_count = Gemm::sufficient(problem_sizes_host.data(), num_experts),
|
|
}; |
|
|
} |
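// MakeArguments picks the host or device argument-building path above and
// packs the result into a CUTLASS GemmGrouped::Arguments object.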
|
|
|
|
|
template < |
|
|
bool kDynamicK, |
|
|
typename Gemm, |
|
|
typename ElementA, typename ElementB, typename ElementC, |
|
|
typename LayoutA, typename LayoutB, typename LayoutC |
|
|
> |
|
|
typename Gemm::Arguments MakeArguments(torch::Tensor a, |
|
|
torch::Tensor b, |
|
|
torch::Tensor c, |
|
|
torch::Tensor batch_sizes, |
|
|
::cutlass::gemm::GemmCoord coord_template, |
|
|
int64_t num_experts) { |
|
|
RawGemmArguments raw_args; |
|
|
if (batch_sizes.is_cuda()) { |
|
|
raw_args = MakeArgumentsOnDevice< |
|
|
Gemm, ElementA, ElementB, ElementC |
|
|
>(num_experts, a.device()); |
|
|
} else { |
|
|
raw_args = MakeArgumentsOnHost< |
|
|
kDynamicK, |
|
|
Gemm, |
|
|
ElementA, ElementB, ElementC, |
|
|
LayoutA, LayoutB, LayoutC |
|
|
>(a, b, c, batch_sizes, coord_template, num_experts); |
|
|
} |
|
|
|
|
|
printf("Using %d threadblocks for grouped GEMM.\n", raw_args.threadblock_count); |
|
|
|
|
|
if (!raw_args.threadblock_count) {
TORCH_CHECK(false, "Grouped GEMM execution is not possible on this hardware");
}
|
|
|
|
|
// alpha = 1, beta = 0: the epilogue writes C = A * B without accumulation.
typename Gemm::EpilogueOutputOp::Params epilogue_op(1.0f, 0.0f);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
typename Gemm::Arguments arguments((cutlass::gemm::GemmCoord*)raw_args.problem_sizes.data_ptr(), |
|
|
(int)num_experts, |
|
|
(int)raw_args.threadblock_count, |
|
|
epilogue_op, |
|
|
(ElementA**)raw_args.ptr_a.data_ptr(), |
|
|
(ElementB**)raw_args.ptr_b.data_ptr(), |
|
|
(ElementC**)raw_args.ptr_c.data_ptr(), |
|
|
(ElementC**)raw_args.ptr_c.data_ptr(), |
|
|
(int64_t*)raw_args.lda.data_ptr(), |
|
|
(int64_t*)raw_args.ldb.data_ptr(), |
|
|
(int64_t*)raw_args.ldc.data_ptr(), |
|
|
(int64_t*)raw_args.ldc.data_ptr(), |
|
|
nullptr); |
|
|
return arguments; |
|
|
} |
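// FillCutlassArguments launches a single kernel that writes the per-expert
// pointers, leading dimensions and problem sizes directly on the device,
// avoiding a host/device synchronization when batch_sizes is a GPU tensor.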
|
|
|
|
|
template < |
|
|
bool trans_a, |
|
|
typename ElementA, typename ElementB, typename ElementC, |
|
|
typename LayoutA, typename LayoutB, typename LayoutC, |
|
|
typename Arguments |
|
|
> |
|
|
void FillCutlassArguments(int num_experts, |
|
|
torch::Tensor batch_sizes, |
|
|
torch::Tensor a, |
|
|
torch::Tensor b, |
|
|
torch::Tensor c, |
|
|
const Arguments& arguments, |
|
|
::cutlass::gemm::GemmCoord coord_template) { |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hipLaunchKernelGGL(( FillArguments< |
|
|
trans_a, |
|
|
ElementA, ElementB, ElementC, |
|
|
LayoutA, LayoutB, LayoutC |
|
|
>), dim3(1), dim3(kMaxExperts), 0, c10::hip::getCurrentHIPStreamMasqueradingAsCUDA(), |
|
|
num_experts, batch_sizes.data_ptr<int64_t>(), |
|
|
(ElementA*)a.data_ptr(), (ElementB*)b.data_ptr(), (ElementC*)c.data_ptr(), |
|
|
arguments, coord_template |
|
|
); |
|
|
C10_HIP_KERNEL_LAUNCH_CHECK(); |
|
|
} |
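// RemoveK0Problems zeroes the output blocks of experts that received no
// tokens (k == 0) and drops those problems from the device-side schedule.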
|
|
|
|
|
template <typename Args> |
|
|
void RemoveK0Problems(int num_experts, const Args& arguments) { |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hipLaunchKernelGGL(( ZeroOutK0Outputs<>), |
|
|
dim3(arguments.threadblock_count), dim3(at::cuda::detail::CUDA_NUM_THREADS), 0, c10::hip::getCurrentHIPStreamMasqueradingAsCUDA() |
|
|
, |
|
|
num_experts, arguments |
|
|
); |
|
|
hipLaunchKernelGGL(( IgnoreK0Problems<>), |
|
|
dim3(1), dim3(kMaxExperts), 0, c10::hip::getCurrentHIPStreamMasqueradingAsCUDA() |
|
|
, |
|
|
num_experts, arguments |
|
|
); |
|
|
} |
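// CutlassGroupedGemm runs one CUTLASS grouped GEMM over all experts.
// coord_template carries the two static dimensions; the dynamic one (m, or k
// when trans_a) is filled in per expert from batch_sizes.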
|
|
|
|
|
template <bool trans_a, bool trans_b> |
|
|
torch::Tensor CutlassGroupedGemm(torch::Tensor a, |
|
|
torch::Tensor b, |
|
|
torch::Tensor c, |
|
|
torch::Tensor batch_sizes, |
|
|
::cutlass::gemm::GemmCoord coord_template) { |
|
|
using Gemm = GemmGrouped<trans_a, trans_b>; |
|
|
using LayoutA = typename Gemm::LayoutA; |
|
|
using LayoutB = typename Gemm::LayoutB; |
|
|
using LayoutC = typename Gemm::LayoutC; |
|
|
|
|
|
using ElementA = typename Gemm::ElementA; |
|
|
using ElementB = typename Gemm::ElementB; |
|
|
using ElementC = typename Gemm::ElementC; |
|
|
|
|
|
Gemm gemm; |
|
|
int64_t num_experts = batch_sizes.size(0); |
|
|
auto arguments = MakeArguments< |
|
|
trans_a, |
|
|
Gemm, |
|
|
ElementA, ElementB, ElementC, |
|
|
LayoutA, LayoutB, LayoutC |
|
|
>(a, b, c, batch_sizes, coord_template, num_experts); |
|
|
int64_t workspace_size = gemm.get_workspace_size(arguments); |
|
|
auto options = torch::TensorOptions().dtype(torch::kInt8).device(a.device()); |
|
|
torch::Tensor workspace = torch::empty(workspace_size, options); |
|
|
|
|
|
if (batch_sizes.is_cuda()) { |
|
|
FillCutlassArguments< |
|
|
trans_a, |
|
|
ElementA, ElementB, ElementC, |
|
|
LayoutA, LayoutB, LayoutC |
|
|
>(num_experts, batch_sizes, a, b, c, arguments, coord_template); |
|
|
|
|
|
RemoveK0Problems<>(num_experts, arguments); |
|
|
} |
|
|
|
|
|
|
|
|
if (gemm.initialize(arguments, workspace.data_ptr()) != cutlass::Status::kSuccess) {
|
|
TORCH_CHECK(false, "Failed to initialize CUTLASS Grouped GEMM"); |
|
|
} |
|
|
|
|
|
|
|
|
if (gemm.run(c10::hip::getCurrentHIPStreamMasqueradingAsCUDA()) != cutlass::Status::kSuccess) {
|
|
TORCH_CHECK(false, "Failed to run CUTLASS Grouped GEMM"); |
|
|
} |
|
|
return c; |
|
|
} |
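// CublasGemm issues a single bf16 GEMM through hipblasGemmEx. The tensors
// are row-major, so the operands are swapped to compute C^T = B^T * A^T with
// the column-major BLAS convention.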
|
|
|
|
|
void CublasGemm(c10::BFloat16 *a, int64_t a_rows, int64_t a_cols, bool trans_a, |
|
|
c10::BFloat16 *b, int64_t b_rows, int64_t b_cols, bool trans_b, |
|
|
c10::BFloat16 *c, int64_t c_rows, int64_t c_cols) { |
|
|
int m = trans_b ? b_rows : b_cols; |
|
|
int k = trans_b ? b_cols : b_rows; |
|
|
int n = trans_a ? a_cols : a_rows; |
|
|
|
|
|
int lda = trans_a ? n : k; |
|
|
int ldb = trans_b ? k : m; |
|
|
hipblasOperation_t transpose_a = trans_a ? HIPBLAS_OP_T : HIPBLAS_OP_N; |
|
|
hipblasOperation_t transpose_b = trans_b ? HIPBLAS_OP_T : HIPBLAS_OP_N; |
|
|
|
|
|
float alpha = 1.0, beta = 0.0; |
|
|
CUBLAS_CALL(hipblasGemmEx(at::cuda::getCurrentCUDABlasHandle(), |
|
|
transpose_b, transpose_a, |
|
|
m, n, k, &alpha, |
|
|
b, HIP_R_16BF, ldb, |
|
|
a, HIP_R_16BF, lda, |
|
|
&beta, |
|
|
c, HIP_R_16BF, c_cols, HIP_R_32F, |
|
|
HIPBLAS_GEMM_DEFAULT)); |
|
|
} |
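// CublasGroupedGemm is the looped hipBLAS fallback for the forward pass:
// one GEMM per expert over contiguous row blocks of `a` and `c`.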
|
|
|
|
|
void CublasGroupedGemm(torch::Tensor a, |
|
|
torch::Tensor b, |
|
|
torch::Tensor c, |
|
|
torch::Tensor batch_sizes, |
|
|
bool trans_b) { |
|
|
int64_t bs = batch_sizes.size(0), k = a.size(1); |
|
|
int64_t n = trans_b ? b.size(1) : b.size(2); |
|
|
int64_t b_rows = b.size(1), b_cols = b.size(2); |
|
|
c10::BFloat16* a_ptr = a.data_ptr<c10::BFloat16>(); |
|
|
c10::BFloat16* b_ptr = b.data_ptr<c10::BFloat16>(); |
|
|
c10::BFloat16* c_ptr = c.data_ptr<c10::BFloat16>(); |
|
|
for (int i = 0; i < bs; ++i) { |
|
|
int64_t m = batch_sizes.data_ptr<int64_t>()[i]; |
|
|
CublasGemm(a_ptr, m, k, false, |
|
|
b_ptr, b_rows, b_cols, trans_b, |
|
|
c_ptr, m, n); |
|
|
a_ptr += m * k; |
|
|
b_ptr += b_rows * b_cols; |
|
|
c_ptr += m * n; |
|
|
} |
|
|
} |
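// CublasGroupedGemmVariableK handles the trans_a case without CUTLASS:
// per expert, c[e] = a_e^T @ b_e with k equal to the tokens routed to that
// expert.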
|
|
|
|
|
void CublasGroupedGemmVariableK(torch::Tensor a, |
|
|
torch::Tensor b, |
|
|
torch::Tensor c, |
|
|
torch::Tensor batch_sizes) { |
|
|
int64_t bs = batch_sizes.size(0), m = a.size(1), n = b.size(1); |
|
|
c10::BFloat16* a_ptr = a.data_ptr<c10::BFloat16>(); |
|
|
c10::BFloat16* b_ptr = b.data_ptr<c10::BFloat16>(); |
|
|
c10::BFloat16* c_ptr = c.data_ptr<c10::BFloat16>(); |
|
|
for (int i = 0; i < bs; ++i) { |
|
|
int64_t k = batch_sizes.data_ptr<int64_t>()[i]; |
|
|
CublasGemm(a_ptr, k, m, true, |
|
|
b_ptr, k, n, false, |
|
|
c_ptr, m, n); |
|
|
a_ptr += k * m; |
|
|
b_ptr += k * n; |
|
|
c_ptr += m * n; |
|
|
} |
|
|
} |
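// GroupedGemmVariableK validates shapes for the variable-k case
// (a: [tokens, m], b: [tokens, n], c: [num_experts, m, n], all bf16 CUDA
// tensors) and dispatches to the looped hipBLAS implementation.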
|
|
|
|
|
void GroupedGemmVariableK(torch::Tensor a, |
|
|
torch::Tensor b, |
|
|
torch::Tensor c, |
|
|
torch::Tensor batch_sizes) { |
|
|
|
|
|
|
|
|
TORCH_CHECK(b.is_cuda()); |
|
|
TORCH_CHECK(b.ndimension() == 2); |
|
|
TORCH_CHECK(b.scalar_type() == torch::kBFloat16); |
|
|
|
|
|
|
|
|
int64_t tokens = a.size(0), num_experts = batch_sizes.size(0); |
|
|
int64_t m = a.size(1), n = b.size(1); |
|
|
|
|
|
|
|
|
TORCH_CHECK(tokens == b.size(0)); |
|
|
|
|
|
|
|
|
TORCH_CHECK(c.is_cuda()); |
|
|
TORCH_CHECK(c.ndimension() == 3); |
|
|
TORCH_CHECK(c.scalar_type() == torch::kBFloat16); |
|
|
TORCH_CHECK(c.size(0) == num_experts); |
|
|
TORCH_CHECK(c.size(1) == m); |
|
|
TORCH_CHECK(c.size(2) == n); |
|
|
|
|
|
|
|
|
CublasGroupedGemmVariableK(a, b, c, batch_sizes); |
|
|
} |
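// GroupedGemm is the public entry point for this backend: it validates
// devices, dtypes and shapes, then dispatches on (trans_a, trans_b). At most
// one of the two transpose flags may be set.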
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void GroupedGemm(torch::Tensor a, |
|
|
torch::Tensor b, |
|
|
torch::Tensor c, |
|
|
torch::Tensor batch_sizes, |
|
|
bool trans_a, bool trans_b) { |
|
|
|
|
|
TORCH_CHECK(!(trans_a && trans_b)); |
|
|
|
|
|
#if !defined(GROUPED_GEMM_CUTLASS) |
|
|
|
|
|
// The hipBLAS fallback needs batch_sizes on the host.
TORCH_CHECK(batch_sizes.is_cpu());
|
|
#else |
|
|
|
|
|
// The CUTLASS path accepts batch_sizes on either the host or the device.
TORCH_CHECK(batch_sizes.is_cuda() || batch_sizes.is_cpu());
|
|
#endif |
|
|
TORCH_CHECK(batch_sizes.ndimension() == 1); |
|
|
TORCH_CHECK(batch_sizes.scalar_type() == torch::kInt64); |
|
|
|
|
|
|
|
|
|
|
|
TORCH_CHECK(a.is_cuda()); |
|
|
TORCH_CHECK(a.ndimension() == 2); |
|
|
TORCH_CHECK(a.scalar_type() == torch::kBFloat16); |
|
|
|
|
|
#if !defined(GROUPED_GEMM_CUTLASS) |
|
|
// Without CUTLASS, the trans_a (variable-k) case goes through the looped
// hipBLAS path.
if (trans_a) {
|
|
|
|
|
|
|
|
GroupedGemmVariableK(a, b, c, batch_sizes); |
|
|
return; |
|
|
} |
|
|
#endif |
|
|
|
|
|
TORCH_CHECK(b.is_cuda()); |
|
|
TORCH_CHECK(c.is_cuda()); |
|
|
TORCH_CHECK(b.scalar_type() == torch::kBFloat16); |
|
|
TORCH_CHECK(c.scalar_type() == torch::kBFloat16); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
size_t hidden_in{}, hidden_out{}; |
|
|
if (trans_a) { |
|
|
hidden_in = a.size(1); |
|
|
hidden_out = b.size(1); |
|
|
|
|
|
TORCH_CHECK(b.ndimension() == 2); |
|
|
TORCH_CHECK(c.ndimension() == 3); |
|
|
TORCH_CHECK(b.size(0) == a.size(0)); |
|
|
TORCH_CHECK(c.size(0) == batch_sizes.size(0)); |
|
|
TORCH_CHECK(c.size(1) == hidden_in); |
|
|
TORCH_CHECK(c.size(2) == hidden_out); |
|
|
} else { |
|
|
TORCH_CHECK(b.ndimension() == 3); |
|
|
TORCH_CHECK(c.ndimension() == 2); |
|
|
|
|
|
|
|
|
int64_t tokens = a.size(0), num_experts = b.size(0); |
|
|
hidden_in = trans_b ? b.size(2) : b.size(1); |
|
|
hidden_out = trans_b ? b.size(1) : b.size(2); |
|
|
TORCH_CHECK(hidden_in == a.size(1)); |
|
|
|
|
|
|
|
|
TORCH_CHECK(batch_sizes.size(0) == num_experts); |
|
|
} |
|
|
|
|
|
|
|
|
TORCH_CHECK(a.is_contiguous()); |
|
|
TORCH_CHECK(b.is_contiguous()); |
|
|
TORCH_CHECK(c.is_contiguous()); |
|
|
|
|
|
#if !defined(GROUPED_GEMM_CUTLASS) |
|
|
CublasGroupedGemm(a, b, c, batch_sizes, trans_b); |
|
|
return; |
|
|
#else |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// One GEMM dimension varies per expert (k when trans_a, m otherwise); it is
// marked with kDynamicDim here and filled in from batch_sizes later.
const auto coord_template = trans_a
|
|
? cutlass::gemm::GemmCoord(hidden_in, hidden_out, kDynamicDim) |
|
|
: cutlass::gemm::GemmCoord(kDynamicDim, hidden_out, hidden_in); |
|
|
if (trans_a) { |
|
|
CutlassGroupedGemm<true, false>(a, b, c, batch_sizes, coord_template); |
|
|
return; |
|
|
} |
|
|
if (trans_b) { |
|
|
CutlassGroupedGemm<false, true>(a, b, c, batch_sizes, coord_template); |
|
|
return; |
|
|
} |
|
|
CutlassGroupedGemm<false, false>(a, b, c, batch_sizes, coord_template); |
|
|
return; |
|
|
#endif |
|
|
} |
|
|
|
|
|
}  // namespace grouped_gemm
|
|
|
|
|
#endif  // __HIP_PLATFORM_AMD__
|
|
|