From 03228433f53bec09a28b622fd095776a42c64788 Mon Sep 17 00:00:00 2001
From: Josh Fromm
Date: Fri, 20 Dec 2024 13:49:25 -0800
Subject: [PATCH] Improve performance of prefill mode CK FP8 Grouped Gemm (#3522)

Summary:
X-link: https://github.com/facebookresearch/FBGEMM/pull/603

I previously assumed that using hipMemcpy would be more efficient than
launching many kernels that directly set GPU memory. This assumption is
apparently (and very surprisingly) untrue. It seems that the
multi-kernel-launch approach reduces overhead considerably, giving a 10%
speedup.

Differential Revision: D67531231
---
 .../gen_ai/bench/quantize_bench.py |   9 +
 .../fp8_rowwise_grouped_gemm.hip   | 268 +++++++++++-------
 2 files changed, 179 insertions(+), 98 deletions(-)

diff --git a/fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py b/fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py
index fa8dc2142..1f6d083c7 100644
--- a/fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py
+++ b/fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py
@@ -275,6 +275,10 @@ def main(args: Any):
     M = [int(m) for m in args.M.strip().split(",")]
     N = [int(n) for n in args.N.strip().split(",")]
     K = [int(k) for k in args.K.strip().split(",")]
+    if args.groups:
+        M = M * int(args.groups)
+        N = N * int(args.groups)
+        K = K * int(args.groups)
     if args.B is None:
         B = [1] * len(M)
     else:
@@ -384,6 +388,11 @@ def invoke_main() -> None:
         help="If set, do grouped gemm. In this mode, M, N, and K are interpreted "
         "as the size of groups. The length of each must be the same.",
     )
+    parser.add_argument(
+        "--groups",
+        default=None,
+        help="If set with grouped mode, repeat input shapes this many times.",
+    )
     parser.add_argument(
         "--no_cuda_graph",
         default=False,
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip
index d59e6db95..a5faae201 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip
+++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip
@@ -80,6 +80,99 @@ RowwiseGroupedKernel rowwise_grouped_heuristic_dispatch(int M, int N, int K) {
   return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1;
 }
+__global__ void set_kernel_args_kernel(
+    KernelArguments* kernel_args,
+    ADataType* XQ,
+    BDataType* WQ,
+    D0DataType* w_scale,
+    D1DataType* x_scale,
+    EDataType* output,
+    int M,
+    int N,
+    int K) {
+  int idx = blockIdx.x * blockDim.x + threadIdx.x;
+  // Each kernel annoyingly can only set the kernel args for one group.
+  // This could only be avoided with complicated memory management.
+  if (idx == 0) {
+    // Write kernel arguments directly to memory.
+    KernelArguments kernel_group_args = {
+        XQ, WQ, {w_scale, x_scale}, output, M, N, K, K, K, {0, 0}, N};
+    kernel_args[0] = kernel_group_args;
+  }
+}
+
+void set_static_kernel_args(
+    at::Tensor kernel_args,
+    at::TensorList XQ,
+    at::TensorList WQ,
+    at::TensorList x_scale,
+    at::TensorList w_scale,
+    std::vector<at::Tensor> output) {
+  // Get current cuda stream.
+  auto stream = at::cuda::getCurrentHIPStream().stream();
+  int group_count = XQ.size();
+  // When group count is large, we can more efficiently initialize
+  // by doing host setup and a memcpy. This is only viable if cuda
+  // graphs aren't being used.
+  if (group_count >= 16 && stream == 0) {
+    std::vector<KernelArguments> ggemm_kargs;
+    ggemm_kargs.reserve(group_count);
+
+    // Iterate over inputs and get group information.
+    for (int i = 0; i < group_count; i++) {
+      int M = XQ[i].size(0);
+      int K = XQ[i].size(1);
+      int N = WQ[i].size(0);
+      KernelArguments group_args = {
+          reinterpret_cast<ADataType*>(XQ[i].data_ptr()),
+          reinterpret_cast<BDataType*>(WQ[i].data_ptr()),
+          {reinterpret_cast<D0DataType*>(w_scale[i].data_ptr()),
+           reinterpret_cast<D1DataType*>(x_scale[i].data_ptr())},
+          reinterpret_cast<EDataType*>(output[i].data_ptr()),
+          M,
+          N,
+          K,
+          K,
+          K,
+          {0, 0},
+          N};
+      ggemm_kargs.push_back(group_args);
+    }
+    // Copy data onto device.
+    hipMemcpy(
+        kernel_args.data_ptr(), // Destination
+        ggemm_kargs.data(), // Source
+        sizeof(KernelArguments) * group_count, // Number of bytes
+        hipMemcpyHostToDevice); // Copy Type
+  } else {
+    // We use the smallest reasonable block size since we effectively need only
+    // 1 thread.
+    int blockSize = 32;
+    int numBlocks = 1;
+    // Launch a kernel for each group to set kernel memory on device.
+    // Using multiple kernels this way allows us to support arbitrary M, N, K.
+    // For some reason, this approach is faster than using hipMemcpy.
+    for (int i = 0; i < group_count; i++) {
+      int M = XQ[i].size(0);
+      int K = XQ[i].size(1);
+      int N = WQ[i].size(0);
+      // Launch kernel to set kernel arguments.
+      set_kernel_args_kernel<<<numBlocks, blockSize, 0, stream>>>(
+          reinterpret_cast<KernelArguments*>(
+              reinterpret_cast<char*>(kernel_args.data_ptr()) +
+              (i * sizeof(KernelArguments))),
+          reinterpret_cast<ADataType*>(XQ[i].data_ptr()),
+          reinterpret_cast<BDataType*>(WQ[i].data_ptr()),
+          reinterpret_cast<D0DataType*>(w_scale[i].data_ptr()),
+          reinterpret_cast<D1DataType*>(x_scale[i].data_ptr()),
+          reinterpret_cast<EDataType*>(output[i].data_ptr()),
+          M,
+          N,
+          K);
+    }
+  }
+}
+
 __global__ void set_kernel_args_fixed_nk_kernel(
     KernelArguments* kernel_args,
     ADataType* XQ,
     BDataType* WQ,
@@ -114,6 +207,74 @@ __global__ void set_kernel_args_fixed_nk_kernel(
   }
 }
 
+void set_dynamic_kernel_args(
+    at::Tensor kernel_args,
+    at::TensorList XQ,
+    at::TensorList WQ,
+    at::TensorList x_scale,
+    at::TensorList w_scale,
+    std::vector<at::Tensor> output,
+    at::Tensor zero_start_index_M) {
+  // Get current cuda stream.
+  auto stream = at::cuda::getCurrentHIPStream().stream();
+  int group_count = XQ.size();
+  // Confirm M is on the proper device.
+  TORCH_CHECK(
+      XQ[0].device() == zero_start_index_M.device(),
+      "zero_start_index_M and inputs must be on the same device.");
+  TORCH_CHECK(
+      zero_start_index_M.size(0) == group_count,
+      "zero_start_index_M must have an entry for each group.");
+
+  // We assume that M, N, and K are fixed across groups.
+  // The actual M values are stored in the passed M tensor.
+  int M = XQ[0].size(0);
+  int K = XQ[0].size(1);
+  int N = WQ[0].size(0);
+
+  // Make sure that inputs are allocated in sequential memory as required by
+  // this mode.
+  for (int i = 1; i < group_count; i++) {
+    // Check that all inputs are allocated directly following preceding input.
+    TORCH_CHECK(
+        XQ[i].data_ptr() ==
+            (reinterpret_cast<ADataType*>(XQ[i - 1].data_ptr()) + (M * K)),
+        "Inputs must be sequential in memory to support dynamic M, but XQ is not.");
+    TORCH_CHECK(
+        WQ[i].data_ptr() ==
+            (reinterpret_cast<BDataType*>(WQ[i - 1].data_ptr()) + (N * K)),
+        "Inputs must be sequential in memory to support dynamic M, but WQ is not.");
+    TORCH_CHECK(
+        x_scale[i].data_ptr() ==
+            (reinterpret_cast<D1DataType*>(x_scale[i - 1].data_ptr()) + (M)),
+        "Inputs must be sequential in memory to support dynamic M, but x_scale is not.");
+    TORCH_CHECK(
+        w_scale[i].data_ptr() ==
+            (reinterpret_cast<D0DataType*>(w_scale[i - 1].data_ptr()) + (N)),
+        "Inputs must be sequential in memory to support dynamic M, but w_scale is not.");
+    TORCH_CHECK(
+        output[i].data_ptr() ==
+            (reinterpret_cast<EDataType*>(output[i - 1].data_ptr()) + (M * N)),
+        "Inputs must be sequential in memory to support dynamic M, but output is not.");
+  }
+
+  // Launch a kernel that sets kernel argument memory.
+  int const blockSize = std::min(1024, group_count);
+  int const numBlocks = (group_count + blockSize - 1) / blockSize;
+  set_kernel_args_fixed_nk_kernel<<<numBlocks, blockSize, 0, stream>>>(
+      reinterpret_cast<KernelArguments*>(kernel_args.data_ptr()),
+      reinterpret_cast<ADataType*>(XQ[0].data_ptr()),
+      reinterpret_cast<BDataType*>(WQ[0].data_ptr()),
+      reinterpret_cast<D0DataType*>(w_scale[0].data_ptr()),
+      reinterpret_cast<D1DataType*>(x_scale[0].data_ptr()),
+      reinterpret_cast<EDataType*>(output[0].data_ptr()),
+      reinterpret_cast<int64_t*>(zero_start_index_M.data_ptr()),
+      M,
+      N,
+      K,
+      group_count);
+}
+
 at::Tensor get_grouped_kernel_args(
     at::TensorList XQ,
     at::TensorList WQ,
@@ -122,8 +283,6 @@
     std::optional<at::Tensor> zero_start_index_M,
     std::vector<at::Tensor> output) {
   int group_count = XQ.size();
-  auto stream = at::cuda::getCurrentHIPStream().stream();
-
   // Get space on device for the kernel argument tensor.
   at::Tensor kernel_args = at::empty(
       {static_cast<long>(group_count * sizeof(KernelArguments))},
@@ -139,104 +298,17 @@
   // using zero_start_index_M.
   if (zero_start_index_M.has_value()) {
-    // Make sure zero_start_index_M is configured properly.
-    at::Tensor prepad_M = zero_start_index_M.value();
-    // Confirm M is on the proper device.
-    TORCH_CHECK(
-        XQ[0].device() == prepad_M.device(),
-        "zero_start_index_M and inputs must be on the same device.");
-    TORCH_CHECK(
-        prepad_M.size(0) == group_count,
-        "zero_start_index_M must have an entry for each group.");
-
-    // We assume that M, N, and K are fixed across groups.
-    // The actual m values are sstored in the passed M tensor.
-    int M = XQ[0].size(0);
-    int K = XQ[0].size(1);
-    int N = WQ[0].size(0);
-
-    // Make sure that inputs are allocated in sequential memory as required by
-    // this mode.
-    for (int i = 1; i < group_count; i++) {
-      // Check that all inputs are allocated directly following preceding input.
-      TORCH_CHECK(
-          XQ[i].data_ptr() ==
-              (reinterpret_cast<ADataType*>(XQ[i - 1].data_ptr()) + (M * K)),
-          "Inputs must be sequential in memory to support dynamic M, but XQ is not.");
-      TORCH_CHECK(
-          WQ[i].data_ptr() ==
-              (reinterpret_cast<BDataType*>(WQ[i - 1].data_ptr()) + (N * K)),
-          "Inputs must be sequential in memory to support dynamic M, but WQ is not.");
-      TORCH_CHECK(
-          x_scale[i].data_ptr() ==
-              (reinterpret_cast<D1DataType*>(x_scale[i - 1].data_ptr()) + (M)),
-          "Inputs must be sequential in memory to support dynamic M, but x_scale is not.");
-      TORCH_CHECK(
-          w_scale[i].data_ptr() ==
-              (reinterpret_cast<D0DataType*>(w_scale[i - 1].data_ptr()) + (N)),
-          "Inputs must be sequential in memory to support dynamic M, but w_scale is not.");
-      TORCH_CHECK(
-          output[i].data_ptr() ==
-              (reinterpret_cast<EDataType*>(output[i - 1].data_ptr()) +
-               (M * N)),
-          "Inputs must be sequential in memory to support dynamic M, but output is not.");
-    }
-
-    // Launch a kernel that sets kernel argument memory.
-    int const blockSize = std::min(1024, group_count);
-    int const numBlocks = (group_count + blockSize - 1) / blockSize;
-    set_kernel_args_fixed_nk_kernel<<<numBlocks, blockSize, 0, stream>>>(
-        reinterpret_cast<KernelArguments*>(kernel_args.data_ptr()),
-        reinterpret_cast<ADataType*>(XQ[0].data_ptr()),
-        reinterpret_cast<BDataType*>(WQ[0].data_ptr()),
-        reinterpret_cast<D0DataType*>(w_scale[0].data_ptr()),
-        reinterpret_cast<D1DataType*>(x_scale[0].data_ptr()),
-        reinterpret_cast<EDataType*>(output[0].data_ptr()),
-        reinterpret_cast<int64_t*>(prepad_M.data_ptr()),
-        M,
-        N,
-        K,
-        group_count);
-    return kernel_args;
+    set_dynamic_kernel_args(
+        kernel_args,
+        XQ,
+        WQ,
+        x_scale,
+        w_scale,
+        output,
+        zero_start_index_M.value());
   } else {
-    // When running in eager mode, we assume we can directly interact with host
-    // values.
-    // Note that this version is not supported with cuda graphs.
-    TORCH_CHECK(
-        stream == 0,
-        "f8f8bf16_rowwise_grouped eager mode is not supported with cuda graphs.");
-
-    std::vector<KernelArguments> ggemm_kargs;
-    ggemm_kargs.reserve(group_count);
-
-    // Iterate over inputs and get group information.
-    for (int i = 0; i < group_count; i++) {
-      int M = XQ[i].size(0);
-      int K = XQ[i].size(1);
-      int N = WQ[i].size(0);
-      KernelArguments group_args = {
-          reinterpret_cast<ADataType*>(XQ[i].data_ptr()),
-          reinterpret_cast<BDataType*>(WQ[i].data_ptr()),
-          {reinterpret_cast<D0DataType*>(w_scale[i].data_ptr()),
-           reinterpret_cast<D1DataType*>(x_scale[i].data_ptr())},
-          reinterpret_cast<EDataType*>(output[i].data_ptr()),
-          M,
-          N,
-          K,
-          K,
-          K,
-          {0, 0},
-          N};
-      ggemm_kargs.push_back(group_args);
-    }
-    // Copy data onto device.
-    hipMemcpy(
-        kernel_args.data_ptr(), // Destination
-        ggemm_kargs.data(), // Source
-        sizeof(KernelArguments) * group_count, // Number of bytes
-        hipMemcpyHostToDevice); // Copy Type
+    set_static_kernel_args(kernel_args, XQ, WQ, x_scale, w_scale, output);
   }
-
   return kernel_args;
 }
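
Usage sketch (assumptions noted): the new --groups flag repeats a single M,N,K
shape across many groups, which makes it easy to benchmark this change. The
commands below are a sketch only; they assume quantize_bench.py is run directly
as a script and that the grouped-gemm switch is named --grouped. Only --groups,
--M/--N/--K, and --no_cuda_graph are visible in the hunks above, and the shape
values are illustrative.

    # 16 identical groups with cuda graphs enabled (the benchmark default);
    # setup runs on a non-default stream, so the per-group kernel-launch path
    # added in this patch is expected to be taken.
    python quantize_bench.py --grouped --groups 16 --M 128 --N 8192 --K 8192

    # Same shapes with cuda graphs disabled; on the default stream with >= 16
    # groups, setup should fall back to host-side packing plus a single hipMemcpy.
    python quantize_bench.py --grouped --groups 16 --M 128 --N 8192 --K 8192 --no_cuda_graph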