From b2d5cf8a78df291d40354e8a216168ea5461eca1 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Tue, 8 Aug 2023 07:40:52 +0000 Subject: [PATCH] GQA-4 example --- .../CMakeLists.txt | 1 + ...uped_query_attention_forward_wmma_fp16.cpp | 302 ++++ ...ulti_query_attention_forward_wmma_fp16.cpp | 9 +- ...n_grouped_query_attention_forward_wmma.inc | 340 +++++ ...e_grouped_query_attention_forward_wmma.hpp | 1257 +++++++++++++++++ .../cpu/reference_batched_gemm.hpp | 124 ++ 6 files changed, 2028 insertions(+), 5 deletions(-) create mode 100644 example/32_batched_gemm_scale_softmax_gemm/grouped_query_attention_forward_wmma_fp16.cpp create mode 100644 example/32_batched_gemm_scale_softmax_gemm/run_grouped_query_attention_forward_wmma.inc create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_grouped_query_attention_forward_wmma.hpp diff --git a/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt b/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt index a88a3503144..af6609faad1 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt +++ b/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt @@ -11,6 +11,7 @@ if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS add_example_executable(example_self_attention_forward_wmma_fp16 self_attention_forward_wmma_fp16.cpp) add_example_executable(example_cross_attention_forward_wmma_fp16 cross_attention_forward_wmma_fp16.cpp) add_example_executable(example_multi_query_attention_forward_wmma_fp16 multi_query_attention_forward_wmma_fp16.cpp) + add_example_executable(example_grouped_query_attention_forward_wmma_fp16 grouped_query_attention_forward_wmma_fp16.cpp) endif() add_custom_target(example_gemm_scale_softmax_gemm) diff --git a/example/32_batched_gemm_scale_softmax_gemm/grouped_query_attention_forward_wmma_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/grouped_query_attention_forward_wmma_fp16.cpp new file mode 100644 index 00000000000..12dcfcc36d9 --- 
/dev/null +++ b/example/32_batched_gemm_scale_softmax_gemm/grouped_query_attention_forward_wmma_fp16.cpp @@ -0,0 +1,302 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +/* +Grouped Query Attention, +Ainslie, Joshua, James Lee-Thorp, Michiel de Jong, Yury Zemlyanskiy, Federico Lebrón, and Sumit +Sanghai. “GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints.” +arXiv, May 22, 2023. https://doi.org/10.48550/arXiv.2305.13245. + +Example is GQA-4 +*/ + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_query_attention_forward_wmma.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using B0DataType = F16; +using B1DataType = F16; +using Acc0DataType = F32; +using Acc1DataType = F32; +using CShuffleDataType = F32; +using CDataType = F16; +using Acc0BiasDataType = ck::Tuple<>; +using Acc1BiasDataType = ck::Tuple<>; + +static constexpr ck::index_t NumDimG = 2; +static constexpr ck::index_t NumDimM = 1; +static constexpr ck::index_t NumDimN = 1; +static constexpr ck::index_t NumDimK = 1; +static constexpr ck::index_t NumDimO = 1; +static constexpr ck::index_t 
QueryGroupNumber = 4; + +using AElementOp = PassThrough; +using B0ElementOp = PassThrough; +using Acc0ElementOp = ck::tensor_operation::element_wise::Scale; +using B1ElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; +static constexpr auto MaskingSpec = + ck::tensor_operation::device::MaskingSpecialization::MaskDisabled; + +static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default; + +// clang-format off +// #define CK_MHA_USE_WAVE_1 +// #define CK_MHA_USE_WAVE_2 +// #define CK_MHA_USE_WAVE_4 +#define CK_MHA_USE_WAVE_8 +using DeviceMHAFactory = + std::tuple< +#ifdef CK_MHA_USE_WAVE_1 + // 1 wave, mrepeat = 1, nrepeat = 2, k/o repeat = 1~5 + ck::tensor_operation::device::DeviceGroupedQueryAttentionForward_Wmma< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + QueryGroupNumber, + 32, + // Gemm 0 + 16, 128, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 8, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 16, 1, 2>, 8, + MaskingSpec>, + 
ck::tensor_operation::device::DeviceGroupedQueryAttentionForward_Wmma< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + QueryGroupNumber, + 32, + // Gemm 0 + 16, 64, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 16, 1, 2>, 8, + MaskingSpec>, +#endif +#ifdef CK_MHA_USE_WAVE_2 + ck::tensor_operation::device::DeviceGroupedQueryAttentionForward_Wmma< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + QueryGroupNumber, + 64, + // Gemm 0 + 32, 128, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 8, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 32, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceGroupedQueryAttentionForward_Wmma< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, 
Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + QueryGroupNumber, + 64, + // Gemm 0 + 32, 64, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 32, 1, 2>, 8, + MaskingSpec>, +#endif +#ifdef CK_MHA_USE_WAVE_4 + ck::tensor_operation::device::DeviceGroupedQueryAttentionForward_Wmma< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + QueryGroupNumber, + 128, + // Gemm 0 + 64, 128, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 8, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 64, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceGroupedQueryAttentionForward_Wmma< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 
QueryGroupNumber, + 128, + // Gemm 0 + 64, 64, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 64, 1, 2>, 8, + MaskingSpec>, +#endif +#ifdef CK_MHA_USE_WAVE_8 + ck::tensor_operation::device::DeviceGroupedQueryAttentionForward_Wmma< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + QueryGroupNumber, + 256, + // Gemm 0 + 128, 128, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 8, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 128, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceGroupedQueryAttentionForward_Wmma< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + QueryGroupNumber, + 256, + // Gemm 0 + 128, 128, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 8, 4, + // ABlockTransfer MK -> K0 M K1 + 
S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 128, 1, 2>, 8, + MaskingSpec> +#endif + >; +// clang-format on +// Ref Gemm0: fp16 in, fp32 out +using ReferenceGemm0Instance = + ck::tensor_operation::host::ReferenceBatchedGemm_GQA; + +// Ref Softmax: fp32 in, fp16 out +using ReferenceSoftmaxInstance = + ck::tensor_operation::host::ReferenceSoftmax; + +// Ref Gemm1: fp16 in, fp16 out +using ReferenceGemm1Instance = + ck::tensor_operation::host::ReferenceBatchedGemm_GQA; + +#include "run_grouped_query_attention_forward_wmma.inc" + +int main(int argc, char* argv[]) { return run(argc, argv); } diff --git a/example/32_batched_gemm_scale_softmax_gemm/multi_query_attention_forward_wmma_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/multi_query_attention_forward_wmma_fp16.cpp index 43feea12fb4..694a320a45f 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/multi_query_attention_forward_wmma_fp16.cpp +++ b/example/32_batched_gemm_scale_softmax_gemm/multi_query_attention_forward_wmma_fp16.cpp @@ -2,11 +2,10 @@ // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. /* -Gemm + Softmax + Gemm fused operation. Computes C_g_m_n = Softmax(A_g_m_k * B0_g_k_l) * B1_g_l_n - |-----------------| - Gemm0 - |-------------------------------------| - Gemm1 +Multi-Query Attention +Shazeer, Noam. “Fast Transformer Decoding: One Write-Head Is All You Need.” arXiv.org, November 6, +2019. https://arxiv.org/abs/1911.02150v1. 
+ */ #include diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_grouped_query_attention_forward_wmma.inc b/example/32_batched_gemm_scale_softmax_gemm/run_grouped_query_attention_forward_wmma.inc new file mode 100644 index 00000000000..0d66d837d30 --- /dev/null +++ b/example/32_batched_gemm_scale_softmax_gemm/run_grouped_query_attention_forward_wmma.inc @@ -0,0 +1,340 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +int run(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape for A/B0/B1/C + // C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o + ck::index_t M = 1024; + ck::index_t N = 1024; + ck::index_t K = 64; + ck::index_t O = 64; + + // Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape + // C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o]) + // C_g0_m_g1_o = permute(C_g0_g1_m_o, [0, 2, 1, 3]) + ck::index_t G0 = 4; + ck::index_t G1 = 16; + ck::index_t KV_head = QueryGroupNumber; + + float alpha = 1; + + bool input_permute = false; + bool output_permute = true; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 13) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + O = std::stoi(argv[7]); + G0 = std::stoi(argv[8]); + G1 = std::stoi(argv[9]); + + alpha = std::stof(argv[10]); + + input_permute = std::stoi(argv[11]); + output_permute = std::stoi(argv[12]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 9: M, N, K, O, G0, G1\n"); + 
printf("arg10: scale (alpha)\n"); + printf("arg11 to 12: input / output permute\n"); + exit(0); + } + + std::vector a_gs_ms_ks_lengths{G0, G1, M, K}; + std::vector a_gs_ms_ks_strides = + input_permute + ? std::vector{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K] + : std::vector{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K] + + std::vector b0_gs_ns_ks_lengths{G0, KV_head, N, K}; + std::vector b0_gs_ns_ks_strides = + input_permute + ? std::vector{N * KV_head * K, K, KV_head * K, 1} + // B0 layout [G0, N, G1, K] + : std::vector{KV_head * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K] + + std::vector b1_gs_os_ns_lengths{G0, KV_head, O, N}; + std::vector b1_gs_os_ns_strides = + input_permute + ? std::vector{N * KV_head * O, O, 1, KV_head * O} + // B1 layout [G0, N, G1, O] + : std::vector{KV_head * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O] + + std::vector c_gs_ms_os_lengths{G0, G1, M, O}; + std::vector c_gs_ms_os_strides = + output_permute + ? std::vector{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O] + : std::vector{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O] + + Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); + Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); + Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + + std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; + std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl; + std::cout << "b1_gs_os_ns: " << b1_gs_os_ns.mDesc << std::endl; + std::cout << "c_gs_ms_os: " << c_gs_ms_os_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 2: + 
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + case 3: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); + break; + case 4: // A, B0, B1 1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 5: // Rand: b1 b0; unit: a + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 6: // Rand: a b0 ; unit: B1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 7: // Rand: a b1 ; unit: b0 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 8: // Rand: a ; unit: b0 b1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 9: // Rand: b0 ; unit: a b1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 10: // Rand: b1 ; unit: a b0 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + default: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); + 
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_gs_ns_ks.mDesc.GetElementSpaceSize()); + DeviceMem b1_device_buf(sizeof(B1DataType) * b1_gs_os_ns.mDesc.GetElementSpaceSize()); + DeviceMem c_device_buf(sizeof(CDataType) * + c_gs_ms_os_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_gs_ms_ks.mData.data()); + b0_device_buf.ToDevice(b0_gs_ns_ks.mData.data()); + b1_device_buf.ToDevice(b1_gs_os_ns.mData.data()); + + auto a_element_op = AElementOp{}; + auto b0_element_op = B0ElementOp{}; + auto acc0_element_op = Acc0ElementOp{alpha}; + auto b1_element_op = B1ElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + float best_perf = .0; + float best_time = .0; + int not_pass = 0; + std::string best_kernel = ""; + printf("Verification: %s\n", do_verification ? "ON" : "OFF"); + // TODO ANT: replace array with vector? 
+ ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) -> void { + const auto device_conv_mha_instance = std::get(DeviceMHAFactory{}); + + using DeviceMHAInstance = ck::remove_cvref_t; + auto gemm = DeviceMHAInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b0_device_buf.GetDeviceBuffer()), + static_cast(b1_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + O, + G0, + G1, + alpha, + input_permute, + output_permute); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl; + + // return 0; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * G0 * G1; + std::size_t num_btype = + (sizeof(ADataType) * M * K + sizeof(CDataType) * M * O) * G0 * G1 + + (sizeof(B0DataType) * K * N + sizeof(B1DataType) * N * O) * G0 * QueryGroupNumber; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + if(tflops > best_perf) + { + best_perf = tflops; + best_time = ave_time * 1000; + best_kernel = gemm.GetTypeString(); + } + if(do_verification) + { + c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data()); + + Tensor a_g0_g1_m_k({G0, G1, M, K}); + Tensor b0_g0_gq_k_n({G0, QueryGroupNumber, K, N}); + Tensor b1_g0_gq_n_o({G0, QueryGroupNumber, N, O}); + Tensor acc0_g0_g1_m_n({G0, G1, M, N}); // scratch object after gemm0 + Tensor a1_g0_g1_m_n({G0, G1, M, N}); // scratch object after softmax + Tensor c_g0_g1_m_o_host_result({G0, G1, M, O}); // scratch object after gemm1 + + // permute + a_gs_ms_ks.ForEach([&](auto& self, auto idx) { + a_g0_g1_m_k(idx[0], idx[1], idx[2], idx[3]) 
= self(idx); + }); + b0_gs_ns_ks.ForEach([&](auto& self, auto idx) { + b0_g0_gq_k_n(idx[0], idx[1], idx[3], idx[2]) = self(idx); + }); + b1_gs_os_ns.ForEach([&](auto& self, auto idx) { + b1_g0_gq_n_o(idx[0], idx[1], idx[3], idx[2]) = self(idx); + }); + + // gemm 0 + auto ref_gemm0 = ReferenceGemm0Instance{}; + auto ref_gemm0_invoker = ref_gemm0.MakeInvoker(); + auto ref_gemm0_argument = ref_gemm0.MakeArgument(a_g0_g1_m_k, + b0_g0_gq_k_n, + acc0_g0_g1_m_n, + a_element_op, + b0_element_op, + acc0_element_op); + + ref_gemm0_invoker.Run(ref_gemm0_argument); + + // masking + const auto mask = typename DeviceMHAInstance::C0MatrixMask(N); + acc0_g0_g1_m_n.ForEach([&](auto& self, auto idx) { + if(mask.IsMaskedElement(idx[2], idx[3])) + self(idx) = -ck::NumericLimits::Infinity(); + }); + + // softmax + auto ref_softmax = ReferenceSoftmaxInstance{}; + auto ref_softmax_invoker = ref_softmax.MakeInvoker(); + auto ref_softmax_argument = + ref_softmax.MakeArgument(acc0_g0_g1_m_n, a1_g0_g1_m_n, 1, 0, {3}); + + ref_softmax_invoker.Run(ref_softmax_argument); + + // gemm1 + auto ref_gemm1 = ReferenceGemm1Instance{}; + auto ref_gemm1_invoker = ref_gemm1.MakeInvoker(); + auto ref_gemm1_argument = ref_gemm1.MakeArgument(a1_g0_g1_m_n, + b1_g0_gq_n_o, + c_g0_g1_m_o_host_result, + PassThrough{}, + b1_element_op, + c_element_op); + + ref_gemm1_invoker.Run(ref_gemm1_argument); + + // permute + c_gs_ms_os_host_result.ForEach( + [&](auto& self, auto idx) { self(idx) = c_g0_g1_m_o_host_result(idx); }); + + // default absolute error and relative error is 0.001 + double rtol = 1e-3; + double atol = 1e-3; + + // when BF16 is taken, set absolute error and relative error to 0.01 + if(std::is_same_v && std::is_same_v && + std::is_same_v && std::is_same_v) + { + rtol = 1e-2; + atol = 1e-2; + } + + bool this_run_verification = ck::utils::check_err(c_gs_ms_os_device_result.mData, + c_gs_ms_os_host_result.mData, + "Error: Incorrect results!", + rtol, + atol); + printf("Verification: %s, Pass: %s\n", + 
do_verification ? "ON" : "OFF", + this_run_verification ? "YES" : "NO"); + + if(!this_run_verification) + { + not_pass = 1; + printf("%d th GQA instance verification Failed \n", i.value); + } + } + }); + std::cout << "---------------------------------------------------------------------------------" + "-----------" + << std::endl; + std::cout << "Problem Size: BatchCount: " << G0 << ", HeadNum: " << G1 << ", M: " << M + << ", N: " << N << ", K: " << K << ", O: " << O << std::endl; + std::cout << "---------------------------------------------------------------------------------" + "-----------" + << std::endl; + std::cout << "Best kernel: " << best_kernel << " , " << best_perf << " TFlops , " << best_time + << " us" << std::endl; + std::cout << "---------------------------------------------------------------------------------" + "-----------" + << std::endl; + return not_pass; +} diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_query_attention_forward_wmma.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_query_attention_forward_wmma.hpp new file mode 100644 index 00000000000..2313b256c32 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_query_attention_forward_wmma.hpp @@ -0,0 +1,1257 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp" +#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Grouped-Query Attention (GQA) kernel implementation +// K and V have QueryGroupNumber heads; Q has G1 heads. +// Q [G0, G1, M, K] * K [G0, QueryGroupNumber, K, N] = P [G0, G1, M, N] +// P [G0, G1, M, N] * V [G0, QueryGroupNumber, N, O] = Out [G0, G1, M, O] +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_grouped_query_attention_wmma(const ADataType* __restrict__ p_a_grid, + const B0DataType* __restrict__ p_b0_grid, + const B1DataType* __restrict__ p_b1_grid, + CDataType* __restrict__ p_c_grid, + index_t M, // SequenceQ + index_t N, // SequenceK + index_t K, // HeadDim + index_t O, // HeadDim of V / output + index_t G0, // Batch + index_t G1, // HeadNum + float alpha, + bool input_permute, + bool output_permute) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \ + defined(__gfx1102__)) + + // clang-format off +// *************************************************** + const auto q_head = G1; + const auto kv_head = QueryGroupNumber; +// Make Tensor Descriptors + constexpr index_t array_size = 4; + std::array a_gs_ms_ks_lengths{G0, 
q_head, M, K}; + std::array a_gs_ms_ks_strides = + input_permute + ? std::array{M * q_head * K, K, q_head * K, 1} // A layout [G0, M, G1, K] + : std::array{q_head * M * K, M * K, K, 1}; // A layout [G0, G1, M, K] + + std::array b0_gs_ns_ks_lengths{G0, kv_head, N, K}; + std::array b0_gs_ns_ks_strides = + input_permute + ? std::array{N * kv_head * K, K, kv_head * K, 1} // B0 layout [G0, N, kv_head, K] + : std::array{kv_head * N * K, N * K, K, 1}; // B0 layout [G0, kv_head, N, K] + + std::array b1_gs_os_ns_lengths{G0, kv_head, O, N}; + std::array b1_gs_os_ns_strides = + input_permute + ? std::array{N * kv_head * O, O, 1, kv_head * O} // B1 layout [G0, N, kv_head, O] + : std::array{kv_head * N * O, N * O, 1, O}; // B1 layout [G0, kv_head, N, O] + + std::array c_gs_ms_os_lengths{G0, q_head, M, O}; + std::array c_gs_ms_os_strides = + output_permute + ? std::array{M * q_head * O, O, q_head * O, 1} // C layout [G0, M, G1, O] + : std::array{q_head * M * O, M * O, O, 1}; // C layout [G0, G1, M, O] + + const auto a_element_op = AElementwiseOperation{}; + const auto b0_element_op = B0ElementwiseOperation{}; + const auto acc0_element_op = AccElementwiseOperation{alpha}; + const auto b1_element_op = B1ElementwiseOperation{}; + const auto c_element_op = CElementwiseOperation{}; + // Cannot reuse DeviceOp::MakeArgument() here: a __device__ function is required. 
+ + const auto a_grid_desc = DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + const auto b0_grid_desc = + DeviceOp::MakeB0GridDescriptor(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); + const auto b1_grid_desc = + DeviceOp::MakeB1GridDescriptor(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); + const auto c_grid_desc_m_n = + DeviceOp::Transform::MakeCGridDescriptor_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n); + const auto block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1); + + const auto a_grid_desc_g_m_k = + DeviceOp::Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + const auto b0_grid_desc_g_l_k = + DeviceOp::Transform::MakeB0GridDescriptor_G_N_K(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); + const auto b1_grid_desc_g_n_l = + DeviceOp::Transform::MakeB1GridDescriptor_G_N_K(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); + const auto c_grid_desc_g_m_n = + DeviceOp::Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + const auto compute_base_ptr_of_batch = + typename DeviceOp::ComputeBasePtrOfStridedBatch{a_grid_desc_g_m_k, b0_grid_desc_g_l_k, b1_grid_desc_g_n_l, c_grid_desc_g_m_n}; + index_t batch_count = c_grid_desc_g_m_n.GetLength(Number<0>{}); + const auto c0_matrix_mask = typename DeviceOp::C0MatrixMask{b0_grid_desc_g_l_k.GetLength(Number<1>{})}; + + // clang-format on + __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()]; + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetABasePtr(g_idx))); + const long_index_t b0_batch_offset = 
__builtin_amdgcn_readfirstlane(static_cast( + compute_base_ptr_of_batch.GetB0BasePtr(g_idx * QueryGroupNumber / G1))); + const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(static_cast( + compute_base_ptr_of_batch.GetB1BasePtr(g_idx * QueryGroupNumber / G1))); + const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetCBasePtr(g_idx))); + + GridwiseOp::template Run(p_a_grid + a_batch_offset, + p_b0_grid + b0_batch_offset, + p_b1_grid + b1_batch_offset, + p_c_grid + c_batch_offset, + p_shared, + a_grid_desc, + b0_grid_desc, + b1_grid_desc, + c_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b0_element_op, + acc0_element_op, + b1_element_op, + c_element_op, + c0_matrix_mask, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b0_grid; + ignore = p_b1_grid; + ignore = p_c_grid; + ignore = M; + ignore = N; + ignore = K; + ignore = O; + ignore = G0; + ignore = G1; + ignore = input_permute; + ignore = output_permute; +#endif // end of if (defined(__gfx1100__)) +} + +// Computes C = A * B0 * B1 +// MN = MK * KL * LN +// ^^^^^^ (Acc0) +// ^^^^^^^^^^^ (Acc1) +template +struct DeviceGroupedQueryAttentionForward_Wmma + : public DeviceBatchedGemmSoftmaxGemmPermute +{ + static_assert(NumDimG > 0 && NumDimM > 0 && NumDimL > 0 && NumDimK > 0 && NumDimN > 0, + "Number of dimension must be greater than 0"); + + static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size(); + static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size(); + + // TODO ANT: implement bias combination + static_assert(NumAcc0Bias == 0 && NumAcc1Bias == 0, "Bias addition is unimplemented"); + + static constexpr index_t NumDimGemm0M = NumDimM; + static constexpr index_t NumDimGemm0N = NumDimL; + static constexpr index_t NumDimGemm0K = NumDimK; + static constexpr index_t NumDimGemm1M = NumDimM; + static constexpr index_t NumDimGemm1N = NumDimN; + static constexpr index_t NumDimGemm1K = NumDimL; + + using 
DeviceOp = DeviceGroupedQueryAttentionForward_Wmma; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + + static constexpr auto WmmaK = 16; + + static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); + static constexpr auto LWaves = LPerBlock / (LRepeat * LPerWmma); + static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); + + static constexpr auto AEnableLds_auto = LWaves == 1 ? false : true; + static constexpr auto B0EnableLds_auto = MWaves == 1 ? false : true; + static constexpr auto B1EnableLds_auto = MWaves == 1 ? false : true; + + static constexpr auto AEnableLds_manu = false; + static constexpr auto B0EnableLds_manu = true; + static constexpr auto B1EnableLds_manu = true; + + static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1); + static constexpr auto B0EnableLds = B0EnableLds_auto || B0EnableLds_manu || (NumPrefetch > 1); + static constexpr auto B1EnableLds = B1EnableLds_auto || B1EnableLds_manu || (NumPrefetch > 1); + + using Transform = TransformBatchedContractionContractionToBatchedGemmGemm_Wmma< + Sequence, + Sequence, + GemmSpec, + ASpec, + B0Spec, + B1Spec, + CSpec>; + + __host__ __device__ static auto MakeAGridDescriptor( + const std::array& a_gs_ms_ks_lengths_vec, + const std::array& a_gs_ms_ks_strides_vec) + { + if constexpr(AEnableLds) + { + return Transform::MakeAGridDescriptor_AK0_M_AK1( + Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec), + Number{}); + } + else + { + return Transform:: + MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AK0PerWmma_AKRow_MPerWmma_AK1( + Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, + a_gs_ms_ks_strides_vec), + Number{}, + Number{}, + Number{}, + Number{}, + Number{}); + } 
+ } + + __host__ __device__ static auto MakeB0GridDescriptor( + const std::array& b0_gs_ls_ks_lengths_vec, + const std::array& b0_gs_ls_ks_strides_vec) + { + if constexpr(B0EnableLds) + { + return Transform::MakeB0GridDescriptor_BK0_N_BK1( + Transform::MakeB0GridDescriptor_N_K(b0_gs_ls_ks_lengths_vec, + b0_gs_ls_ks_strides_vec), + Number{}); + } + else + { + return Transform:: + MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_BK0PerWmma_BKRow_LPerWmma_BK1( + Transform::MakeB0GridDescriptor_N_K(b0_gs_ls_ks_lengths_vec, + b0_gs_ls_ks_strides_vec), + Number{}, + Number{}, + Number{}, + Number{}, + Number{}); + } + } + + __host__ __device__ static auto MakeB1GridDescriptor( + const std::array& b1_gs_ns_ls_lengths_vec, + const std::array& b1_gs_ns_ls_strides_vec) + { + if constexpr(B1EnableLds) + { + return Transform::MakeB1GridDescriptor_BK0_N_BK1( + Transform::MakeB1GridDescriptor_N_K(b1_gs_ns_ls_lengths_vec, + b1_gs_ns_ls_strides_vec), + Number{}); + } + else + { + return Transform:: + MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves__BL0PerWmma_BLRow_NPerWmma_BL1( + Transform::MakeB1GridDescriptor_N_K(b1_gs_ns_ls_lengths_vec, + b1_gs_ns_ls_strides_vec), + Number{}, + Number{}, + Number{}, + Number{}, + Number{}); + } + } + + using AGridDesc = decltype(MakeAGridDescriptor({}, {})); + using B0GridDesc = decltype(MakeB0GridDescriptor({}, {})); + using B1GridDesc = decltype(MakeB1GridDescriptor({}, {})); + using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {})); + using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {})); + using B0GridDesc_G_L_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {})); + using B1GridDesc_G_N_L = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {})); + using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); + + __host__ __device__ constexpr static auto make_MaskOutPredicate() + { + if constexpr(MaskingSpec == MaskingSpecialization::MaskDisabled) + { + return 
MaskDisabledPredicate{}; + } + else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle) + { + return MaskOutUpperTrianglePredicate{}; + } + } + using C0MatrixMask = C0MatrixMask_impl; + + struct ComputeBasePtrOfStridedBatch + { + __host__ __device__ ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k, + const B0GridDesc_G_L_K& b0_grid_desc_g_l_k, + const B1GridDesc_G_N_L& b1_grid_desc_g_n_l, + const CGridDesc_G_M_N& c_grid_desc_g_m_n) + : a_grid_desc_g_m_k_(a_grid_desc_g_m_k), + b0_grid_desc_g_l_k_(b0_grid_desc_g_l_k), + b1_grid_desc_g_n_l_(b1_grid_desc_g_n_l), + c_grid_desc_g_m_n_(c_grid_desc_g_m_n) + { + } + + __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const + { + return a_grid_desc_g_m_k_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + __host__ __device__ constexpr long_index_t GetB0BasePtr(index_t g_idx) const + { + return b0_grid_desc_g_l_k_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const + { + return b1_grid_desc_g_n_l_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const + { + return c_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + private: + AGridDesc_G_M_K a_grid_desc_g_m_k_; + B0GridDesc_G_L_K b0_grid_desc_g_l_k_; + B1GridDesc_G_N_L b1_grid_desc_g_n_l_; + CGridDesc_G_M_N c_grid_desc_g_m_n_; + }; + + // GridwiseOp + using GridwiseOp = GridwiseBatchedGemmSoftmaxGemm_Wmma< + // DataType Family + ADataType, + B0DataType, + Acc0DataType, + B1DataType, + Acc1DataType, + CShuffleDataType, + CDataType, + // ElementwiseOp Family + AElementwiseOperation, + B0ElementwiseOperation, + AccElementwiseOperation, + B1ElementwiseOperation, + CElementwiseOperation, + InMemoryDataOperationEnum::Set, + // InMemory Data Descriptor + AGridDesc, + B0GridDesc, + B1GridDesc, + CGridDesc_M_N, + // Tiling Family + 
MPerBlock, + LPerBlock, + KPerBlock, + AK1, + BK1, + NPerBlock, + LTilePerBlock, + L1, + MPerWmma, + LPerWmma, + NPerWmma, + MRepeat, + LRepeat, + NRepeat, + // ThreadCluster Family + BlockSize, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + true, + AEnableLds, + ABlockLdsAddExtraM, + B0BlockTransferThreadClusterLengths_K0_L_K1, + B0BlockTransferThreadClusterArrangeOrder, + B0BlockTransferSrcAccessOrder, + B0BlockTransferSrcVectorDim, + B0BlockTransferSrcScalarPerVector, + B0BlockTransferDstScalarPerVector_K1, + true, + B0EnableLds, + B0BlockLdsAddExtraL, + B1BlockTransferThreadClusterLengths_L0_N_L1, + B1BlockTransferThreadClusterArrangeOrder, + B1BlockTransferSrcAccessOrder, + B1BlockTransferSrcVectorDim, + B1BlockTransferSrcScalarPerVector, + B1BlockTransferDstScalarPerVector_L1, + false, + B1EnableLds, + B1BlockLdsAddExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + Transform::matrix_padder.PadN, + MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle, + NumPrefetch, + LoopSched, + PipelineVer>; + + struct RawArg : public BaseArgument + { + RawArg(const ADataType* p_a_grid, + const B0DataType* p_b0_grid, + const B1DataType* p_b1_grid, + CDataType* p_c_grid, + index_t M, + index_t N, + index_t K, + index_t O, + index_t G0, + index_t G1, + float alpha, + bool input_permute, + bool output_permute) + : p_a_grid_{p_a_grid}, + p_b0_grid_{p_b0_grid}, + p_b1_grid_{p_b1_grid}, + p_c_grid_{p_c_grid}, + M_{M}, + N_{N}, + K_{K}, + O_{O}, + G0_{G0}, + G1_{G1}, + alpha_{alpha}, + input_permute_{input_permute}, + output_permute_{output_permute} + { + } + // Pointers + const ADataType* p_a_grid_; + const B0DataType* p_b0_grid_; + const B1DataType* 
p_b1_grid_; + CDataType* p_c_grid_; + + // Raw Problem Size + index_t M_; + index_t N_; + index_t K_; + index_t O_; + index_t G0_; + index_t G1_; + float alpha_; + bool input_permute_; + bool output_permute_; + }; + + static auto MakeArgument(const ADataType* p_a, + const B0DataType* p_b0, + const B1DataType* p_b1, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t O, + index_t G0, + index_t G1, + float alpha, + bool input_permute, + bool output_permute) + { + return RawArg{ + p_a, p_b0, p_b1, p_c, M, N, K, O, G0, G1, alpha, input_permute, output_permute}; + } + + static bool IsSupportedArgument(const RawArg& arg) + { + if(ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" || + ck::get_device_name() == "gfx1102") + { + if constexpr(!(is_same_v || is_same_v)) + { + printf("DeviceOp: Acc0 Type err"); + return false; + } + + if constexpr(!(is_same_v || is_same_v)) + { + printf("DeviceOp: Acc1 Type err"); + return false; + } + } + else + { + printf("DeviceOp: Arch err"); + return false; + } + + if(arg.G1_ % QueryGroupNumber != 0) + { + return false; + } + + constexpr index_t array_size = 4; + ck::index_t G0 = arg.G0_; + ck::index_t G1 = arg.G1_; + ck::index_t M = arg.M_; + ck::index_t N = arg.N_; + ck::index_t K = arg.K_; + ck::index_t O = arg.O_; + bool input_permute = arg.input_permute_; + bool output_permute = arg.output_permute_; + + std::array a_gs_ms_ks_lengths{G0, G1, M, K}; + std::array a_gs_ms_ks_strides = + input_permute ? std::array{M * G1 * K, K, G1 * K, 1} + // A layout [G0, M, G1, K] + : std::array{ + G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K] + + std::array b0_gs_ns_ks_lengths{G0, G1, N, K}; + std::array b0_gs_ns_ks_strides = + input_permute ? std::array{N * G1 * K, K, G1 * K, 1} + // B0 layout [G0, N, G1, K] + : std::array{ + G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K] + + std::array b1_gs_os_ns_lengths{G0, G1, O, N}; + std::array b1_gs_os_ns_strides = + input_permute ? 
std::array{N * G1 * O, O, 1, G1 * O} + // B1 layout [G0, N, G1, O] + : std::array{ + G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O] + + std::array c_gs_ms_os_lengths{G0, G1, M, O}; + std::array c_gs_ms_os_strides = + output_permute ? std::array{M * G1 * O, O, G1 * O, 1} + // C layout [G0, M, G1, O] + : std::array{ + G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O] + + const auto a_grid_desc = + DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + const auto b0_grid_desc = + DeviceOp::MakeB0GridDescriptor(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); + const auto b1_grid_desc = + DeviceOp::MakeB1GridDescriptor(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); + const auto c_grid_desc_m_n = + DeviceOp::Transform::MakeCGridDescriptor_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + + const auto block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1); + + const auto c_grid_desc_g_m_n = + DeviceOp::Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + index_t batch_count = c_grid_desc_g_m_n.GetLength(Number<0>{}); + + if(!GridwiseOp::CheckValidity( + a_grid_desc, b0_grid_desc, b1_grid_desc, c_grid_desc_m_n, block_2_ctile_map)) + { + return false; + } + + // Check if C permute dimension matches GEMM + GEMM shape + const index_t c_g = c_grid_desc_g_m_n.GetLength(I0); // unpadded + + if(!(c_g == batch_count)) + { + printf("DeviceOp: BatchCount err"); + return false; + } + + // Note: we need raw lengths since threadwise copy can not handle vector load when part of + // vector is out of bounds + // Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O + const auto MzRaw = M; + const auto LzRaw = N; + const auto KzRaw = K; + const auto NzRaw = O; + + // Check scalar per vector requirement + const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw; + const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? 
KzRaw : LzRaw; + const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? LzRaw : NzRaw; + const auto c_extent_lowest = NzRaw; + + if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && + b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 && + b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 && + c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + printf("DeviceOp: Data Transfer Vector scalar err"); + return false; + } + + std::array a_mz_kz_strides_{ + a_gs_ms_ks_strides[NumDimG + NumDimM - 1], + a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]}; + std::array b0_lz_kz_strides_{ + b0_gs_ns_ks_strides[NumDimG + NumDimL - 1], + b0_gs_ns_ks_strides[NumDimG + NumDimL + NumDimK - 1]}; + std::array b1_nz_lz_strides_{ + b1_gs_os_ns_strides[NumDimG + NumDimN - 1], + b1_gs_os_ns_strides[NumDimG + NumDimN + NumDimL - 1]}; + std::array c_mz_nz_strides_{ + c_gs_ms_os_strides[NumDimG + NumDimM - 1], + c_gs_ms_os_strides[NumDimG + NumDimM + NumDimN - 1]}; + + // Check vector load/store requirement + const auto a_stride_lowest = + ABlockTransferSrcVectorDim == 2 ? a_mz_kz_strides_[1] : a_mz_kz_strides_[0]; + const auto b0_stride_lowest = + B0BlockTransferSrcVectorDim == 2 ? b0_lz_kz_strides_[1] : b0_lz_kz_strides_[0]; + const auto b1_stride_lowest = + B1BlockTransferSrcVectorDim == 2 ? 
b1_nz_lz_strides_[1] : b1_nz_lz_strides_[0];
+        const auto c_stride_lowest = c_mz_nz_strides_[1];
+
+        // Every tensor is checked on its own: its vectorised dimension must have unit
+        // stride unless the access is scalar (ScalarPerVector == 1). Joining the four
+        // tests with || would accept an argument as soon as any single tensor passed.
+        if(!((a_stride_lowest == 1 || ABlockTransferSrcScalarPerVector == 1) &&
+             (b0_stride_lowest == 1 || B0BlockTransferSrcScalarPerVector == 1) &&
+             (b1_stride_lowest == 1 || B1BlockTransferSrcScalarPerVector == 1) &&
+             (c_stride_lowest == 1 || CShuffleBlockTransferScalarPerVector_NPerBlock == 1)))
+        {
+            printf("DeviceOp: Data Vectorize transfer err");
+            return false;
+        }
+
+        return true;
+    }
+
+    // polymorphic
+    // Forwards to the static overload above.
+    // NOTE(review): the dynamic_cast result is dereferenced without a null check --
+    // confirm that callers only ever pass a RawArg here.
+    bool IsSupportedArgument(const BaseArgument* p_arg) override
+    {
+        return IsSupportedArgument(*dynamic_cast(p_arg));
+    }
+
+    // Argument
+    // Descriptor-based argument; the grid descriptors are built from the supplied
+    // lengths/strides in the member-initializer list.
+    struct Argument : public BaseArgument
+    {
+        Argument(
+            const ADataType* p_a_grid,
+            const B0DataType* p_b0_grid,
+            const B1DataType* p_b1_grid,
+            CDataType* p_c_grid,
+            const std::array p_acc0_biases,
+            const std::array p_acc1_biases,
+            const std::array& a_gs_ms_ks_lengths,
+            const std::array& a_gs_ms_ks_strides,
+            const std::array& b0_gs_ls_ks_lengths,
+            const std::array& b0_gs_ls_ks_strides,
+            const std::array& b1_gs_ns_ls_lengths,
+            const std::array& b1_gs_ns_ls_strides,
+            const std::array& c_gs_ms_ns_lengths,
+            const std::array& c_gs_ms_ns_strides,
+            const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths,
+            const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_strides,
+            const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths,
+            const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_strides,
+            const index_t M01,
+            const index_t N01,
+            AElementwiseOperation a_element_op,
+            B0ElementwiseOperation b0_element_op,
+            AccElementwiseOperation acc_element_op,
+            B1ElementwiseOperation b1_element_op,
+            CElementwiseOperation c_element_op)
+            : p_a_grid_{p_a_grid},
+              p_b0_grid_{p_b0_grid},
+              p_b1_grid_{p_b1_grid},
+              p_c_grid_{p_c_grid},
+              a_grid_desc{DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
+              b0_grid_desc{
+                  DeviceOp::MakeB0GridDescriptor(b0_gs_ls_ks_lengths, b0_gs_ls_ks_strides)},
+              b1_grid_desc{
+                  DeviceOp::MakeB1GridDescriptor(b1_gs_ns_ls_lengths, b1_gs_ns_ls_strides)},
+              c_grid_desc_m_n_{
+                  Transform::MakeCGridDescriptor_M_N(c_gs_ms_ns_lengths, c_gs_ms_ns_strides)},
+
a_grid_desc_g_m_k_{ + Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)}, + b0_grid_desc_g_l_k_{ + Transform::MakeB0GridDescriptor_G_N_K(b0_gs_ls_ks_lengths, b0_gs_ls_ks_strides)}, + b1_grid_desc_g_n_l_{ + Transform::MakeB1GridDescriptor_G_N_K(b1_gs_ns_ls_lengths, b1_gs_ns_ls_strides)}, + c_grid_desc_g_m_n_{ + Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_ns_lengths, c_gs_ms_ns_strides)}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_ctile_map_{GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01)}, + a_element_op_{a_element_op}, + b0_element_op_{b0_element_op}, + acc_element_op_{acc_element_op}, + b1_element_op_{b1_element_op}, + c_element_op_{c_element_op}, + c0_matrix_mask_{b0_grid_desc_g_l_k_.GetLength(I1)}, + raw_lengths_mz_lz_kz_nz_{a_gs_ms_ks_lengths[NumDimG + NumDimM - 1], + b0_gs_ls_ks_lengths[NumDimG + NumDimL - 1], + b0_gs_ls_ks_lengths[NumDimG + NumDimL + NumDimK - 1], + b1_gs_ns_ls_lengths[NumDimG + NumDimN - 1]}, + a_mz_kz_strides_{a_gs_ms_ks_strides[NumDimG + NumDimM - 1], + a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]}, + b0_lz_kz_strides_{b0_gs_ls_ks_strides[NumDimG + NumDimL - 1], + b0_gs_ls_ks_strides[NumDimG + NumDimL + NumDimK - 1]}, + b1_nz_lz_strides_{b1_gs_ns_ls_strides[NumDimG + NumDimN - 1], + b1_gs_ns_ls_strides[NumDimG + NumDimN + NumDimL - 1]}, + c_mz_nz_strides_{c_gs_ms_ns_strides[NumDimG + NumDimM - 1], + c_gs_ms_ns_strides[NumDimG + NumDimM + NumDimN - 1]}, + batch_count_{c_grid_desc_g_m_n_.GetLength(I0)}, + compute_ptr_offset_of_batch_{ + a_grid_desc_g_m_k_, b0_grid_desc_g_l_k_, b1_grid_desc_g_n_l_, c_grid_desc_g_m_n_} + { + // TODO ANT: implement bias addition + ignore = p_acc0_biases; + ignore = p_acc1_biases; + ignore = acc0_biases_gs_ms_ls_lengths; + ignore = acc0_biases_gs_ms_ls_strides; + ignore = acc1_biases_gs_ms_ns_lengths; + ignore = acc1_biases_gs_ms_ns_strides; + + if(GridwiseOp::CheckValidity( + a_grid_desc, b0_grid_desc, b1_grid_desc, c_grid_desc_m_n_, 
block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n_); + } + } + + // Pointers + const ADataType* p_a_grid_; + const B0DataType* p_b0_grid_; + const B1DataType* p_b1_grid_; + CDataType* p_c_grid_; + + // Tensor Descriptors + AGridDesc a_grid_desc; + B0GridDesc b0_grid_desc; + B1GridDesc b1_grid_desc; + CGridDesc_M_N c_grid_desc_m_n_; + + AGridDesc_G_M_K a_grid_desc_g_m_k_; + B0GridDesc_G_L_K b0_grid_desc_g_l_k_; + B1GridDesc_G_N_L b1_grid_desc_g_n_l_; + CGridDesc_G_M_N c_grid_desc_g_m_n_; + + typename GridwiseOp::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock_; + + // Block to Tile mapping + typename GridwiseOp::DefaultBlock2CTileMap block_2_ctile_map_; + + // ElementwiseOp + AElementwiseOperation a_element_op_; + B0ElementwiseOperation b0_element_op_; + AccElementwiseOperation acc_element_op_; + B1ElementwiseOperation b1_element_op_; + CElementwiseOperation c_element_op_; + + // check C0 masking and padding + C0MatrixMask c0_matrix_mask_; + + // Strides for the last M/N/K dimensions of A/B0/B1/C + // for sanity check of vector load/store + std::array raw_lengths_mz_lz_kz_nz_; + std::array a_mz_kz_strides_; + std::array b0_lz_kz_strides_; + std::array b1_nz_lz_strides_; + std::array c_mz_nz_strides_; + + index_t batch_count_; + // Batch Offset + ComputeBasePtrOfStridedBatch compute_ptr_offset_of_batch_; + }; + + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::RawArg; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto M0 = math::integer_divide_ceil(arg.M_, MPerBlock); + const auto N0 = math::integer_divide_ceil(arg.O_, NPerBlock); + + const index_t grid_size = arg.G0_ * arg.G1_ * M0 * N0; + const auto K = arg.K_; + // printf("HasKBlockLoop: %d\n", GridwiseOp::CalculateHasMainKBlockLoop(K)); + auto launch_kernel = [&](auto 
has_main_k_block_loop) { + const auto kernel = kernel_grouped_query_attention_wmma; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b0_grid_, + arg.p_b1_grid_, + arg.p_c_grid_, + arg.M_, + arg.N_, + arg.K_, + arg.O_, + arg.G0_, + arg.G1_, + arg.alpha_, + arg.input_permute_, + arg.output_permute_); + }; + + if(GridwiseOp::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } +#if 0 + static bool IsSupportedArgument(const Argument& arg) + { + if(ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" || + ck::get_device_name() == "gfx1102") + { + if constexpr(!(is_same_v || is_same_v)) + { + printf("DeviceOp: Acc0 Type err"); + return false; + } + + if constexpr(!(is_same_v || is_same_v)) + { + printf("DeviceOp: Acc1 Type err"); + return false; + } + } + else + { + printf("DeviceOp: Arch err"); + return false; + } + + if(!GridwiseOp::CheckValidity(arg.a_grid_desc, + arg.b0_grid_desc, + arg.b1_grid_desc, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + return false; + } + + // Check if C permute dimension matches GEMM + GEMM shape + const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded + + if(!(c_g == arg.batch_count_)) + { + printf("DeviceOp: BatchCount err"); + return false; + } + + // Note: we need raw lengths since threadwise copy can not handle vector load when part of + // vector is out of bounds + // Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O + const auto MzRaw = arg.raw_lengths_mz_lz_kz_nz_[0]; + const auto LzRaw = 
arg.raw_lengths_mz_lz_kz_nz_[1]; + const auto KzRaw = arg.raw_lengths_mz_lz_kz_nz_[2]; + const auto NzRaw = arg.raw_lengths_mz_lz_kz_nz_[3]; + + // Check scalar per vector requirement + const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw; + const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? KzRaw : LzRaw; + const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? LzRaw : NzRaw; + const auto c_extent_lowest = NzRaw; + + if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && + b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 && + b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 && + c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + printf("DeviceOp: Data Transfer Vector scalar err"); + return false; + } + + // Check vector load/store requirement + const auto a_stride_lowest = + ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0]; + const auto b0_stride_lowest = + B0BlockTransferSrcVectorDim == 2 ? arg.b0_lz_kz_strides_[1] : arg.b0_lz_kz_strides_[0]; + const auto b1_stride_lowest = + B1BlockTransferSrcVectorDim == 2 ? 
arg.b1_nz_lz_strides_[1] : arg.b1_nz_lz_strides_[0]; + const auto c_stride_lowest = arg.c_mz_nz_strides_[1]; + + if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 || + c_stride_lowest == 1)) + { + printf("DeviceOp: Data Vectorize transfer err"); + return false; + } + + return true; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument( + const ADataType* p_a, + const B0DataType* p_b0, + const B1DataType* p_b1, + CDataType* p_c, + const std::array p_acc0_biases, + const std::array p_acc1_biases, + const std::array& a_gs_ms_ks_lengths, + const std::array& a_gs_ms_ks_strides, + const std::array& b0_gs_ls_ks_lengths, + const std::array& b0_gs_ls_ks_strides, + const std::array& b1_gs_ns_ls_lengths, + const std::array& b1_gs_ns_ls_strides, + const std::array& c_gs_ms_ns_lengths, + const std::array& c_gs_ms_ns_strides, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_strides, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_strides, + AElementwiseOperation a_element_op, + B0ElementwiseOperation b0_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, + p_b0, + p_b1, + p_c, + p_acc0_biases, + p_acc1_biases, + a_gs_ms_ks_lengths, + a_gs_ms_ks_strides, + b0_gs_ls_ks_lengths, + b0_gs_ls_ks_strides, + b1_gs_ns_ls_lengths, + b1_gs_ns_ls_strides, + c_gs_ms_ns_lengths, + c_gs_ms_ns_strides, + acc0_biases_gs_ms_ls_lengths, + acc0_biases_gs_ms_ls_strides, + acc1_biases_gs_ms_ns_lengths, + acc1_biases_gs_ms_ns_strides, + 1, + 1, + a_element_op, + b0_element_op, + acc_element_op, + b1_element_op, + c_element_op}; + } +#endif + + // polymorphic + std::unique_ptr MakeArgumentPointer( + const void* p_a, + 
const void* p_b0, + const void* p_b1, + void* p_c, + const std::array p_acc0_biases, + const std::array p_acc1_biases, + const std::vector& a_gs_ms_ks_lengths, + const std::vector& a_gs_ms_ks_strides, + const std::vector& b0_gs_ls_ks_lengths, + const std::vector& b0_gs_ls_ks_strides, + const std::vector& b1_gs_ns_ls_lengths, + const std::vector& b1_gs_ns_ls_strides, + const std::vector& c_gs_ms_ns_lengths, + const std::vector& c_gs_ms_ns_strides, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_strides, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_strides, + AElementwiseOperation a_element_op, + B0ElementwiseOperation b0_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) override + { + std::array a_lengths; + std::array a_strides; + std::array b0_lengths; + std::array b0_strides; + std::array b1_lengths; + std::array b1_strides; + std::array c_lengths; + std::array c_strides; + std::transform(a_gs_ms_ks_lengths.begin(), + a_gs_ms_ks_lengths.end(), + a_lengths.begin(), + [](index_t i) { return i; }); + std::transform(a_gs_ms_ks_strides.begin(), + a_gs_ms_ks_strides.end(), + a_strides.begin(), + [](index_t i) { return i; }); + std::transform(b0_gs_ls_ks_lengths.begin(), + b0_gs_ls_ks_lengths.end(), + b0_lengths.begin(), + [](index_t i) { return i; }); + std::transform(b0_gs_ls_ks_strides.begin(), + b0_gs_ls_ks_strides.end(), + b0_strides.begin(), + [](index_t i) { return i; }); + std::transform(b1_gs_ns_ls_lengths.begin(), + b1_gs_ns_ls_lengths.end(), + b1_lengths.begin(), + [](index_t i) { return i; }); + std::transform(b1_gs_ns_ls_strides.begin(), + b1_gs_ns_ls_strides.end(), + b1_strides.begin(), + [](index_t i) { return i; }); + std::transform(c_gs_ms_ns_lengths.begin(), + c_gs_ms_ns_lengths.end(), + c_lengths.begin(), + [](index_t i) { return i; 
}); + std::transform(c_gs_ms_ns_strides.begin(), + c_gs_ms_ns_strides.end(), + c_strides.begin(), + [](index_t i) { return i; }); + return std::make_unique(static_cast(p_a), + static_cast(p_b0), + static_cast(p_b1), + static_cast(p_c), + p_acc0_biases, + p_acc1_biases, + a_lengths, + a_strides, + b0_lengths, + b0_strides, + b1_lengths, + b1_strides, + c_lengths, + c_strides, + acc0_biases_gs_ms_ls_lengths, + acc0_biases_gs_ms_ls_strides, + acc1_biases_gs_ms_ns_lengths, + acc1_biases_gs_ms_ns_strides, + 1, + 1, + a_element_op, + b0_element_op, + acc_element_op, + b1_element_op, + c_element_op); + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map LoopSchedToString{ + {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; + + std::map PipelineVersionToString{{PipelineVersion::v1, "v1"}, + {PipelineVersion::v2, "v2"}}; + + // clang-format off + str << "DeviceGroupedQueryAttentionForward_Wmma, " + << "QueryGroupNumber: " + << QueryGroupNumber << ", " + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << LPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << LTilePerBlock << ", " + << L1 << ", " + << getGemmSpecializationString(GemmSpec) << ", " + << "ASpec" << getTensorSpecializationString(ASpec) << ", " + << "B0Spec" << getTensorSpecializationString(B0Spec) << ", " + << "B1Spec" << getTensorSpecializationString(B1Spec) << ", " + << "CSpec" << getTensorSpecializationString(CSpec) << ", " + << getMaskingSpecializationString(MaskingSpec) + << ">" + << " AEnableLds: " + << AEnableLds << ", " + << "B0EnableLds: " + << B0EnableLds << ", " + << "B1EnableLds: " + << B1EnableLds << ", " + << "NumPrefetch: " + << NumPrefetch << ", " + << 
"LoopScheduler: " + << LoopSchedToString[LoopSched] << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer]; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp index 327cc9e28c6..7a8e1d9a37d 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp @@ -255,6 +255,130 @@ struct ReferenceBatchedGemm_MQA : public device::BaseOperator } }; +template +struct ReferenceBatchedGemm_GQA : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& a_g0_g1_m_k, + const Tensor& b_g0_gq_k_n, + Tensor& c_g0_g1_m_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : a_g0_g1_m_k_{a_g0_g1_m_k}, + b_g0_gq_k_n_{b_g0_gq_k_n}, + c_g0_g1_m_n_{c_g0_g1_m_n}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + } + + const Tensor& a_g0_g1_m_k_; + const Tensor& b_g0_gq_k_n_; + Tensor& c_g0_g1_m_n_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceBatchedGemm_GQA::Argument; + + float Run(const Argument& arg) + { + auto f_g0g1mk_g0gqkn_g0g1mn = [&](auto g0, auto g1, auto m, auto n) { + const int G1 = arg.a_g0_g1_m_k_.mDesc.GetLengths()[1]; + const int K = arg.a_g0_g1_m_k_.mDesc.GetLengths()[3]; + + AccDataType v_acc = 0; + + for(int k = 0; k < K; ++k) + { + ADataType v_a; + BDataType v_b; + + arg.a_element_op_(v_a, arg.a_g0_g1_m_k_(g0, g1, m, k)); + 
arg.b_element_op_(v_b, arg.b_g0_gq_k_n_(g0, g1 * QueryGroupNumber / G1, k, n)); + + v_acc += + ck::type_convert(v_a) * ck::type_convert(v_b); + } + + AccDataType v_c; + + arg.c_element_op_(v_c, v_acc); + + arg.c_g0_g1_m_n_(g0, g1, m, n) = ck::type_convert(v_c); + }; + + make_ParallelTensorFunctor(f_g0g1mk_g0gqkn_g0g1mn, + arg.c_g0_g1_m_n_.mDesc.GetLengths()[0], + arg.c_g0_g1_m_n_.mDesc.GetLengths()[1], + arg.c_g0_g1_m_n_.mDesc.GetLengths()[2], + arg.c_g0_g1_m_n_.mDesc.GetLengths()[3])( + std::thread::hardware_concurrency()); + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& a_g0_g1_m_k, + const Tensor& b_g0_gq_k_n, + Tensor& c_g0_g1_m_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{ + a_g0_g1_m_k, b_g0_gq_k_n, c_g0_g1_m_n, a_element_op, b_element_op, c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceBatchedGemm_GQA" + << std::endl; + // clang-format on + + return str.str(); + } +}; + } // namespace host } // namespace tensor_operation } // namespace ck