// megablocks-hip/csrc/grouped_gemm/grouped_gemm.hip
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
#include "grouped_gemm.h"
#ifdef __HIP_PLATFORM_AMD__
#include "gpu_backend_hip.h"
#include <ATen/hip/HIPContext.h>
#include <hipblaslt/hipblaslt.h>
#include <torch/autograd.h>
#include <vector>
#include <algorithm>
#include <cctype>
#include <cstdlib>
#include <string>
namespace grouped_gemm {
namespace {
// Experimental: toggled via MEGABLOCKS_GG_USE_HIPBLASLT=1. This flag is
// intentionally off by default because the hipBLASLt path still fails on the
// largest `tests/ops_test.py` configurations.
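//
// Example (hypothetical shell invocation; `tests/ops_test.py` is the suite
// referenced above):
//   MEGABLOCKS_GG_USE_HIPBLASLT=1 python -m pytest tests/ops_test.py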
bool use_hipblaslt_backend() {
static int cached = [] {
const char* raw = std::getenv("MEGABLOCKS_GG_USE_HIPBLASLT");
if (raw == nullptr) {
return 0;
}
std::string value(raw);
std::transform(value.begin(), value.end(), value.begin(), [](unsigned char c) {
return static_cast<char>(std::tolower(c));
});
if (value == "1" || value == "true" || value == "yes" || value == "on") {
return 1;
}
return 0;
}();
return cached == 1;
}
inline void hipblaslt_check(hipblasStatus_t status, const char* expr) {
TORCH_CHECK(status == HIPBLAS_STATUS_SUCCESS, "hipBLASLt call failed with status ", status, " when executing ", expr);
}
#define HIPBLASLT_CHECK(cmd) hipblaslt_check((cmd), #cmd)
hipblasLtHandle_t hipblaslt_handle() {
static hipblasLtHandle_t handle = [] {
hipblasLtHandle_t h;
HIPBLASLT_CHECK(hipblasLtCreate(&h));
return h;
}();
return handle;
}
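// Runs a single hipBLASLt matmul, D = op_a(A) * op_b(B) + beta * C, with
// alpha = 1 and beta = accumulate ? 1 : 0. All layouts are row-major:
// `rows_*`/`cols_*` describe the matrices as stored (before op_a/op_b are
// applied) and `ld*` are the row-major leading dimensions. Empty problems
// return immediately without touching D.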
void hipblaslt_run_matmul(const void* a_ptr,
const void* b_ptr,
const void* c_ptr,
void* d_ptr,
int64_t rows_a,
int64_t cols_a,
int64_t rows_b,
int64_t cols_b,
int64_t rows_d,
int64_t cols_d,
int64_t lda,
int64_t ldb,
int64_t ldc,
int64_t ldd,
hipblasOperation_t op_a,
hipblasOperation_t op_b,
bool accumulate) {
if (rows_a == 0 || cols_a == 0 || rows_b == 0 || cols_b == 0 || rows_d == 0 || cols_d == 0)
return;
auto handle = hipblaslt_handle();
auto stream = c10::hip::getCurrentHIPStream();
hipblasLtMatmulDesc_t matmul_desc;
HIPBLASLT_CHECK(hipblasLtMatmulDescCreate(&matmul_desc, HIPBLAS_COMPUTE_32F, HIP_R_32F));
HIPBLASLT_CHECK(hipblasLtMatmulDescSetAttribute(
matmul_desc, HIPBLASLT_MATMUL_DESC_TRANSA, &op_a, sizeof(op_a)));
HIPBLASLT_CHECK(hipblasLtMatmulDescSetAttribute(
matmul_desc, HIPBLASLT_MATMUL_DESC_TRANSB, &op_b, sizeof(op_b)));
hipblasLtPointerMode_t pointer_mode = HIPBLASLT_POINTER_MODE_HOST;
HIPBLASLT_CHECK(hipblasLtMatmulDescSetAttribute(
matmul_desc, HIPBLASLT_MATMUL_DESC_POINTER_MODE, &pointer_mode, sizeof(pointer_mode)));
hipblasLtOrder_t order = HIPBLASLT_ORDER_ROW;
hipblasLtMatrixLayout_t layout_a;
HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&layout_a, HIP_R_16BF, rows_a, cols_a, lda));
HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
layout_a, HIPBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(order)));
hipblasLtMatrixLayout_t layout_b;
HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&layout_b, HIP_R_16BF, rows_b, cols_b, ldb));
HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
layout_b, HIPBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(order)));
hipblasLtMatrixLayout_t layout_c;
HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&layout_c, HIP_R_16BF, rows_d, cols_d, ldc));
HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
layout_c, HIPBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(order)));
hipblasLtMatrixLayout_t layout_d;
HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&layout_d, HIP_R_16BF, rows_d, cols_d, ldd));
HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
layout_d, HIPBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(order)));
hipblasLtMatmulPreference_t preference;
HIPBLASLT_CHECK(hipblasLtMatmulPreferenceCreate(&preference));
uint64_t workspace_size = 0;
HIPBLASLT_CHECK(hipblasLtMatmulPreferenceSetAttribute(
preference,
HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&workspace_size,
sizeof(workspace_size)));
hipblasLtMatmulHeuristicResult_t heuristic{};
int returned_results = 0;
HIPBLASLT_CHECK(hipblasLtMatmulAlgoGetHeuristic(
handle,
matmul_desc,
layout_a,
layout_b,
layout_c,
layout_d,
preference,
1,
&heuristic,
&returned_results));
TORCH_CHECK(returned_results > 0, "hipBLASLt could not find a suitable algorithm");
const float alpha = 1.0f;
const float beta = accumulate ? 1.0f : 0.0f;
HIPBLASLT_CHECK(hipblasLtMatmul(handle,
matmul_desc,
&alpha,
a_ptr,
layout_a,
b_ptr,
layout_b,
&beta,
c_ptr,
layout_c,
d_ptr,
layout_d,
&heuristic.algo,
nullptr,
0,
stream));
hipblasLtMatmulPreferenceDestroy(preference);
hipblasLtMatrixLayoutDestroy(layout_d);
hipblasLtMatrixLayoutDestroy(layout_c);
hipblasLtMatrixLayoutDestroy(layout_b);
hipblasLtMatrixLayoutDestroy(layout_a);
hipblasLtMatmulDescDestroy(matmul_desc);
}
} // namespace
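// Grouped-GEMM entry point for the ROCm build. `batch_sizes` holds one row
// count per expert (on the CPU); its prefix sums delimit each expert's slice
// of `a`. The three layouts match the CUTLASS path further below:
// * trans_a:  out[e] = a_e^T @ b_e    -> out is (num_experts, hidden_in, hidden_out)
// * trans_b:  out_e  = a_e   @ b[e]^T -> out is (tokens, hidden_out)
// * default:  out_e  = a_e   @ b[e]   -> out is (tokens, hidden_out)
// Without MEGABLOCKS_GG_USE_HIPBLASLT the per-expert products are computed in
// FP32 via torch::matmul and cast back to BF16.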
torch::Tensor hipblaslt_gmm_internal(torch::Tensor a,
torch::Tensor b,
torch::Tensor batch_sizes,
bool trans_a,
bool trans_b,
c10::optional<torch::Tensor> c_opt) {
torch::NoGradGuard no_grad;
TORCH_CHECK(a.is_cuda(), "hipblaslt_gmm requires CUDA tensors");
TORCH_CHECK(b.is_cuda(), "hipblaslt_gmm requires CUDA tensors");
TORCH_CHECK(a.scalar_type() == torch::kBFloat16, "hipblaslt_gmm expects BF16 inputs");
TORCH_CHECK(b.scalar_type() == torch::kBFloat16, "hipblaslt_gmm expects BF16 weights");
TORCH_CHECK(batch_sizes.device().is_cpu(), "batch_sizes must reside on CPU");
a = a.contiguous();
b = b.contiguous();
auto device = a.device();
auto dtype = a.scalar_type();
const bool use_hip = use_hipblaslt_backend();
const auto counts_ptr = batch_sizes.data_ptr<int64_t>();
const int64_t num_experts = batch_sizes.size(0);
std::vector<int64_t> prefix(num_experts);
int64_t running = 0;
for (int64_t i = 0; i < num_experts; ++i) {
running += counts_ptr[i];
prefix[i] = running;
}
const int64_t tokens = num_experts ? prefix.back() : 0;
TORCH_CHECK(a.size(0) == tokens, "tokens mismatch with batch sizes");
torch::Tensor out;
if (trans_a) {
const int64_t hidden_in = a.size(1);
const int64_t hidden_out = b.size(1);
// Avoid allocating a scratch tensor when the caller already provided `c`;
// `value_or` would evaluate the `torch::empty(...)` argument unconditionally.
out = c_opt.has_value()
? *c_opt
: torch::empty({num_experts, hidden_in, hidden_out}, a.options().dtype(dtype));
TORCH_CHECK(out.is_contiguous(), "Output tensor must be contiguous");
auto b_contig = b.contiguous();
if (use_hip) {
int64_t start = 0;
for (int64_t expert = 0; expert < num_experts; ++expert) {
const int64_t end = prefix[expert];
const int64_t rows = end - start;
auto out_chunk = out.select(0, expert);
if (rows == 0) {
out_chunk.zero_();
start = end;
continue;
}
auto a_chunk = a.narrow(0, start, rows).contiguous();
auto b_chunk = b_contig.narrow(0, start, rows).contiguous();
hipblaslt_run_matmul(a_chunk.data_ptr(),
b_chunk.data_ptr(),
out_chunk.data_ptr(),
out_chunk.data_ptr(),
/*rows_a=*/rows,
/*cols_a=*/hidden_in,
/*rows_b=*/rows,
/*cols_b=*/hidden_out,
/*rows_d=*/hidden_in,
/*cols_d=*/hidden_out,
/*lda=*/hidden_in,
/*ldb=*/hidden_out,
/*ldc=*/hidden_out,
/*ldd=*/hidden_out,
/*op_a=*/HIPBLAS_OP_T,
/*op_b=*/HIPBLAS_OP_N,
/*accumulate=*/false);
start = end;
}
} else {
int64_t start = 0;
for (int64_t expert = 0; expert < num_experts; ++expert) {
const int64_t end = prefix[expert];
const int64_t rows = end - start;
auto out_chunk = out.select(0, expert);
if (rows == 0) {
out_chunk.zero_();
start = end;
continue;
}
auto a_slice = a.narrow(0, start, rows);
auto b_slice = b_contig.narrow(0, start, rows);
auto a_f32 = a_slice.contiguous().to(torch::kFloat32);
auto b_f32 = b_slice.contiguous().to(torch::kFloat32);
auto prod = torch::matmul(a_f32.transpose(0, 1), b_f32);
auto prod_bf16 = prod.to(dtype);
out_chunk.copy_(prod_bf16);
start = end;
}
}
return out;
}
if (trans_b) {
const int64_t hidden_in = a.size(1);
const int64_t hidden_out = b.size(1);
// Avoid allocating a scratch tensor when the caller already provided `c`.
out = c_opt.has_value() ? *c_opt : torch::empty({tokens, hidden_out}, a.options());
TORCH_CHECK(out.is_contiguous(), "Output tensor must be contiguous");
auto b_contig = b.contiguous();
if (use_hip) {
int64_t start = 0;
for (int64_t expert = 0; expert < num_experts; ++expert) {
const int64_t end = prefix[expert];
const int64_t rows = end - start;
if (rows == 0) {
start = end;
continue;
}
auto a_chunk = a.narrow(0, start, rows).contiguous();
auto b_chunk = b_contig.select(0, expert).contiguous();
auto out_chunk = out.narrow(0, start, rows);
hipblaslt_run_matmul(a_chunk.data_ptr(),
b_chunk.data_ptr(),
out_chunk.data_ptr(),
out_chunk.data_ptr(),
/*rows_a=*/rows,
/*cols_a=*/hidden_in,
/*rows_b=*/hidden_out,
/*cols_b=*/hidden_in,
/*rows_d=*/rows,
/*cols_d=*/hidden_out,
/*lda=*/hidden_in,
/*ldb=*/hidden_in,
/*ldc=*/hidden_out,
/*ldd=*/hidden_out,
/*op_a=*/HIPBLAS_OP_N,
/*op_b=*/HIPBLAS_OP_T,
/*accumulate=*/false);
start = end;
}
} else {
int64_t start = 0;
for (int64_t expert = 0; expert < num_experts; ++expert) {
const int64_t end = prefix[expert];
const int64_t rows = end - start;
if (rows == 0) {
start = end;
continue;
}
auto a_slice = a.narrow(0, start, rows);
auto b_slice = b_contig.select(0, expert);
auto out_chunk = out.narrow(0, start, rows);
auto a_f32 = a_slice.contiguous().to(torch::kFloat32);
auto b_f32 = b_slice.contiguous().to(torch::kFloat32);
auto prod = torch::matmul(a_f32, b_f32.transpose(0, 1));
auto prod_bf16 = prod.to(dtype);
out_chunk.copy_(prod_bf16);
start = end;
}
}
return out;
}
// Note: a.size(1) is the contraction dimension and b.size(2) is the output width.
const int64_t hidden_in = a.size(1);
const int64_t hidden_out = b.size(2);
// Avoid allocating a scratch tensor when the caller already provided `c`.
out = c_opt.has_value() ? *c_opt : torch::empty({tokens, hidden_out}, a.options());
TORCH_CHECK(out.is_contiguous(), "Output tensor must be contiguous");
auto b_contig = b.contiguous();
if (use_hip) {
int64_t start = 0;
for (int64_t expert = 0; expert < num_experts; ++expert) {
const int64_t end = prefix[expert];
const int64_t rows = end - start;
if (rows == 0) {
start = end;
continue;
}
auto a_chunk = a.narrow(0, start, rows).contiguous();
auto b_chunk = b_contig.select(0, expert).contiguous();
auto out_chunk = out.narrow(0, start, rows);
hipblaslt_run_matmul(a_chunk.data_ptr(),
b_chunk.data_ptr(),
out_chunk.data_ptr(),
out_chunk.data_ptr(),
/*rows_a=*/rows,
/*cols_a=*/hidden_in,
/*rows_b=*/hidden_in,
/*cols_b=*/hidden_out,
/*rows_d=*/rows,
/*cols_d=*/hidden_out,
/*lda=*/hidden_in,
/*ldb=*/hidden_out,
/*ldc=*/hidden_out,
/*ldd=*/hidden_out,
/*op_a=*/HIPBLAS_OP_N,
/*op_b=*/HIPBLAS_OP_N,
/*accumulate=*/false);
start = end;
}
} else {
int64_t start = 0;
for (int64_t expert = 0; expert < num_experts; ++expert) {
const int64_t end = prefix[expert];
const int64_t rows = end - start;
if (rows == 0) {
start = end;
continue;
}
auto a_slice = a.narrow(0, start, rows);
auto b_slice = b_contig.select(0, expert);
auto out_chunk = out.narrow(0, start, rows);
auto a_f32 = a_slice.contiguous().to(torch::kFloat32);
auto b_f32 = b_slice.contiguous().to(torch::kFloat32);
auto prod = torch::matmul(a_f32, b_f32);
auto prod_bf16 = prod.to(dtype);
out_chunk.copy_(prod_bf16);
start = end;
}
}
return out;
}
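// Public ROCm entry point: normalizes `batch_sizes` to the CPU, runs the
// grouped GEMM above, and copies the result into the caller-provided
// contiguous output `c` unless the result already aliases it.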
void GroupedGemm(torch::Tensor a,
torch::Tensor b,
torch::Tensor c,
torch::Tensor batch_sizes,
bool trans_a,
bool trans_b) {
if (!batch_sizes.device().is_cpu()) {
batch_sizes = batch_sizes.cpu();
}
TORCH_CHECK(c.is_contiguous(), "Output tensor must be contiguous");
auto result = hipblaslt_gmm_internal(a, b, batch_sizes, trans_a, trans_b, c);
if (!c.is_alias_of(result)) {
c.copy_(result);
}
}
} // namespace grouped_gemm
#else
#include "fill_arguments_hip.cuh"
#include <ATen/hip/HIPContext.h>
#include <ATen/hip/detail/KernelUtils.h>
#include <c10/util/BFloat16.h>
#include <ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h>
#include <hipcub/hipcub.hpp>
#include <torch/torch.h>
#include "cutlass/bfloat16.h"
#include "cutlass/complex.h"
#include "cutlass/gemm/kernel/gemm_grouped.h"
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
#include "cutlass/gemm/device/gemm_grouped.h"
#include <algorithm>
#include <numeric>
#include <type_traits>
#include <vector>
namespace grouped_gemm {
#define CUDA_CALL(code) \
do { \
hipError_t status = code; \
std::string err = hipGetErrorString(status); \
TORCH_CHECK(status == hipSuccess, err); \
} while (0)
#define CUBLAS_CALL(code) \
do { \
hipblasStatus_t status = code; \
TORCH_CHECK(status == HIPBLAS_STATUS_SUCCESS, "CuBLAS Error"); \
} while (0)
#define GROUPED_GEMM_STRINGIFY_HELPER(x) #x
#define GROUPED_GEMM_STRINGIFY(x) \
GROUPED_GEMM_STRINGIFY_HELPER(x)
template <bool trans>
using GroupedGemmInputLayout = std::conditional_t<trans, ::cutlass::layout::ColumnMajor, ::cutlass::layout::RowMajor>;
using GroupedGemmConfig = ::cutlass::gemm::device::DefaultGemmConfiguration<
::cutlass::arch::OpClassTensorOp,
::cutlass::arch::Sm80,
::cutlass::bfloat16_t,
::cutlass::bfloat16_t,
::cutlass::bfloat16_t,
float
>;
// TODO(tgale): Update this for SM90 when it's supported by CUTLASS.
template <bool trans_a, bool trans_b>
using GroupedGemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped<
// A operand.
::cutlass::bfloat16_t,
GroupedGemmInputLayout<trans_a>,
::cutlass::ComplexTransform::kNone,
GroupedGemmConfig::kAlignmentA,
// B operand.
::cutlass::bfloat16_t,
GroupedGemmInputLayout<trans_b>,
::cutlass::ComplexTransform::kNone,
GroupedGemmConfig::kAlignmentB,
// C operand.
::cutlass::bfloat16_t,
::cutlass::layout::RowMajor,
float,
::cutlass::arch::OpClassTensorOp,
::cutlass::arch::Sm80,
GroupedGemmConfig::ThreadblockShape,
GroupedGemmConfig::WarpShape,
GroupedGemmConfig::InstructionShape,
GroupedGemmConfig::EpilogueOutputOp,
// NOTE: Threadblock swizzling is currently not supported by CUTLASS's grouped kernels.
// This parameter is passed in at present to match the APIs of other kernels. The parameter
// is unused within the kernel.
::cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,
// TODO(tgale): Tune this for SM90.
GroupedGemmConfig::kStages>::GemmKernel;
template <bool trans_a, bool trans_b>
using GemmGrouped = ::cutlass::gemm::device::GemmGrouped<GroupedGemmKernel<trans_a, trans_b>>;
template <typename T>
torch::Tensor CopyToDevice(const std::vector<T> &x, const torch::Device &device) {
size_t bytes = x.size() * sizeof(T);
auto options = torch::TensorOptions().dtype(torch::kInt8).device(device);
torch::Tensor out = torch::empty(bytes, options);
CUDA_CALL(hipMemcpyAsync(out.data_ptr(),
x.data(), bytes,
hipMemcpyHostToDevice,
c10::hip::getCurrentHIPStreamMasqueradingAsCUDA()));
return out;
}
template <typename T>
static void ReorderArray(T* data, const std::vector<size_t>& indices) {
// For now, simply create a copy of the data and then copy over to the original.
std::vector<T> copy(data, data + indices.size());
for (size_t i = 0; i < indices.size(); ++i) {
data[i] = copy.at(indices[i]);
}
}
template <typename T>
torch::Tensor TypedEmpty(size_t numel, const torch::Device& device) {
return torch::empty(numel * sizeof(T), torch::dtype(torch::kInt8).device(device));
}
struct RawGemmArguments {
torch::Tensor lda, ldb, ldc, ptr_a, ptr_b, ptr_c, problem_sizes;
int threadblock_count{};
};
template <
typename Gemm,
typename ElementA, typename ElementB, typename ElementC
>
RawGemmArguments MakeArgumentsOnDevice(int num_experts, const torch::Device& device) {
TORCH_CHECK(
num_experts <= kMaxExperts,
"At most ", kMaxExperts,
" experts are supported when batch_sizes is a CUDA tensor, but got ", num_experts
);
return RawGemmArguments {
.lda = TypedEmpty<int64_t>(num_experts, device),
.ldb = TypedEmpty<int64_t>(num_experts, device),
.ldc = TypedEmpty<int64_t>(num_experts, device),
.ptr_a = TypedEmpty<ElementA*>(num_experts, device),
.ptr_b = TypedEmpty<ElementB*>(num_experts, device),
.ptr_c = TypedEmpty<ElementC*>(num_experts, device),
.problem_sizes = TypedEmpty<cutlass::gemm::GemmCoord>(num_experts, device),
// We don't know the problem dimensions on the host, so we just base the number of threadblocks on occupancy here.
.threadblock_count = Gemm::sufficient(),
};
}
template <
bool kDynamicK,
typename Gemm,
typename ElementA, typename ElementB, typename ElementC,
typename LayoutA, typename LayoutB, typename LayoutC
>
RawGemmArguments MakeArgumentsOnHost(torch::Tensor a,
torch::Tensor b,
torch::Tensor c,
torch::Tensor batch_sizes,
::cutlass::gemm::GemmCoord coord_template,
int64_t num_experts) {
std::vector<::cutlass::gemm::GemmCoord> problem_sizes_host(num_experts);
// Create the host arrays of leading dimension data and pointer data.
std::vector<int64_t> lda_host(num_experts), ldb_host(num_experts), ldc_host(num_experts);
int64_t elements_a = 0, elements_b = 0, elements_c = 0;
std::vector<ElementA *> ptr_a_host(num_experts), ptr_b_host(num_experts), ptr_c_host(num_experts);
for (int i = 0; i < num_experts; ++i) {
auto& problem = problem_sizes_host[i];
problem = coord_template;
(kDynamicK ? problem.k() : problem.m()) = batch_sizes.data_ptr<int64_t>()[i];
lda_host[i] = LayoutA::packed({problem.m(), problem.k()}).stride(0);
ldb_host[i] = LayoutB::packed({problem.k(), problem.n()}).stride(0);
ldc_host[i] = LayoutC::packed({problem.m(), problem.n()}).stride(0);
ptr_a_host[i] = (ElementA*)a.data_ptr() + elements_a;
ptr_b_host[i] = (ElementB*)b.data_ptr() + elements_b;
ptr_c_host[i] = (ElementC*)c.data_ptr() + elements_c;
elements_a += problem.m() * problem.k();
elements_b += problem.k() * problem.n();
elements_c += problem.m() * problem.n();
if (problem.k() == 0) {
// CUTLASS doesn't handle problems with `k=0` correctly, see https://github.com/NVIDIA/cutlass/pull/1593.
// Until a fix is available on the CUTLASS side, handle these problems by ourselves:
// * set the output to zero with `hipMemsetAsync()`
// * make this problem a no-op by setting `m=0` and `n=0` (CUTLASS can handle the outer dimensions being zero)
CUDA_CALL(hipMemsetAsync(ptr_c_host[i],
0,
problem.m() * problem.n() * sizeof(ElementC),
c10::hip::getCurrentHIPStreamMasqueradingAsCUDA()));
problem.m() = 0;
problem.n() = 0;
}
}
// Only sort the problems when K varies per expert (the `kDynamicK` case).
if (kDynamicK) {
std::vector<size_t> indices(num_experts);
std::iota(indices.begin(), indices.end(), 0);
std::stable_sort(indices.begin(), indices.end(), [&problem_sizes_host](size_t i, size_t j) {
return problem_sizes_host[i].k() > problem_sizes_host[j].k();
});
ReorderArray(problem_sizes_host.data(), indices);
ReorderArray(lda_host.data(), indices);
ReorderArray(ldb_host.data(), indices);
ReorderArray(ldc_host.data(), indices);
ReorderArray(ptr_a_host.data(), indices);
ReorderArray(ptr_b_host.data(), indices);
ReorderArray(ptr_c_host.data(), indices);
}
// Copy the problem sizes, pointers and leading dimension data to the device.
return RawGemmArguments {
.lda = CopyToDevice(lda_host, a.device()),
.ldb = CopyToDevice(ldb_host, a.device()),
.ldc = CopyToDevice(ldc_host, a.device()),
.ptr_a = CopyToDevice(ptr_a_host, a.device()),
.ptr_b = CopyToDevice(ptr_b_host, a.device()),
.ptr_c = CopyToDevice(ptr_c_host, a.device()),
.problem_sizes = CopyToDevice(problem_sizes_host, a.device()),
// We know the problem dimensions on the host, so we can calculate the number of threadblocks based on that.
.threadblock_count = Gemm::sufficient(problem_sizes_host.data(), num_experts),
};
}
template <
bool kDynamicK,
typename Gemm,
typename ElementA, typename ElementB, typename ElementC,
typename LayoutA, typename LayoutB, typename LayoutC
>
typename Gemm::Arguments MakeArguments(torch::Tensor a,
torch::Tensor b,
torch::Tensor c,
torch::Tensor batch_sizes,
::cutlass::gemm::GemmCoord coord_template,
int64_t num_experts) {
RawGemmArguments raw_args;
if (batch_sizes.is_cuda()) {
raw_args = MakeArgumentsOnDevice<
Gemm, ElementA, ElementB, ElementC
>(num_experts, a.device());
} else {
raw_args = MakeArgumentsOnHost<
kDynamicK,
Gemm,
ElementA, ElementB, ElementC,
LayoutA, LayoutB, LayoutC
>(a, b, c, batch_sizes, coord_template, num_experts);
}
printf("Using %d threadblocks for grouped GEMM.\n", raw_args.threadblock_count);
// Validate the result.
if (!raw_args.threadblock_count) {
TORCH_CHECK(false, "Grouped GEMM execution not possible with HW");
}
typename Gemm::EpilogueOutputOp::Params epilogue_op(/*alpha=*/1.0f, /*beta=*/0.0f);
// We currently always use `GroupScheduleMode::kDeviceOnly`, which doesn't use `host_problem_sizes` at all,
// so we can safely pass `nullptr` for `host_problem_sizes`.
// TODO(tgale): Experiment with `GroupScheduleMode::kHostPrecompute` for `batch_sizes.is_cpu()`, where we
// know the problem dimensions on the host.
typename Gemm::Arguments arguments((cutlass::gemm::GemmCoord*)raw_args.problem_sizes.data_ptr(),
(int)num_experts,
(int)raw_args.threadblock_count,
epilogue_op,
(ElementA**)raw_args.ptr_a.data_ptr(),
(ElementB**)raw_args.ptr_b.data_ptr(),
(ElementC**)raw_args.ptr_c.data_ptr(),
(ElementC**)raw_args.ptr_c.data_ptr(),
/*lda=*/(int64_t*)raw_args.lda.data_ptr(),
/*ldb=*/(int64_t*)raw_args.ldb.data_ptr(),
/*ldc=*/(int64_t*)raw_args.ldc.data_ptr(),
/*ldd=*/(int64_t*)raw_args.ldc.data_ptr(),
/*host_problem_sizes=*/nullptr);
return arguments;
}
template <
bool trans_a,
typename ElementA, typename ElementB, typename ElementC,
typename LayoutA, typename LayoutB, typename LayoutC,
typename Arguments
>
void FillCutlassArguments(int num_experts,
torch::Tensor batch_sizes,
torch::Tensor a,
torch::Tensor b,
torch::Tensor c,
const Arguments& arguments,
::cutlass::gemm::GemmCoord coord_template) {
// Convert the batch sizes to the format CUTLASS understands on the device.
// Use a single block here because:
// * the number of elements to process is microscopically small
// * we don't need any additional global memory
hipLaunchKernelGGL(( FillArguments<
/*kDynamicK*/trans_a,
ElementA, ElementB, ElementC,
LayoutA, LayoutB, LayoutC
>), dim3(1), dim3(kMaxExperts), 0, c10::hip::getCurrentHIPStreamMasqueradingAsCUDA(),
num_experts, batch_sizes.data_ptr<int64_t>(),
(ElementA*)a.data_ptr(), (ElementB*)b.data_ptr(), (ElementC*)c.data_ptr(),
arguments, coord_template
);
C10_HIP_KERNEL_LAUNCH_CHECK();
}
template <typename Args>
void RemoveK0Problems(int num_experts, const Args& arguments) {
// For zeroing out the outputs (which might be arbitrarily large), we want to use
// as many threadblocks as possible in order to hit the maximum possible global memory bandwidth.
// `arguments.threadblock_count`, which we will use for the grouped GEMM proper,
// should be a good approximation for this.
// When the `k=0` case is fixed in CUTLASS, we can completely remove this function.
hipLaunchKernelGGL(( ZeroOutK0Outputs<>),
dim3(arguments.threadblock_count), dim3(at::cuda::detail::CUDA_NUM_THREADS), 0, c10::hip::getCurrentHIPStreamMasqueradingAsCUDA()
,
num_experts, arguments
);
hipLaunchKernelGGL(( IgnoreK0Problems<>),
dim3(1), dim3(kMaxExperts), 0, c10::hip::getCurrentHIPStreamMasqueradingAsCUDA()
,
num_experts, arguments
);
}
template <bool trans_a, bool trans_b>
torch::Tensor CutlassGroupedGemm(torch::Tensor a,
torch::Tensor b,
torch::Tensor c,
torch::Tensor batch_sizes,
::cutlass::gemm::GemmCoord coord_template) {
using Gemm = GemmGrouped<trans_a, trans_b>;
using LayoutA = typename Gemm::LayoutA;
using LayoutB = typename Gemm::LayoutB;
using LayoutC = typename Gemm::LayoutC;
using ElementA = typename Gemm::ElementA;
using ElementB = typename Gemm::ElementB;
using ElementC = typename Gemm::ElementC;
Gemm gemm;
int64_t num_experts = batch_sizes.size(0);
auto arguments = MakeArguments<
/*kDynamicK*/trans_a,
Gemm,
ElementA, ElementB, ElementC,
LayoutA, LayoutB, LayoutC
>(a, b, c, batch_sizes, coord_template, num_experts);
int64_t workspace_size = gemm.get_workspace_size(arguments);
auto options = torch::TensorOptions().dtype(torch::kInt8).device(a.device());
torch::Tensor workspace = torch::empty(workspace_size, options);
if (batch_sizes.is_cuda()) {
FillCutlassArguments<
trans_a,
ElementA, ElementB, ElementC,
LayoutA, LayoutB, LayoutC
>(num_experts, batch_sizes, a, b, c, arguments, coord_template);
RemoveK0Problems<>(num_experts, arguments);
}
// Initialize the kernel.
if(gemm.initialize(arguments, workspace.data_ptr()) != cutlass::Status::kSuccess) {
TORCH_CHECK(false, "Failed to initialize CUTLASS Grouped GEMM");
}
// Execute the kernel in the current stream.
if(gemm.run(c10::hip::getCurrentHIPStreamMasqueradingAsCUDA()) != cutlass::Status::kSuccess) {
TORCH_CHECK(false, "Failed to run CUTLASS Grouped GEMM");
}
return c;
}
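// Computes the row-major product c = op(a) * op(b) with a single hipblasGemmEx
// call. hipBLAS expects column-major operands, so the call swaps the operand
// order and effectively computes c^T = op(b)^T * op(a)^T, which stores the
// same row-major result in `c`.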
void CublasGemm(c10::BFloat16 *a, int64_t a_rows, int64_t a_cols, bool trans_a,
c10::BFloat16 *b, int64_t b_rows, int64_t b_cols, bool trans_b,
c10::BFloat16 *c, int64_t c_rows, int64_t c_cols) {
int m = trans_b ? b_rows : b_cols;
int k = trans_b ? b_cols : b_rows;
int n = trans_a ? a_cols : a_rows;
int lda = trans_a ? n : k;
int ldb = trans_b ? k : m;
hipblasOperation_t transpose_a = trans_a ? HIPBLAS_OP_T : HIPBLAS_OP_N;
hipblasOperation_t transpose_b = trans_b ? HIPBLAS_OP_T : HIPBLAS_OP_N;
float alpha = 1.0, beta = 0.0;
CUBLAS_CALL(hipblasGemmEx(at::cuda::getCurrentCUDABlasHandle(),
transpose_b, transpose_a,
m, n, k, &alpha,
b, HIP_R_16BF, ldb,
a, HIP_R_16BF, lda,
&beta,
c, HIP_R_16BF, c_cols, HIP_R_32F,
HIPBLAS_GEMM_DEFAULT));
}
void CublasGroupedGemm(torch::Tensor a,
torch::Tensor b,
torch::Tensor c,
torch::Tensor batch_sizes,
bool trans_b) {
int64_t bs = batch_sizes.size(0), k = a.size(1);
int64_t n = trans_b ? b.size(1) : b.size(2);
int64_t b_rows = b.size(1), b_cols = b.size(2);
c10::BFloat16* a_ptr = a.data_ptr<c10::BFloat16>();
c10::BFloat16* b_ptr = b.data_ptr<c10::BFloat16>();
c10::BFloat16* c_ptr = c.data_ptr<c10::BFloat16>();
for (int i = 0; i < bs; ++i) {
int64_t m = batch_sizes.data_ptr<int64_t>()[i];
CublasGemm(a_ptr, m, k, /*trans_a=*/false,
b_ptr, b_rows, b_cols, trans_b,
c_ptr, m, n);
a_ptr += m * k;
b_ptr += b_rows * b_cols;
c_ptr += m * n;
}
}
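// Variable-K variant used for the 'trans_a' case: expert i owns a (k_i x m)
// slice of `a` and a (k_i x n) slice of `b`, and the per-expert product
// a_i^T @ b_i is the (m x n) block written to c[i].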
void CublasGroupedGemmVariableK(torch::Tensor a,
torch::Tensor b,
torch::Tensor c,
torch::Tensor batch_sizes) {
int64_t bs = batch_sizes.size(0), m = a.size(1), n = b.size(1);
c10::BFloat16* a_ptr = a.data_ptr<c10::BFloat16>();
c10::BFloat16* b_ptr = b.data_ptr<c10::BFloat16>();
c10::BFloat16* c_ptr = c.data_ptr<c10::BFloat16>();
for (int i = 0; i < bs; ++i) {
int64_t k = batch_sizes.data_ptr<int64_t>()[i];
CublasGemm(a_ptr, k, m, /*trans_a=*/true,
b_ptr, k, n, /*trans_b=*/false,
c_ptr, m, n);
a_ptr += k * m;
b_ptr += k * n;
c_ptr += m * n;
}
}
void GroupedGemmVariableK(torch::Tensor a,
torch::Tensor b,
torch::Tensor c,
torch::Tensor batch_sizes) {
// We expect a CUDA tensor with two dimensions and shape
// (tokens, hidden_out) for 'b'.
TORCH_CHECK(b.is_cuda());
TORCH_CHECK(b.ndimension() == 2);
TORCH_CHECK(b.scalar_type() == torch::kBFloat16);
// Validate the dimensions.
int64_t tokens = a.size(0), num_experts = batch_sizes.size(0);
int64_t m = a.size(1), n = b.size(1);
// Validate that we have the same contraction dimension.
TORCH_CHECK(tokens == b.size(0));
// Validate the output shape.
TORCH_CHECK(c.is_cuda());
TORCH_CHECK(c.ndimension() == 3);
TORCH_CHECK(c.scalar_type() == torch::kBFloat16);
TORCH_CHECK(c.size(0) == num_experts);
TORCH_CHECK(c.size(1) == m);
TORCH_CHECK(c.size(2) == n);
// Run the computation.
CublasGroupedGemmVariableK(a, b, c, batch_sizes);
}
// NOTE: We only support dynamic group sizes for the 'a' tensor. Tensor 'b' is
// assumed to be batched with fixed-size batches.
//
// TODO(tgale): Validate alignment is true for every batch element.
void GroupedGemm(torch::Tensor a,
torch::Tensor b,
torch::Tensor c,
torch::Tensor batch_sizes,
bool trans_a, bool trans_b) {
// NOTE: We only support 'trans_a' or 'trans_b', not both.
TORCH_CHECK(!(trans_a && trans_b));
#if !defined(GROUPED_GEMM_CUTLASS)
// No way to run cuBLAS kernels if the problem dimensions are not known on the host.
TORCH_CHECK(batch_sizes.is_cpu());
#else
// CUTLASS can handle both CPU- and CUDA-resident problem dimensions.
TORCH_CHECK(batch_sizes.is_cuda() || batch_sizes.is_cpu());
#endif
TORCH_CHECK(batch_sizes.ndimension() == 1);
TORCH_CHECK(batch_sizes.scalar_type() == torch::kInt64);
// We expect a CUDA tensor with two dimensions and shape
// (tokens, hidden_in) for 'a'.
TORCH_CHECK(a.is_cuda());
TORCH_CHECK(a.ndimension() == 2);
TORCH_CHECK(a.scalar_type() == torch::kBFloat16);
#if !defined(GROUPED_GEMM_CUTLASS)
if (trans_a) {
// Without CUTLASS we can't handle the 'trans_a' case directly, so defer to the
// variable-'k' cuBLAS helper for the whole op.
GroupedGemmVariableK(a, b, c, batch_sizes);
return;
}
#endif
TORCH_CHECK(b.is_cuda());
TORCH_CHECK(c.is_cuda());
TORCH_CHECK(b.scalar_type() == torch::kBFloat16);
TORCH_CHECK(c.scalar_type() == torch::kBFloat16);
// The expected shapes of 'b' and 'c' are:
// * when 'trans_a' is set: b=(tokens, hidden_out), c=(num_experts, hidden_in, hidden_out)
// * when 'trans_b' is set: b=(num_experts, hidden_out, hidden_in), c=(tokens, hidden_out)
// * otherwise: b=(num_experts, hidden_in, hidden_out), c=(tokens, hidden_out)
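//
// A concrete (illustrative) sizing with num_experts=4, tokens=1024, hidden_in=512, hidden_out=2048:
// * trans_a:  a=(1024, 512), b=(1024, 2048),   c=(4, 512, 2048)
// * trans_b:  a=(1024, 512), b=(4, 2048, 512), c=(1024, 2048)
// * default:  a=(1024, 512), b=(4, 512, 2048), c=(1024, 2048)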
size_t hidden_in{}, hidden_out{};
if (trans_a) {
hidden_in = a.size(1);
hidden_out = b.size(1);
TORCH_CHECK(b.ndimension() == 2);
TORCH_CHECK(c.ndimension() == 3);
TORCH_CHECK(b.size(0) == a.size(0));
TORCH_CHECK(c.size(0) == batch_sizes.size(0));
TORCH_CHECK(c.size(1) == hidden_in);
TORCH_CHECK(c.size(2) == hidden_out);
} else {
TORCH_CHECK(b.ndimension() == 3);
TORCH_CHECK(c.ndimension() == 2);
// Validate the contraction dimensions match.
int64_t tokens = a.size(0), num_experts = b.size(0);
hidden_in = trans_b ? b.size(2) : b.size(1);
hidden_out = trans_b ? b.size(1) : b.size(2);
TORCH_CHECK(hidden_in == a.size(1));
// Validate that we have one size per expert.
TORCH_CHECK(batch_sizes.size(0) == num_experts);
// Validate the output shape.
TORCH_CHECK(c.size(0) == tokens);
TORCH_CHECK(c.size(1) == hidden_out);
}
// NOTE: We support transposition through the 'trans_b' flag.
TORCH_CHECK(a.is_contiguous());
TORCH_CHECK(b.is_contiguous());
TORCH_CHECK(c.is_contiguous());
#if !defined(GROUPED_GEMM_CUTLASS)
CublasGroupedGemm(a, b, c, batch_sizes, trans_b);
return;
#else
// The `coord_template` argument contains `kDynamicDim` as one of its dimensions
// as a placeholder. This placeholder is later expanded into the actual dimension
// for every element of the batch, either on the host or on the device
// (if we can't do it on the host).
const auto coord_template = trans_a
? cutlass::gemm::GemmCoord(hidden_in, hidden_out, kDynamicDim)
: cutlass::gemm::GemmCoord(kDynamicDim, hidden_out, hidden_in);
if (trans_a) {
CutlassGroupedGemm<true, false>(a, b, c, batch_sizes, coord_template);
return;
}
if (trans_b) {
CutlassGroupedGemm<false, true>(a, b, c, batch_sizes, coord_template);
return;
}
CutlassGroupedGemm<false, false>(a, b, c, batch_sizes, coord_template);
return;
#endif
}
} // namespace grouped_gemm
#endif