Skip to content

Commit

Permalink
Improve performance of prefill mode CK FP8 Grouped Gemm (pytorch#3522)
Browse files Browse the repository at this point in the history
Summary:

X-link: facebookresearch/FBGEMM#603

I previously assumed that using hipMemcpy would be more efficient than launching many kernels that directly set GPU memory. This assumption is apparently (and very surprisingly) untrue. It seems the multi-kernel-launch approach reduces overhead considerably, giving a 10% speedup.

Differential Revision: D67531231
  • Loading branch information
jwfromm authored and facebook-github-bot committed Dec 20, 2024
1 parent 6da23d5 commit 0322843
Show file tree
Hide file tree
Showing 2 changed files with 179 additions and 98 deletions.
9 changes: 9 additions & 0 deletions fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,10 @@ def main(args: Any):
M = [int(m) for m in args.M.strip().split(",")]
N = [int(n) for n in args.N.strip().split(",")]
K = [int(k) for k in args.K.strip().split(",")]
if args.groups:
M = M * int(args.groups)
N = N * int(args.groups)
K = K * int(args.groups)
if args.B is None:
B = [1] * len(M)
else:
Expand Down Expand Up @@ -384,6 +388,11 @@ def invoke_main() -> None:
help="If set, do grouped gemm. In this mode, M, N, and K are interpreted "
"as the size of groups. The length of each must be the same.",
)
parser.add_argument(
"--groups",
default=None,
help="If set with grouped mode, repeat input shapes this many times.",
)
parser.add_argument(
"--no_cuda_graph",
default=False,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,99 @@ RowwiseGroupedKernel rowwise_grouped_heuristic_dispatch(int M, int N, int K) {
return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1;
}

// Device kernel that fills in a single KernelArguments entry.
//
// kernel_args points at the entry to write; the remaining parameters are the
// operand/scale/output pointers and the M, N, K extents stored in that entry.
// Only the thread with global index 0 performs the write, so every launch
// configures exactly one group.
__global__ void set_kernel_args_kernel(
    KernelArguments* kernel_args,
    ADataType* XQ,
    BDataType* WQ,
    D0DataType* w_scale,
    D1DataType* x_scale,
    EDataType* output,
    int M,
    int N,
    int K) {
  const int global_tid = threadIdx.x + blockDim.x * blockIdx.x;
  // A launch can only describe one group; handling several groups per launch
  // would need more involved memory management. Every thread other than the
  // first therefore has nothing to do.
  if (global_tid != 0) {
    return;
  }
  // Assemble the argument record and store it straight into device memory.
  // NOTE(review): the trailing K, K, {0, 0}, N presumably are the A/B/D/E
  // strides -- confirm against the KernelArguments definition.
  KernelArguments group_entry = {
      XQ, WQ, {w_scale, x_scale}, output, M, N, K, K, K, {0, 0}, N};
  *kernel_args = group_entry;
}

// Writes one KernelArguments entry per group into the device buffer backing
// `kernel_args`, taking each group's shape from its own tensors (M and K from
// XQ[i], N from WQ[i]).
//
// Two strategies are used:
//  * group_count >= 16 and the current stream is the default (0) stream:
//    build all entries on the host and upload them with one hipMemcpy.
//  * otherwise: launch set_kernel_args_kernel once per group on the current
//    stream.
//
// Raises (via TORCH_CHECK) if the host-to-device copy reports an error.
void set_static_kernel_args(
    at::Tensor kernel_args,
    at::TensorList XQ,
    at::TensorList WQ,
    at::TensorList x_scale,
    at::TensorList w_scale,
    std::vector<at::Tensor> output) {
  // Get current cuda stream.
  auto stream = at::cuda::getCurrentHIPStream().stream();
  int group_count = XQ.size();
  // When group count is large, we can more efficiently initialize
  // by doing host setup and a memcpy. This is only viable if cuda
  // graphs aren't being used.
  if (group_count >= 16 && stream == 0) {
    std::vector<KernelArguments> ggemm_kargs;
    ggemm_kargs.reserve(group_count);

    // Iterate over inputs and get group information.
    for (int i = 0; i < group_count; i++) {
      int M = XQ[i].size(0);
      int K = XQ[i].size(1);
      int N = WQ[i].size(0);
      // NOTE(review): the trailing K, K, {0, 0}, N presumably are the
      // A/B/D/E strides -- confirm against the KernelArguments definition.
      KernelArguments group_args = {
          reinterpret_cast<ADataType*>(XQ[i].data_ptr()),
          reinterpret_cast<BDataType*>(WQ[i].data_ptr()),
          {reinterpret_cast<D0DataType*>(w_scale[i].data_ptr()),
           reinterpret_cast<D1DataType*>(x_scale[i].data_ptr())},
          reinterpret_cast<EDataType*>(output[i].data_ptr()),
          M,
          N,
          K,
          K,
          K,
          {0, 0},
          N};
      ggemm_kargs.push_back(group_args);
    }
    // Copy data onto device. The status is checked so that a failed upload is
    // reported here instead of leaving the argument buffer uninitialized.
    const hipError_t copy_status = hipMemcpy(
        kernel_args.data_ptr(), // Destination
        ggemm_kargs.data(), // Source
        sizeof(KernelArguments) * group_count, // Number of bytes
        hipMemcpyHostToDevice); // Copy Type
    TORCH_CHECK(
        copy_status == hipSuccess,
        "Failed to copy grouped gemm kernel arguments to device: ",
        hipGetErrorString(copy_status));
  } else {
    // We use the smallest reasonable block size since we effectively need only
    // 1 thread.
    int blockSize = 32;
    int numBlocks = 1;
    // Launch a kernel for each group to set kernel memory on device.
    // Using multiple kernels this way allows us to support arbitrary M,N,K.
    // For some reason, this approach is faster than using hipMemcpy.
    for (int i = 0; i < group_count; i++) {
      int M = XQ[i].size(0);
      int K = XQ[i].size(1);
      int N = WQ[i].size(0);
      // Launch kernel to set kernel arguments. The destination is the i-th
      // KernelArguments slot, addressed by byte offset into the buffer.
      set_kernel_args_kernel<<<numBlocks, blockSize, 0, stream>>>(
          reinterpret_cast<KernelArguments*>(
              reinterpret_cast<char*>(kernel_args.data_ptr()) +
              (i * sizeof(KernelArguments))),
          reinterpret_cast<ADataType*>(XQ[i].data_ptr()),
          reinterpret_cast<BDataType*>(WQ[i].data_ptr()),
          reinterpret_cast<D0DataType*>(w_scale[i].data_ptr()),
          reinterpret_cast<D1DataType*>(x_scale[i].data_ptr()),
          reinterpret_cast<EDataType*>(output[i].data_ptr()),
          M,
          N,
          K);
    }
  }
}

__global__ void set_kernel_args_fixed_nk_kernel(
KernelArguments* kernel_args,
ADataType* XQ,
Expand Down Expand Up @@ -114,6 +207,74 @@ __global__ void set_kernel_args_fixed_nk_kernel(
}
}

// Fills the device-side kernel argument buffer for the dynamic-M mode, where
// every group shares the same (M, N, K) taken from group 0.
//
// Validates that zero_start_index_M lives on the same device as the inputs
// and has one entry per group, and that each group's XQ, WQ, x_scale, w_scale
// and output buffers start exactly where the previous group's buffer ends.
// It then launches set_kernel_args_fixed_nk_kernel once on the current stream.
//
// Raises (via TORCH_CHECK) when any of those validations fail. Does nothing
// when there are no groups.
void set_dynamic_kernel_args(
    at::Tensor kernel_args,
    at::TensorList XQ,
    at::TensorList WQ,
    at::TensorList x_scale,
    at::TensorList w_scale,
    std::vector<at::Tensor> output,
    at::Tensor zero_start_index_M) {
  // Get current cuda stream.
  auto stream = at::cuda::getCurrentHIPStream().stream();
  int group_count = XQ.size();
  // With no groups there is nothing to write. Returning here also avoids
  // reading XQ[0] out of range and computing a zero block size (which would
  // divide by zero when deriving the grid size below).
  if (group_count == 0) {
    return;
  }
  // Confirm M is on the proper device.
  TORCH_CHECK(
      XQ[0].device() == zero_start_index_M.device(),
      "zero_start_index_M and inputs must be on the same device.");
  TORCH_CHECK(
      zero_start_index_M.size(0) == group_count,
      "zero_start_index_M must have an entry for each group.");

  // We assume that M, N, and K are fixed across groups.
  // The actual m values are stored in the passed M tensor.
  int M = XQ[0].size(0);
  int K = XQ[0].size(1);
  int N = WQ[0].size(0);

  // Element counts between consecutive groups, computed in 64 bits so that
  // large shapes cannot overflow a 32-bit int.
  const int64_t xq_numel = static_cast<int64_t>(M) * K;
  const int64_t wq_numel = static_cast<int64_t>(N) * K;
  const int64_t out_numel = static_cast<int64_t>(M) * N;

  // Make sure that inputs are allocated in sequential memory as required by
  // this mode.
  for (int i = 1; i < group_count; i++) {
    // Check that all inputs are allocated directly following preceding input.
    TORCH_CHECK(
        XQ[i].data_ptr() ==
            (reinterpret_cast<ADataType*>(XQ[i - 1].data_ptr()) + xq_numel),
        "Inputs must be sequential in memory to support dynamic M, but XQ is not.");
    TORCH_CHECK(
        WQ[i].data_ptr() ==
            (reinterpret_cast<BDataType*>(WQ[i - 1].data_ptr()) + wq_numel),
        "Inputs must be sequential in memory to support dynamic M, but WQ is not.");
    // x_scale is handed to the kernel as D1DataType and w_scale as
    // D0DataType, so the pointer arithmetic here uses the same element types.
    TORCH_CHECK(
        x_scale[i].data_ptr() ==
            (reinterpret_cast<D1DataType*>(x_scale[i - 1].data_ptr()) + (M)),
        "Inputs must be sequential in memory to support dynamic M, but x_scale is not.");
    TORCH_CHECK(
        w_scale[i].data_ptr() ==
            (reinterpret_cast<D0DataType*>(w_scale[i - 1].data_ptr()) + (N)),
        "Inputs must be sequential in memory to support dynamic M, but w_scale is not.");
    TORCH_CHECK(
        output[i].data_ptr() ==
            (reinterpret_cast<EDataType*>(output[i - 1].data_ptr()) +
             out_numel),
        "Inputs must be sequential in memory to support dynamic M, but output is not.");
  }

  // Launch a kernel that sets kernel argument memory, with at least one
  // thread available per group (blocks are capped at 1024 threads).
  // NOTE(review): zero_start_index_M is reinterpreted as int32_t without a
  // dtype check -- confirm callers always pass an int32 tensor.
  int const blockSize = std::min(1024, group_count);
  int const numBlocks = (group_count + blockSize - 1) / blockSize;
  set_kernel_args_fixed_nk_kernel<<<numBlocks, blockSize, 0, stream>>>(
      reinterpret_cast<KernelArguments*>(kernel_args.data_ptr()),
      reinterpret_cast<ADataType*>(XQ[0].data_ptr()),
      reinterpret_cast<BDataType*>(WQ[0].data_ptr()),
      reinterpret_cast<D0DataType*>(w_scale[0].data_ptr()),
      reinterpret_cast<D1DataType*>(x_scale[0].data_ptr()),
      reinterpret_cast<EDataType*>(output[0].data_ptr()),
      reinterpret_cast<int32_t*>(zero_start_index_M.data_ptr()),
      M,
      N,
      K,
      group_count);
}

at::Tensor get_grouped_kernel_args(
at::TensorList XQ,
at::TensorList WQ,
Expand All @@ -122,8 +283,6 @@ at::Tensor get_grouped_kernel_args(
std::optional<at::Tensor> zero_start_index_M,
std::vector<at::Tensor> output) {
int group_count = XQ.size();
auto stream = at::cuda::getCurrentHIPStream().stream();

// Get space on device for the kernel argument tensor.
at::Tensor kernel_args = at::empty(
{static_cast<long>(group_count * sizeof(KernelArguments))},
Expand All @@ -139,104 +298,17 @@ at::Tensor get_grouped_kernel_args(
// using zero_start_index_M.

if (zero_start_index_M.has_value()) {
// Make sure zero_start_index_M is configured properly.
at::Tensor prepad_M = zero_start_index_M.value();
// Confirm M is on the proper device.
TORCH_CHECK(
XQ[0].device() == prepad_M.device(),
"zero_start_index_M and inputs must be on the same device.");
TORCH_CHECK(
prepad_M.size(0) == group_count,
"zero_start_index_M must have an entry for each group.");

// We assume that M, N, and K are fixed across groups.
// The actual m values are sstored in the passed M tensor.
int M = XQ[0].size(0);
int K = XQ[0].size(1);
int N = WQ[0].size(0);

// Make sure that inputs are allocated in sequential memory as required by
// this mode.
for (int i = 1; i < group_count; i++) {
// Check that all inputs are allocated directly following preceding input.
TORCH_CHECK(
XQ[i].data_ptr() ==
(reinterpret_cast<ADataType*>(XQ[i - 1].data_ptr()) + (M * K)),
"Inputs must be sequential in memory to support dynamic M, but XQ is not.");
TORCH_CHECK(
WQ[i].data_ptr() ==
(reinterpret_cast<BDataType*>(WQ[i - 1].data_ptr()) + (N * K)),
"Inputs must be sequential in memory to support dynamic M, but WQ is not.");
TORCH_CHECK(
x_scale[i].data_ptr() ==
(reinterpret_cast<D0DataType*>(x_scale[i - 1].data_ptr()) + (M)),
"Inputs must be sequential in memory to support dynamic M, but x_scale is not.");
TORCH_CHECK(
w_scale[i].data_ptr() ==
(reinterpret_cast<D1DataType*>(w_scale[i - 1].data_ptr()) + (N)),
"Inputs must be sequential in memory to support dynamic M, but w_scale is not.");
TORCH_CHECK(
output[i].data_ptr() ==
(reinterpret_cast<EDataType*>(output[i - 1].data_ptr()) +
(M * N)),
"Inputs must be sequential in memory to support dynamic M, but output is not.");
}

// Launch a kernel that sets kernel argument memory.
int const blockSize = std::min(1024, group_count);
int const numBlocks = (group_count + blockSize - 1) / blockSize;
set_kernel_args_fixed_nk_kernel<<<numBlocks, blockSize, 0, stream>>>(
reinterpret_cast<KernelArguments*>(kernel_args.data_ptr()),
reinterpret_cast<ADataType*>(XQ[0].data_ptr()),
reinterpret_cast<BDataType*>(WQ[0].data_ptr()),
reinterpret_cast<D0DataType*>(w_scale[0].data_ptr()),
reinterpret_cast<D1DataType*>(x_scale[0].data_ptr()),
reinterpret_cast<EDataType*>(output[0].data_ptr()),
reinterpret_cast<int32_t*>(prepad_M.data_ptr()),
M,
N,
K,
group_count);
return kernel_args;
set_dynamic_kernel_args(
kernel_args,
XQ,
WQ,
x_scale,
w_scale,
output,
zero_start_index_M.value());
} else {
// When running in eager mode, we assume we can directly interact with host
// values.
// Note that this version is not supported with cuda graphs.
TORCH_CHECK(
stream == 0,
"f8f8bf16_rowwise_grouped eager mode is not supported with cuda graphs.");

std::vector<KernelArguments> ggemm_kargs;
ggemm_kargs.reserve(group_count);

// Iterate over inputs and get group information.
for (int i = 0; i < group_count; i++) {
int M = XQ[i].size(0);
int K = XQ[i].size(1);
int N = WQ[i].size(0);
KernelArguments group_args = {
reinterpret_cast<ADataType*>(XQ[i].data_ptr()),
reinterpret_cast<BDataType*>(WQ[i].data_ptr()),
{reinterpret_cast<D0DataType*>(w_scale[i].data_ptr()),
reinterpret_cast<D1DataType*>(x_scale[i].data_ptr())},
reinterpret_cast<EDataType*>(output[i].data_ptr()),
M,
N,
K,
K,
K,
{0, 0},
N};
ggemm_kargs.push_back(group_args);
}
// Copy data onto device.
hipMemcpy(
kernel_args.data_ptr(), // Destination
ggemm_kargs.data(), // Source
sizeof(KernelArguments) * group_count, // Number of bytes
hipMemcpyHostToDevice); // Copy Type
set_static_kernel_args(kernel_args, XQ, WQ, x_scale, w_scale, output);
}

return kernel_args;
}

Expand Down

0 comments on commit 0322843

Please sign in to comment.