Improve performance of prefill mode FP8 Grouped Gemm
Summary: I previously assumed that using hipMemcpy would be more efficient than launching many kernels that directly set GPU memory. This assumption is apparently (and very surprisingly) untrue. It seems the multi-kernel-launch approach reduces overhead considerably, giving a 10% speedup.
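For context, the pattern being compared is roughly the sketch below: a trivially small kernel, launched once per group, writes one argument struct straight into device memory, instead of staging all structs on the host and issuing a single hipMemcpy. This is a simplified illustration, not the actual FBGEMM code; the GroupArgs struct, write_group_args kernel, and device_args pointer are hypothetical stand-ins for the real KernelArguments layout shown in the diff below.

    #include <hip/hip_runtime.h>
    #include <vector>

    // Hypothetical, simplified argument struct standing in for KernelArguments.
    struct GroupArgs {
      const void* a;
      const void* b;
      void* out;
      int m;
      int n;
      int k;
    };

    // A single thread writes one group's arguments directly into device memory.
    __global__ void write_group_args(
        GroupArgs* dst,
        const void* a,
        const void* b,
        void* out,
        int m,
        int n,
        int k) {
      if (blockIdx.x == 0 && threadIdx.x == 0) {
        dst[0] = GroupArgs{a, b, out, m, n, k};
      }
    }

    // New approach (what this commit adopts): one cheap launch per group.
    void fill_args_with_kernels(
        GroupArgs* device_args,
        const std::vector<GroupArgs>& host_args,
        hipStream_t stream) {
      for (size_t g = 0; g < host_args.size(); ++g) {
        const GroupArgs& h = host_args[g];
        write_group_args<<<1, 32, 0, stream>>>(
            device_args + g, h.a, h.b, h.out, h.m, h.n, h.k);
      }
    }

    // Old approach: build the structs on the host, then one blocking copy.
    void fill_args_with_memcpy(
        GroupArgs* device_args, const std::vector<GroupArgs>& host_args) {
      hipMemcpy(
          device_args,
          host_args.data(),
          host_args.size() * sizeof(GroupArgs),
          hipMemcpyHostToDevice);
    }

The per-launch version keeps the argument writes asynchronous on the stream and avoids a blocking host-to-device copy; the summary above attributes the ~10% win to this reduced overhead.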

Differential Revision: D67531231
jwfromm authored and facebook-github-bot committed Dec 20, 2024
1 parent 6da23d5 commit 71fed6a
Showing 1 changed file with 35 additions and 26 deletions.
@@ -80,6 +80,27 @@ RowwiseGroupedKernel rowwise_grouped_heuristic_dispatch(int M, int N, int K) {
   return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1;
 }
 
+__global__ void set_kernel_args_kernel(
+    KernelArguments* kernel_args,
+    ADataType* XQ,
+    BDataType* WQ,
+    D0DataType* w_scale,
+    D1DataType* x_scale,
+    EDataType* output,
+    int M,
+    int N,
+    int K) {
+  int idx = blockIdx.x * blockDim.x + threadIdx.x;
+  // Each kernel annoyingly can only set the kernel args for one group.
+  // This could only be avoided with complicated memory management.
+  if (idx == 0) {
+    // Write kernel arguments directly to memory.
+    KernelArguments kernel_group_args = {
+        XQ, WQ, {w_scale, x_scale}, output, M, N, K, K, K, {0, 0}, N};
+    kernel_args[0] = kernel_group_args;
+  }
+}
+
 __global__ void set_kernel_args_fixed_nk_kernel(
     KernelArguments* kernel_args,
     ADataType* XQ,
@@ -199,42 +220,30 @@ at::Tensor get_grouped_kernel_args(
         group_count);
     return kernel_args;
   } else {
-    // When running in eager mode, we assume we can directly interact with host
-    // values.
-    // Note that this version is not supported with cuda graphs.
-    TORCH_CHECK(
-        stream == 0,
-        "f8f8bf16_rowwise_grouped eager mode is not supported with cuda graphs.");
-
-    std::vector<KernelArguments> ggemm_kargs;
-    ggemm_kargs.reserve(group_count);
-
-    // Iterate over inputs and get group information.
+    // We use the smallest reasonable block size since we effectively need only 1 thread.
+    int blockSize = 32;
+    int numBlocks = 1;
+    // Launch a kernel for each group to set kernel memory on device.
+    // Using multiple kernels this way allows us to support arbitrary M,N,K.
+    // For some reason, this approach is faster than using hipmemcpy.
     for (int i = 0; i < group_count; i++) {
       int M = XQ[i].size(0);
       int K = XQ[i].size(1);
       int N = WQ[i].size(0);
-      KernelArguments group_args = {
+      // Launch kernel to set kernel arguments.
+      set_kernel_args_kernel<<<numBlocks, blockSize, 0, stream>>>(
+          reinterpret_cast<KernelArguments*>(
+              reinterpret_cast<char*>(kernel_args.data_ptr()) +
+              (i * sizeof(KernelArguments))),
           reinterpret_cast<ADataType*>(XQ[i].data_ptr()),
           reinterpret_cast<BDataType*>(WQ[i].data_ptr()),
-          {reinterpret_cast<D0DataType*>(w_scale[i].data_ptr()),
-           reinterpret_cast<D1DataType*>(x_scale[i].data_ptr())},
+          reinterpret_cast<D0DataType*>(w_scale[i].data_ptr()),
+          reinterpret_cast<D1DataType*>(x_scale[i].data_ptr()),
           reinterpret_cast<EDataType*>(output[i].data_ptr()),
           M,
           N,
-          K,
-          K,
-          K,
-          {0, 0},
-          N};
-      ggemm_kargs.push_back(group_args);
+          K);
     }
-    // Copy data onto device.
-    hipMemcpy(
-        kernel_args.data_ptr(), // Destination
-        ggemm_kargs.data(), // Source
-        sizeof(KernelArguments) * group_count, // Number of bytes
-        hipMemcpyHostToDevice); // Copy Type
   }
 
   return kernel_args;