From f6a511cfd3a67f8fed8ff482e5f9b6d0b5817171 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Thu, 10 Mar 2022 20:49:51 +0000 Subject: [PATCH 01/32] add gridwise gemm v4r1 --- example/01_gemm/gemm_xdl_fp16.cpp | 12 +- .../gpu/device/device_gemm_xdl_c_shuffle.hpp | 182 +++-- .../gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp | 27 +- .../gpu/grid/gridwise_gemm_xdlops_v3r1.hpp | 1 + .../gpu/grid/gridwise_gemm_xdlops_v4r1.hpp | 728 ++++++++++++++++++ 5 files changed, 837 insertions(+), 113 deletions(-) create mode 100644 include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v4r1.hpp diff --git a/example/01_gemm/gemm_xdl_fp16.cpp b/example/01_gemm/gemm_xdl_fp16.cpp index ad369e774d4..3b95ad29d86 100644 --- a/example/01_gemm/gemm_xdl_fp16.cpp +++ b/example/01_gemm/gemm_xdl_fp16.cpp @@ -54,11 +54,13 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1>; #elif 1 using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle -//######|AData| BData| CData| AccData| Shuffle| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| -//######| Type| Type| Type| Type| Data| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| -//######| | | | | Type| | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| -//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < F16, F16, F16, F32, F16, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>; +//######|AData| BData| CData| AccData| Shuffle| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| Type| Type| Type| Type| Data| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWaveMWaveMPerXdl| ScalarPerVector| +//######| | | | | Type| | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWaveNWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +// < F16, F16, F16, F32, F16, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>; + < F16, F16, F16, F32, F16, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>; +// < F16, F16, F16, F32, F16, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 2, 2, S<1, 16, 1,16>, 8>; #elif 0 using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl //######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| Num| diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle.hpp index a335a327a16..ccb4b5ebd7d 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle.hpp @@ -1,6 +1,4 @@ -#ifndef DEVICE_GEMM_XDL_C_SHUFFLE_HPP -#define DEVICE_GEMM_XDL_C_SHUFFLE_HPP - +#pragma once #include #include #include "device.hpp" @@ -9,53 +7,52 @@ #include "tensor_layout.hpp" #include "tensor_descriptor.hpp" #include "tensor_descriptor_helper.hpp" -#include "gridwise_gemm_xdlops_v3r1.hpp" +#include "gridwise_gemm_xdlops_v4r1.hpp" namespace ck { namespace tensor_operation { namespace device { -template < - typename ADataType, - typename BDataType, - typename CDataType, - typename AccDataType, - typename CShuffleDataType, - typename ALayout, - typename BLayout, - typename CLayout, - typename AElementwiseOperation, - typename BElementwiseOperation, - typename CElementwiseOperation, - ck::index_t BlockSize, - ck::index_t MPerBlock, - ck::index_t NPerBlock, - ck::index_t KPerBlock, - ck::index_t AK1, - ck::index_t BK1, - ck::index_t MPerXDL, - ck::index_t NPerXDL, - ck::index_t MXdlPerWave, - ck::index_t NXdlPerWave, - typename ABlockTransferThreadClusterLengths_K0_M_K1, - typename ABlockTransferThreadClusterArrangeOrder, - typename ABlockTransferSrcAccessOrder, - ck::index_t ABlockTransferSrcVectorDim, - ck::index_t ABlockTransferSrcScalarPerVector, - ck::index_t ABlockTransferDstScalarPerVector_K1, - bool ABlockLdsAddExtraM, - typename BBlockTransferThreadClusterLengths_K0_N_K1, - typename BBlockTransferThreadClusterArrangeOrder, - typename BBlockTransferSrcAccessOrder, - ck::index_t BBlockTransferSrcVectorDim, - ck::index_t BBlockTransferSrcScalarPerVector, - ck::index_t BBlockTransferDstScalarPerVector_K1, - bool BBlockLdsAddExtraN, - index_t CShuffleMXdlPerWavePerShuffle, - index_t CShuffleNXdlPerWavePerShuffle, - typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, - index_t CBlockTransferScalarPerVector_NWaveNPerXdl, - index_t NumPrefetch = 1> +template struct DeviceGemmXdl_C_Shuffle : public DeviceGemm { @@ -132,7 +129,7 @@ struct DeviceGemmXdl_C_Shuffle using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); // GridwiseGemm - using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1< + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v4r1< BlockSize, ADataType, // TODO: distinguish A/B datatype AccDataType, @@ -172,8 +169,8 @@ struct DeviceGemmXdl_C_Shuffle BBlockLdsAddExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, - CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, - CBlockTransferScalarPerVector_NWaveNPerXdl, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, NumPrefetch>; // Argument @@ -199,7 +196,7 @@ struct DeviceGemmXdl_C_Shuffle a_grid_desc_k0_m_k1_{}, b_grid_desc_k0_n_k1_{}, c_grid_desc_m_n_{}, - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, block_2_ctile_map_{}, M01_{M01}, N01_{N01}, @@ -216,10 +213,9 @@ struct DeviceGemmXdl_C_Shuffle if(GridwiseGemm::CheckValidity( a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_)) { - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = - GridwiseGemm:: - MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( - c_grid_desc_m_n_); + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n_); block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); @@ -233,9 +229,8 @@ struct DeviceGemmXdl_C_Shuffle AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; CGridDesc_M_N c_grid_desc_m_n_; - typename GridwiseGemm:: - CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock_; typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; index_t M01_; index_t N01_; @@ -284,71 +279,69 @@ struct DeviceGemmXdl_C_Shuffle if(has_main_k0_block_loop) { - const auto kernel = kernel_gemm_xdlops_v3r1< + const auto kernel = kernel_gemm_xdlops_v4r1< GridwiseGemm, ADataType, // TODO: distiguish A/B datatype CDataType, remove_reference_t, remove_reference_t, remove_reference_t< - typename GridwiseGemm:: - CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, remove_reference_t, true>; - ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.block_2_ctile_map_); + ave_time = + launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); } else { - const auto kernel = kernel_gemm_xdlops_v3r1< + const auto kernel = kernel_gemm_xdlops_v4r1< GridwiseGemm, ADataType, // TODO: distiguish A/B datatype CDataType, remove_reference_t, remove_reference_t, remove_reference_t< - typename GridwiseGemm:: - CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, remove_reference_t, false>; - ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.block_2_ctile_map_); + ave_time = + launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); } return ave_time; @@ -474,4 +467,3 @@ struct DeviceGemmXdl_C_Shuffle } // namespace device } // namespace tensor_operation } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp index bf6c3610b7f..d51ebf7faaf 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp @@ -277,14 +277,14 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 __host__ __device__ static constexpr auto GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() { - constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); - constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); + constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL); + constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL); return make_naive_tensor_descriptor_packed( make_tuple(I1, - Number{}, + Number{}, I1, - Number{})); + Number{})); } using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = @@ -539,8 +539,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 // output: register to global memory { - constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); - constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); + constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL); + constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL); constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc = blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); @@ -564,8 +564,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 static_cast(p_shared_block), c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - static_assert(M1 == MWaves, ""); - static_assert(N1 == NWaves, ""); + static_assert(M1 == MWave, ""); + static_assert(N1 == NWave, ""); static_assert(M2 * M3 * M4 == MPerXDL, ""); static_assert(N2 == NPerXDL, ""); @@ -646,14 +646,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 n_thread_data_on_block_idx[I2]), ck::tensor_operation::element_wise::PassThrough{}}; + // LDS to global auto c_block_copy_lds_to_global = BlockwiseTensorSliceTransfer_v6r1< BlockSize, // index_t BlockSize, CElementwiseOperation, // ElementwiseOperation, CGlobalMemoryDataOperation, // DstInMemOp, Sequence<1, - CShuffleMRepeatPerShuffle * MWaves * MPerXDL, + CShuffleMRepeatPerShuffle * MWave * MPerXDL, 1, - CShuffleNRepeatPerShuffle * NWaves * NPerXDL>, // BlockSliceLengths, + CShuffleNRepeatPerShuffle * NWave * NPerXDL>, // BlockSliceLengths, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, FloatC, // typename SrcData, @@ -672,11 +673,11 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 c_element_op}; constexpr auto mxdlperwave_forward_step = - make_multi_index(0, CShuffleMRepeatPerShuffle * MWaves * MPerXDL, 0, 0); + make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXDL, 0, 0); constexpr auto nxdlperwave_forward_step = - make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWaves * NPerXDL); + make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXDL); constexpr auto nxdlperwave_backward_step = - make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWaves * NPerXDL); + make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXDL); static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mxdlperwave_iter) { constexpr auto mxdlperwave = mxdlperwave_iter; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp index 3c815716259..49a5f9f3e30 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp @@ -657,6 +657,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 n_thread_data_on_block_idx[I2]), ck::tensor_operation::element_wise::PassThrough{}}; + // LDS to global auto c_block_copy_lds_to_global = BlockwiseTensorSliceTransfer_v6r1< BlockSize, // index_t BlockSize, CElementwiseOperation, // ElementwiseOperation, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v4r1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v4r1.hpp new file mode 100644 index 00000000000..b18c6fcc507 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v4r1.hpp @@ -0,0 +1,728 @@ +#pragma once +#include "common_header.hpp" +#include "multi_index_transform_helper.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "blockwise_gemm_xdlops.hpp" +#include "blockwise_tensor_slice_transfer_v4r1.hpp" +#include "blockwise_tensor_slice_transfer_v6r1.hpp" +#include "threadwise_tensor_slice_transfer.hpp" +#include "gridwise_gemm_pipeline_v1.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_xdlops_v4r1(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const Block2CTileMap block_2_ctile_map) +{ + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); +} + +template +struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v4r1 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK0 = Number{}; + static constexpr auto BK0 = Number{}; + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + constexpr auto max_lds_align = AK1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(AK0, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(AK0, Number{}, AK1), max_lds_align); + } + }(); + + return a_block_desc_ak0_m_ak1; + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + constexpr auto max_lds_align = BK1; + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(BK0, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(BK0, Number{}, BK1), max_lds_align); + } + }(); + + return b_block_desc_bk0_n_bk1; + } + + __host__ __device__ static constexpr auto + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + constexpr auto a_block_space_size_aligned = + math::integer_least_multiple(a_block_desc_ak0_m_ak1.GetElementSpaceSize(), AK1); + + constexpr auto b_block_space_size_aligned = + math::integer_least_multiple(b_block_desc_bk0_n_bk1.GetElementSpaceSize(), BK1); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * + sizeof(FloatAB), + c_block_size * sizeof(FloatCShuffle)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const CGridDesc_M_N& c_grid_desc_m_n, + index_t M01, + index_t N01) + { + // static_assert(is_known_at_compile_time>::value && + // is_known_at_compile_time>::value, + // "wrong! K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1); + const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1); + const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2); + + if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + return false; + + // check NumPrefetch + if constexpr(NumPrefetch == 1) + { + // 1-stage prefetch always supported + } + else if constexpr(NumPrefetch == 2) + { + // 2-stage prefetch currently only support even number of K0 loop + // TODO: add support for odd number of K0 loop + if(!((K / KPerBlock) % 2 == 0)) + { + return false; + } + } + else + { + return false; + } + + // check M01, N01 + constexpr auto M1 = Number{}; + constexpr auto N1 = Number{}; + + const auto M0 = M / M1; + const auto N0 = N / N1; + + if(!(M0 % M01 == 0 && N0 % N01 == 0)) + return false; + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr index_t + CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const index_t grid_size = (M / MPerBlock) * (N / NPerBlock); + + return grid_size; + } + + // TODO move this function into GEMM-pipeline class + __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) + { + const bool has_main_k0_block_loop = ((K0 * AK1) / (NumPrefetch * KPerBlock)) > 1; + + return has_main_k0_block_loop; + } + + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto + MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + constexpr auto M1 = Number{}; + constexpr auto N1 = Number{}; + + const auto M0 = M / M1; + const auto N0 = N / N1; + + const auto M00 = M0 / M01; + const auto N00 = N0 / N01; + + const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(M00, M01)), + make_unmerge_transform(make_tuple(N00, N01))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{})); + + const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + const auto cblockid_to_m0_n0_block_cluster_adaptor = + chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor, + cblockid_to_m00_m01_n00_n01_block_cluster_adaptor); + + return cblockid_to_m0_n0_block_cluster_adaptor; + } + + using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + + using DefaultBlock2CTileMap = + remove_cvref_t; + + template + __device__ static void Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + void* __restrict__ p_shared, + const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, + const Block2CTileMap& block_2_ctile_map) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + BlockwiseTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumPrefetch>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + BlockwiseTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumPrefetch>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + constexpr index_t KPack = math::max( + math::lcm(AK1, BK1), MfmaSelector::selected_mfma.k_per_blk); + + auto blockwise_gemm = + BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); + + // gridwise GEMM pipeline + const auto gridwise_gemm_pipeline = + GridwiseGemmPipeline_v1, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + NumPrefetch, + HasMainK0BlockLoop>{}; + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + gridwise_gemm_pipeline.Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum_t::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // LDS to global + auto c_shuffle_block_copy_lds_to_global = BlockwiseTensorSliceTransfer_v6r1< + BlockSize, // index_t BlockSize, + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + FloatCShuffle, // typename SrcData, + FloatC, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), + c_element_op}; + + constexpr auto mperblock_forward_step = + make_multi_index(0, CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, 0, 0); + constexpr auto nperblock_forward_step = + make_multi_index(0, 0, 0, CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl); + constexpr auto nperblock_backward_step = + make_multi_index(0, 0, 0, -CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl); + + static_for<0, MXdlPerWave, CShuffleMXdlPerWavePerShuffle>{}([&](auto mxdlperwave_iter) { + constexpr auto mxdlperwave = mxdlperwave_iter; + + static_for<0, NXdlPerWave, CShuffleNXdlPerWavePerShuffle>{}( + [&](auto nxdlperwave_iter) { + constexpr bool nxdlperwave_forward_sweep = + (mxdlperwave % (2 * CShuffleMXdlPerWavePerShuffle) == 0); + + constexpr index_t nxdlperwave_value = + nxdlperwave_forward_sweep + ? nxdlperwave_iter + : (NXdlPerWave - nxdlperwave_iter - CShuffleNXdlPerWavePerShuffle); + + constexpr auto nxdlperwave = Number{}; + + // make sure it's safe to do ds_write + block_sync_lds(); + + // VGPR to LDS + c_thread_copy_vgpr_to_lds.Run( + c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to do ds_read + block_sync_lds(); + + // LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + // move on nxdlperwave dimension + if constexpr(nxdlperwave_forward_sweep && + (nxdlperwave < NXdlPerWave - CShuffleNXdlPerWavePerShuffle)) + { + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + nperblock_forward_step); + } + else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0)) + { + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + nperblock_backward_step); + } + }); + + // move on mxdlperwave dimension + if constexpr(mxdlperwave < MXdlPerWave - CShuffleMXdlPerWavePerShuffle) + { + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, mperblock_forward_step); + } + }); + } + } +}; + +} // namespace ck From 6997a0a7c2879e29a35b91ab9366c22afce061d2 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Thu, 10 Mar 2022 22:36:27 +0000 Subject: [PATCH 02/32] rename --- example/01_gemm/gemm_xdl_fp16.cpp | 14 +- .../gpu/device/device_gemm_xdl_c_shuffle.hpp | 155 ++++++++++-------- ....hpp => gridwise_gemm_xdl_cshuffle_v1.hpp} | 24 +-- 3 files changed, 104 insertions(+), 89 deletions(-) rename include/ck/tensor_operation/gpu/grid/{gridwise_gemm_xdlops_v4r1.hpp => gridwise_gemm_xdl_cshuffle_v1.hpp} (97%) diff --git a/example/01_gemm/gemm_xdl_fp16.cpp b/example/01_gemm/gemm_xdl_fp16.cpp index 3b95ad29d86..2f3d82ec9f0 100644 --- a/example/01_gemm/gemm_xdl_fp16.cpp +++ b/example/01_gemm/gemm_xdl_fp16.cpp @@ -54,13 +54,13 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1>; #elif 1 using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle -//######|AData| BData| CData| AccData| Shuffle| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| -//######| Type| Type| Type| Type| Data| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWaveMWaveMPerXdl| ScalarPerVector| -//######| | | | | Type| | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWaveNWaveNPerXdl| _NWaveNPerXdl| -//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -// < F16, F16, F16, F32, F16, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>; - < F16, F16, F16, F32, F16, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>; -// < F16, F16, F16, F32, F16, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 2, 2, S<1, 16, 1,16>, 8>; +//######|AData| BData| CData| AccData| Shuffle| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| Type| Type| Type| Type| Data| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| +//######| | | | | Type| | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +// < F16, F16, F16, F32, F16, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>; + < F16, F16, F16, F32, F16, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>; +// < F16, F16, F16, F32, F16, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 2, 2, S<1, 16, 1,16>, 8>; #elif 0 using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl //######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| Num| diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle.hpp index ccb4b5ebd7d..5cd1bb3e782 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle.hpp @@ -7,7 +7,7 @@ #include "tensor_layout.hpp" #include "tensor_descriptor.hpp" #include "tensor_descriptor_helper.hpp" -#include "gridwise_gemm_xdlops_v4r1.hpp" +#include "gridwise_gemm_xdl_cshuffle_v1.hpp" namespace ck { namespace tensor_operation { @@ -129,49 +129,54 @@ struct DeviceGemmXdl_C_Shuffle using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); // GridwiseGemm - using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v4r1< - BlockSize, - ADataType, // TODO: distinguish A/B datatype - AccDataType, - CShuffleDataType, - CDataType, - InMemoryDataOperationEnum_t::Set, - AGridDesc_K0_M_K1, - BGridDesc_K0_N_K1, - CGridDesc_M_N, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - MPerBlock, - NPerBlock, - KPerBlock, - AK1, - BK1, - MPerXDL, - NPerXDL, - MXdlPerWave, - NXdlPerWave, - ABlockTransferThreadClusterLengths_K0_M_K1, - ABlockTransferThreadClusterArrangeOrder, - ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorDim, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_K1, - false, - ABlockLdsAddExtraM, - BBlockTransferThreadClusterLengths_K0_N_K1, - BBlockTransferThreadClusterArrangeOrder, - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorDim, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_K1, - false, - BBlockLdsAddExtraN, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - CShuffleBlockTransferScalarPerVector_NPerBlock, - NumPrefetch>; + using GridwiseGemm = +#if 0 + GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v4r1 +#else + GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 +#endif + ; // Argument struct Argument : public BaseArgument @@ -279,19 +284,24 @@ struct DeviceGemmXdl_C_Shuffle if(has_main_k0_block_loop) { - const auto kernel = kernel_gemm_xdlops_v4r1< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t< - typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - remove_reference_t, - true>; + const auto kernel = +#if 0 + kernel_gemm_xdlops_v4r1 +#else + kernel_gemm_xdl_cshuffle_v1 +#endif + , + remove_reference_t, + remove_reference_t< + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + true>; ave_time = launch_and_time_kernel(kernel, @@ -312,19 +322,24 @@ struct DeviceGemmXdl_C_Shuffle } else { - const auto kernel = kernel_gemm_xdlops_v4r1< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t< - typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - remove_reference_t, - false>; + const auto kernel = +#if 0 + kernel_gemm_xdlops_v4r1 +#else + kernel_gemm_xdl_cshuffle_v1 +#endif + , + remove_reference_t, + remove_reference_t< + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + false>; ave_time = launch_and_time_kernel(kernel, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v4r1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp similarity index 97% rename from include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v4r1.hpp rename to include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp index b18c6fcc507..7ba5f7edbe2 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v4r1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp @@ -26,17 +26,17 @@ __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_xdlops_v4r1(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, - const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op, - const Block2CTileMap block_2_ctile_map) + kernel_gemm_xdl_cshuffle_v1(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const Block2CTileMap block_2_ctile_map) { __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; @@ -95,7 +95,7 @@ template -struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v4r1 +struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; From 94b42affc58c2857305789bd91cad3069b3bfb19 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Fri, 11 Mar 2022 19:32:13 +0000 Subject: [PATCH 03/32] adding gemm+reduce --- example/14_gemm_reduce/CMakeLists.txt | 1 + .../14_gemm_reduce/gemm_xdl_reduce_fp16.cpp | 216 +++++ example/CMakeLists.txt | 1 + .../gpu/device/device_gemm_reduce.hpp | 40 + .../device_gemm_xdl_c_shuffle_reduce.hpp | 477 ++++++++++ .../gridwise_gemm_reduce_xdl_cshuffle_v1.hpp | 854 ++++++++++++++++++ 6 files changed, 1589 insertions(+) create mode 100644 example/14_gemm_reduce/CMakeLists.txt create mode 100644 example/14_gemm_reduce/gemm_xdl_reduce_fp16.cpp create mode 100644 include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp create mode 100644 include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_reduce.hpp create mode 100644 include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp diff --git a/example/14_gemm_reduce/CMakeLists.txt b/example/14_gemm_reduce/CMakeLists.txt new file mode 100644 index 00000000000..569dca61519 --- /dev/null +++ b/example/14_gemm_reduce/CMakeLists.txt @@ -0,0 +1 @@ +add_example_executable(example_gemm_xdl_reduce_fp16 gemm_xdl_reduce_fp16.cpp) diff --git a/example/14_gemm_reduce/gemm_xdl_reduce_fp16.cpp b/example/14_gemm_reduce/gemm_xdl_reduce_fp16.cpp new file mode 100644 index 00000000000..e4dc8111f1c --- /dev/null +++ b/example/14_gemm_reduce/gemm_xdl_reduce_fp16.cpp @@ -0,0 +1,216 @@ +#include +#include +#include +#include +#include +#include +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "host_gemm.hpp" +#include "device_tensor.hpp" +#include "device_gemm_xdl.hpp" +#include "device_gemm_xdl_c_shuffle_reduce.hpp" +#include "element_wise_operation.hpp" +#include "reference_gemm.hpp" +#include "gemm_specialization.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = ck::half_t; +using BDataType = ck::half_t; +using CDataType = ck::half_t; +using AccDataType = float; +using DDataType = float; + +using ALayout = ck::tensor_layout::gemm::RowMajor; +using BLayout = ck::tensor_layout::gemm::ColumnMajor; +using CLayout = ck::tensor_layout::gemm::RowMajor; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default; + +// clang-format off +#if 1 +using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle_Reduce +//######|AData| BData| CData| AccData| CShuffle| DData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| Type| Type| Type| Type| DataType| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| +//######| | | | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < F16, F16, F16, F32, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>; +// < F16, F16, F16, F32, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 2, 2, S<1, 16, 1,16>, 8>; +#endif +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +int main(int argc, char* argv[]) +{ + bool do_verification = 0; + int init_method = 0; + int nrepeat = 5; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 4096; + + if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + nrepeat = std::stoi(argv[3]); + } + else if(argc == 10) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + nrepeat = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideC = std::stoi(argv[9]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: run kernel # of times (>1)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor d_m_host_result( + HostTensorDescriptor(std::vector({static_cast(M)}))); + Tensor d_m_device_result( + HostTensorDescriptor(std::vector({static_cast(M)}))); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + std::cout << "d_m: " << d_m_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<0>{}); + b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + } + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); + DeviceMem d_m_device_buf(sizeof(DDataType) * d_m_device_result.mDesc.GetElementSpace()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmReduceInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + static_cast(d_m_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, nrepeat); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + d_m_device_buf.FromDevice(d_m_device_result.mData.data()); + + if(do_verification) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + + check_error(c_m_n_host_result, c_m_n_device_result); + } + + return 0; +} diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index 6f9201d8351..2099e1eb3d9 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -38,3 +38,4 @@ add_subdirectory(10_conv2d_bwd_data) add_subdirectory(11_conv2d_bwd_wgt) add_subdirectory(12_reduce) add_subdirectory(13_pool2d_fwd) +add_subdirectory(14_gemm_reduce) diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp new file mode 100644 index 00000000000..88246fb8754 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp @@ -0,0 +1,40 @@ +#pragma once +#include +#include "device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmReduce : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + void* p_d, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceGemmReducePtr = std::unique_ptr< + DeviceGemmReduce>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_reduce.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_reduce.hpp new file mode 100644 index 00000000000..81d76837252 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_reduce.hpp @@ -0,0 +1,477 @@ +#pragma once +#include +#include +#include "device.hpp" +#include "device_gemm_reduce.hpp" +#include "common_header.hpp" +#include "tensor_layout.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_reduce_xdl_cshuffle_v1.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmXdl_C_Shuffle_Reduce + : public DeviceGemmReduce +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA) + { + assert(K % AK1 == 0); + + const index_t K0 = K / AK1; + + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + const auto a_grid_desc_k0_m_k1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, AK1)), make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_k0_m_k1; + } + + static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB) + { + assert(K % BK1 == 0); + + const index_t K0 = K / BK1; + + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + const auto b_grid_desc_k0_n_k1 = transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, BK1)), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_k0_n_k1; + } + + static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) + { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + } + + using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1)); + using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< + BlockSize, + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CShuffleDataType, + CDataType, + InMemoryDataOperationEnum_t::Set, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, + BBlockLdsAddExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + NumPrefetch>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + DDataType* p_d_grid, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t M01, + index_t N01, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_c_grid_{p_c_grid}, + p_d_grid_{p_d_grid}, + a_grid_desc_k0_m_k1_{}, + b_grid_desc_k0_n_k1_{}, + c_grid_desc_m_n_{}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_ctile_map_{}, + M01_{M01}, + N01_{N01}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + a_grid_desc_k0_m_k1_ = + DeviceGemmXdl_C_Shuffle_Reduce::MakeAGridDescriptor_K0_M_K1(M, K, StrideA); + b_grid_desc_k0_n_k1_ = + DeviceGemmXdl_C_Shuffle_Reduce::MakeBGridDescriptor_K0_N_K1(K, N, StrideB); + c_grid_desc_m_n_ = + DeviceGemmXdl_C_Shuffle_Reduce::MakeCGridDescriptor_M_N(M, N, StrideC); + + if(GridwiseGemm::CheckValidity( + a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n_); + + block_2_ctile_map_ = + GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + DDataType* p_d_grid_; + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock_; + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + index_t M01_; + index_t N01_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceGemmXdl_C_Shuffle_Reduce::Argument; + + float Run(const Argument& arg, int nrepeat = 1) + { + { + std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) + << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) + << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.M01_, + arg.N01_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"); + } + + const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0); + + const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + + float ave_time = 0; + + if(has_main_k0_block_loop) + { + const auto kernel = kernel_gemm_reduce_xdl_cshuffle_v1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + true>; + + ave_time = + launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_gemm_reduce_xdl_cshuffle_v1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + false>; + + ave_time = + launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, int nrepeat = 1) override + { + return Run(*dynamic_cast(p_arg), nrepeat); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.M01_, + arg.N01_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + DDataType* p_d, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, + p_b, + p_c, + p_d, + M, + N, + K, + StrideA, + StrideB, + StrideC, + 1, + 1, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + void* p_d, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + static_cast(p_d), + M, + N, + K, + StrideA, + StrideB, + StrideC, + 1, + 1, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGemmXdl_C_Shuffle_Reduce" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp new file mode 100644 index 00000000000..14715017fa3 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp @@ -0,0 +1,854 @@ +#pragma once +#include "common_header.hpp" +#include "multi_index_transform_helper.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "blockwise_gemm_xdlops.hpp" +#include "blockwise_tensor_slice_transfer_v4r1.hpp" +#include "blockwise_tensor_slice_transfer_v6r1.hpp" +#include "threadwise_tensor_slice_transfer.hpp" +#include "gridwise_gemm_pipeline_v1.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_reduce_xdl_cshuffle_v1(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const Block2CTileMap block_2_ctile_map) +{ + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); +} + +template +struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK0 = Number{}; + static constexpr auto BK0 = Number{}; + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + constexpr auto max_lds_align = AK1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(AK0, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(AK0, Number{}, AK1), max_lds_align); + } + }(); + + return a_block_desc_ak0_m_ak1; + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + constexpr auto max_lds_align = BK1; + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(BK0, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(BK0, Number{}, BK1), max_lds_align); + } + }(); + + return b_block_desc_bk0_n_bk1; + } + + __host__ __device__ static constexpr auto + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + constexpr auto a_block_space_size_aligned = + math::integer_least_multiple(a_block_desc_ak0_m_ak1.GetElementSpaceSize(), AK1); + + constexpr auto b_block_space_size_aligned = + math::integer_least_multiple(b_block_desc_bk0_n_bk1.GetElementSpaceSize(), BK1); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * + sizeof(FloatAB), + c_block_size * sizeof(FloatCShuffle)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const CGridDesc_M_N& c_grid_desc_m_n, + index_t M01, + index_t N01) + { + // static_assert(is_known_at_compile_time>::value && + // is_known_at_compile_time>::value, + // "wrong! K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1); + const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1); + const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2); + + if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + return false; + + // check NumPrefetch + if constexpr(NumPrefetch == 1) + { + // 1-stage prefetch always supported + } + else if constexpr(NumPrefetch == 2) + { + // 2-stage prefetch currently only support even number of K0 loop + // TODO: add support for odd number of K0 loop + if(!((K / KPerBlock) % 2 == 0)) + { + return false; + } + } + else + { + return false; + } + + // check M01, N01 + constexpr auto M1 = Number{}; + constexpr auto N1 = Number{}; + + const auto M0 = M / M1; + const auto N0 = N / N1; + + if(!(M0 % M01 == 0 && N0 % N01 == 0)) + return false; + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr index_t + CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const index_t grid_size = (M / MPerBlock) * (N / NPerBlock); + + return grid_size; + } + + // TODO move this function into GEMM-pipeline class + __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) + { + const bool has_main_k0_block_loop = ((K0 * AK1) / (NumPrefetch * KPerBlock)) > 1; + + return has_main_k0_block_loop; + } + + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto + MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + constexpr auto M1 = Number{}; + constexpr auto N1 = Number{}; + + const auto M0 = M / M1; + const auto N0 = N / N1; + + const auto M00 = M0 / M01; + const auto N00 = N0 / N01; + + const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(M00, M01)), + make_unmerge_transform(make_tuple(N00, N01))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{})); + + const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + const auto cblockid_to_m0_n0_block_cluster_adaptor = + chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor, + cblockid_to_m00_m01_n00_n01_block_cluster_adaptor); + + return cblockid_to_m0_n0_block_cluster_adaptor; + } + + using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + + using DefaultBlock2CTileMap = + remove_cvref_t; + + template + __device__ static void Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + void* __restrict__ p_shared, + const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, + const Block2CTileMap& block_2_ctile_map) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + BlockwiseTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumPrefetch>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + BlockwiseTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumPrefetch>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + constexpr index_t KPack = math::max( + math::lcm(AK1, BK1), MfmaSelector::selected_mfma.k_per_blk); + + auto blockwise_gemm = + BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); + + // gridwise GEMM pipeline + const auto gridwise_gemm_pipeline = + GridwiseGemmPipeline_v1, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + NumPrefetch, + HasMainK0BlockLoop>{}; + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + gridwise_gemm_pipeline.Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum_t::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // LDS to global + auto c_shuffle_block_copy_lds_to_global = BlockwiseTensorSliceTransfer_v6r1< + BlockSize, // index_t BlockSize, + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + FloatCShuffle, // typename SrcData, + FloatC, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), + c_element_op}; + + // LDS c_reduce_block_desc_mperblock_nperblock + constexpr auto c_reduce_block_desc_mperblock_nperblock = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_pass_through_transform( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I1)), + make_freeze_transform(I0), + make_pass_through_transform( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I3))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<>{}, Sequence<1>{})); + +#if 0 + // FIXME: hardcoded + using CReduceThreadClusterLengths = Sequence<64, 4>{}; + using c_in_reduce_thread_lengths_mperblock_nperblock = Sequence<1, 16> {} + + // VGPR c_in_reduce_thread_desc_mperblock_nperblock + constexpr auto c_in_reduce_thread_desc_mperblock_nperblock = + make_naive_tensor_descriptor_packed( + sequence_to_tuple_of_number(c_in_reduce_thread_lengths_mperblock_nperblock)); + + // VGPR c_out_reduce_thread_desc_mperblock + constexpr auto c_out_reduce_thread_desc_mperblock = + make_naive_tensor_descriptor_packed(make_tuple(Number<1>{})); + + // VGPR c_reduce_thread_desc_mblock_mperblock + constexpr auto c_out_reduce_thread_desc_mblock_mperblock = + make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{})); + + // global c_reduce_grid_desc_mblock_mperblock + + // TODO: this should be implemented as a blockwise reduction + // FIXME: hardcoded + auto c_in_reduce_thread_buf = make_static_buffer( + c_in_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize()); + + auto c_out_thread_reduce_buf = + make_static_buffer(c_out_reduce_thread_desc_mperblock.GetElementSpaceSize()); + + // reduce: copy from LDS to VGPR + constexpr auto c_reduce_thread_cluster_desc = + make_cluster_descriptor(CReduceThreadClusterLengths{}, Sequence<1, 0>{}); + + const auto c_reduce_thread_cluster_idx = + c_reduce_thread_cluster_desc.CalculateBottomIndex( + make_multi_index(get_thread_local_1d_id())); + + const auto c_reduce_thread_data_idx_begin = + c_reduce_thread_cluster_idx * c_in_reduce_thread_lengths_mperblock_nperblock; + + auto c_reduce_thread_copy_lds_to_vgpr = + ThreadwiseTensorSliceTransfer_v2, + 1, + 8, + 1, + true>{c_in_reduce_thread_desc_mperblock_nperblock, + c_reduce_thread_data_idx_begin}; + + // reduce: copy from VGPR to global + auto c_reduce_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3< + FloatCShuffle, + FloatC, + decltype(c_out_reduce_thread_desc_mblock_mperblock), + decltype(c_out_reduce_grid_desc_mblock_mperblock), + ck::tensor_operation::element_wise::PassThrough, + Sequence<1, 1>, // FIXME: hardcoded + Sequence<0, 1>, + 1, + 1, + InMemoryDataOperationEnum_t::AtomicAdd, + 1, + true>{c_out_reduce_grid_desc_mblock_mperblock, + make_multi_index(block_work_idx[I0], // mblock + c_reduce_thread_data_idx_begin[I0]), // mperblock + ck::tensor_operation::element_wise::PassThrough{}}; +#endif + + constexpr auto mperblock_forward_step = + make_multi_index(0, CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, 0, 0); + constexpr auto nperblock_forward_step = + make_multi_index(0, 0, 0, CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl); + constexpr auto nperblock_backward_step = + make_multi_index(0, 0, 0, -CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl); + + static_for<0, MXdlPerWave, CShuffleMXdlPerWavePerShuffle>{}([&](auto mxdlperwave_iter) { + constexpr auto mxdlperwave = mxdlperwave_iter; + + static_for<0, NXdlPerWave, CShuffleNXdlPerWavePerShuffle>{}( + [&](auto nxdlperwave_iter) { + constexpr bool nxdlperwave_forward_sweep = + (mxdlperwave % (2 * CShuffleMXdlPerWavePerShuffle) == 0); + + constexpr index_t nxdlperwave_value = + nxdlperwave_forward_sweep + ? nxdlperwave_iter + : (NXdlPerWave - nxdlperwave_iter - CShuffleNXdlPerWavePerShuffle); + + constexpr auto nxdlperwave = Number{}; + + // make sure it's safe to do ds_write + block_sync_lds(); + + // VGPR to LDS + c_thread_copy_vgpr_to_lds.Run( + c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to do ds_read + block_sync_lds(); + + // LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + +#if 0 + // reduce + { + // LDS to VGPR + c_reduce_thread_copy_lds_to_vgpr.Run( + c_in_reduce_thread_desc_mperblock_nperblock, + c_shuffle_block_buf, + c_in_reduce_thread_desc_mperblock_nperblock, + make_tuple(I0, I0)); + + // reduce in VGPR + // FIXME: hardcoded + static_for<0, 1, 1>{}([&](auto im) { + // FIXME: Accumulator for reduction may use a different datatype? + FloatCShuffle v = 0; + + // FIXME: hardcoded + static_for<0, 16, 1>{}([&](auto in) { + v += c_in_reduce_thread_buf + [c_in_reduce_thread_desc_mperblock_nperblock + .CalculateOffset(make_tuple(im, in))]; + }); + + c_out_reduce_thread_buf(im) = v; + }); + + // VGPR to Global + c_out_reduce_thread_desc_mblock_mperblock.Run( + c_out_reduce_thread_desc_mblock_mperblock, + make_tuple(I0, I0), + c_out_reduce_thread_desc_mblock_mperblock, + c_reduce_grid_desc_mblock_mperblock, + c_grid_buf); + } +#endif + + // move on nxdlperwave dimension + if constexpr(nxdlperwave_forward_sweep && + (nxdlperwave < NXdlPerWave - CShuffleNXdlPerWavePerShuffle)) + { + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + nperblock_forward_step); + } + else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0)) + { + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + nperblock_backward_step); + } + }); + + // move on mxdlperwave dimension + if constexpr(mxdlperwave < MXdlPerWave - CShuffleMXdlPerWavePerShuffle) + { + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, mperblock_forward_step); + +#if 0 + // reduce: move + c_reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow( + c_reduce_grid_desc_mblock_mperblock, XXXXXXXXXXX); +#endif + } + }); + } + } +}; + +} // namespace ck From c50355938d54f8bda28b49e5d8c722aa7ac82a6c Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Fri, 11 Mar 2022 20:14:53 +0000 Subject: [PATCH 04/32] adding gemm+reduce --- .../14_gemm_reduce/gemm_xdl_reduce_fp16.cpp | 4 +- ...pp => device_gemm_reduce_xdl_cshuffle.hpp} | 48 +++++++++++++---- .../gridwise_gemm_reduce_xdl_cshuffle_v1.hpp | 53 ++++++++++++++----- 3 files changed, 81 insertions(+), 24 deletions(-) rename include/ck/tensor_operation/gpu/device/{device_gemm_xdl_c_shuffle_reduce.hpp => device_gemm_reduce_xdl_cshuffle.hpp} (89%) diff --git a/example/14_gemm_reduce/gemm_xdl_reduce_fp16.cpp b/example/14_gemm_reduce/gemm_xdl_reduce_fp16.cpp index e4dc8111f1c..d3a337564c4 100644 --- a/example/14_gemm_reduce/gemm_xdl_reduce_fp16.cpp +++ b/example/14_gemm_reduce/gemm_xdl_reduce_fp16.cpp @@ -12,7 +12,7 @@ #include "host_gemm.hpp" #include "device_tensor.hpp" #include "device_gemm_xdl.hpp" -#include "device_gemm_xdl_c_shuffle_reduce.hpp" +#include "device_gemm_reduce_xdl_cshuffle.hpp" #include "element_wise_operation.hpp" #include "reference_gemm.hpp" #include "gemm_specialization.hpp" @@ -46,7 +46,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa // clang-format off #if 1 -using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle_Reduce +using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_Xdl_CShuffle //######|AData| BData| CData| AccData| CShuffle| DData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //######| Type| Type| Type| Type| DataType| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| //######| | | | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_reduce.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp similarity index 89% rename from include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_reduce.hpp rename to include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp index 81d76837252..6d7a85a8050 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_reduce.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp @@ -54,7 +54,7 @@ template -struct DeviceGemmXdl_C_Shuffle_Reduce +struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce { static constexpr auto I0 = Number<0>{}; @@ -125,9 +125,16 @@ struct DeviceGemmXdl_C_Shuffle_Reduce } } + // assume D is packed tensor + static auto MakeDGridDescriptor_M(index_t M) + { + return make_naive_tensor_descriptor_packed(make_tuple(M)); + } + using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1)); using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)); using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + using DGridDesc_M = decltype(MakeDGridDescriptor_M(1)); // GridwiseGemm using GridwiseGemm = GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< @@ -136,10 +143,12 @@ struct DeviceGemmXdl_C_Shuffle_Reduce AccDataType, CShuffleDataType, CDataType, + DDataType, InMemoryDataOperationEnum_t::Set, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, CGridDesc_M_N, + DGridDesc_M, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, @@ -199,7 +208,9 @@ struct DeviceGemmXdl_C_Shuffle_Reduce a_grid_desc_k0_m_k1_{}, b_grid_desc_k0_n_k1_{}, c_grid_desc_m_n_{}, + d_grid_desc_m_{}, c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + d_grid_desc_mblock_mperblock_{}, block_2_ctile_map_{}, M01_{M01}, N01_{N01}, @@ -208,11 +219,12 @@ struct DeviceGemmXdl_C_Shuffle_Reduce c_element_op_{c_element_op} { a_grid_desc_k0_m_k1_ = - DeviceGemmXdl_C_Shuffle_Reduce::MakeAGridDescriptor_K0_M_K1(M, K, StrideA); + DeviceGemmReduce_Xdl_CShuffle::MakeAGridDescriptor_K0_M_K1(M, K, StrideA); b_grid_desc_k0_n_k1_ = - DeviceGemmXdl_C_Shuffle_Reduce::MakeBGridDescriptor_K0_N_K1(K, N, StrideB); + DeviceGemmReduce_Xdl_CShuffle::MakeBGridDescriptor_K0_N_K1(K, N, StrideB); c_grid_desc_m_n_ = - DeviceGemmXdl_C_Shuffle_Reduce::MakeCGridDescriptor_M_N(M, N, StrideC); + DeviceGemmReduce_Xdl_CShuffle::MakeCGridDescriptor_M_N(M, N, StrideC); + d_grid_desc_m_ = DeviceGemmReduce_Xdl_CShuffle::MakeDGridDescriptor_M(M); if(GridwiseGemm::CheckValidity( a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_)) @@ -221,6 +233,9 @@ struct DeviceGemmXdl_C_Shuffle_Reduce GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( c_grid_desc_m_n_); + d_grid_desc_mblock_mperblock_ = + GridwiseGemm::MakeDGridDescriptor_MBlock_MPerBlock(d_grid_desc_m_); + block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); } @@ -234,8 +249,10 @@ struct DeviceGemmXdl_C_Shuffle_Reduce AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; CGridDesc_M_N c_grid_desc_m_n_; + DGridDesc_M d_grid_desc_m_; typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_; + typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock_; typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; index_t M01_; index_t N01_; @@ -247,7 +264,7 @@ struct DeviceGemmXdl_C_Shuffle_Reduce // Invoker struct Invoker : public BaseInvoker { - using Argument = DeviceGemmXdl_C_Shuffle_Reduce::Argument; + using Argument = DeviceGemmReduce_Xdl_CShuffle::Argument; float Run(const Argument& arg, int nrepeat = 1) { @@ -262,6 +279,9 @@ struct DeviceGemmXdl_C_Shuffle_Reduce std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + + std::cout << "arg.d_grid_desc_m_{ " << arg.d_grid_desc_m_.GetLength(I0) << "}" + << std::endl; } if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, @@ -288,10 +308,12 @@ struct DeviceGemmXdl_C_Shuffle_Reduce GridwiseGemm, ADataType, // TODO: distiguish A/B datatype CDataType, - remove_reference_t, - remove_reference_t, + DDataType, + remove_reference_t, + remove_reference_t, remove_reference_t< typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, + remove_reference_t, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, @@ -307,9 +329,11 @@ struct DeviceGemmXdl_C_Shuffle_Reduce arg.p_a_grid_, arg.p_b_grid_, arg.p_c_grid_, + arg.p_d_grid_, arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.d_grid_desc_mblock_mperblock_, arg.a_element_op_, arg.b_element_op_, arg.c_element_op_, @@ -321,10 +345,12 @@ struct DeviceGemmXdl_C_Shuffle_Reduce GridwiseGemm, ADataType, // TODO: distiguish A/B datatype CDataType, - remove_reference_t, - remove_reference_t, + DDataType, + remove_reference_t, + remove_reference_t, remove_reference_t< typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, + remove_reference_t, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, @@ -340,9 +366,11 @@ struct DeviceGemmXdl_C_Shuffle_Reduce arg.p_a_grid_, arg.p_b_grid_, arg.p_c_grid_, + arg.p_d_grid_, arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.d_grid_desc_mblock_mperblock_, arg.a_element_op_, arg.b_element_op_, arg.c_element_op_, @@ -457,7 +485,7 @@ struct DeviceGemmXdl_C_Shuffle_Reduce auto str = std::stringstream(); // clang-format off - str << "DeviceGemmXdl_C_Shuffle_Reduce" + str << "DeviceGemmReduce_Xdl_CShuffle" << "<" << BlockSize << ", " << MPerBlock << ", " diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp index 14715017fa3..9a1c7e3f450 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp @@ -14,9 +14,11 @@ namespace ck { template (p_a_grid, p_b_grid, p_c_grid, + p_d_grid, p_shared, a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, c_grid_desc_mblock_mperblock_nblock_nperblock, + d_grid_desc_mblock_mperblock, a_element_op, b_element_op, c_element_op, @@ -58,10 +65,12 @@ template {}))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1>{})); + + return d_grid_desc_mblock_mperblock; + } + // return block_id to C matrix tile idx (m0, n0) mapping __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01) @@ -333,6 +357,9 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + using DGridDescriptor_MBlock_MPerBlock = + remove_cvref_t; + using DefaultBlock2CTileMap = remove_cvref_t; @@ -340,11 +367,13 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 __device__ static void Run(const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid, FloatC* __restrict__ p_c_grid, + FloatD* __restrict__ p_d_grid, void* __restrict__ p_shared, const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& c_grid_desc_mblock_mperblock_nblock_nperblock, + const DGridDescriptor_MBlock_MPerBlock& d_grid_desc_mblock_mperblock, const AElementwiseOperation& a_element_op, const BElementwiseOperation& b_element_op, const CElementwiseOperation& c_element_op, @@ -655,6 +684,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), c_element_op}; +#if 0 // LDS c_reduce_block_desc_mperblock_nperblock constexpr auto c_reduce_block_desc_mperblock_nperblock = transform_tensor_descriptor( c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, @@ -668,7 +698,6 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<>{}, Sequence<1>{})); -#if 0 // FIXME: hardcoded using CReduceThreadClusterLengths = Sequence<64, 4>{}; using c_in_reduce_thread_lengths_mperblock_nperblock = Sequence<1, 16> {} From fe87d17b6fda8aa782a0c76e459b8ad60a2b6e8b Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Fri, 11 Mar 2022 21:37:25 +0000 Subject: [PATCH 05/32] adding gemm+reduce --- .../14_gemm_reduce/gemm_xdl_reduce_fp16.cpp | 21 +++- .../gridwise_gemm_reduce_xdl_cshuffle_v1.hpp | 111 ++++++++++-------- 2 files changed, 77 insertions(+), 55 deletions(-) diff --git a/example/14_gemm_reduce/gemm_xdl_reduce_fp16.cpp b/example/14_gemm_reduce/gemm_xdl_reduce_fp16.cpp index d3a337564c4..51790788e63 100644 --- a/example/14_gemm_reduce/gemm_xdl_reduce_fp16.cpp +++ b/example/14_gemm_reduce/gemm_xdl_reduce_fp16.cpp @@ -45,15 +45,13 @@ using CElementOp = ck::tensor_operation::element_wise::PassThrough; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default; // clang-format off -#if 1 using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_Xdl_CShuffle //######|AData| BData| CData| AccData| CShuffle| DData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //######| Type| Type| Type| Type| DataType| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| //######| | | | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < F16, F16, F16, F32, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>; + < F16, F16, F16, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>; // < F16, F16, F16, F32, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 2, 2, S<1, 16, 1,16>, 8>; -#endif // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host:: @@ -196,11 +194,11 @@ int main(int argc, char* argv[]) std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << gemm.GetTypeString() << std::endl; - c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); - d_m_device_buf.FromDevice(d_m_device_result.mData.data()); - if(do_verification) { + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + d_m_device_buf.FromDevice(d_m_device_result.mData.data()); + auto ref_gemm = ReferenceGemmInstance{}; auto ref_invoker = ref_gemm.MakeInvoker(); @@ -209,7 +207,18 @@ int main(int argc, char* argv[]) ref_invoker.Run(ref_argument); + for(int m = 0; m < M; ++m) + { + d_m_host_result(m) = 0; + + for(int n = 0; n < N; ++n) + { + d_m_host_result(m) += c_m_n_host_result(m, n); + } + } + check_error(c_m_n_host_result, c_m_n_device_result); + check_error(d_m_host_result, d_m_device_result); } return 0; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp index 9a1c7e3f450..654ba4e2be9 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp @@ -385,6 +385,8 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + auto d_grid_buf = make_dynamic_buffer( + p_d_grid, d_grid_desc_mblock_mperblock.GetElementSpaceSize()); // divide block work by [M, N] const auto block_work_idx = @@ -684,7 +686,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), c_element_op}; -#if 0 +#if 1 // LDS c_reduce_block_desc_mperblock_nperblock constexpr auto c_reduce_block_desc_mperblock_nperblock = transform_tensor_descriptor( c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, @@ -698,9 +700,9 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<>{}, Sequence<1>{})); - // FIXME: hardcoded - using CReduceThreadClusterLengths = Sequence<64, 4>{}; - using c_in_reduce_thread_lengths_mperblock_nperblock = Sequence<1, 16> {} + // FIXME: hardcode + using CReduceThreadClusterLengths_MPerBlock_NPerBlock = Sequence<64, 4>; + constexpr auto c_in_reduce_thread_lengths_mperblock_nperblock = Sequence<1, 16>{}; // VGPR c_in_reduce_thread_desc_mperblock_nperblock constexpr auto c_in_reduce_thread_desc_mperblock_nperblock = @@ -708,26 +710,28 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 sequence_to_tuple_of_number(c_in_reduce_thread_lengths_mperblock_nperblock)); // VGPR c_out_reduce_thread_desc_mperblock + // FIXME: hardcode constexpr auto c_out_reduce_thread_desc_mperblock = make_naive_tensor_descriptor_packed(make_tuple(Number<1>{})); - // VGPR c_reduce_thread_desc_mblock_mperblock + // VGPR c_out_reduce_thread_desc_mblock_mperblock + // FIXME: hardcode constexpr auto c_out_reduce_thread_desc_mblock_mperblock = make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{})); - // global c_reduce_grid_desc_mblock_mperblock - // TODO: this should be implemented as a blockwise reduction - // FIXME: hardcoded - auto c_in_reduce_thread_buf = make_static_buffer( - c_in_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize()); + // FIXME: should use a different datatype? + auto c_in_reduce_thread_buf = + make_static_buffer( + c_in_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize()); - auto c_out_thread_reduce_buf = - make_static_buffer(c_out_reduce_thread_desc_mperblock.GetElementSpaceSize()); + auto c_out_reduce_thread_buf = + make_static_buffer( + c_out_reduce_thread_desc_mperblock.GetElementSpaceSize()); // reduce: copy from LDS to VGPR - constexpr auto c_reduce_thread_cluster_desc = - make_cluster_descriptor(CReduceThreadClusterLengths{}, Sequence<1, 0>{}); + constexpr auto c_reduce_thread_cluster_desc = make_cluster_descriptor( + CReduceThreadClusterLengths_MPerBlock_NPerBlock{}, Sequence<1, 0>{}); const auto c_reduce_thread_cluster_idx = c_reduce_thread_cluster_desc.CalculateBottomIndex( @@ -736,36 +740,35 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 const auto c_reduce_thread_data_idx_begin = c_reduce_thread_cluster_idx * c_in_reduce_thread_lengths_mperblock_nperblock; - auto c_reduce_thread_copy_lds_to_vgpr = - ThreadwiseTensorSliceTransfer_v2, - 1, - 8, - 1, - true>{c_in_reduce_thread_desc_mperblock_nperblock, - c_reduce_thread_data_idx_begin}; + auto c_reduce_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2< + FloatCShuffle, + FloatCShuffle, + decltype(c_reduce_block_desc_mperblock_nperblock), + decltype(c_in_reduce_thread_desc_mperblock_nperblock), + decltype(c_in_reduce_thread_lengths_mperblock_nperblock), + Sequence<0, 1>, + 1, + 8, // FIXME: hardcoded + 1, + true>{c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_data_idx_begin}; // reduce: copy from VGPR to global auto c_reduce_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3< FloatCShuffle, - FloatC, + FloatD, decltype(c_out_reduce_thread_desc_mblock_mperblock), - decltype(c_out_reduce_grid_desc_mblock_mperblock), + decltype(d_grid_desc_mblock_mperblock), ck::tensor_operation::element_wise::PassThrough, - Sequence<1, 1>, // FIXME: hardcoded + Sequence<1, 1>, // FIXME: hardcode Sequence<0, 1>, 1, - 1, + 1, // FIXME: hardcode InMemoryDataOperationEnum_t::AtomicAdd, 1, - true>{c_out_reduce_grid_desc_mblock_mperblock, - make_multi_index(block_work_idx[I0], // mblock - c_reduce_thread_data_idx_begin[I0]), // mperblock - ck::tensor_operation::element_wise::PassThrough{}}; + false>{d_grid_desc_mblock_mperblock, + make_multi_index(block_work_idx[I0], // mblock + c_reduce_thread_data_idx_begin[I0]), // mperblock + ck::tensor_operation::element_wise::PassThrough{}}; #endif constexpr auto mperblock_forward_step = @@ -811,39 +814,48 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_buf); -#if 0 +#if 1 // reduce { // LDS to VGPR c_reduce_thread_copy_lds_to_vgpr.Run( - c_in_reduce_thread_desc_mperblock_nperblock, + c_reduce_block_desc_mperblock_nperblock, c_shuffle_block_buf, c_in_reduce_thread_desc_mperblock_nperblock, - make_tuple(I0, I0)); + make_tuple(I0, I0), + c_in_reduce_thread_buf); // reduce in VGPR - // FIXME: hardcoded + // loop over MPerBlock + // FIXME: hardcode static_for<0, 1, 1>{}([&](auto im) { // FIXME: Accumulator for reduction may use a different datatype? FloatCShuffle v = 0; - // FIXME: hardcoded + // loop over NPerBlock + // FIXME: hardcode static_for<0, 16, 1>{}([&](auto in) { - v += c_in_reduce_thread_buf - [c_in_reduce_thread_desc_mperblock_nperblock - .CalculateOffset(make_tuple(im, in))]; + constexpr index_t in_offset = + c_in_reduce_thread_desc_mperblock_nperblock.CalculateOffset( + make_tuple(im, in)); + + v += c_in_reduce_thread_buf[Number{}]; }); - c_out_reduce_thread_buf(im) = v; + constexpr index_t out_offset = + c_out_reduce_thread_desc_mperblock.CalculateOffset( + make_tuple(im)); + + c_out_reduce_thread_buf(Number{}) = v; }); // VGPR to Global - c_out_reduce_thread_desc_mblock_mperblock.Run( + c_reduce_thread_copy_vgpr_to_global.Run( c_out_reduce_thread_desc_mblock_mperblock, make_tuple(I0, I0), - c_out_reduce_thread_desc_mblock_mperblock, - c_reduce_grid_desc_mblock_mperblock, - c_grid_buf); + c_out_reduce_thread_buf, + d_grid_desc_mblock_mperblock, + d_grid_buf); } #endif @@ -869,10 +881,11 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( c_grid_desc_mblock_mperblock_nblock_nperblock, mperblock_forward_step); -#if 0 +#if 1 // reduce: move c_reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow( - c_reduce_grid_desc_mblock_mperblock, XXXXXXXXXXX); + d_grid_desc_mblock_mperblock, + make_multi_index(0, CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl)); #endif } }); From 4085ae1197b5876c7307763d877350f0b6c143ca Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Mon, 14 Mar 2022 00:18:07 +0000 Subject: [PATCH 06/32] adding gemm+reduce --- .../gpu/device/device_gemm.hpp | 34 +------------------ .../gpu/device/device_gemm_bias.hpp | 34 +++++++++++++++++++ .../device_gemm_xdl_c_shuffle_bias_2d.hpp | 4 +-- .../gridwise_gemm_reduce_xdl_cshuffle_v1.hpp | 2 -- 4 files changed, 36 insertions(+), 38 deletions(-) create mode 100644 include/ck/tensor_operation/gpu/device/device_gemm_bias.hpp diff --git a/include/ck/tensor_operation/gpu/device/device_gemm.hpp b/include/ck/tensor_operation/gpu/device/device_gemm.hpp index 72b79e85316..b1b50a082fc 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm.hpp @@ -1,6 +1,4 @@ -#ifndef DEVICE_GEMM_HPP -#define DEVICE_GEMM_HPP - +#pragma once #include #include "device_base.hpp" @@ -8,35 +6,6 @@ namespace ck { namespace tensor_operation { namespace device { -template -struct DeviceGemmBias : public BaseOperator -{ - virtual std::unique_ptr - MakeArgumentPointer(const void* p_a, - const void* p_b, - const void* p_bias, - void* p_c, - ck::index_t M, - ck::index_t N, - ck::index_t K, - ck::index_t StrideA, - ck::index_t StrideB, - ck::index_t StrideC, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) = 0; - - virtual std::unique_ptr MakeInvokerPointer() = 0; -}; - -template -using DeviceGemmBiasPtr = std::unique_ptr< - DeviceGemmBias>; - template @@ -68,4 +37,3 @@ using DeviceGemmPtr = std::unique_ptr< } // namespace device } // namespace tensor_operation } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_bias.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_bias.hpp new file mode 100644 index 00000000000..27d5ce9d8a1 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_gemm_bias.hpp @@ -0,0 +1,34 @@ +#pragma once +#include +#include "device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmBias : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + const void* p_bias, + void* p_c, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_2d.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_2d.hpp index d1e0d6d84ef..ac2c2ec25f6 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_2d.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_2d.hpp @@ -4,9 +4,7 @@ #include #include #include "device.hpp" -#include "device_base.hpp" -#include "device_gemm.hpp" -#include "device_gemm_xdl.hpp" +#include "device_gemm_bias.hpp" #include "common_header.hpp" #include "tensor_layout.hpp" #include "tensor_descriptor.hpp" diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp index 654ba4e2be9..9a40bbdb29b 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp @@ -686,7 +686,6 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), c_element_op}; -#if 1 // LDS c_reduce_block_desc_mperblock_nperblock constexpr auto c_reduce_block_desc_mperblock_nperblock = transform_tensor_descriptor( c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, @@ -769,7 +768,6 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 make_multi_index(block_work_idx[I0], // mblock c_reduce_thread_data_idx_begin[I0]), // mperblock ck::tensor_operation::element_wise::PassThrough{}}; -#endif constexpr auto mperblock_forward_step = make_multi_index(0, CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, 0, 0); From 9d8c71970154b5b5cf6dbeb785a62e6ac96b93cf Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Mon, 14 Mar 2022 05:17:31 +0000 Subject: [PATCH 07/32] use sfc in shuffling --- .../gridwise_gemm_reduce_xdl_cshuffle_v1.hpp | 200 ++++++++---------- .../grid/gridwise_gemm_xdl_cshuffle_v1.hpp | 113 +++++----- .../gpu/grid/gridwise_gemm_xdlops_v3r1.hpp | 1 + .../ck/utility/tensor_space_filling_curve.hpp | 9 + 4 files changed, 153 insertions(+), 170 deletions(-) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp index 9a40bbdb29b..cfd8acd4c35 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp @@ -769,122 +769,106 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 c_reduce_thread_data_idx_begin[I0]), // mperblock ck::tensor_operation::element_wise::PassThrough{}}; - constexpr auto mperblock_forward_step = - make_multi_index(0, CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, 0, 0); - constexpr auto nperblock_forward_step = - make_multi_index(0, 0, 0, CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl); - constexpr auto nperblock_backward_step = - make_multi_index(0, 0, 0, -CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl); - - static_for<0, MXdlPerWave, CShuffleMXdlPerWavePerShuffle>{}([&](auto mxdlperwave_iter) { - constexpr auto mxdlperwave = mxdlperwave_iter; - - static_for<0, NXdlPerWave, CShuffleNXdlPerWavePerShuffle>{}( - [&](auto nxdlperwave_iter) { - constexpr bool nxdlperwave_forward_sweep = - (mxdlperwave % (2 * CShuffleMXdlPerWavePerShuffle) == 0); - - constexpr index_t nxdlperwave_value = - nxdlperwave_forward_sweep - ? nxdlperwave_iter - : (NXdlPerWave - nxdlperwave_iter - CShuffleNXdlPerWavePerShuffle); - - constexpr auto nxdlperwave = Number{}; - - // make sure it's safe to do ds_write - block_sync_lds(); - - // VGPR to LDS - c_thread_copy_vgpr_to_lds.Run( - c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0), - c_thread_buf, - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c_shuffle_block_buf); - - // make sure it's safe to do ds_read - block_sync_lds(); - - // LDS to global - c_shuffle_block_copy_lds_to_global.Run( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - c_shuffle_block_buf, - c_grid_desc_mblock_mperblock_nblock_nperblock, - c_grid_buf); - -#if 1 - // reduce - { - // LDS to VGPR - c_reduce_thread_copy_lds_to_vgpr.Run( - c_reduce_block_desc_mperblock_nperblock, - c_shuffle_block_buf, - c_in_reduce_thread_desc_mperblock_nperblock, - make_tuple(I0, I0), - c_in_reduce_thread_buf); - - // reduce in VGPR - // loop over MPerBlock - // FIXME: hardcode - static_for<0, 1, 1>{}([&](auto im) { - // FIXME: Accumulator for reduction may use a different datatype? - FloatCShuffle v = 0; - - // loop over NPerBlock - // FIXME: hardcode - static_for<0, 16, 1>{}([&](auto in) { - constexpr index_t in_offset = - c_in_reduce_thread_desc_mperblock_nperblock.CalculateOffset( - make_tuple(im, in)); - - v += c_in_reduce_thread_buf[Number{}]; - }); - - constexpr index_t out_offset = - c_out_reduce_thread_desc_mperblock.CalculateOffset( - make_tuple(im)); - - c_out_reduce_thread_buf(Number{}) = v; - }); - - // VGPR to Global - c_reduce_thread_copy_vgpr_to_global.Run( - c_out_reduce_thread_desc_mblock_mperblock, - make_tuple(I0, I0), - c_out_reduce_thread_buf, - d_grid_desc_mblock_mperblock, - d_grid_buf); - } -#endif - - // move on nxdlperwave dimension - if constexpr(nxdlperwave_forward_sweep && - (nxdlperwave < NXdlPerWave - CShuffleNXdlPerWavePerShuffle)) - { - c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( - c_grid_desc_mblock_mperblock_nblock_nperblock, - nperblock_forward_step); - } - else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0)) - { - c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( - c_grid_desc_mblock_mperblock_nblock_nperblock, - nperblock_backward_step); - } + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + // reduce + { + // LDS to VGPR + c_reduce_thread_copy_lds_to_vgpr.Run( + c_reduce_block_desc_mperblock_nperblock, + c_shuffle_block_buf, + c_in_reduce_thread_desc_mperblock_nperblock, + make_tuple(I0, I0), + c_in_reduce_thread_buf); + + // reduce in VGPR + // loop over MPerBlock + // FIXME: hardcode + static_for<0, 1, 1>{}([&](auto im) { + // FIXME: Accumulator for reduction may use a different datatype? + FloatCShuffle v = 0; + + // loop over NPerBlock + // FIXME: hardcode + static_for<0, 16, 1>{}([&](auto in) { + constexpr index_t in_offset = + c_in_reduce_thread_desc_mperblock_nperblock.CalculateOffset( + make_tuple(im, in)); + + v += c_in_reduce_thread_buf[Number{}]; + }); + + constexpr index_t out_offset = + c_out_reduce_thread_desc_mperblock.CalculateOffset(make_tuple(im)); + + c_out_reduce_thread_buf(Number{}) = v; }); - // move on mxdlperwave dimension - if constexpr(mxdlperwave < MXdlPerWave - CShuffleMXdlPerWavePerShuffle) + // VGPR to Global + c_reduce_thread_copy_vgpr_to_global.Run( + c_out_reduce_thread_desc_mblock_mperblock, + make_tuple(I0, I0), + c_out_reduce_thread_buf, + d_grid_desc_mblock_mperblock, + d_grid_buf); + } + + if constexpr(access_id < num_access - 1) { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( - c_grid_desc_mblock_mperblock_nblock_nperblock, mperblock_forward_step); + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); -#if 1 // reduce: move c_reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow( d_grid_desc_mblock_mperblock, - make_multi_index(0, CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl)); -#endif + make_tuple(c_global_step[I0], c_global_step[I1])); } }); } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp index 7ba5f7edbe2..42b0d071e18 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp @@ -655,70 +655,59 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), c_element_op}; - constexpr auto mperblock_forward_step = - make_multi_index(0, CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, 0, 0); - constexpr auto nperblock_forward_step = - make_multi_index(0, 0, 0, CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl); - constexpr auto nperblock_backward_step = - make_multi_index(0, 0, 0, -CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl); - - static_for<0, MXdlPerWave, CShuffleMXdlPerWavePerShuffle>{}([&](auto mxdlperwave_iter) { - constexpr auto mxdlperwave = mxdlperwave_iter; - - static_for<0, NXdlPerWave, CShuffleNXdlPerWavePerShuffle>{}( - [&](auto nxdlperwave_iter) { - constexpr bool nxdlperwave_forward_sweep = - (mxdlperwave % (2 * CShuffleMXdlPerWavePerShuffle) == 0); - - constexpr index_t nxdlperwave_value = - nxdlperwave_forward_sweep - ? nxdlperwave_iter - : (NXdlPerWave - nxdlperwave_iter - CShuffleNXdlPerWavePerShuffle); - - constexpr auto nxdlperwave = Number{}; - - // make sure it's safe to do ds_write - block_sync_lds(); - - // VGPR to LDS - c_thread_copy_vgpr_to_lds.Run( - c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0), - c_thread_buf, - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c_shuffle_block_buf); - - // make sure it's safe to do ds_read - block_sync_lds(); - - // LDS to global - c_shuffle_block_copy_lds_to_global.Run( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - c_shuffle_block_buf, - c_grid_desc_mblock_mperblock_nblock_nperblock, - c_grid_buf); - - // move on nxdlperwave dimension - if constexpr(nxdlperwave_forward_sweep && - (nxdlperwave < NXdlPerWave - CShuffleNXdlPerWavePerShuffle)) - { - c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( - c_grid_desc_mblock_mperblock_nblock_nperblock, - nperblock_forward_step); - } - else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0)) - { - c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( - c_grid_desc_mblock_mperblock_nblock_nperblock, - nperblock_backward_step); - } - }); - - // move on mxdlperwave dimension - if constexpr(mxdlperwave < MXdlPerWave - CShuffleMXdlPerWavePerShuffle) + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( - c_grid_desc_mblock_mperblock_nblock_nperblock, mperblock_forward_step); + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); } }); } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp index 49a5f9f3e30..bf89bfe681b 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp @@ -10,6 +10,7 @@ #include "blockwise_tensor_slice_transfer_v6r1.hpp" #include "threadwise_tensor_slice_transfer.hpp" #include "gridwise_gemm_pipeline_v1.hpp" +#include "tensor_space_filling_curve.hpp" namespace ck { diff --git a/include/ck/utility/tensor_space_filling_curve.hpp b/include/ck/utility/tensor_space_filling_curve.hpp index c5cbe461f0b..4eebcd86e34 100644 --- a/include/ck/utility/tensor_space_filling_curve.hpp +++ b/include/ck/utility/tensor_space_filling_curve.hpp @@ -140,6 +140,15 @@ struct SpaceFillingCurve }(); return idx_md; } + + // FIXME: rename this function + template + static __device__ __host__ constexpr auto GetIndexTupleOfNumber(Number) + { + constexpr auto idx = GetIndex(Number{}); + + return generate_tuple([&](auto i) { return Number{}; }, Number{}); + } }; } // namespace ck From dd2da851083584c09c1f6efc20f763984c3be99b Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Tue, 15 Mar 2022 21:52:27 +0000 Subject: [PATCH 08/32] remove hardcode --- .../14_gemm_reduce/gemm_xdl_reduce_fp16.cpp | 12 +- .../device_gemm_reduce_xdl_cshuffle.hpp | 26 ++-- .../gridwise_gemm_reduce_xdl_cshuffle_v1.hpp | 129 ++++++++---------- 3 files changed, 79 insertions(+), 88 deletions(-) diff --git a/example/14_gemm_reduce/gemm_xdl_reduce_fp16.cpp b/example/14_gemm_reduce/gemm_xdl_reduce_fp16.cpp index 51790788e63..18d82b0248c 100644 --- a/example/14_gemm_reduce/gemm_xdl_reduce_fp16.cpp +++ b/example/14_gemm_reduce/gemm_xdl_reduce_fp16.cpp @@ -46,12 +46,12 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa // clang-format off using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_Xdl_CShuffle -//######|AData| BData| CData| AccData| CShuffle| DData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| -//######| Type| Type| Type| Type| DataType| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| -//######| | | | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| -//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < F16, F16, F16, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>; -// < F16, F16, F16, F32, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 2, 2, S<1, 16, 1,16>, 8>; +//######|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| ALayout| BLayout| CLayout| A| B| C| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| +//######| Type| Type| Type| DataType| DataType| DataType| Type| | | | Elementwise| Elementwise| Elementwise| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < F16, F16, F16, F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>; +// < F16, F16, F16, F32, F16, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 2, 2, S<1, 16, 1,16>, 8, S<64, 4>, 4, 1>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host:: diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp index 6d7a85a8050..e3f2e8d6518 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp @@ -16,8 +16,9 @@ namespace device { template + typename CReduceThreadClusterLengths_MPerBlock_NPerBlock, + index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, + index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock> struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce { @@ -138,11 +142,11 @@ struct DeviceGemmReduce_Xdl_CShuffle // GridwiseGemm using GridwiseGemm = GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< - BlockSize, ADataType, // TODO: distinguish A/B datatype - AccDataType, + GemmAccDataType, CShuffleDataType, CDataType, + ReduceAccDataType, DDataType, InMemoryDataOperationEnum_t::Set, AGridDesc_K0_M_K1, @@ -152,6 +156,8 @@ struct DeviceGemmReduce_Xdl_CShuffle AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, + NumGemmKPrefetchStage, + BlockSize, MPerBlock, NPerBlock, KPerBlock, @@ -168,7 +174,7 @@ struct DeviceGemmReduce_Xdl_CShuffle ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, - ABlockLdsAddExtraM, + ABlockLdsExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, @@ -176,12 +182,14 @@ struct DeviceGemmReduce_Xdl_CShuffle BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, - BBlockLdsAddExtraN, + BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, - NumPrefetch>; + CReduceThreadClusterLengths_MPerBlock_NPerBlock, + CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, + CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock>; // Argument struct Argument : public BaseArgument diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp index cfd8acd4c35..69cecd7153b 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp @@ -60,11 +60,11 @@ __global__ void block_2_ctile_map); } -template + typename CReduceThreadClusterLengths_MPerBlock_NPerBlock, + index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, + index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock> struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 { static constexpr auto I0 = Number<0>{}; @@ -123,46 +127,18 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() { - constexpr auto max_lds_align = AK1; - // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_ak0_m_ak1 = [&]() { - if constexpr(ABlockLdsExtraM) - { - return make_naive_tensor_descriptor( - make_tuple(AK0, Number{}, AK1), - make_tuple(Number{} * AK1, AK1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(AK0, Number{}, AK1), max_lds_align); - } - }(); - - return a_block_desc_ak0_m_ak1; + return make_naive_tensor_descriptor( + make_tuple(AK0, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); } __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() { - constexpr auto max_lds_align = BK1; - // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_bk0_n_bk1 = [&]() { - if constexpr(BBlockLdsExtraN) - { - return make_naive_tensor_descriptor( - make_tuple(BK0, Number{}, BK1), - make_tuple(Number{} * BK1, BK1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(BK0, Number{}, BK1), max_lds_align); - } - }(); - - return b_block_desc_bk0_n_bk1; + return make_naive_tensor_descriptor( + make_tuple(BK0, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); } __host__ __device__ static constexpr auto @@ -232,12 +208,12 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) return false; - // check NumPrefetch - if constexpr(NumPrefetch == 1) + // check NumGemmKPrefetchStage + if constexpr(NumGemmKPrefetchStage == 1) { // 1-stage prefetch always supported } - else if constexpr(NumPrefetch == 2) + else if constexpr(NumGemmKPrefetchStage == 2) { // 2-stage prefetch currently only support even number of K0 loop // TODO: add support for odd number of K0 loop @@ -279,7 +255,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 // TODO move this function into GEMM-pipeline class __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) { - const bool has_main_k0_block_loop = ((K0 * AK1) / (NumPrefetch * KPerBlock)) > 1; + const bool has_main_k0_block_loop = ((K0 * AK1) / (NumGemmKPrefetchStage * KPerBlock)) > 1; return has_main_k0_block_loop; } @@ -431,7 +407,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 1, AThreadTransferSrcResetCoordinateAfterRun, true, - NumPrefetch>( + NumGemmKPrefetchStage>( a_grid_desc_ak0_m_ak1, make_multi_index(0, m_block_data_idx_on_grid, 0), a_element_op, @@ -462,7 +438,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 1, BThreadTransferSrcResetCoordinateAfterRun, true, - NumPrefetch>( + NumGemmKPrefetchStage>( b_grid_desc_bk0_n_bk1, make_multi_index(0, n_block_data_idx_on_grid, 0), b_element_op, @@ -483,7 +459,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1, remove_cvref_t, remove_cvref_t, - NumPrefetch, + NumGemmKPrefetchStage, HasMainK0BlockLoop>{}; const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( @@ -628,9 +604,9 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( make_multi_index(n_thread_data_on_block)); - // VGPR to LDS + // shuffle: threadwise copy C from VGPR to LDS auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<>{}, Sequence<1>{})); - // FIXME: hardcode - using CReduceThreadClusterLengths_MPerBlock_NPerBlock = Sequence<64, 4>; - constexpr auto c_in_reduce_thread_lengths_mperblock_nperblock = Sequence<1, 16>{}; + static_assert((CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) % + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0) == + 0 && + (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) % + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1) == + 0, + "wrong!"); + + constexpr index_t mreduce_per_thread = + (CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) / + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0); + + constexpr index_t nreduce_per_thread = + (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) / + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1); + + constexpr auto c_in_reduce_thread_lengths_mperblock_nperblock = + Sequence{}; // VGPR c_in_reduce_thread_desc_mperblock_nperblock constexpr auto c_in_reduce_thread_desc_mperblock_nperblock = make_naive_tensor_descriptor_packed( - sequence_to_tuple_of_number(c_in_reduce_thread_lengths_mperblock_nperblock)); + make_tuple(Number{}, Number{})); // VGPR c_out_reduce_thread_desc_mperblock - // FIXME: hardcode constexpr auto c_out_reduce_thread_desc_mperblock = - make_naive_tensor_descriptor_packed(make_tuple(Number<1>{})); + make_naive_tensor_descriptor_packed(make_tuple(Number{})); // VGPR c_out_reduce_thread_desc_mblock_mperblock - // FIXME: hardcode constexpr auto c_out_reduce_thread_desc_mblock_mperblock = - make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{})); + make_naive_tensor_descriptor_packed(make_tuple(I1, Number{})); // TODO: this should be implemented as a blockwise reduction - // FIXME: should use a different datatype? auto c_in_reduce_thread_buf = make_static_buffer( c_in_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize()); @@ -728,7 +716,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 make_static_buffer( c_out_reduce_thread_desc_mperblock.GetElementSpaceSize()); - // reduce: copy from LDS to VGPR + // reduce: threadwise copy from LDS to VGPR constexpr auto c_reduce_thread_cluster_desc = make_cluster_descriptor( CReduceThreadClusterLengths_MPerBlock_NPerBlock{}, Sequence<1, 0>{}); @@ -747,7 +735,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 decltype(c_in_reduce_thread_lengths_mperblock_nperblock), Sequence<0, 1>, 1, - 8, // FIXME: hardcoded + CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, 1, true>{c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_data_idx_begin}; @@ -758,10 +746,10 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 decltype(c_out_reduce_thread_desc_mblock_mperblock), decltype(d_grid_desc_mblock_mperblock), ck::tensor_operation::element_wise::PassThrough, - Sequence<1, 1>, // FIXME: hardcode + Sequence<1, mreduce_per_thread>, Sequence<0, 1>, 1, - 1, // FIXME: hardcode + CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, InMemoryDataOperationEnum_t::AtomicAdd, 1, false>{d_grid_desc_mblock_mperblock, @@ -827,15 +815,10 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 c_in_reduce_thread_buf); // reduce in VGPR - // loop over MPerBlock - // FIXME: hardcode - static_for<0, 1, 1>{}([&](auto im) { - // FIXME: Accumulator for reduction may use a different datatype? - FloatCShuffle v = 0; - - // loop over NPerBlock - // FIXME: hardcode - static_for<0, 16, 1>{}([&](auto in) { + static_for<0, mreduce_per_thread, 1>{}([&](auto im) { + FloatReduceAcc v = 0; // FloatReduceAcc + + static_for<0, nreduce_per_thread, 1>{}([&](auto in) { constexpr index_t in_offset = c_in_reduce_thread_desc_mperblock_nperblock.CalculateOffset( make_tuple(im, in)); From 073eeed59c5b0e7a0ed2245c5686d8ca2f6905e9 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Tue, 15 Mar 2022 22:13:51 +0000 Subject: [PATCH 09/32] remove hardcode --- example/01_gemm/gemm_xdl_fp16.cpp | 13 ++- .../gpu/device/device_gemm_xdl_c_shuffle.hpp | 101 +++++++++--------- .../grid/gridwise_gemm_xdl_cshuffle_v1.hpp | 72 ++++--------- 3 files changed, 76 insertions(+), 110 deletions(-) diff --git a/example/01_gemm/gemm_xdl_fp16.cpp b/example/01_gemm/gemm_xdl_fp16.cpp index 2f3d82ec9f0..5f73ca03f0b 100644 --- a/example/01_gemm/gemm_xdl_fp16.cpp +++ b/example/01_gemm/gemm_xdl_fp16.cpp @@ -54,13 +54,12 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1>; #elif 1 using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle -//######|AData| BData| CData| AccData| Shuffle| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| -//######| Type| Type| Type| Type| Data| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| -//######| | | | | Type| | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| -//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -// < F16, F16, F16, F32, F16, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>; - < F16, F16, F16, F32, F16, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>; -// < F16, F16, F16, F32, F16, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 2, 2, S<1, 16, 1,16>, 8>; +//######|AData| BData| CData| GemmAcc| CShuffle| ALayout| BLayout| CLayout| A| B| C| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| Type| Type| Type| DataType| DataType| | | | Elementwise| Elementwise| Elementwise| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| +//######| | | | | | | | | Operation| Operation| Operation| Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < F16, F16, F16, F32, F16, Row, Col, Row, PassThrough, PassThrough, PassThrough, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; +// < F16, F16, F16, F32, F16, Row, Col, Row, PassThrough, PassThrough, PassThrough, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 2, 2, S<1, 16, 1,16>, 8>; #elif 0 using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl //######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| Num| diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle.hpp index 5cd1bb3e782..a15a088a357 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle.hpp @@ -16,7 +16,7 @@ namespace device { template + index_t CShuffleBlockTransferScalarPerVector_NPerBlock> struct DeviceGemmXdl_C_Shuffle : public DeviceGemm { @@ -129,54 +129,49 @@ struct DeviceGemmXdl_C_Shuffle using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); // GridwiseGemm - using GridwiseGemm = -#if 0 - GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v4r1 -#else - GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 -#endif - ; + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1< + ADataType, // TODO: distinguish A/B datatype + GemmAccDataType, + CShuffleDataType, + CDataType, + InMemoryDataOperationEnum_t::Set, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock>; // Argument struct Argument : public BaseArgument diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp index 42b0d071e18..7874ddebb57 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp @@ -53,9 +53,8 @@ __global__ void block_2_ctile_map); } -template + index_t CShuffleBlockTransferScalarPerVector_NPerBlock> struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 { static constexpr auto I0 = Number<0>{}; @@ -114,46 +114,18 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() { - constexpr auto max_lds_align = AK1; - // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_ak0_m_ak1 = [&]() { - if constexpr(ABlockLdsExtraM) - { - return make_naive_tensor_descriptor( - make_tuple(AK0, Number{}, AK1), - make_tuple(Number{} * AK1, AK1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(AK0, Number{}, AK1), max_lds_align); - } - }(); - - return a_block_desc_ak0_m_ak1; + return make_naive_tensor_descriptor( + make_tuple(AK0, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); } __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() { - constexpr auto max_lds_align = BK1; - // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_bk0_n_bk1 = [&]() { - if constexpr(BBlockLdsExtraN) - { - return make_naive_tensor_descriptor( - make_tuple(BK0, Number{}, BK1), - make_tuple(Number{} * BK1, BK1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(BK0, Number{}, BK1), max_lds_align); - } - }(); - - return b_block_desc_bk0_n_bk1; + return make_naive_tensor_descriptor( + make_tuple(BK0, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); } __host__ __device__ static constexpr auto @@ -223,12 +195,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) return false; - // check NumPrefetch - if constexpr(NumPrefetch == 1) + // check NumGemmKPrefetchStage + if constexpr(NumGemmKPrefetchStage == 1) { // 1-stage prefetch always supported } - else if constexpr(NumPrefetch == 2) + else if constexpr(NumGemmKPrefetchStage == 2) { // 2-stage prefetch currently only support even number of K0 loop // TODO: add support for odd number of K0 loop @@ -270,7 +242,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 // TODO move this function into GEMM-pipeline class __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) { - const bool has_main_k0_block_loop = ((K0 * AK1) / (NumPrefetch * KPerBlock)) > 1; + const bool has_main_k0_block_loop = ((K0 * AK1) / (NumGemmKPrefetchStage * KPerBlock)) > 1; return has_main_k0_block_loop; } @@ -400,7 +372,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 1, AThreadTransferSrcResetCoordinateAfterRun, true, - NumPrefetch>( + NumGemmKPrefetchStage>( a_grid_desc_ak0_m_ak1, make_multi_index(0, m_block_data_idx_on_grid, 0), a_element_op, @@ -431,7 +403,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 1, BThreadTransferSrcResetCoordinateAfterRun, true, - NumPrefetch>( + NumGemmKPrefetchStage>( b_grid_desc_bk0_n_bk1, make_multi_index(0, n_block_data_idx_on_grid, 0), b_element_op, @@ -452,7 +424,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1, remove_cvref_t, remove_cvref_t, - NumPrefetch, + NumGemmKPrefetchStage, HasMainK0BlockLoop>{}; const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( @@ -599,7 +571,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 // VGPR to LDS auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3 Date: Tue, 15 Mar 2022 22:44:10 +0000 Subject: [PATCH 10/32] refactor --- example/01_gemm/gemm_xdl_fp16.cpp | 10 +- .../gpu/device/device_gemm_xdl_c_shuffle.hpp | 250 +++++---- .../device/device_gemm_xdl_cshuffle_v1.hpp | 481 ++++++++++++++++++ .../gridwise_gemm_reduce_xdl_cshuffle_v1.hpp | 7 +- .../grid/gridwise_gemm_xdl_cshuffle_v1.hpp | 7 +- 5 files changed, 624 insertions(+), 131 deletions(-) create mode 100644 include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle_v1.hpp diff --git a/example/01_gemm/gemm_xdl_fp16.cpp b/example/01_gemm/gemm_xdl_fp16.cpp index 5f73ca03f0b..f894639a622 100644 --- a/example/01_gemm/gemm_xdl_fp16.cpp +++ b/example/01_gemm/gemm_xdl_fp16.cpp @@ -13,6 +13,7 @@ #include "device_tensor.hpp" #include "device_gemm_xdl.hpp" #include "device_gemm_xdl_c_shuffle.hpp" +#include "device_gemm_xdl_cshuffle_v1.hpp" #include "element_wise_operation.hpp" #include "reference_gemm.hpp" #include "gemm_specialization.hpp" @@ -52,8 +53,15 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // [256, 128, 4, 8], 1 stage, 2 occupancy < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1>; -#elif 1 +#elif 0 using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle +//######|AData| BData| CData| AccData| Shuffle| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| Type| Type| Type| Type| Data| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| +//######| | | | | Type| | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < F16, F16, F16, F32, F16, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>; +#elif 1 +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle_v1 //######|AData| BData| CData| GemmAcc| CShuffle| ALayout| BLayout| CLayout| A| B| C| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //######| Type| Type| Type| DataType| DataType| | | | Elementwise| Elementwise| Elementwise| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| //######| | | | | | | | | Operation| Operation| Operation| Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle.hpp index a15a088a357..a335a327a16 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle.hpp @@ -1,4 +1,6 @@ -#pragma once +#ifndef DEVICE_GEMM_XDL_C_SHUFFLE_HPP +#define DEVICE_GEMM_XDL_C_SHUFFLE_HPP + #include #include #include "device.hpp" @@ -7,52 +9,53 @@ #include "tensor_layout.hpp" #include "tensor_descriptor.hpp" #include "tensor_descriptor_helper.hpp" -#include "gridwise_gemm_xdl_cshuffle_v1.hpp" +#include "gridwise_gemm_xdlops_v3r1.hpp" namespace ck { namespace tensor_operation { namespace device { -template +template < + typename ADataType, + typename BDataType, + typename CDataType, + typename AccDataType, + typename CShuffleDataType, + typename ALayout, + typename BLayout, + typename CLayout, + typename AElementwiseOperation, + typename BElementwiseOperation, + typename CElementwiseOperation, + ck::index_t BlockSize, + ck::index_t MPerBlock, + ck::index_t NPerBlock, + ck::index_t KPerBlock, + ck::index_t AK1, + ck::index_t BK1, + ck::index_t MPerXDL, + ck::index_t NPerXDL, + ck::index_t MXdlPerWave, + ck::index_t NXdlPerWave, + typename ABlockTransferThreadClusterLengths_K0_M_K1, + typename ABlockTransferThreadClusterArrangeOrder, + typename ABlockTransferSrcAccessOrder, + ck::index_t ABlockTransferSrcVectorDim, + ck::index_t ABlockTransferSrcScalarPerVector, + ck::index_t ABlockTransferDstScalarPerVector_K1, + bool ABlockLdsAddExtraM, + typename BBlockTransferThreadClusterLengths_K0_N_K1, + typename BBlockTransferThreadClusterArrangeOrder, + typename BBlockTransferSrcAccessOrder, + ck::index_t BBlockTransferSrcVectorDim, + ck::index_t BBlockTransferSrcScalarPerVector, + ck::index_t BBlockTransferDstScalarPerVector_K1, + bool BBlockLdsAddExtraN, + index_t CShuffleMXdlPerWavePerShuffle, + index_t CShuffleNXdlPerWavePerShuffle, + typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + index_t CBlockTransferScalarPerVector_NWaveNPerXdl, + index_t NumPrefetch = 1> struct DeviceGemmXdl_C_Shuffle : public DeviceGemm { @@ -129,9 +132,10 @@ struct DeviceGemmXdl_C_Shuffle using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); // GridwiseGemm - using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1< + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1< + BlockSize, ADataType, // TODO: distinguish A/B datatype - GemmAccDataType, + AccDataType, CShuffleDataType, CDataType, InMemoryDataOperationEnum_t::Set, @@ -141,8 +145,6 @@ struct DeviceGemmXdl_C_Shuffle AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, - NumGemmKPrefetchStage, - BlockSize, MPerBlock, NPerBlock, KPerBlock, @@ -159,7 +161,7 @@ struct DeviceGemmXdl_C_Shuffle ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, - ABlockLdsExtraM, + ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, @@ -167,11 +169,12 @@ struct DeviceGemmXdl_C_Shuffle BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, - BBlockLdsExtraN, + BBlockLdsAddExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - CShuffleBlockTransferScalarPerVector_NPerBlock>; + CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + CBlockTransferScalarPerVector_NWaveNPerXdl, + NumPrefetch>; // Argument struct Argument : public BaseArgument @@ -196,7 +199,7 @@ struct DeviceGemmXdl_C_Shuffle a_grid_desc_k0_m_k1_{}, b_grid_desc_k0_n_k1_{}, c_grid_desc_m_n_{}, - c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, block_2_ctile_map_{}, M01_{M01}, N01_{N01}, @@ -213,9 +216,10 @@ struct DeviceGemmXdl_C_Shuffle if(GridwiseGemm::CheckValidity( a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_)) { - c_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - c_grid_desc_m_n_); + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = + GridwiseGemm:: + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + c_grid_desc_m_n_); block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); @@ -229,8 +233,9 @@ struct DeviceGemmXdl_C_Shuffle AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; CGridDesc_M_N c_grid_desc_m_n_; - typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock_; + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; index_t M01_; index_t N01_; @@ -279,79 +284,71 @@ struct DeviceGemmXdl_C_Shuffle if(has_main_k0_block_loop) { - const auto kernel = -#if 0 - kernel_gemm_xdlops_v4r1 -#else - kernel_gemm_xdl_cshuffle_v1 -#endif - , - remove_reference_t, - remove_reference_t< - typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - remove_reference_t, - true>; - - ave_time = - launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.block_2_ctile_map_); + const auto kernel = kernel_gemm_xdlops_v3r1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + true>; + + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); } else { - const auto kernel = -#if 0 - kernel_gemm_xdlops_v4r1 -#else - kernel_gemm_xdl_cshuffle_v1 -#endif - , - remove_reference_t, - remove_reference_t< - typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - remove_reference_t, - false>; - - ave_time = - launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.block_2_ctile_map_); + const auto kernel = kernel_gemm_xdlops_v3r1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + false>; + + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); } return ave_time; @@ -477,3 +474,4 @@ struct DeviceGemmXdl_C_Shuffle } // namespace device } // namespace tensor_operation } // namespace ck +#endif diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle_v1.hpp new file mode 100644 index 00000000000..8970f93172c --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle_v1.hpp @@ -0,0 +1,481 @@ +#pragma once +#include +#include +#include "device.hpp" +#include "device_gemm.hpp" +#include "common_header.hpp" +#include "tensor_layout.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_xdl_cshuffle_v1.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemm_Xdl_CShuffle_v1 + : public DeviceGemm +{ + using DeviceOp = DeviceGemm_Xdl_CShuffle_v1; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA) + { + assert(K % AK1 == 0); + + const index_t K0 = K / AK1; + + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + const auto a_grid_desc_k0_m_k1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, AK1)), make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_k0_m_k1; + } + + static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB) + { + assert(K % BK1 == 0); + + const index_t K0 = K / BK1; + + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + const auto b_grid_desc_k0_n_k1 = transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, BK1)), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_k0_n_k1; + } + + static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) + { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + } + + using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1)); + using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1< + ADataType, // TODO: distinguish A/B datatype + GemmAccDataType, + CShuffleDataType, + CDataType, + InMemoryDataOperationEnum_t::Set, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t M01, + index_t N01, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_c_grid_{p_c_grid}, + a_grid_desc_k0_m_k1_{}, + b_grid_desc_k0_n_k1_{}, + c_grid_desc_m_n_{}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_ctile_map_{}, + M01_{M01}, + N01_{N01}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + a_grid_desc_k0_m_k1_ = + DeviceOp::MakeAGridDescriptor_K0_M_K1(M, K, StrideA); + b_grid_desc_k0_n_k1_ = + DeviceOp::MakeBGridDescriptor_K0_N_K1(K, N, StrideB); + c_grid_desc_m_n_ = DeviceOp::MakeCGridDescriptor_M_N(M, N, StrideC); + + if(GridwiseGemm::CheckValidity( + a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n_); + + block_2_ctile_map_ = + GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock_; + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + index_t M01_; + index_t N01_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, int nrepeat = 1) + { + { + std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) + << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) + << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.M01_, + arg.N01_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"); + } + + const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0); + + const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + + float ave_time = 0; + + if(has_main_k0_block_loop) + { + const auto kernel = +#if 0 + kernel_gemm_xdlops_v4r1 +#else + kernel_gemm_xdl_cshuffle_v1 +#endif + , + remove_reference_t, + remove_reference_t< + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + true>; + + ave_time = + launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = +#if 0 + kernel_gemm_xdlops_v4r1 +#else + kernel_gemm_xdl_cshuffle_v1 +#endif + , + remove_reference_t, + remove_reference_t< + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + false>; + + ave_time = + launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, int nrepeat = 1) override + { + return Run(*dynamic_cast(p_arg), nrepeat); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.M01_, + arg.N01_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, + p_b, + p_c, + M, + N, + K, + StrideA, + StrideB, + StrideC, + 1, + 1, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + index_t /* KBatch */ = 1) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + 1, + 1, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGemm_Xdl_CShuffle_v1" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp index 69cecd7153b..15f2fc8caca 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp @@ -164,11 +164,14 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + constexpr auto a_block_space_size_aligned = - math::integer_least_multiple(a_block_desc_ak0_m_ak1.GetElementSpaceSize(), AK1); + math::integer_least_multiple(a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); constexpr auto b_block_space_size_aligned = - math::integer_least_multiple(b_block_desc_bk0_n_bk1.GetElementSpaceSize(), BK1); + math::integer_least_multiple(b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); // LDS allocation for C shuffle in LDS constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp index 7874ddebb57..d2c059c2103 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp @@ -151,11 +151,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + constexpr auto a_block_space_size_aligned = - math::integer_least_multiple(a_block_desc_ak0_m_ak1.GetElementSpaceSize(), AK1); + math::integer_least_multiple(a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); constexpr auto b_block_space_size_aligned = - math::integer_least_multiple(b_block_desc_bk0_n_bk1.GetElementSpaceSize(), BK1); + math::integer_least_multiple(b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); // LDS allocation for C shuffle in LDS constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = From 0a83402453c7ec11fc4000bb0985fa2a6dc8324d Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Wed, 16 Mar 2022 04:01:46 +0000 Subject: [PATCH 11/32] fix build --- include/ck/tensor_operation/gpu/device/device_gemm_bias.hpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_bias.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_bias.hpp index 27d5ce9d8a1..9f5d16a1f9b 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_bias.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_bias.hpp @@ -29,6 +29,12 @@ struct DeviceGemmBias : public BaseOperator virtual std::unique_ptr MakeInvokerPointer() = 0; }; +template +using DeviceGemmBiasPtr = std::unique_ptr< + DeviceGemmBias>; + } // namespace device } // namespace tensor_operation } // namespace ck From b32e19abbdcf86a0c553650a86a5f6f74b562154 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Wed, 16 Mar 2022 23:15:37 +0000 Subject: [PATCH 12/32] adding gemm+reduce --- .../15_gemm_reduce/gemm_xdl_reduce_fp16.cpp | 38 +++-- .../gpu/device/device_gemm_reduce.hpp | 40 ++--- .../device_gemm_reduce_xdl_cshuffle.hpp | 150 +++++++++++++----- .../device/device_gemm_xdl_cshuffle_v1.hpp | 8 +- .../gpu/device/gemm_specialization.hpp | 6 + .../gridwise_gemm_reduce_xdl_cshuffle_v1.hpp | 54 ++++--- .../grid/gridwise_gemm_xdl_cshuffle_v1.hpp | 8 +- .../include/ck/library/host_tensor/device.hpp | 2 + .../ck/library/host_tensor/host_tensor.hpp | 8 +- library/src/host_tensor/device.cpp | 4 + 10 files changed, 208 insertions(+), 110 deletions(-) diff --git a/example/15_gemm_reduce/gemm_xdl_reduce_fp16.cpp b/example/15_gemm_reduce/gemm_xdl_reduce_fp16.cpp index 18d82b0248c..8b42f350b50 100644 --- a/example/15_gemm_reduce/gemm_xdl_reduce_fp16.cpp +++ b/example/15_gemm_reduce/gemm_xdl_reduce_fp16.cpp @@ -16,6 +16,14 @@ #include "element_wise_operation.hpp" #include "reference_gemm.hpp" #include "gemm_specialization.hpp" +#include "reduction_enums.hpp" + +struct ReduceSum +{ + __host__ __device__ static constexpr float GetReduceZeroValue() { return float(0); } + + __host__ __device__ void Reduce(float& acc, float v) const { acc += v; } +}; template using S = ck::Sequence; @@ -26,13 +34,10 @@ using F32 = float; using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -using ADataType = ck::half_t; -using BDataType = ck::half_t; -using CDataType = ck::half_t; -using AccDataType = float; -using DDataType = float; +using ADataType = F16; +using BDataType = F16; +using CDataType = F16; +using DDataType = F32; using ALayout = ck::tensor_layout::gemm::RowMajor; using BLayout = ck::tensor_layout::gemm::ColumnMajor; @@ -41,17 +46,17 @@ using CLayout = ck::tensor_layout::gemm::RowMajor; using AElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough; using CElementOp = ck::tensor_operation::element_wise::PassThrough; +using DReduceOp = ReduceSum; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default; // clang-format off using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_Xdl_CShuffle -//######|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| ALayout| BLayout| CLayout| A| B| C| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| -//######| Type| Type| Type| DataType| DataType| DataType| Type| | | | Elementwise| Elementwise| Elementwise| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| -//######| | | | | | | | | | | Operation| Operation| Operation| Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| -//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < F16, F16, F16, F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>; -// < F16, F16, F16, F32, F16, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 2, 2, S<1, 16, 1,16>, 8, S<64, 4>, 4, 1>; +//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| +//######| | | | Type| Type| Type| DataType| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Reduce|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, AElementOp, BElementOp, CElementOp, DReduceOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host:: @@ -153,9 +158,13 @@ int main(int argc, char* argv[]) a_m_k_device_buf.ToDevice(a_m_k.mData.data()); b_k_n_device_buf.ToDevice(b_k_n.mData.data()); + // init to 0 + d_m_device_buf.SetZero(); + auto a_element_op = AElementOp{}; auto b_element_op = BElementOp{}; auto c_element_op = CElementOp{}; + auto d_reduce_op = DReduceOp{}; // do GEMM auto gemm = DeviceGemmReduceInstance{}; @@ -172,7 +181,8 @@ int main(int argc, char* argv[]) StrideC, a_element_op, b_element_op, - c_element_op); + c_element_op, + d_reduce_op); if(!gemm.IsSupportedArgument(argument)) { diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp index 88246fb8754..cab4ddb0bd5 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp @@ -8,32 +8,36 @@ namespace device { template + typename CElementwiseOperation, + typename DReduceOperation> struct DeviceGemmReduce : public BaseOperator { - virtual std::unique_ptr - MakeArgumentPointer(const void* p_a, - const void* p_b, - void* p_c, - void* p_d, - ck::index_t M, - ck::index_t N, - ck::index_t K, - ck::index_t StrideA, - ck::index_t StrideB, - ck::index_t StrideC, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) = 0; + virtual std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + void* p_d, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + DReduceOperation d_reduce_op) = 0; virtual std::unique_ptr MakeInvokerPointer() = 0; }; template -using DeviceGemmReducePtr = std::unique_ptr< - DeviceGemmReduce>; + typename CElementwiseOperation, + typename DReduceOperation> +using DeviceGemmReducePtr = std::unique_ptr>; } // namespace device } // namespace tensor_operation diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp index e3f2e8d6518..4acea641bda 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp @@ -13,19 +13,21 @@ namespace ck { namespace tensor_operation { namespace device { -template -struct DeviceGemmReduce_Xdl_CShuffle - : public DeviceGemmReduce +struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; - static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA) + static auto MakeAGridDescriptor_K0_M_K1(index_t MRaw, index_t KRaw, index_t StrideA) { - assert(K % AK1 == 0); - - const index_t K0 = K / AK1; - - const auto a_grid_desc_m_k = [&]() { + const auto a_grid_desc_mraw_kraw = [&]() { if constexpr(is_same::value) { - return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); } else if constexpr(is_same::value) { - return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); } }(); - const auto a_grid_desc_k0_m_k1 = transform_tensor_descriptor( - a_grid_desc_m_k, - make_tuple(make_unmerge_transform(make_tuple(K0, AK1)), make_pass_through_transform(M)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; - return a_grid_desc_k0_m_k1; + const auto MPad = M - MRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpecialization == GemmSpecialization_t::MKPadding || + GemmSpecialization == GemmSpecialization_t::MNKPadding) + { + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpecialization == GemmSpecialization_t::MPadding || + GemmSpecialization == GemmSpecialization_t::MNPadding) + { + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } } static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB) @@ -148,14 +203,16 @@ struct DeviceGemmReduce_Xdl_CShuffle CDataType, ReduceAccDataType, DDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + DReduceOperation, InMemoryDataOperationEnum_t::Set, + InMemoryDataOperationEnum_t::AtomicAdd, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, CGridDesc_M_N, DGridDesc_M, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, NumGemmKPrefetchStage, BlockSize, MPerBlock, @@ -208,7 +265,8 @@ struct DeviceGemmReduce_Xdl_CShuffle index_t N01, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) + CElementwiseOperation c_element_op, + DReduceOperation d_reduce_op) : p_a_grid_{p_a_grid}, p_b_grid_{p_b_grid}, p_c_grid_{p_c_grid}, @@ -224,7 +282,8 @@ struct DeviceGemmReduce_Xdl_CShuffle N01_{N01}, a_element_op_{a_element_op}, b_element_op_{b_element_op}, - c_element_op_{c_element_op} + c_element_op_{c_element_op}, + d_reduce_op_{d_reduce_op} { a_grid_desc_k0_m_k1_ = DeviceGemmReduce_Xdl_CShuffle::MakeAGridDescriptor_K0_M_K1(M, K, StrideA); @@ -267,6 +326,7 @@ struct DeviceGemmReduce_Xdl_CShuffle AElementwiseOperation a_element_op_; BElementwiseOperation b_element_op_; CElementwiseOperation c_element_op_; + DReduceOperation d_reduce_op_; }; // Invoker @@ -317,14 +377,15 @@ struct DeviceGemmReduce_Xdl_CShuffle ADataType, // TODO: distiguish A/B datatype CDataType, DDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + DReduceOperation, remove_reference_t, remove_reference_t, remove_reference_t< typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, remove_reference_t, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, remove_reference_t, true>; @@ -338,13 +399,14 @@ struct DeviceGemmReduce_Xdl_CShuffle arg.p_b_grid_, arg.p_c_grid_, arg.p_d_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.d_reduce_op_, arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.d_grid_desc_mblock_mperblock_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, arg.block_2_ctile_map_); } else @@ -354,14 +416,15 @@ struct DeviceGemmReduce_Xdl_CShuffle ADataType, // TODO: distiguish A/B datatype CDataType, DDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + DReduceOperation, remove_reference_t, remove_reference_t, remove_reference_t< typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, remove_reference_t, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, remove_reference_t, false>; @@ -375,13 +438,14 @@ struct DeviceGemmReduce_Xdl_CShuffle arg.p_b_grid_, arg.p_c_grid_, arg.p_d_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.d_reduce_op_, arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.d_grid_desc_mblock_mperblock_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, arg.block_2_ctile_map_); } @@ -428,7 +492,8 @@ struct DeviceGemmReduce_Xdl_CShuffle index_t StrideC, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) + CElementwiseOperation c_element_op, + DReduceOperation d_reduce_op) { return Argument{p_a, p_b, @@ -444,7 +509,8 @@ struct DeviceGemmReduce_Xdl_CShuffle 1, a_element_op, b_element_op, - c_element_op}; + c_element_op, + d_reduce_op}; } static auto MakeInvoker() { return Invoker{}; } @@ -462,7 +528,8 @@ struct DeviceGemmReduce_Xdl_CShuffle index_t StrideC, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) override + CElementwiseOperation c_element_op, + DReduceOperation d_reduce_op) override { return std::make_unique(static_cast(p_a), static_cast(p_b), @@ -478,7 +545,8 @@ struct DeviceGemmReduce_Xdl_CShuffle 1, a_element_op, b_element_op, - c_element_op); + c_element_op, + d_reduce_op); } // polymorphic diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle_v1.hpp index 8970f93172c..e83767b1d2b 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle_v1.hpp @@ -206,11 +206,9 @@ struct DeviceGemm_Xdl_CShuffle_v1 b_element_op_{b_element_op}, c_element_op_{c_element_op} { - a_grid_desc_k0_m_k1_ = - DeviceOp::MakeAGridDescriptor_K0_M_K1(M, K, StrideA); - b_grid_desc_k0_n_k1_ = - DeviceOp::MakeBGridDescriptor_K0_N_K1(K, N, StrideB); - c_grid_desc_m_n_ = DeviceOp::MakeCGridDescriptor_M_N(M, N, StrideC); + a_grid_desc_k0_m_k1_ = DeviceOp::MakeAGridDescriptor_K0_M_K1(M, K, StrideA); + b_grid_desc_k0_n_k1_ = DeviceOp::MakeBGridDescriptor_K0_N_K1(K, N, StrideB); + c_grid_desc_m_n_ = DeviceOp::MakeCGridDescriptor_M_N(M, N, StrideC); if(GridwiseGemm::CheckValidity( a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_)) diff --git a/include/ck/tensor_operation/gpu/device/gemm_specialization.hpp b/include/ck/tensor_operation/gpu/device/gemm_specialization.hpp index 37cc7b37824..0daf482cf74 100644 --- a/include/ck/tensor_operation/gpu/device/gemm_specialization.hpp +++ b/include/ck/tensor_operation/gpu/device/gemm_specialization.hpp @@ -8,7 +8,13 @@ namespace device { enum GemmSpecialization_t { Default, + MPadding, + NPadding, + KPadding, MNPadding, + MKPadding, + NKPadding, + MNKPadding, }; } // namespace device diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp index 15f2fc8caca..5174b298ddf 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp @@ -15,13 +15,14 @@ template __global__ void @@ -33,14 +34,15 @@ __global__ void const FloatAB* __restrict__ p_b_grid, FloatC* __restrict__ p_c_grid, FloatD* __restrict__ p_d_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const DReduceOperation d_reduce_op, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, const DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map) { __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; @@ -50,13 +52,14 @@ __global__ void p_c_grid, p_d_grid, p_shared, + a_element_op, + b_element_op, + c_element_op, + d_reduce_op, a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, c_grid_desc_mblock_mperblock_nblock_nperblock, d_grid_desc_mblock_mperblock, - a_element_op, - b_element_op, - c_element_op, block_2_ctile_map); } @@ -66,14 +69,16 @@ template ; - template + template __device__ static void Run(const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid, FloatC* __restrict__ p_c_grid, FloatD* __restrict__ p_d_grid, void* __restrict__ p_shared, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, + const DReduceOperation& d_reduce_op, const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& c_grid_desc_mblock_mperblock_nblock_nperblock, const DGridDescriptor_MBlock_MPerBlock& d_grid_desc_mblock_mperblock, - const AElementwiseOperation& a_element_op, - const BElementwiseOperation& b_element_op, - const CElementwiseOperation& c_element_op, const Block2CTileMap& block_2_ctile_map) { const auto a_grid_buf = make_dynamic_buffer( @@ -753,7 +759,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 Sequence<0, 1>, 1, CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, - InMemoryDataOperationEnum_t::AtomicAdd, + DGlobalMemoryDataOperation, 1, false>{d_grid_desc_mblock_mperblock, make_multi_index(block_work_idx[I0], // mblock @@ -819,20 +825,20 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 // reduce in VGPR static_for<0, mreduce_per_thread, 1>{}([&](auto im) { - FloatReduceAcc v = 0; // FloatReduceAcc + FloatReduceAcc r_acc = d_reduce_op.GetReduceZeroValue(); static_for<0, nreduce_per_thread, 1>{}([&](auto in) { constexpr index_t in_offset = c_in_reduce_thread_desc_mperblock_nperblock.CalculateOffset( make_tuple(im, in)); - v += c_in_reduce_thread_buf[Number{}]; + d_reduce_op.Reduce(r_acc, c_in_reduce_thread_buf[Number{}]); }); constexpr index_t out_offset = c_out_reduce_thread_desc_mperblock.CalculateOffset(make_tuple(im)); - c_out_reduce_thread_buf(Number{}) = v; + c_out_reduce_thread_buf(Number{}) = r_acc; }); // VGPR to Global diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp index d2c059c2103..54e269c09d3 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp @@ -154,11 +154,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 // lds max alignment constexpr auto max_lds_align = math::lcm(AK1, BK1); - constexpr auto a_block_space_size_aligned = - math::integer_least_multiple(a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); - constexpr auto b_block_space_size_aligned = - math::integer_least_multiple(b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); // LDS allocation for C shuffle in LDS constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = diff --git a/library/include/ck/library/host_tensor/device.hpp b/library/include/ck/library/host_tensor/device.hpp index 87af0bbd784..cefb249d459 100644 --- a/library/include/ck/library/host_tensor/device.hpp +++ b/library/include/ck/library/host_tensor/device.hpp @@ -13,8 +13,10 @@ struct DeviceMem DeviceMem() = delete; DeviceMem(std::size_t mem_size); void* GetDeviceBuffer(); + std::size_t GetBufferSize(); void ToDevice(const void* p); void FromDevice(void* p); + void SetZero(); ~DeviceMem(); void* mpDeviceBuf; diff --git a/library/include/ck/library/host_tensor/host_tensor.hpp b/library/include/ck/library/host_tensor/host_tensor.hpp index f9f462d7fd8..ebd228964ef 100644 --- a/library/include/ck/library/host_tensor/host_tensor.hpp +++ b/library/include/ck/library/host_tensor/host_tensor.hpp @@ -341,13 +341,13 @@ void check_error(const Tensor& ref, const Tensor& result) { for(int i = 0; i < ref.mData.size(); ++i) { - error += std::abs(double(ref.mData[i]) - double(result.mData[i])); - float diff = std::abs(double(ref.mData[i]) - double(result.mData[i])); + error += std::abs(float(ref.mData[i]) - float(result.mData[i])); + float diff = std::abs(float(ref.mData[i]) - float(result.mData[i])); if(max_diff < diff) { max_diff = diff; - ref_value = ref.mData[i]; - result_value = result.mData[i]; + ref_value = float(ref.mData[i]); + result_value = float(result.mData[i]); } } } diff --git a/library/src/host_tensor/device.cpp b/library/src/host_tensor/device.cpp index 0d1b3d6883b..3e80df80fba 100644 --- a/library/src/host_tensor/device.cpp +++ b/library/src/host_tensor/device.cpp @@ -7,6 +7,8 @@ DeviceMem::DeviceMem(std::size_t mem_size) : mMemSize(mem_size) void* DeviceMem::GetDeviceBuffer() { return mpDeviceBuf; } +std::size_t DeviceMem::GetBufferSize() { return mMemSize; } + void DeviceMem::ToDevice(const void* p) { hipGetErrorString( @@ -18,6 +20,8 @@ void DeviceMem::FromDevice(void* p) hipGetErrorString(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost)); } +void DeviceMem::SetZero() { hipGetErrorString(hipMemset(mpDeviceBuf, 0, mMemSize)); } + DeviceMem::~DeviceMem() { hipGetErrorString(hipFree(mpDeviceBuf)); } struct KernelTimerImpl From ce48abb6ddead53dc7b59509c26b2a1bc1d3df8a Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Thu, 17 Mar 2022 03:52:01 +0000 Subject: [PATCH 13/32] adding gemm+reduce --- .../device_gemm_reduce_xdl_cshuffle.hpp | 145 +++++++++++++++--- 1 file changed, 124 insertions(+), 21 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp index 4acea641bda..c762a20ba66 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp @@ -69,15 +69,15 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce{}; static constexpr auto I2 = Number<2>{}; - static auto MakeAGridDescriptor_K0_M_K1(index_t MRaw, index_t KRaw, index_t StrideA) + static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) { const auto a_grid_desc_mraw_kraw = [&]() { - if constexpr(is_same::value) + if constexpr(is_same_v) { return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), make_tuple(StrideA, I1)); } - else if constexpr(is_same::value) + else if constexpr(is_same_v) { return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), make_tuple(I1, StrideA)); @@ -93,6 +93,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } else { + // no pad assert(KRaw % AK1 == 0); const auto AK0 = KRaw / AK1; @@ -146,30 +172,107 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce::value) { - return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); } else if constexpr(is_same::value) { - return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); } }(); - const auto b_grid_desc_k0_n_k1 = transform_tensor_descriptor( - b_grid_desc_k_n, - make_tuple(make_unmerge_transform(make_tuple(K0, BK1)), make_pass_through_transform(N)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto NPad = N - NRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpecialization == GemmSpecialization_t::NKPadding || + GemmSpecialization == GemmSpecialization_t::MNKPadding) + { + // pad both N and K + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(NRaw, NPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpecialization == GemmSpecialization_t::NPadding || + GemmSpecialization == GemmSpecialization_t::MNPadding) + { + // pad N, but not K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpecialization == GemmSpecialization_t::KPadding || + GemmSpecialization == GemmSpecialization_t::MKPadding) + { + // pad K, but not N + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // no pad + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; - return b_grid_desc_k0_n_k1; + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } } static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) @@ -190,8 +293,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce Date: Thu, 17 Mar 2022 04:38:27 +0000 Subject: [PATCH 14/32] adding gemm+reduce --- .../device_gemm_reduce_xdl_cshuffle.hpp | 223 +++++++++++------- .../gridwise_gemm_reduce_xdl_cshuffle_v1.hpp | 31 +-- 2 files changed, 148 insertions(+), 106 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp index c762a20ba66..0f5f1b8ec52 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp @@ -39,19 +39,19 @@ template ::value) + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideC)); + } + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + + const auto MPad = M - MRaw; + const auto NPad = N - NRaw; + + if constexpr(GemmSpecialization == GemmSpecialization_t::MNPadding || + GemmSpecialization == GemmSpecialization_t::MNKPadding) { - return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + // pad M and N + transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); } - else if constexpr(is_same::value) + else if constexpr(GemmSpecialization == GemmSpecialization_t::MPadding || + GemmSpecialization == GemmSpecialization_t::MKPadding) { - return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + // pad M, but not N + transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpecialization == GemmSpecialization_t::NPadding || + GemmSpecialization == GemmSpecialization_t::NKPadding) + { + // pad N, but not M + transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; } } // assume D is packed tensor - static auto MakeDGridDescriptor_M(index_t M) + static auto MakeDGridDescriptor_M(index_t MRaw) { - return make_naive_tensor_descriptor_packed(make_tuple(M)); + const auto d_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw)); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto MPad = M - MRaw; + + if constexpr(GemmSpecialization == GemmSpecialization_t::MPadding || + GemmSpecialization == GemmSpecialization_t::MNPadding || + GemmSpecialization == GemmSpecialization_t::MKPadding || + GemmSpecialization == GemmSpecialization_t::MNKPadding) + { + // pad M + return transform_tensor_descriptor(d_grid_desc_mraw, + make_tuple(make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + } + else + { + // not pad M + return d_grid_desc_mraw; + } } - using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); - using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); - using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); - using DGridDesc_M = decltype(MakeDGridDescriptor_M(1)); + using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); + using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + using DGridDesc_M = decltype(MakeDGridDescriptor_M(1)); // GridwiseGemm using GridwiseGemm = GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< @@ -312,8 +378,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce, - remove_reference_t, + remove_reference_t, + remove_reference_t, remove_reference_t< typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, remove_reference_t, @@ -506,8 +564,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce, - remove_reference_t, + remove_reference_t, + remove_reference_t, remove_reference_t< typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, remove_reference_t, @@ -545,8 +603,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce(p_b), static_cast(p_c), static_cast(p_d), - M, - N, - K, + MRaw, + NRaw, + KRaw, StrideA, StrideB, StrideC, - 1, - 1, a_element_op, b_element_op, c_element_op, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp index 5174b298ddf..58831744974 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp @@ -95,7 +95,7 @@ template >::value && // is_known_at_compile_time>::value, @@ -235,16 +232,6 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 return false; } - // check M01, N01 - constexpr auto M1 = Number{}; - constexpr auto N1 = Number{}; - - const auto M0 = M / M1; - const auto N0 = N / N1; - - if(!(M0 % M01 == 0 && N0 % N01 == 0)) - return false; - // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) return true; } @@ -304,7 +291,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 // return block_id to C matrix tile idx (m0, n0) mapping __host__ __device__ static constexpr auto - MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01) + MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n) { const auto M = c_grid_desc_m_n.GetLength(I0); const auto N = c_grid_desc_m_n.GetLength(I1); @@ -315,6 +302,10 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 const auto M0 = M / M1; const auto N0 = N / N1; + // FIXME: remove + constexpr auto M01 = I1; + constexpr auto N01 = I1; + const auto M00 = M0 / M01; const auto N00 = N0 / N01; @@ -345,7 +336,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 remove_cvref_t; using DefaultBlock2CTileMap = - remove_cvref_t; + remove_cvref_t; template __device__ static void Run(const FloatAB* __restrict__ p_a_grid, @@ -411,7 +402,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ABlockTransferSrcVectorDim, 2, ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_K1, + ABlockTransferDstScalarPerVector_AK1, 1, 1, AThreadTransferSrcResetCoordinateAfterRun, @@ -442,7 +433,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 BBlockTransferSrcVectorDim, 2, BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_K1, + BBlockTransferDstScalarPerVector_BK1, 1, 1, BThreadTransferSrcResetCoordinateAfterRun, From 6624172bd54d25c964a986505a7eb04e61d04e9d Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Thu, 17 Mar 2022 04:48:29 +0000 Subject: [PATCH 15/32] adding gemm+reduce --- example/15_gemm_reduce/gemm_xdl_reduce_fp16.cpp | 3 ++- .../gpu/device/device_gemm_reduce_xdl_cshuffle.hpp | 6 +++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/example/15_gemm_reduce/gemm_xdl_reduce_fp16.cpp b/example/15_gemm_reduce/gemm_xdl_reduce_fp16.cpp index 8b42f350b50..9559e44cdf8 100644 --- a/example/15_gemm_reduce/gemm_xdl_reduce_fp16.cpp +++ b/example/15_gemm_reduce/gemm_xdl_reduce_fp16.cpp @@ -48,7 +48,8 @@ using BElementOp = ck::tensor_operation::element_wise::PassThrough; using CElementOp = ck::tensor_operation::element_wise::PassThrough; using DReduceOp = ReduceSum; -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default; +//static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default; +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::MNKPadding; // clang-format off using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_Xdl_CShuffle diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp index 0f5f1b8ec52..31874d4a28b 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp @@ -300,7 +300,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce{}, Sequence<1>{}), @@ -310,7 +310,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce{}, Sequence<1>{}), @@ -320,7 +320,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce{}, Sequence<1>{}), From c5002e3576f29c0e0f99621c88ae97ea906bdcf8 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Thu, 17 Mar 2022 04:48:50 +0000 Subject: [PATCH 16/32] adding gemm+reduce --- example/15_gemm_reduce/gemm_xdl_reduce_fp16.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/example/15_gemm_reduce/gemm_xdl_reduce_fp16.cpp b/example/15_gemm_reduce/gemm_xdl_reduce_fp16.cpp index 9559e44cdf8..8b42f350b50 100644 --- a/example/15_gemm_reduce/gemm_xdl_reduce_fp16.cpp +++ b/example/15_gemm_reduce/gemm_xdl_reduce_fp16.cpp @@ -48,8 +48,7 @@ using BElementOp = ck::tensor_operation::element_wise::PassThrough; using CElementOp = ck::tensor_operation::element_wise::PassThrough; using DReduceOp = ReduceSum; -//static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default; -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::MNKPadding; +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default; // clang-format off using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_Xdl_CShuffle From 08e54dc82f150d8c2023ccc41a3dd72110af2fec Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Thu, 17 Mar 2022 04:49:19 +0000 Subject: [PATCH 17/32] format --- .../gpu/device/device_gemm_reduce_xdl_cshuffle.hpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp index 31874d4a28b..970ec515d6a 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp @@ -301,10 +301,10 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); } else if constexpr(GemmSpecialization == GemmSpecialization_t::MPadding || GemmSpecialization == GemmSpecialization_t::MKPadding) From b11b5bb6a87ac998dcf854c2cfec8711ecae95da Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Thu, 17 Mar 2022 05:01:37 +0000 Subject: [PATCH 18/32] clean --- example/15_gemm_reduce/gemm_xdl_reduce_fp16.cpp | 13 +++++++------ .../gpu/device/device_gemm_reduce_xdl_cshuffle.hpp | 2 +- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/example/15_gemm_reduce/gemm_xdl_reduce_fp16.cpp b/example/15_gemm_reduce/gemm_xdl_reduce_fp16.cpp index 8b42f350b50..8c62f7ba486 100644 --- a/example/15_gemm_reduce/gemm_xdl_reduce_fp16.cpp +++ b/example/15_gemm_reduce/gemm_xdl_reduce_fp16.cpp @@ -48,15 +48,16 @@ using BElementOp = ck::tensor_operation::element_wise::PassThrough; using CElementOp = ck::tensor_operation::element_wise::PassThrough; using DReduceOp = ReduceSum; -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default; +static constexpr auto GemmSpecialization = + ck::tensor_operation::device::GemmSpecialization_t::Default; // clang-format off using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_Xdl_CShuffle -//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| -//######| | | | Type| Type| Type| DataType| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Reduce|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| -//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| -//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, AElementOp, BElementOp, CElementOp, DReduceOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>; +//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| +//######| | | | Type| Type| Type| DataType| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, AElementOp, BElementOp, CElementOp, DReduceOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host:: diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp index 970ec515d6a..c6b6161e618 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp @@ -93,7 +93,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce Date: Thu, 17 Mar 2022 18:32:02 +0000 Subject: [PATCH 19/32] adding gemm+reduce --- example/15_gemm_reduce/gemm_xdl_reduce_fp16.cpp | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/example/15_gemm_reduce/gemm_xdl_reduce_fp16.cpp b/example/15_gemm_reduce/gemm_xdl_reduce_fp16.cpp index 8c62f7ba486..6529c3d7d8d 100644 --- a/example/15_gemm_reduce/gemm_xdl_reduce_fp16.cpp +++ b/example/15_gemm_reduce/gemm_xdl_reduce_fp16.cpp @@ -25,6 +25,13 @@ struct ReduceSum __host__ __device__ void Reduce(float& acc, float v) const { acc += v; } }; +struct ReduceSquareSum +{ + __host__ __device__ static constexpr float GetReduceZeroValue() { return float(0); } + + __host__ __device__ void Reduce(float& acc, float v) const { acc += v * v; } +}; + template using S = ck::Sequence; @@ -46,7 +53,7 @@ using CLayout = ck::tensor_layout::gemm::RowMajor; using AElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough; using CElementOp = ck::tensor_operation::element_wise::PassThrough; -using DReduceOp = ReduceSum; +using DReduceOp = ReduceSquareSum; static constexpr auto GemmSpecialization = ck::tensor_operation::device::GemmSpecialization_t::Default; @@ -220,12 +227,14 @@ int main(int argc, char* argv[]) for(int m = 0; m < M; ++m) { - d_m_host_result(m) = 0; + float r_acc = d_reduce_op.GetReduceZeroValue(); for(int n = 0; n < N; ++n) { - d_m_host_result(m) += c_m_n_host_result(m, n); + d_reduce_op.Reduce(r_acc, c_m_n_host_result(m, n)); } + + d_m_host_result(m) = r_acc; } check_error(c_m_n_host_result, c_m_n_device_result); From a0c85e86ab6084041a99d373eff9ade6fee23dc0 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Fri, 18 Mar 2022 19:00:19 +0000 Subject: [PATCH 20/32] adding profiler for gemm+reduce --- .../15_gemm_reduce/gemm_xdl_reduce_fp16.cpp | 18 +- .../device_gemm_reduce_xdl_cshuffle.hpp | 1 + ...le_v1.hpp => device_gemm_xdl_cshuffle.hpp} | 0 .../gpu/element/element_wise_operation.hpp | 6 +- .../element/element_wise_reduce_operation.hpp | 24 ++ .../gridwise_gemm_reduce_xdl_cshuffle_v1.hpp | 5 + .../gpu/CMakeLists.txt | 8 + .../gpu/gemm_reduce/CMakeLists.txt | 7 + ..._f16_f16_f16_f32_f32_km_kn_mn_instance.cpp | 63 +++++ ..._f16_f16_f16_f32_f32_km_nk_mn_instance.cpp | 63 +++++ ..._f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp | 63 +++++ ..._f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp | 63 +++++ profiler/CMakeLists.txt | 46 ++-- profiler/include/profile_gemm_reduce_impl.hpp | 251 ++++++++++++++++++ profiler/src/profile_gemm_reduce.cpp | 147 ++++++++++ profiler/src/profiler.cpp | 34 ++- 16 files changed, 744 insertions(+), 55 deletions(-) rename include/ck/tensor_operation/gpu/device/{device_gemm_xdl_cshuffle_v1.hpp => device_gemm_xdl_cshuffle.hpp} (100%) create mode 100644 include/ck/tensor_operation/gpu/element/element_wise_reduce_operation.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_reduce/CMakeLists.txt create mode 100644 library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp create mode 100644 profiler/include/profile_gemm_reduce_impl.hpp create mode 100644 profiler/src/profile_gemm_reduce.cpp diff --git a/example/15_gemm_reduce/gemm_xdl_reduce_fp16.cpp b/example/15_gemm_reduce/gemm_xdl_reduce_fp16.cpp index 6529c3d7d8d..b3740c2bc87 100644 --- a/example/15_gemm_reduce/gemm_xdl_reduce_fp16.cpp +++ b/example/15_gemm_reduce/gemm_xdl_reduce_fp16.cpp @@ -16,21 +16,7 @@ #include "element_wise_operation.hpp" #include "reference_gemm.hpp" #include "gemm_specialization.hpp" -#include "reduction_enums.hpp" - -struct ReduceSum -{ - __host__ __device__ static constexpr float GetReduceZeroValue() { return float(0); } - - __host__ __device__ void Reduce(float& acc, float v) const { acc += v; } -}; - -struct ReduceSquareSum -{ - __host__ __device__ static constexpr float GetReduceZeroValue() { return float(0); } - - __host__ __device__ void Reduce(float& acc, float v) const { acc += v * v; } -}; +#include "element_wise_reduce_operation.hpp" template using S = ck::Sequence; @@ -53,7 +39,7 @@ using CLayout = ck::tensor_layout::gemm::RowMajor; using AElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough; using CElementOp = ck::tensor_operation::element_wise::PassThrough; -using DReduceOp = ReduceSquareSum; +using DReduceOp = ck::tensor_operation::element_wise::ReduceSquareSum; static constexpr auto GemmSpecialization = ck::tensor_operation::device::GemmSpecialization_t::Default; diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp index c6b6161e618..e5cdde758f2 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp @@ -8,6 +8,7 @@ #include "tensor_descriptor.hpp" #include "tensor_descriptor_helper.hpp" #include "gridwise_gemm_reduce_xdl_cshuffle_v1.hpp" +#include "gemm_specialization.hpp" namespace ck { namespace tensor_operation { diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp similarity index 100% rename from include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle_v1.hpp rename to include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp diff --git a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp index 2c45d1f5441..6235d5fd1ee 100644 --- a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp @@ -1,7 +1,4 @@ -#ifndef CK_ELEMENT_WISE_OPERATION_HPP -#define CK_ELEMENT_WISE_OPERATION_HPP -#include "data_type.hpp" - +#pragma once #include "data_type.hpp" namespace ck { @@ -333,4 +330,3 @@ struct UnarySqrt } // namespace element_wise } // namespace tensor_operation } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/element/element_wise_reduce_operation.hpp b/include/ck/tensor_operation/gpu/element/element_wise_reduce_operation.hpp new file mode 100644 index 00000000000..2b5df58aa88 --- /dev/null +++ b/include/ck/tensor_operation/gpu/element/element_wise_reduce_operation.hpp @@ -0,0 +1,24 @@ +#pragma once +#include "data_type.hpp" + +namespace ck { +namespace tensor_operation { +namespace element_wise { + +struct ReduceSum +{ + __host__ __device__ static constexpr float GetReduceZeroValue() { return float(0); } + + __host__ __device__ void Reduce(float& acc, float v) const { acc += v; } +}; + +struct ReduceSquareSum +{ + __host__ __device__ static constexpr float GetReduceZeroValue() { return float(0); } + + __host__ __device__ void Reduce(float& acc, float v) const { acc += v * v; } +}; + +} // namespace element_wise +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp index 58831744974..4fe54be7d83 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp @@ -675,6 +675,11 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<>{}, Sequence<1>{})); + static_assert(CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0) * + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1) == + BlockSize, + "wrong!"); + static_assert((CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) % CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0) == 0 && diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt index 52277f0ee3d..0637aa94053 100644 --- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt @@ -16,10 +16,18 @@ include_directories(BEFORE ${PROJECT_SOURCE_DIR}/external/include/half ) +function(add_instance_library INSTANCE_NAME) + message("adding instance ${INSTANCE_NAME}") + add_library(${INSTANCE_NAME} SHARED ${ARGN}) + target_compile_features(${INSTANCE_NAME} PUBLIC) + set_target_properties(${INSTANCE_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON) +endfunction(add_instance_library INSTANCE_NAME) + add_subdirectory(gemm) add_subdirectory(gemm_bias2d) add_subdirectory(gemm_bias_relu) add_subdirectory(gemm_bias_relu_add) +add_subdirectory(gemm_reduce) add_subdirectory(batched_gemm) add_subdirectory(conv1d_fwd) add_subdirectory(conv2d_fwd) diff --git a/library/src/tensor_operation_instance/gpu/gemm_reduce/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_reduce/CMakeLists.txt new file mode 100644 index 00000000000..9137f60a4ac --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_reduce/CMakeLists.txt @@ -0,0 +1,7 @@ +set(DEVICE_GEMM_REDUCE_INSTANCE_SOURCE + device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp +) + +add_instance_library(device_gemm_reduce_instance ${DEVICE_GEMM_REDUCE_INSTANCE_SOURCE}) +install(TARGETS device_gemm_reduce_instance LIBRARY DESTINATION lib) +clang_tidy_check(device_gemm_reduce_instance) diff --git a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp new file mode 100644 index 00000000000..7aedd0ce3d9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp @@ -0,0 +1,63 @@ +#include +#include "config.hpp" +#include "device_gemm_reduce_xdl_cshuffle.hpp" +#include "element_wise_operation.hpp" +#include "element_wise_reduce_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::tensor_operation::element_wise::ReduceSum; +using ReduceSquareSum = ck::tensor_operation::element_wise::ReduceSquareSum; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default; + +// c[m, n] = a[k, m] * b[k, n] +// d0[m] = reduce0(c[m, n]) +// d1[m] = reduce1(c[m, n]) +using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_fp32_mk_nk_mn_instances = std::tuple< + // clang-format off + //###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //###########################| | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1> + // clang-format on + >; + +void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_fp32_mk_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_fp32_mk_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp new file mode 100644 index 00000000000..b93115318df --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp @@ -0,0 +1,63 @@ +#include +#include "config.hpp" +#include "device_gemm_reduce_xdl_cshuffle.hpp" +#include "element_wise_operation.hpp" +#include "element_wise_reduce_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::tensor_operation::element_wise::ReduceSum; +using ReduceSquareSum = ck::tensor_operation::element_wise::ReduceSquareSum; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default; + +// c[m, n] = a[m, k] * b[n, k] +// d0[m] = reduce0(c[m, n]) +// d1[m] = reduce1(c[m, n]) +using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_fp32_mk_nk_mn_instances = std::tuple< + // clang-format off + //###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //###########################| | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1> + // clang-format on + >; + +void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_fp32_mk_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_fp32_mk_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp new file mode 100644 index 00000000000..b93115318df --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp @@ -0,0 +1,63 @@ +#include +#include "config.hpp" +#include "device_gemm_reduce_xdl_cshuffle.hpp" +#include "element_wise_operation.hpp" +#include "element_wise_reduce_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::tensor_operation::element_wise::ReduceSum; +using ReduceSquareSum = ck::tensor_operation::element_wise::ReduceSquareSum; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default; + +// c[m, n] = a[m, k] * b[n, k] +// d0[m] = reduce0(c[m, n]) +// d1[m] = reduce1(c[m, n]) +using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_fp32_mk_nk_mn_instances = std::tuple< + // clang-format off + //###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //###########################| | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1> + // clang-format on + >; + +void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_fp32_mk_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_fp32_mk_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp new file mode 100644 index 00000000000..b93115318df --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp @@ -0,0 +1,63 @@ +#include +#include "config.hpp" +#include "device_gemm_reduce_xdl_cshuffle.hpp" +#include "element_wise_operation.hpp" +#include "element_wise_reduce_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::tensor_operation::element_wise::ReduceSum; +using ReduceSquareSum = ck::tensor_operation::element_wise::ReduceSquareSum; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default; + +// c[m, n] = a[m, k] * b[n, k] +// d0[m] = reduce0(c[m, n]) +// d1[m] = reduce1(c[m, n]) +using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_fp32_mk_nk_mn_instances = std::tuple< + // clang-format off + //###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //###########################| | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1> + // clang-format on + >; + +void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_fp32_mk_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_fp32_mk_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/CMakeLists.txt b/profiler/CMakeLists.txt index 5e7156a3996..3a5086b310a 100644 --- a/profiler/CMakeLists.txt +++ b/profiler/CMakeLists.txt @@ -22,30 +22,32 @@ include_directories(BEFORE # ck_profiler set(PROFILER_SOURCE src/profiler.cpp - src/profile_gemm.cpp - src/profile_gemm_bias_2d.cpp - src/profile_gemm_bias_relu.cpp - src/profile_gemm_bias_relu_add.cpp - src/profile_batched_gemm.cpp - src/profile_conv_fwd.cpp - src/profile_conv_fwd_bias_relu.cpp - src/profile_conv_fwd_bias_relu_add.cpp - src/profile_conv_fwd_bias_relu_atomic_add.cpp - src/profile_conv_bwd_data.cpp - src/profile_reduce.cpp +# src/profile_gemm.cpp +# src/profile_gemm_bias_2d.cpp +# src/profile_gemm_bias_relu.cpp +# src/profile_gemm_bias_relu_add.cpp + src/profile_gemm_reduce.cpp +# src/profile_batched_gemm.cpp +# src/profile_conv_fwd.cpp +# src/profile_conv_fwd_bias_relu.cpp +# src/profile_conv_fwd_bias_relu_add.cpp +# src/profile_conv_fwd_bias_relu_atomic_add.cpp +# src/profile_conv_bwd_data.cpp +# src/profile_reduce.cpp ) add_executable(ckProfiler ${PROFILER_SOURCE}) target_link_libraries(ckProfiler PRIVATE host_tensor) -target_link_libraries(ckProfiler PRIVATE device_gemm_instance) -target_link_libraries(ckProfiler PRIVATE device_gemm_bias2d_instance) -target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_instance) -target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_add_instance) -target_link_libraries(ckProfiler PRIVATE device_batched_gemm_instance) -target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_instance) -target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_instance) -target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_add_instance) -target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_atomic_add_instance) -target_link_libraries(ckProfiler PRIVATE device_conv2d_bwd_data_instance) -target_link_libraries(ckProfiler PRIVATE device_reduce_instance) +target_link_libraries(ckProfiler PRIVATE device_gemm_reduce_instance) +#target_link_libraries(ckProfiler PRIVATE device_gemm_instance) +#target_link_libraries(ckProfiler PRIVATE device_gemm_bias2d_instance) +#target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_instance) +#target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_add_instance) +#target_link_libraries(ckProfiler PRIVATE device_batched_gemm_instance) +#target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_instance) +#target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_instance) +#target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_add_instance) +#target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_atomic_add_instance) +#target_link_libraries(ckProfiler PRIVATE device_conv2d_bwd_data_instance) +#target_link_libraries(ckProfiler PRIVATE device_reduce_instance) diff --git a/profiler/include/profile_gemm_reduce_impl.hpp b/profiler/include/profile_gemm_reduce_impl.hpp new file mode 100644 index 00000000000..0732f160637 --- /dev/null +++ b/profiler/include/profile_gemm_reduce_impl.hpp @@ -0,0 +1,251 @@ +#pragma once +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "host_conv.hpp" +#include "tensor_layout.hpp" +#include "device_tensor.hpp" +#include "element_wise_operation.hpp" +#include "element_wise_reduce_operation.hpp" +#include "device_gemm_reduce.hpp" +#include "reference_gemm.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using DeviceGemmReduceNoOpPtr = ck::tensor_operation::device::DeviceGemmReducePtr< + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::ReduceSum>; + +void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_fp32_mk_nk_mn_instances( + std::vector&); + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +namespace ck { +namespace profiler { + +template +void profile_gemm_reduce_impl(int do_verification, + int init_method, + bool do_log, + int nrepeat, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideC) +{ + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + // d_m[m] + Tensor d_m(HostTensorDescriptor( + std::vector({static_cast(M)}), std::vector({1}))); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + std::cout << "d_m: " << d_m.mDesc << std::endl; + + std::size_t num_thread = std::thread::hardware_concurrency(); + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + } + + // set zero to c_device_buf + c_m_n_device_result.GenerateTensorValue(GeneratorTensor_0{}, num_thread); + + using AElementOp = ck::tensor_operation::element_wise::PassThrough; + using BElementOp = ck::tensor_operation::element_wise::PassThrough; + using CElementOp = ck::tensor_operation::element_wise::PassThrough; + using DReduceOp = ck::tensor_operation::element_wise::ReduceSum; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto c_element_op = CElementOp{}; + const auto d_reduce_op = DReduceOp{}; + + if(do_verification) + { + using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + + // TODO + // cpu reduce + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); + DeviceMem d_device_buf(sizeof(CDataType) * d_m.mDesc.GetElementSpace()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + d_device_buf.SetZero(); + + // add device GEMM instances + std::vector + gemm_ptrs; + + if constexpr(is_same::value && is_same::value && + is_same::value) + { + if constexpr(is_same::value && + is_same::value && + is_same::value) + {} + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_fp32_mk_nk_mn_instances( + gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + } + } + + if(gemm_ptrs.size() <= 0) + { + throw std::runtime_error("wrong! no device GEMM instance found"); + } + + std::string best_gemm_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device GEMM instances + for(auto& gemm_ptr : gemm_ptrs) + { + auto argument_ptr = + gemm_ptr->MakeArgumentPointer(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + static_cast(d_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op, + d_reduce_op); + + auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); + + if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) + { + std::string gemm_name = gemm_ptr->GetTypeString(); + + float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat); + + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * M + + sizeof(CDataType) * M * N + sizeof(CDataType) * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm_name << std::endl; + + if(tflops > best_tflops) + { + best_gemm_name = gemm_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + c_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + check_error(c_m_n_host_result, c_m_n_device_result); + + if(do_log) + { + LogRangeAsType(std::cout << "a : ", a_m_k.mData, ",") << std::endl; + LogRangeAsType(std::cout << "b: ", b_k_n.mData, ",") << std::endl; + LogRangeAsType(std::cout << "c0 : ", d_m.mData, ",") << std::endl; + LogRangeAsType(std::cout << "c_host : ", c_m_n_host_result.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "c_device: ", c_m_n_device_result.mData, ",") + << std::endl; + } + } + } + else + { + std::cout << "does not support this GEMM problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/src/profile_gemm_reduce.cpp b/profiler/src/profile_gemm_reduce.cpp new file mode 100644 index 00000000000..8335d0a33b2 --- /dev/null +++ b/profiler/src/profile_gemm_reduce.cpp @@ -0,0 +1,147 @@ +#include +#include +#include +#include +#include +#include +#include "profile_gemm_reduce_impl.hpp" + +int profile_gemm_reduce(int argc, char* argv[]) +{ + enum GemmMatrixLayout_t + { + MK_KN_MN, // 0 + MK_NK_MN, // 1 + KM_KN_MN, // 2 + KM_NK_MN, // 3 + }; + + enum GemmReduceDataType_t + { + F32_F32_F32_F32_F32, // 0 + F16_F16_F16_F32_F32, // 1 + }; + + if(!(argc == 14 || argc == 15)) + { + printf("arg1: tensor operation (gemm: GEMM+Reduce)\n"); + printf("arg2: data type (0: fp32; 1: fp16)\n"); + printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"); + printf(" 1: A[m, k] * B[n, k] = C[m, n];\n"); + printf(" 2: A[k, m] * B[k, n] = C[m, n];\n"); + printf(" 3: A[k, m] * B[n, k] = C[m, n])\n"); + printf("arg4: verification (0: no; 1: yes)\n"); + printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg8: print tensor value (0: no; 1: yes)\n"); + printf("arg7: run kernel # of times (>1)\n"); + printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n"); + printf("arg14: split k into mulitiple batch\n"); + exit(1); + } + + const int data_type = static_cast(std::stoi(argv[2])); + const int layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const int nrepeat = std::stoi(argv[7]); + + const int M = std::stoi(argv[8]); + const int N = std::stoi(argv[9]); + const int K = std::stoi(argv[10]); + + const int StrideA = std::stoi(argv[11]); + const int StrideB = std::stoi(argv[12]); + const int StrideC = std::stoi(argv[13]); + + if(data_type == GemmReduceDataType_t::F16_F16_F16_F32_F32 && + layout == GemmMatrixLayout_t::MK_KN_MN) + { + ck::profiler::profile_gemm_reduce_impl( + do_verification, + init_method, + do_log, + nrepeat, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC); + } + else if(data_type == GemmReduceDataType_t::F16_F16_F16_F32_F32 && + layout == GemmMatrixLayout_t::MK_NK_MN) + { + ck::profiler::profile_gemm_reduce_impl( + do_verification, + init_method, + do_log, + nrepeat, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? K : StrideB, + (StrideC < 0) ? N : StrideC); + } + else if(data_type == GemmReduceDataType_t::F16_F16_F16_F32_F32 && + layout == GemmMatrixLayout_t::KM_KN_MN) + { + ck::profiler::profile_gemm_reduce_impl( + do_verification, + init_method, + do_log, + nrepeat, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC); + } + else if(data_type == GemmReduceDataType_t::F16_F16_F16_F32_F32 && + layout == GemmMatrixLayout_t::KM_NK_MN) + { + ck::profiler::profile_gemm_reduce_impl( + do_verification, + init_method, + do_log, + nrepeat, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? K : StrideB, + (StrideC < 0) ? N : StrideC); + } + else + { + throw std::runtime_error("wrong! this data_type & layout is not implemented"); + } + + return 1; +} diff --git a/profiler/src/profiler.cpp b/profiler/src/profiler.cpp index 80ce1f83247..cba8df3f811 100644 --- a/profiler/src/profiler.cpp +++ b/profiler/src/profiler.cpp @@ -4,20 +4,22 @@ #include #include -int profile_gemm(int, char*[]); -int profile_batched_gemm(int, char*[]); -int profile_gemm_bias_2d(int, char*[]); -int profile_gemm_bias_relu(int, char*[]); -int profile_gemm_bias_relu_add(int, char*[]); -int profile_conv_fwd(int, char*[]); -int profile_conv_fwd_bias_relu(int, char*[]); -int profile_conv_fwd_bias_relu_add(int, char*[]); -int profile_conv_fwd_bias_relu_atomic_add(int, char*[]); -int profile_conv_bwd_data(int, char*[]); -int profile_reduce(int, char*[]); +// int profile_gemm(int, char*[]); +// int profile_gemm_bias_2d(int, char*[]); +// int profile_gemm_bias_relu(int, char*[]); +// int profile_gemm_bias_relu_add(int, char*[]); +int profile_gemm_reduce(int, char*[]); +// int profile_batched_gemm(int, char*[]); +// int profile_conv_fwd(int, char*[]); +// int profile_conv_fwd_bias_relu(int, char*[]); +// int profile_conv_fwd_bias_relu_add(int, char*[]); +// int profile_conv_fwd_bias_relu_atomic_add(int, char*[]); +// int profile_conv_bwd_data(int, char*[]); +// int profile_reduce(int, char*[]); int main(int argc, char* argv[]) { +#if 0 if(strcmp(argv[1], "gemm") == 0) { return profile_gemm(argc, argv); @@ -34,6 +36,12 @@ int main(int argc, char* argv[]) { return profile_gemm_bias_relu_add(argc, argv); } +#endif + if(strcmp(argv[1], "gemm_reduce") == 0) + { + return profile_gemm_reduce(argc, argv); + } +#if 0 else if(strcmp(argv[1], "batched_gemm") == 0) { return profile_batched_gemm(argc, argv); @@ -62,6 +70,7 @@ int main(int argc, char* argv[]) { return profile_reduce(argc, argv); } +#endif else { // clang-format off @@ -69,12 +78,13 @@ int main(int argc, char* argv[]) " gemm_bias_2d: GEMM+Bias(2D)\n" " gemm_bias_relu: GEMM+Bias+ReLU\n" " gemm_bias_relu_add: GEMM+Bias+ReLU+Add\n" + " gemm_reduce: GEMM+Reduce\n" " conv_fwd: ForwardConvolution\n" " conv_fwd_bias_relu: ForwardConvolution+Bias+ReLU\n" " conv_fwd_bias_relu_add: ForwardConvolution+Bias+ReLU+Add\n" " conv_fwd_bias_relu_atomic_add: ForwardConvolution+Bias+ReLU+AtomicAdd\n" " conv_bwd: BackwardConvolution\n" - " reduce: REDUCE\n"); + " reduce: Reduce\n"); // clang-format on return 0; From 1d481516e21bb0b77a8bcbeaec71c78db9ced67b Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Sun, 20 Mar 2022 02:24:01 +0000 Subject: [PATCH 21/32] adding gemm+reduce profiler --- .../gpu/CMakeLists.txt | 24 +++---- .../gpu/gemm_reduce/CMakeLists.txt | 3 + ..._f16_f16_f16_f32_f32_km_kn_mn_instance.cpp | 43 ++++++------ ..._f16_f16_f16_f32_f32_km_nk_mn_instance.cpp | 45 ++++++------ ..._f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp | 43 ++++++------ ..._f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp | 40 +++++------ profiler/CMakeLists.txt | 44 ++++++------ profiler/include/profile_gemm_reduce_impl.hpp | 69 ++++++++++++++----- profiler/src/profiler.cpp | 28 ++++---- 9 files changed, 192 insertions(+), 147 deletions(-) diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt index 0637aa94053..58ca590e573 100644 --- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt @@ -23,16 +23,16 @@ function(add_instance_library INSTANCE_NAME) set_target_properties(${INSTANCE_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON) endfunction(add_instance_library INSTANCE_NAME) -add_subdirectory(gemm) -add_subdirectory(gemm_bias2d) -add_subdirectory(gemm_bias_relu) -add_subdirectory(gemm_bias_relu_add) +#add_subdirectory(gemm) +#add_subdirectory(gemm_bias2d) +#add_subdirectory(gemm_bias_relu) +#add_subdirectory(gemm_bias_relu_add) add_subdirectory(gemm_reduce) -add_subdirectory(batched_gemm) -add_subdirectory(conv1d_fwd) -add_subdirectory(conv2d_fwd) -add_subdirectory(conv2d_fwd_bias_relu) -add_subdirectory(conv2d_fwd_bias_relu_add) -add_subdirectory(conv2d_fwd_bias_relu_atomic_add) -add_subdirectory(conv2d_bwd_data) -add_subdirectory(reduce) +#add_subdirectory(batched_gemm) +#add_subdirectory(conv1d_fwd) +#add_subdirectory(conv2d_fwd) +#add_subdirectory(conv2d_fwd_bias_relu) +#add_subdirectory(conv2d_fwd_bias_relu_add) +#add_subdirectory(conv2d_fwd_bias_relu_atomic_add) +#add_subdirectory(conv2d_bwd_data) +#add_subdirectory(reduce) diff --git a/library/src/tensor_operation_instance/gpu/gemm_reduce/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_reduce/CMakeLists.txt index 9137f60a4ac..5bc6d17a93a 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_reduce/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_reduce/CMakeLists.txt @@ -1,5 +1,8 @@ set(DEVICE_GEMM_REDUCE_INSTANCE_SOURCE + device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp + device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp + device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp ) add_instance_library(device_gemm_reduce_instance ${DEVICE_GEMM_REDUCE_INSTANCE_SOURCE}) diff --git a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp index 7aedd0ce3d9..9794a443e05 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp @@ -28,33 +28,36 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa // c[m, n] = a[k, m] * b[k, n] // d0[m] = reduce0(c[m, n]) // d1[m] = reduce1(c[m, n]) -using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_fp32_mk_nk_mn_instances = std::tuple< +using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances = std::tuple< // clang-format off - //###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| - //###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| - //###########################| | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| - //###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1> + //###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //###########################| | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 256, 32, 4, 4, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1> // clang-format on >; -void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_fp32_mk_nk_mn_instances( +void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances( std::vector>& instances) { add_device_operation_instances( - instances, device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_fp32_mk_nk_mn_instances{}); + instances, device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances{}); } } // namespace device_gemm_instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp index b93115318df..9b1243e5f69 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp @@ -25,36 +25,39 @@ using ReduceSquareSum = ck::tensor_operation::element_wise::ReduceSquareSum; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default; -// c[m, n] = a[m, k] * b[n, k] +// c[m, n] = a[k, m] * b[n, k] // d0[m] = reduce0(c[m, n]) // d1[m] = reduce1(c[m, n]) -using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_fp32_mk_nk_mn_instances = std::tuple< +using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances = std::tuple< // clang-format off - //###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| - //###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| - //###########################| | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| - //###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1> + //###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //###########################| | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1> // clang-format on >; -void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_fp32_mk_nk_mn_instances( +void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances( std::vector>& instances) { add_device_operation_instances( - instances, device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_fp32_mk_nk_mn_instances{}); + instances, device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances{}); } } // namespace device_gemm_instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp index b93115318df..b09cc0c3f03 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp @@ -28,33 +28,36 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa // c[m, n] = a[m, k] * b[n, k] // d0[m] = reduce0(c[m, n]) // d1[m] = reduce1(c[m, n]) -using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_fp32_mk_nk_mn_instances = std::tuple< +using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances = std::tuple< // clang-format off - //###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| - //###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| - //###########################| | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| - //###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1> + //###########################| ALayout| BLayout| CLayout| AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //###########################| | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1> // clang-format on >; -void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_fp32_mk_nk_mn_instances( +void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances( std::vector>& instances) { add_device_operation_instances( - instances, device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_fp32_mk_nk_mn_instances{}); + instances, device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances{}); } } // namespace device_gemm_instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp index b93115318df..35323ba5714 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp @@ -28,33 +28,33 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa // c[m, n] = a[m, k] * b[n, k] // d0[m] = reduce0(c[m, n]) // d1[m] = reduce1(c[m, n]) -using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_fp32_mk_nk_mn_instances = std::tuple< +using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instances = std::tuple< // clang-format off - //###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| - //###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| - //###########################| | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| - //###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1> + //###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //###########################| | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1> // clang-format on >; -void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_fp32_mk_nk_mn_instances( +void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instances( std::vector>& instances) { add_device_operation_instances( - instances, device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_fp32_mk_nk_mn_instances{}); + instances, device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instances{}); } } // namespace device_gemm_instance diff --git a/profiler/CMakeLists.txt b/profiler/CMakeLists.txt index 3a5086b310a..8f09eb74f6b 100644 --- a/profiler/CMakeLists.txt +++ b/profiler/CMakeLists.txt @@ -22,32 +22,32 @@ include_directories(BEFORE # ck_profiler set(PROFILER_SOURCE src/profiler.cpp -# src/profile_gemm.cpp -# src/profile_gemm_bias_2d.cpp -# src/profile_gemm_bias_relu.cpp -# src/profile_gemm_bias_relu_add.cpp + src/profile_gemm.cpp + src/profile_gemm_bias_2d.cpp + src/profile_gemm_bias_relu.cpp + src/profile_gemm_bias_relu_add.cpp src/profile_gemm_reduce.cpp -# src/profile_batched_gemm.cpp -# src/profile_conv_fwd.cpp -# src/profile_conv_fwd_bias_relu.cpp -# src/profile_conv_fwd_bias_relu_add.cpp -# src/profile_conv_fwd_bias_relu_atomic_add.cpp -# src/profile_conv_bwd_data.cpp -# src/profile_reduce.cpp + src/profile_batched_gemm.cpp + src/profile_conv_fwd.cpp + src/profile_conv_fwd_bias_relu.cpp + src/profile_conv_fwd_bias_relu_add.cpp + src/profile_conv_fwd_bias_relu_atomic_add.cpp + src/profile_conv_bwd_data.cpp + src/profile_reduce.cpp ) add_executable(ckProfiler ${PROFILER_SOURCE}) target_link_libraries(ckProfiler PRIVATE host_tensor) target_link_libraries(ckProfiler PRIVATE device_gemm_reduce_instance) -#target_link_libraries(ckProfiler PRIVATE device_gemm_instance) -#target_link_libraries(ckProfiler PRIVATE device_gemm_bias2d_instance) -#target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_instance) -#target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_add_instance) -#target_link_libraries(ckProfiler PRIVATE device_batched_gemm_instance) -#target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_instance) -#target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_instance) -#target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_add_instance) -#target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_atomic_add_instance) -#target_link_libraries(ckProfiler PRIVATE device_conv2d_bwd_data_instance) -#target_link_libraries(ckProfiler PRIVATE device_reduce_instance) +target_link_libraries(ckProfiler PRIVATE device_gemm_instance) +target_link_libraries(ckProfiler PRIVATE device_gemm_bias2d_instance) +target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_instance) +target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_add_instance) +target_link_libraries(ckProfiler PRIVATE device_batched_gemm_instance) +target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_instance) +target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_instance) +target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_add_instance) +target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_atomic_add_instance) +target_link_libraries(ckProfiler PRIVATE device_conv2d_bwd_data_instance) +target_link_libraries(ckProfiler PRIVATE device_reduce_instance) diff --git a/profiler/include/profile_gemm_reduce_impl.hpp b/profiler/include/profile_gemm_reduce_impl.hpp index 0732f160637..840b77f054c 100644 --- a/profiler/include/profile_gemm_reduce_impl.hpp +++ b/profiler/include/profile_gemm_reduce_impl.hpp @@ -22,7 +22,16 @@ using DeviceGemmReduceNoOpPtr = ck::tensor_operation::device::DeviceGemmReducePt ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::ReduceSum>; -void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_fp32_mk_nk_mn_instances( +void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances( + std::vector&); + +void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instances( + std::vector&); + +void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances( + std::vector&); + +void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances( std::vector&); } // namespace device_gemm_instance @@ -67,17 +76,19 @@ void profile_gemm_reduce_impl(int do_verification, Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); - Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor d_m_host_result( + HostTensorDescriptor(std::vector({static_cast(M)}))); - // d_m[m] - Tensor d_m(HostTensorDescriptor( - std::vector({static_cast(M)}), std::vector({1}))); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor d_m_device_result( + HostTensorDescriptor(std::vector({static_cast(M)}))); std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; - std::cout << "d_m: " << d_m.mDesc << std::endl; + std::cout << "d_m: " << d_m_host_result.mDesc << std::endl; std::size_t num_thread = std::thread::hardware_concurrency(); switch(init_method) @@ -92,9 +103,6 @@ void profile_gemm_reduce_impl(int do_verification, b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); } - // set zero to c_device_buf - c_m_n_device_result.GenerateTensorValue(GeneratorTensor_0{}, num_thread); - using AElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough; using CElementOp = ck::tensor_operation::element_wise::PassThrough; @@ -118,17 +126,28 @@ void profile_gemm_reduce_impl(int do_verification, ref_invoker.Run(ref_argument); - // TODO - // cpu reduce + for(int m = 0; m < M; ++m) + { + float r_acc = d_reduce_op.GetReduceZeroValue(); + + for(int n = 0; n < N; ++n) + { + d_reduce_op.Reduce(r_acc, c_m_n_host_result(m, n)); + } + + d_m_host_result(m) = r_acc; + } } DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); - DeviceMem d_device_buf(sizeof(CDataType) * d_m.mDesc.GetElementSpace()); + DeviceMem d_device_buf(sizeof(DDataType) * d_m_device_result.mDesc.GetElementSpace()); a_device_buf.ToDevice(a_m_k.mData.data()); b_device_buf.ToDevice(b_k_n.mData.data()); + + c_device_buf.SetZero(); d_device_buf.SetZero(); // add device GEMM instances @@ -141,24 +160,34 @@ void profile_gemm_reduce_impl(int do_verification, if constexpr(is_same::value && is_same::value && is_same::value) - {} + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances( + gemm_ptrs); + } else if constexpr(is_same::value && is_same::value && is_same::value) { ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_fp32_mk_nk_mn_instances( + add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instances( gemm_ptrs); } else if constexpr(is_same::value && is_same::value && is_same::value) { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances( + gemm_ptrs); } else if constexpr(is_same::value && is_same::value && is_same::value) { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances( + gemm_ptrs); } } @@ -195,6 +224,9 @@ void profile_gemm_reduce_impl(int do_verification, if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) { + // set zero + d_device_buf.SetZero(); + std::string gemm_name = gemm_ptr->GetTypeString(); float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat); @@ -222,18 +254,23 @@ void profile_gemm_reduce_impl(int do_verification, if(do_verification) { c_device_buf.FromDevice(c_m_n_device_result.mData.data()); + d_device_buf.FromDevice(d_m_device_result.mData.data()); check_error(c_m_n_host_result, c_m_n_device_result); + check_error(d_m_host_result, d_m_device_result); if(do_log) { LogRangeAsType(std::cout << "a : ", a_m_k.mData, ",") << std::endl; LogRangeAsType(std::cout << "b: ", b_k_n.mData, ",") << std::endl; - LogRangeAsType(std::cout << "c0 : ", d_m.mData, ",") << std::endl; - LogRangeAsType(std::cout << "c_host : ", c_m_n_host_result.mData, ",") + LogRangeAsType(std::cout << "c_host: ", c_m_n_host_result.mData, ",") << std::endl; LogRangeAsType(std::cout << "c_device: ", c_m_n_device_result.mData, ",") << std::endl; + LogRangeAsType(std::cout << "d_host: ", d_m_host_result.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "d_device: ", d_m_device_result.mData, ",") + << std::endl; } } } diff --git a/profiler/src/profiler.cpp b/profiler/src/profiler.cpp index cba8df3f811..14e21d9c3e7 100644 --- a/profiler/src/profiler.cpp +++ b/profiler/src/profiler.cpp @@ -4,22 +4,21 @@ #include #include -// int profile_gemm(int, char*[]); -// int profile_gemm_bias_2d(int, char*[]); -// int profile_gemm_bias_relu(int, char*[]); -// int profile_gemm_bias_relu_add(int, char*[]); +int profile_gemm(int, char*[]); +int profile_gemm_bias_2d(int, char*[]); +int profile_gemm_bias_relu(int, char*[]); +int profile_gemm_bias_relu_add(int, char*[]); int profile_gemm_reduce(int, char*[]); -// int profile_batched_gemm(int, char*[]); -// int profile_conv_fwd(int, char*[]); -// int profile_conv_fwd_bias_relu(int, char*[]); -// int profile_conv_fwd_bias_relu_add(int, char*[]); -// int profile_conv_fwd_bias_relu_atomic_add(int, char*[]); -// int profile_conv_bwd_data(int, char*[]); -// int profile_reduce(int, char*[]); +int profile_batched_gemm(int, char*[]); +int profile_conv_fwd(int, char*[]); +int profile_conv_fwd_bias_relu(int, char*[]); +int profile_conv_fwd_bias_relu_add(int, char*[]); +int profile_conv_fwd_bias_relu_atomic_add(int, char*[]); +int profile_conv_bwd_data(int, char*[]); +int profile_reduce(int, char*[]); int main(int argc, char* argv[]) { -#if 0 if(strcmp(argv[1], "gemm") == 0) { return profile_gemm(argc, argv); @@ -36,12 +35,10 @@ int main(int argc, char* argv[]) { return profile_gemm_bias_relu_add(argc, argv); } -#endif - if(strcmp(argv[1], "gemm_reduce") == 0) + else if(strcmp(argv[1], "gemm_reduce") == 0) { return profile_gemm_reduce(argc, argv); } -#if 0 else if(strcmp(argv[1], "batched_gemm") == 0) { return profile_batched_gemm(argc, argv); @@ -70,7 +67,6 @@ int main(int argc, char* argv[]) { return profile_reduce(argc, argv); } -#endif else { // clang-format off From c4bae349a39033e417e88ebae58ac2a02b9fc1fe Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Sun, 20 Mar 2022 03:19:53 +0000 Subject: [PATCH 22/32] fix build --- .../gpu/CMakeLists.txt | 24 +++++++++---------- .../include/profile_gemm_bias_2d_impl.hpp | 2 +- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt index 58ca590e573..0637aa94053 100644 --- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt @@ -23,16 +23,16 @@ function(add_instance_library INSTANCE_NAME) set_target_properties(${INSTANCE_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON) endfunction(add_instance_library INSTANCE_NAME) -#add_subdirectory(gemm) -#add_subdirectory(gemm_bias2d) -#add_subdirectory(gemm_bias_relu) -#add_subdirectory(gemm_bias_relu_add) +add_subdirectory(gemm) +add_subdirectory(gemm_bias2d) +add_subdirectory(gemm_bias_relu) +add_subdirectory(gemm_bias_relu_add) add_subdirectory(gemm_reduce) -#add_subdirectory(batched_gemm) -#add_subdirectory(conv1d_fwd) -#add_subdirectory(conv2d_fwd) -#add_subdirectory(conv2d_fwd_bias_relu) -#add_subdirectory(conv2d_fwd_bias_relu_add) -#add_subdirectory(conv2d_fwd_bias_relu_atomic_add) -#add_subdirectory(conv2d_bwd_data) -#add_subdirectory(reduce) +add_subdirectory(batched_gemm) +add_subdirectory(conv1d_fwd) +add_subdirectory(conv2d_fwd) +add_subdirectory(conv2d_fwd_bias_relu) +add_subdirectory(conv2d_fwd_bias_relu_add) +add_subdirectory(conv2d_fwd_bias_relu_atomic_add) +add_subdirectory(conv2d_bwd_data) +add_subdirectory(reduce) diff --git a/profiler/include/profile_gemm_bias_2d_impl.hpp b/profiler/include/profile_gemm_bias_2d_impl.hpp index 94223c4f7a9..935725a808e 100644 --- a/profiler/include/profile_gemm_bias_2d_impl.hpp +++ b/profiler/include/profile_gemm_bias_2d_impl.hpp @@ -7,7 +7,7 @@ #include "tensor_layout.hpp" #include "device_tensor.hpp" #include "element_wise_operation.hpp" -#include "device_gemm.hpp" +#include "device_gemm_bias.hpp" #include "reference_gemm_bias_2d.hpp" namespace ck { From f4d3af4528fabcbc32ba141c53a3eadca7615066 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Sun, 20 Mar 2022 04:10:14 +0000 Subject: [PATCH 23/32] clean up --- example/15_gemm_reduce/CMakeLists.txt | 2 +- ...duce_fp16.cpp => gemm_reduce_xdl_fp16.cpp} | 41 +++++++++-------- .../device_gemm_reduce_xdl_cshuffle.hpp | 46 +++++++++---------- .../gridwise_gemm_reduce_xdl_cshuffle_v1.hpp | 22 ++++----- 4 files changed, 58 insertions(+), 53 deletions(-) rename example/15_gemm_reduce/{gemm_xdl_reduce_fp16.cpp => gemm_reduce_xdl_fp16.cpp} (86%) diff --git a/example/15_gemm_reduce/CMakeLists.txt b/example/15_gemm_reduce/CMakeLists.txt index 569dca61519..08d37b34a6b 100644 --- a/example/15_gemm_reduce/CMakeLists.txt +++ b/example/15_gemm_reduce/CMakeLists.txt @@ -1 +1 @@ -add_example_executable(example_gemm_xdl_reduce_fp16 gemm_xdl_reduce_fp16.cpp) +add_example_executable(example_gemm_reduce_xdl_fp16 gemm_reduce_xdl_fp16.cpp) diff --git a/example/15_gemm_reduce/gemm_xdl_reduce_fp16.cpp b/example/15_gemm_reduce/gemm_reduce_xdl_fp16.cpp similarity index 86% rename from example/15_gemm_reduce/gemm_xdl_reduce_fp16.cpp rename to example/15_gemm_reduce/gemm_reduce_xdl_fp16.cpp index b3740c2bc87..b6db94eb7dd 100644 --- a/example/15_gemm_reduce/gemm_xdl_reduce_fp16.cpp +++ b/example/15_gemm_reduce/gemm_reduce_xdl_fp16.cpp @@ -39,7 +39,8 @@ using CLayout = ck::tensor_layout::gemm::RowMajor; using AElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough; using CElementOp = ck::tensor_operation::element_wise::PassThrough; -using DReduceOp = ck::tensor_operation::element_wise::ReduceSquareSum; +using D0ReduceOp = ck::tensor_operation::element_wise::ReduceSum; +using D1ReduceOp = ck::tensor_operation::element_wise::ReduceSquareSum; static constexpr auto GemmSpecialization = ck::tensor_operation::device::GemmSpecialization_t::Default; @@ -50,7 +51,7 @@ using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_ //######| | | | Type| Type| Type| DataType| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| //######| | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, AElementOp, BElementOp, CElementOp, DReduceOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>; + < Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, AElementOp, BElementOp, CElementOp, D0ReduceOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host:: @@ -58,8 +59,8 @@ using ReferenceGemmInstance = ck::tensor_operation::host:: int main(int argc, char* argv[]) { - bool do_verification = 0; - int init_method = 0; + bool do_verification = 1; + int init_method = 1; int nrepeat = 5; // GEMM shape @@ -71,7 +72,11 @@ int main(int argc, char* argv[]) ck::index_t StrideB = 4096; ck::index_t StrideC = 4096; - if(argc == 4) + if(argc == 1) + { + // do nothing + } + else if(argc == 4) { do_verification = std::stoi(argv[1]); init_method = std::stoi(argv[2]); @@ -118,15 +123,15 @@ int main(int argc, char* argv[]) Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); - Tensor d_m_host_result( + Tensor d0_m_host_result( HostTensorDescriptor(std::vector({static_cast(M)}))); - Tensor d_m_device_result( + Tensor d0_m_device_result( HostTensorDescriptor(std::vector({static_cast(M)}))); std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; - std::cout << "d_m: " << d_m_host_result.mDesc << std::endl; + std::cout << "d0_m: " << d0_m_host_result.mDesc << std::endl; switch(init_method) { @@ -147,18 +152,18 @@ int main(int argc, char* argv[]) DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); - DeviceMem d_m_device_buf(sizeof(DDataType) * d_m_device_result.mDesc.GetElementSpace()); + DeviceMem d0_m_device_buf(sizeof(DDataType) * d0_m_device_result.mDesc.GetElementSpace()); a_m_k_device_buf.ToDevice(a_m_k.mData.data()); b_k_n_device_buf.ToDevice(b_k_n.mData.data()); // init to 0 - d_m_device_buf.SetZero(); + d0_m_device_buf.SetZero(); auto a_element_op = AElementOp{}; auto b_element_op = BElementOp{}; auto c_element_op = CElementOp{}; - auto d_reduce_op = DReduceOp{}; + auto d0_reduce_op = D0ReduceOp{}; // do GEMM auto gemm = DeviceGemmReduceInstance{}; @@ -166,7 +171,7 @@ int main(int argc, char* argv[]) auto argument = gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), static_cast(b_k_n_device_buf.GetDeviceBuffer()), static_cast(c_m_n_device_buf.GetDeviceBuffer()), - static_cast(d_m_device_buf.GetDeviceBuffer()), + static_cast(d0_m_device_buf.GetDeviceBuffer()), M, N, K, @@ -176,7 +181,7 @@ int main(int argc, char* argv[]) a_element_op, b_element_op, c_element_op, - d_reduce_op); + d0_reduce_op); if(!gemm.IsSupportedArgument(argument)) { @@ -201,7 +206,7 @@ int main(int argc, char* argv[]) if(do_verification) { c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); - d_m_device_buf.FromDevice(d_m_device_result.mData.data()); + d0_m_device_buf.FromDevice(d0_m_device_result.mData.data()); auto ref_gemm = ReferenceGemmInstance{}; auto ref_invoker = ref_gemm.MakeInvoker(); @@ -213,18 +218,18 @@ int main(int argc, char* argv[]) for(int m = 0; m < M; ++m) { - float r_acc = d_reduce_op.GetReduceZeroValue(); + float d0_acc = d0_reduce_op.GetReduceZeroValue(); for(int n = 0; n < N; ++n) { - d_reduce_op.Reduce(r_acc, c_m_n_host_result(m, n)); + d0_reduce_op.Reduce(d0_acc, c_m_n_host_result(m, n)); } - d_m_host_result(m) = r_acc; + d0_m_host_result(m) = d0_acc; } check_error(c_m_n_host_result, c_m_n_device_result); - check_error(d_m_host_result, d_m_device_result); + check_error(d0_m_host_result, d0_m_device_result); } return 0; diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp index e5cdde758f2..6013768d49c 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp @@ -27,7 +27,7 @@ template + D0ReduceOperation> { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -376,7 +376,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce, remove_reference_t, remove_reference_t< @@ -560,11 +560,11 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce, remove_reference_t, remove_reference_t< @@ -599,11 +599,11 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce MakeArgumentPointer(const void* p_a, const void* p_b, void* p_c, - void* p_d, + void* p_d0, index_t MRaw, index_t NRaw, index_t KRaw, @@ -686,12 +686,12 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce(static_cast(p_a), static_cast(p_b), static_cast(p_c), - static_cast(p_d), + static_cast(p_d0), MRaw, NRaw, KRaw, @@ -701,7 +701,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce(p_a_grid, p_b_grid, p_c_grid, - p_d_grid, + p_d0_grid, p_shared, a_element_op, b_element_op, c_element_op, - d_reduce_op, + d0_reduce_op, a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, c_grid_desc_mblock_mperblock_nblock_nperblock, @@ -72,7 +72,7 @@ template ( p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); auto d_grid_buf = make_dynamic_buffer( - p_d_grid, d_grid_desc_mblock_mperblock.GetElementSpaceSize()); + p_d0_grid, d_grid_desc_mblock_mperblock.GetElementSpaceSize()); // divide block work by [M, N] const auto block_work_idx = @@ -821,14 +821,14 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 // reduce in VGPR static_for<0, mreduce_per_thread, 1>{}([&](auto im) { - FloatReduceAcc r_acc = d_reduce_op.GetReduceZeroValue(); + FloatReduceAcc r_acc = d0_reduce_op.GetReduceZeroValue(); static_for<0, nreduce_per_thread, 1>{}([&](auto in) { constexpr index_t in_offset = c_in_reduce_thread_desc_mperblock_nperblock.CalculateOffset( make_tuple(im, in)); - d_reduce_op.Reduce(r_acc, c_in_reduce_thread_buf[Number{}]); + d0_reduce_op.Reduce(r_acc, c_in_reduce_thread_buf[Number{}]); }); constexpr index_t out_offset = From ae1e89d31d16a42354b5291b7d691c2a3652a1b1 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Sun, 20 Mar 2022 05:14:02 +0000 Subject: [PATCH 24/32] gemm+reduce --- .../15_gemm_reduce/gemm_reduce_xdl_fp16.cpp | 59 ++++++--- .../gpu/device/device_gemm_reduce.hpp | 15 ++- .../device_gemm_reduce_xdl_cshuffle.hpp | 117 ++++++++++-------- .../gridwise_gemm_reduce_xdl_cshuffle_v1.hpp | 110 +++++++++------- .../include/ck/library/host_tensor/device.hpp | 8 -- ..._f16_f16_f16_f32_f32_km_kn_mn_instance.cpp | 44 +++---- ..._f16_f16_f16_f32_f32_km_nk_mn_instance.cpp | 44 +++---- ..._f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp | 44 +++---- ..._f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp | 38 +++--- profiler/include/profile_gemm_reduce_impl.hpp | 79 ++++++++---- 10 files changed, 335 insertions(+), 223 deletions(-) diff --git a/example/15_gemm_reduce/gemm_reduce_xdl_fp16.cpp b/example/15_gemm_reduce/gemm_reduce_xdl_fp16.cpp index b6db94eb7dd..c42f767c254 100644 --- a/example/15_gemm_reduce/gemm_reduce_xdl_fp16.cpp +++ b/example/15_gemm_reduce/gemm_reduce_xdl_fp16.cpp @@ -47,11 +47,11 @@ static constexpr auto GemmSpecialization = // clang-format off using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_Xdl_CShuffle -//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| -//######| | | | Type| Type| Type| DataType| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| -//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| -//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, AElementOp, BElementOp, CElementOp, D0ReduceOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>; +//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D0| D1| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| +//######| | | | Type| Type| Type| DataType| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Reduce| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, AElementOp, BElementOp, CElementOp, D0ReduceOp, D1ReduceOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host:: @@ -121,17 +121,24 @@ int main(int argc, char* argv[]) Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); - Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); Tensor d0_m_host_result( HostTensorDescriptor(std::vector({static_cast(M)}))); + Tensor d1_m_host_result( + HostTensorDescriptor(std::vector({static_cast(M)}))); + + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); Tensor d0_m_device_result( HostTensorDescriptor(std::vector({static_cast(M)}))); + Tensor d1_m_device_result( + HostTensorDescriptor(std::vector({static_cast(M)}))); std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; std::cout << "d0_m: " << d0_m_host_result.mDesc << std::endl; + std::cout << "d1_m: " << d1_m_host_result.mDesc << std::endl; switch(init_method) { @@ -140,30 +147,26 @@ int main(int argc, char* argv[]) a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); break; - case 2: + default: a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); break; - default: - a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<0>{}); - b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); } DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); DeviceMem d0_m_device_buf(sizeof(DDataType) * d0_m_device_result.mDesc.GetElementSpace()); + DeviceMem d1_m_device_buf(sizeof(DDataType) * d1_m_device_result.mDesc.GetElementSpace()); a_m_k_device_buf.ToDevice(a_m_k.mData.data()); b_k_n_device_buf.ToDevice(b_k_n.mData.data()); - // init to 0 - d0_m_device_buf.SetZero(); - auto a_element_op = AElementOp{}; auto b_element_op = BElementOp{}; auto c_element_op = CElementOp{}; auto d0_reduce_op = D0ReduceOp{}; + auto d1_reduce_op = D1ReduceOp{}; // do GEMM auto gemm = DeviceGemmReduceInstance{}; @@ -172,6 +175,7 @@ int main(int argc, char* argv[]) static_cast(b_k_n_device_buf.GetDeviceBuffer()), static_cast(c_m_n_device_buf.GetDeviceBuffer()), static_cast(d0_m_device_buf.GetDeviceBuffer()), + static_cast(d1_m_device_buf.GetDeviceBuffer()), M, N, K, @@ -181,7 +185,8 @@ int main(int argc, char* argv[]) a_element_op, b_element_op, c_element_op, - d0_reduce_op); + d0_reduce_op, + d1_reduce_op); if(!gemm.IsSupportedArgument(argument)) { @@ -190,7 +195,26 @@ int main(int argc, char* argv[]) "not support this GEMM problem"); } - float ave_time = invoker.Run(argument, nrepeat); + // warm up + invoker.Run(argument); + + // timing + KernelTimer timer; + + timer.Start(); + + for(int i = 0; i < nrepeat; ++i) + { + // init DO, D1 to 0 + d0_m_device_buf.SetZero(); + d1_m_device_buf.SetZero(); + + invoker.Run(argument); + } + + timer.End(); + + float ave_time = timer.GetElapsedTime() / nrepeat; std::size_t flop = std::size_t(2) * M * N * K; std::size_t num_btype = @@ -207,6 +231,7 @@ int main(int argc, char* argv[]) { c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); d0_m_device_buf.FromDevice(d0_m_device_result.mData.data()); + d1_m_device_buf.FromDevice(d1_m_device_result.mData.data()); auto ref_gemm = ReferenceGemmInstance{}; auto ref_invoker = ref_gemm.MakeInvoker(); @@ -219,17 +244,21 @@ int main(int argc, char* argv[]) for(int m = 0; m < M; ++m) { float d0_acc = d0_reduce_op.GetReduceZeroValue(); + float d1_acc = d1_reduce_op.GetReduceZeroValue(); for(int n = 0; n < N; ++n) { d0_reduce_op.Reduce(d0_acc, c_m_n_host_result(m, n)); + d1_reduce_op.Reduce(d1_acc, c_m_n_host_result(m, n)); } d0_m_host_result(m) = d0_acc; + d1_m_host_result(m) = d1_acc; } check_error(c_m_n_host_result, c_m_n_device_result); check_error(d0_m_host_result, d0_m_device_result); + check_error(d1_m_host_result, d1_m_device_result); } return 0; diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp index cab4ddb0bd5..76ea2fc864d 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp @@ -9,13 +9,15 @@ namespace device { template + typename D0ReduceOperation, + typename D1ReduceOperation> struct DeviceGemmReduce : public BaseOperator { virtual std::unique_ptr MakeArgumentPointer(const void* p_a, const void* p_b, void* p_c, - void* p_d, + void* p_d0, + void* p_d1, ck::index_t M, ck::index_t N, ck::index_t K, @@ -25,7 +27,8 @@ struct DeviceGemmReduce : public BaseOperator AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op, - DReduceOperation d_reduce_op) = 0; + D0ReduceOperation d0_reduce_op, + D1ReduceOperation d1_reduce_op) = 0; virtual std::unique_ptr MakeInvokerPointer() = 0; }; @@ -33,11 +36,13 @@ struct DeviceGemmReduce : public BaseOperator template + typename D0ReduceOperation, + typename D1ReduceOperation> using DeviceGemmReducePtr = std::unique_ptr>; + D0ReduceOperation, + D1ReduceOperation>>; } // namespace device } // namespace tensor_operation diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp index 6013768d49c..e85a5b8f9b2 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp @@ -28,6 +28,7 @@ template + D0ReduceOperation, + D1ReduceOperation> { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -377,6 +379,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce, remove_reference_t, remove_reference_t< @@ -551,25 +561,25 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce, true>; - ave_time = - launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.p_d0_grid_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.d0_reduce_op_, - arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.d_grid_desc_mblock_mperblock_, - arg.block_2_ctile_map_); + launch_kernel(kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_d0_grid_, + arg.p_d1_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.d0_reduce_op_, + arg.d1_reduce_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.d_grid_desc_mblock_mperblock_, + arg.block_2_ctile_map_); } else { @@ -582,6 +592,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce, remove_reference_t, remove_reference_t< @@ -590,28 +601,28 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce, false>; - ave_time = - launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.p_d0_grid_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.d0_reduce_op_, - arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.d_grid_desc_mblock_mperblock_, - arg.block_2_ctile_map_); + launch_kernel(kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_d0_grid_, + arg.p_d1_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.d0_reduce_op_, + arg.d1_reduce_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.d_grid_desc_mblock_mperblock_, + arg.block_2_ctile_map_); } - return ave_time; + return 0; } // polymorphic @@ -643,6 +654,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce(static_cast(p_a), static_cast(p_b), static_cast(p_c), static_cast(p_d0), + static_cast(p_d1), MRaw, NRaw, KRaw, @@ -701,7 +719,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce( p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - auto d_grid_buf = make_dynamic_buffer( + auto d0_grid_buf = make_dynamic_buffer( p_d0_grid, d_grid_desc_mblock_mperblock.GetElementSpaceSize()); + auto d1_grid_buf = make_dynamic_buffer( + p_d1_grid, d_grid_desc_mblock_mperblock.GetElementSpaceSize()); // divide block work by [M, N] const auto block_work_idx = @@ -696,30 +706,31 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) / CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1); - constexpr auto c_in_reduce_thread_lengths_mperblock_nperblock = + constexpr auto c_reduce_thread_lengths_mperblock_nperblock = Sequence{}; - // VGPR c_in_reduce_thread_desc_mperblock_nperblock - constexpr auto c_in_reduce_thread_desc_mperblock_nperblock = + // VGPR c_reduce_thread_desc_mperblock_nperblock + constexpr auto c_reduce_thread_desc_mperblock_nperblock = make_naive_tensor_descriptor_packed( make_tuple(Number{}, Number{})); - // VGPR c_out_reduce_thread_desc_mperblock - constexpr auto c_out_reduce_thread_desc_mperblock = + // VGPR d_reduce_thread_desc_mperblock + constexpr auto d_reduce_thread_desc_mperblock = make_naive_tensor_descriptor_packed(make_tuple(Number{})); - // VGPR c_out_reduce_thread_desc_mblock_mperblock - constexpr auto c_out_reduce_thread_desc_mblock_mperblock = + // VGPR d_reduce_thread_desc_mblock_mperblock + constexpr auto d_reduce_thread_desc_mblock_mperblock = make_naive_tensor_descriptor_packed(make_tuple(I1, Number{})); // TODO: this should be implemented as a blockwise reduction - auto c_in_reduce_thread_buf = - make_static_buffer( - c_in_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize()); + auto c_reduce_thread_buf = make_static_buffer( + c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize()); - auto c_out_reduce_thread_buf = - make_static_buffer( - c_out_reduce_thread_desc_mperblock.GetElementSpaceSize()); + auto d0_thread_buf = make_static_buffer( + d_reduce_thread_desc_mperblock.GetElementSpaceSize()); + + auto d1_thread_buf = make_static_buffer( + d_reduce_thread_desc_mperblock.GetElementSpaceSize()); // reduce: threadwise copy from LDS to VGPR constexpr auto c_reduce_thread_cluster_desc = make_cluster_descriptor( @@ -730,14 +741,14 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 make_multi_index(get_thread_local_1d_id())); const auto c_reduce_thread_data_idx_begin = - c_reduce_thread_cluster_idx * c_in_reduce_thread_lengths_mperblock_nperblock; + c_reduce_thread_cluster_idx * c_reduce_thread_lengths_mperblock_nperblock; auto c_reduce_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2< FloatCShuffle, FloatCShuffle, decltype(c_reduce_block_desc_mperblock_nperblock), - decltype(c_in_reduce_thread_desc_mperblock_nperblock), - decltype(c_in_reduce_thread_lengths_mperblock_nperblock), + decltype(c_reduce_thread_desc_mperblock_nperblock), + decltype(c_reduce_thread_lengths_mperblock_nperblock), Sequence<0, 1>, 1, CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, @@ -745,10 +756,10 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 true>{c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_data_idx_begin}; // reduce: copy from VGPR to global - auto c_reduce_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3< + auto d0_reduce_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3< FloatCShuffle, FloatD, - decltype(c_out_reduce_thread_desc_mblock_mperblock), + decltype(d_reduce_thread_desc_mblock_mperblock), decltype(d_grid_desc_mblock_mperblock), ck::tensor_operation::element_wise::PassThrough, Sequence<1, mreduce_per_thread>, @@ -762,6 +773,8 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 c_reduce_thread_data_idx_begin[I0]), // mperblock ck::tensor_operation::element_wise::PassThrough{}}; + auto d1_reduce_thread_copy_vgpr_to_global = d0_reduce_thread_copy_vgpr_to_global; + // space filling curve for threadwise C in VGPR constexpr auto sfc_c_vgpr = SpaceFillingCurve, @@ -811,50 +824,63 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 // reduce { - // LDS to VGPR - c_reduce_thread_copy_lds_to_vgpr.Run( - c_reduce_block_desc_mperblock_nperblock, - c_shuffle_block_buf, - c_in_reduce_thread_desc_mperblock_nperblock, - make_tuple(I0, I0), - c_in_reduce_thread_buf); + // copy from LDS to VGPR + c_reduce_thread_copy_lds_to_vgpr.Run(c_reduce_block_desc_mperblock_nperblock, + c_shuffle_block_buf, + c_reduce_thread_desc_mperblock_nperblock, + make_tuple(I0, I0), + c_reduce_thread_buf); // reduce in VGPR static_for<0, mreduce_per_thread, 1>{}([&](auto im) { - FloatReduceAcc r_acc = d0_reduce_op.GetReduceZeroValue(); + FloatReduceAcc d0_acc = d0_reduce_op.GetReduceZeroValue(); + FloatReduceAcc d1_acc = d1_reduce_op.GetReduceZeroValue(); static_for<0, nreduce_per_thread, 1>{}([&](auto in) { - constexpr index_t in_offset = - c_in_reduce_thread_desc_mperblock_nperblock.CalculateOffset( - make_tuple(im, in)); + constexpr auto offset = + Number{}; - d0_reduce_op.Reduce(r_acc, c_in_reduce_thread_buf[Number{}]); + d0_reduce_op.Reduce(d0_acc, c_reduce_thread_buf[offset]); + d1_reduce_op.Reduce(d1_acc, c_reduce_thread_buf[offset]); }); constexpr index_t out_offset = - c_out_reduce_thread_desc_mperblock.CalculateOffset(make_tuple(im)); + d_reduce_thread_desc_mperblock.CalculateOffset(make_tuple(im)); - c_out_reduce_thread_buf(Number{}) = r_acc; + d0_thread_buf(Number{}) = d0_acc; + d1_thread_buf(Number{}) = d1_acc; }); - // VGPR to Global - c_reduce_thread_copy_vgpr_to_global.Run( - c_out_reduce_thread_desc_mblock_mperblock, - make_tuple(I0, I0), - c_out_reduce_thread_buf, - d_grid_desc_mblock_mperblock, - d_grid_buf); + // copy from VGPR to Global + d0_reduce_thread_copy_vgpr_to_global.Run(d_reduce_thread_desc_mblock_mperblock, + make_tuple(I0, I0), + d0_thread_buf, + d_grid_desc_mblock_mperblock, + d0_grid_buf); + + d1_reduce_thread_copy_vgpr_to_global.Run(d_reduce_thread_desc_mblock_mperblock, + make_tuple(I0, I0), + d1_thread_buf, + d_grid_desc_mblock_mperblock, + d1_grid_buf); } if constexpr(access_id < num_access - 1) { constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + // move on C c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); - // reduce: move - c_reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow( + // move on D0 + d0_reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow( + d_grid_desc_mblock_mperblock, + make_tuple(c_global_step[I0], c_global_step[I1])); + + // move on D1 + d1_reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow( d_grid_desc_mblock_mperblock, make_tuple(c_global_step[I0], c_global_step[I1])); } diff --git a/library/include/ck/library/host_tensor/device.hpp b/library/include/ck/library/host_tensor/device.hpp index cefb249d459..f33b8d4f40c 100644 --- a/library/include/ck/library/host_tensor/device.hpp +++ b/library/include/ck/library/host_tensor/device.hpp @@ -50,7 +50,6 @@ template float launch_and_time_kernel( F kernel, int nrepeat, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args) { -#if 1 KernelTimer timer; printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n", @@ -80,13 +79,6 @@ float launch_and_time_kernel( timer.End(); - // std::this_thread::sleep_for (std::chrono::microseconds(10)); - return timer.GetElapsedTime() / nrepeat; -#else - launch_kernel(kernel, grid_dim, block_dim, lds_byte, args...); - - return 0; -#endif } #endif diff --git a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp index 9794a443e05..fe4aaef9439 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp @@ -30,31 +30,33 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa // d1[m] = reduce1(c[m, n]) using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances = std::tuple< // clang-format off - //###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| - //###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| - //###########################| | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| - //###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 256, 32, 4, 4, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1> + //###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D0| D1| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Reduce| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //###########################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 128, 256, 32, 4, 4, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1> // clang-format on >; void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances( - std::vector>& instances) + std::vector< + DeviceGemmReducePtr>& + instances) { add_device_operation_instances( instances, device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances{}); diff --git a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp index 9b1243e5f69..4ffdf84f8b6 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp @@ -30,31 +30,33 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa // d1[m] = reduce1(c[m, n]) using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances = std::tuple< // clang-format off - //###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| - //###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| - //###########################| | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| - //###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1> + //###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D0| D1| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Reduce| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //###########################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1> // clang-format on >; void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances( - std::vector>& instances) + std::vector< + DeviceGemmReducePtr>& + instances) { add_device_operation_instances( instances, device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances{}); diff --git a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp index b09cc0c3f03..3c9aad584ba 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp @@ -30,31 +30,33 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa // d1[m] = reduce1(c[m, n]) using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances = std::tuple< // clang-format off - //###########################| ALayout| BLayout| CLayout| AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| - //###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| - //###########################| | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| - //###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1> + //###########################| ALayout| BLayout| CLayout| AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D0| D1| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Reduce| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //###########################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1> // clang-format on >; void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances( - std::vector>& instances) + std::vector< + DeviceGemmReducePtr>& + instances) { add_device_operation_instances( instances, device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances{}); diff --git a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp index 35323ba5714..7de3c627dfc 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp @@ -30,28 +30,30 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa // d1[m] = reduce1(c[m, n]) using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instances = std::tuple< // clang-format off - //###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| - //###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| - //###########################| | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| - //###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1> + //###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D0| D1| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Reduce| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //###########################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1> // clang-format on >; void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instances( - std::vector>& instances) + std::vector< + DeviceGemmReducePtr>& + instances) { add_device_operation_instances( instances, device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instances{}); diff --git a/profiler/include/profile_gemm_reduce_impl.hpp b/profiler/include/profile_gemm_reduce_impl.hpp index 840b77f054c..77e85f9ba2c 100644 --- a/profiler/include/profile_gemm_reduce_impl.hpp +++ b/profiler/include/profile_gemm_reduce_impl.hpp @@ -20,7 +20,8 @@ using DeviceGemmReduceNoOpPtr = ck::tensor_operation::device::DeviceGemmReducePt ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::ReduceSum>; + ck::tensor_operation::element_wise::ReduceSum, + ck::tensor_operation::element_wise::ReduceSquareSum>; void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances( std::vector&); @@ -78,17 +79,22 @@ void profile_gemm_reduce_impl(int do_verification, Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); - Tensor d_m_host_result( + Tensor d0_m_host_result( + HostTensorDescriptor(std::vector({static_cast(M)}))); + Tensor d1_m_host_result( HostTensorDescriptor(std::vector({static_cast(M)}))); Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); - Tensor d_m_device_result( + Tensor d0_m_device_result( + HostTensorDescriptor(std::vector({static_cast(M)}))); + Tensor d1_m_device_result( HostTensorDescriptor(std::vector({static_cast(M)}))); std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; - std::cout << "d_m: " << d_m_host_result.mDesc << std::endl; + std::cout << "d0_m: " << d0_m_host_result.mDesc << std::endl; + std::cout << "d1_m: " << d1_m_host_result.mDesc << std::endl; std::size_t num_thread = std::thread::hardware_concurrency(); switch(init_method) @@ -106,12 +112,14 @@ void profile_gemm_reduce_impl(int do_verification, using AElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough; using CElementOp = ck::tensor_operation::element_wise::PassThrough; - using DReduceOp = ck::tensor_operation::element_wise::ReduceSum; + using D0ReduceOp = ck::tensor_operation::element_wise::ReduceSum; + using D1ReduceOp = ck::tensor_operation::element_wise::ReduceSquareSum; const auto a_element_op = AElementOp{}; const auto b_element_op = BElementOp{}; const auto c_element_op = CElementOp{}; - const auto d_reduce_op = DReduceOp{}; + const auto d0_reduce_op = D0ReduceOp{}; + const auto d1_reduce_op = D1ReduceOp{}; if(do_verification) { @@ -128,28 +136,29 @@ void profile_gemm_reduce_impl(int do_verification, for(int m = 0; m < M; ++m) { - float r_acc = d_reduce_op.GetReduceZeroValue(); + float d0_acc = d0_reduce_op.GetReduceZeroValue(); + float d1_acc = d1_reduce_op.GetReduceZeroValue(); for(int n = 0; n < N; ++n) { - d_reduce_op.Reduce(r_acc, c_m_n_host_result(m, n)); + d0_reduce_op.Reduce(d0_acc, c_m_n_host_result(m, n)); + d1_reduce_op.Reduce(d1_acc, c_m_n_host_result(m, n)); } - d_m_host_result(m) = r_acc; + d0_m_host_result(m) = d0_acc; + d1_m_host_result(m) = d1_acc; } } DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); - DeviceMem d_device_buf(sizeof(DDataType) * d_m_device_result.mDesc.GetElementSpace()); + DeviceMem d0_device_buf(sizeof(DDataType) * d0_m_device_result.mDesc.GetElementSpace()); + DeviceMem d1_device_buf(sizeof(DDataType) * d1_m_device_result.mDesc.GetElementSpace()); a_device_buf.ToDevice(a_m_k.mData.data()); b_device_buf.ToDevice(b_k_n.mData.data()); - c_device_buf.SetZero(); - d_device_buf.SetZero(); - // add device GEMM instances std::vector gemm_ptrs; @@ -208,7 +217,8 @@ void profile_gemm_reduce_impl(int do_verification, gemm_ptr->MakeArgumentPointer(static_cast(a_device_buf.GetDeviceBuffer()), static_cast(b_device_buf.GetDeviceBuffer()), static_cast(c_device_buf.GetDeviceBuffer()), - static_cast(d_device_buf.GetDeviceBuffer()), + static_cast(d0_device_buf.GetDeviceBuffer()), + static_cast(d1_device_buf.GetDeviceBuffer()), M, N, K, @@ -218,18 +228,35 @@ void profile_gemm_reduce_impl(int do_verification, a_element_op, b_element_op, c_element_op, - d_reduce_op); + d0_reduce_op, + d1_reduce_op); auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) { - // set zero - d_device_buf.SetZero(); + // warm up + invoker.Run(argument); - std::string gemm_name = gemm_ptr->GetTypeString(); + // timing + KernelTimer timer; + + timer.Start(); + + for(int i = 0; i < nrepeat; ++i) + { + // init DO, D1 to 0 + d0_m_device_buf.SetZero(); + d1_m_device_buf.SetZero(); - float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat); + invoker.Run(argument); + } + + timer.End(); + + float ave_time = timer.GetElapsedTime() / nrepeat; + + std::string gemm_name = gemm_ptr->GetTypeString(); std::size_t flop = std::size_t(2) * M * N * K; @@ -254,10 +281,12 @@ void profile_gemm_reduce_impl(int do_verification, if(do_verification) { c_device_buf.FromDevice(c_m_n_device_result.mData.data()); - d_device_buf.FromDevice(d_m_device_result.mData.data()); + d0_device_buf.FromDevice(d0_m_device_result.mData.data()); + d1_device_buf.FromDevice(d1_m_device_result.mData.data()); check_error(c_m_n_host_result, c_m_n_device_result); - check_error(d_m_host_result, d_m_device_result); + check_error(d0_m_host_result, d0_m_device_result); + check_error(d1_m_host_result, d1_m_device_result); if(do_log) { @@ -267,9 +296,13 @@ void profile_gemm_reduce_impl(int do_verification, << std::endl; LogRangeAsType(std::cout << "c_device: ", c_m_n_device_result.mData, ",") << std::endl; - LogRangeAsType(std::cout << "d_host: ", d_m_host_result.mData, ",") + LogRangeAsType(std::cout << "d0_host: ", d0_m_host_result.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "d0_device: ", d0_m_device_result.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "d1_host: ", d1_m_host_result.mData, ",") << std::endl; - LogRangeAsType(std::cout << "d_device: ", d_m_device_result.mData, ",") + LogRangeAsType(std::cout << "d1_device: ", d1_m_device_result.mData, ",") << std::endl; } } From 21457e7aea94c9fb92ea1cf79ecafea2197f3011 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Mon, 21 Mar 2022 15:38:58 +0000 Subject: [PATCH 25/32] fix build --- profiler/include/profile_gemm_reduce_impl.hpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/profiler/include/profile_gemm_reduce_impl.hpp b/profiler/include/profile_gemm_reduce_impl.hpp index 77e85f9ba2c..011ceecc1f5 100644 --- a/profiler/include/profile_gemm_reduce_impl.hpp +++ b/profiler/include/profile_gemm_reduce_impl.hpp @@ -236,7 +236,7 @@ void profile_gemm_reduce_impl(int do_verification, if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) { // warm up - invoker.Run(argument); + invoker_ptr->Run(argument_ptr.get()); // timing KernelTimer timer; @@ -246,10 +246,10 @@ void profile_gemm_reduce_impl(int do_verification, for(int i = 0; i < nrepeat; ++i) { // init DO, D1 to 0 - d0_m_device_buf.SetZero(); - d1_m_device_buf.SetZero(); + d0_device_buf.SetZero(); + d1_device_buf.SetZero(); - invoker.Run(argument); + invoker_ptr->Run(argument_ptr.get()); } timer.End(); From 38235137636135c9a5c94e837e941b435733cd5b Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Tue, 22 Mar 2022 17:25:25 +0000 Subject: [PATCH 26/32] update DeviceGemm_Xdl_CShuffle; update enum to enum class --- example/01_gemm/gemm_xdl_fp16.cpp | 78 +-- include/ck/config.hpp | 4 +- ...nvolution_backward_data_specialization.hpp | 2 +- .../convolution_forward_specialization.hpp | 3 +- ...evice_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp | 8 +- .../device_gemm_reduce_xdl_cshuffle.hpp | 23 +- .../gpu/device/device_gemm_xdl_cshuffle.hpp | 542 ++++++++++++------ .../gpu/device/gemm_specialization.hpp | 2 +- .../grid/gridwise_gemm_xdl_cshuffle_v1.hpp | 68 +-- include/ck/utility/amd_address_space.hpp | 2 +- include/ck/utility/data_type_enum.hpp | 2 +- .../ck/library/host_tensor/host_tensor.hpp | 6 +- .../gemm_driver_offline.cpp | 4 +- profiler/src/profile_batched_gemm.cpp | 8 +- profiler/src/profile_conv_bwd_data.cpp | 16 +- profiler/src/profile_conv_fwd.cpp | 16 +- profiler/src/profile_conv_fwd_bias_relu.cpp | 16 +- .../src/profile_conv_fwd_bias_relu_add.cpp | 16 +- .../profile_conv_fwd_bias_relu_atomic_add.cpp | 16 +- profiler/src/profile_gemm.cpp | 8 +- profiler/src/profile_gemm_bias_2d.cpp | 8 +- profiler/src/profile_gemm_bias_relu.cpp | 8 +- profiler/src/profile_gemm_bias_relu_add.cpp | 8 +- profiler/src/profile_gemm_reduce.cpp | 8 +- profiler/src/profile_reduce.cpp | 25 +- test/gemm_split_k/gemm_split_k.cpp | 8 +- 26 files changed, 522 insertions(+), 383 deletions(-) diff --git a/example/01_gemm/gemm_xdl_fp16.cpp b/example/01_gemm/gemm_xdl_fp16.cpp index f894639a622..8db97a5b256 100644 --- a/example/01_gemm/gemm_xdl_fp16.cpp +++ b/example/01_gemm/gemm_xdl_fp16.cpp @@ -13,7 +13,7 @@ #include "device_tensor.hpp" #include "device_gemm_xdl.hpp" #include "device_gemm_xdl_c_shuffle.hpp" -#include "device_gemm_xdl_cshuffle_v1.hpp" +#include "device_gemm_xdl_cshuffle.hpp" #include "element_wise_operation.hpp" #include "reference_gemm.hpp" #include "gemm_specialization.hpp" @@ -45,59 +45,12 @@ using CElementOp = ck::tensor_operation::element_wise::PassThrough; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default; // clang-format off -#if 0 -using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl -//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| Num| -//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| Prefetch| -//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | -//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -// [256, 128, 4, 8], 1 stage, 2 occupancy - < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1>; -#elif 0 -using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle -//######|AData| BData| CData| AccData| Shuffle| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| -//######| Type| Type| Type| Type| Data| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| -//######| | | | | Type| | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| -//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < F16, F16, F16, F32, F16, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>; -#elif 1 -using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle_v1 -//######|AData| BData| CData| GemmAcc| CShuffle| ALayout| BLayout| CLayout| A| B| C| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| -//######| Type| Type| Type| DataType| DataType| | | | Elementwise| Elementwise| Elementwise| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| -//######| | | | | | | | | Operation| Operation| Operation| Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| -//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < F16, F16, F16, F32, F16, Row, Col, Row, PassThrough, PassThrough, PassThrough, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; -// < F16, F16, F16, F32, F16, Row, Col, Row, PassThrough, PassThrough, PassThrough, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 2, 2, S<1, 16, 1,16>, 8>; -#elif 0 -using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl -//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| Num| -//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| Prefetch| -//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | -//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -// [128, 144, 8, 8], 1 stage, 1 occupancy, bounded by LDS size -// 99 TFlops, 120 blocks (1024x2160x3840) -// 99 TFlops, 960 blocks (4096x4320x3840) - < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 144, 8, 8, 16, 16, 2, 9, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 8, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1, 1>; -// [128, 144, 4, 8], 1 stage, 2 occupancy, -// 92 TFlops, 120 blocks (1024x2160x3840) -// 120 TFlops, 240 blocks (1024x4320x3840) -// 128 TFlops, 960 blocks (4096x4320x3840) -// < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 144, 4, 8, 16, 16, 2, 9, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1, 1>; -// [ 64, 144, 8, 8], 1 stage, 2 occupancy/ -// 96 TFlops, 240 blocks (1024x2160x3840) -// 96 TFlops, 480 blocks (1024x4320x3840) -// 99 TFlops,1920 blocks (4096x4320x3840) -// < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 144, 8, 8, 16, 16, 1, 9, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 8, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1, 1>; -// [ 64, 144, 8, 8], 2 stage, 2 occupancy -// 93 TFlops -// < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 144, 8, 8, 16, 16, 1, 9, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 8, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1, 2>; -// [ 64, 144, 4, 8], 1 stage, 2 occupancy -// 87 TFlops -// < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 144, 4, 8, 16, 16, 1, 9, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1, 1>; -// [ 64, 144, 4, 8], 2 stage, 2 occupancy -// 85 TFlops -// < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 144, 4, 8, 16, 16, 1, 9, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1, 2>; -#endif +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle +//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | Type| Type| Type| DataType| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| +//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < Row, Col, Row, F16, F16, F16, F32, F32, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host:: @@ -220,7 +173,22 @@ int main(int argc, char* argv[]) "not support this GEMM problem"); } - float ave_time = invoker.Run(argument, nrepeat); + // warm up + invoker.Run(argument); + + // timing + KernelTimer timer; + + timer.Start(); + + for(int i = 0; i < nrepeat; ++i) + { + invoker.Run(argument); + } + + timer.End(); + + float ave_time = timer.GetElapsedTime() / nrepeat; std::size_t flop = std::size_t(2) * M * N * K; std::size_t num_btype = diff --git a/include/ck/config.hpp b/include/ck/config.hpp index 7f51d29715d..2f06140d5ce 100644 --- a/include/ck/config.hpp +++ b/include/ck/config.hpp @@ -159,14 +159,14 @@ namespace ck { -enum InMemoryDataOperationEnum_t +enum struct InMemoryDataOperationEnum_t { Set, AtomicAdd, Add }; -enum ActivTypeEnum_t +enum struct ActivTypeEnum_t { None, LeakyRelu, diff --git a/include/ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp b/include/ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp index 4c1d6747c4e..b62721f5e4c 100644 --- a/include/ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp +++ b/include/ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp @@ -5,7 +5,7 @@ namespace ck { namespace tensor_operation { namespace device { -enum ConvolutionBackwardDataSpecialization_t +enum struct ConvolutionBackwardDataSpecialization_t { Default, Filter1x1Stride1Pad0, diff --git a/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp b/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp index e047acee76f..3ab0c7b2eba 100644 --- a/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp +++ b/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp @@ -5,12 +5,13 @@ namespace ck { namespace tensor_operation { namespace device { -enum ConvolutionForwardSpecialization_t +enum struct ConvolutionForwardSpecialization_t { Default, Filter1x1Pad0, Filter1x1Stride1Pad0, OddC, + UnDefined, }; } // namespace device diff --git a/include/ck/tensor_operation/gpu/device/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp b/include/ck/tensor_operation/gpu/device/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp index ffa1815ab7a..471cdb8bf69 100644 --- a/include/ck/tensor_operation/gpu/device/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp @@ -210,12 +210,16 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_ if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) { - static_assert(ConvForwardSpecialization == -1, "Not implemented!"); + static_assert(ConvForwardSpecialization == + ConvolutionForwardSpecialization_t::UnDefined, + "Not implemented!"); } else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization_t::Filter1x1Pad0) { - static_assert(ConvForwardSpecialization == -1, "Not implemented!"); + static_assert(ConvForwardSpecialization == + ConvolutionForwardSpecialization_t::UnDefined, + "Not implemented!"); } else { diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp index e85a5b8f9b2..47ad29c7560 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp @@ -68,6 +68,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce { + using DeviceOp = DeviceGemmReduce_Xdl_CShuffle; + static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; @@ -458,13 +460,10 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce, - remove_reference_t, + remove_reference_t, + remove_reference_t, remove_reference_t< typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, remove_reference_t, @@ -593,8 +592,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce, - remove_reference_t, + remove_reference_t, + remove_reference_t, remove_reference_t< typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, remove_reference_t, diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp index e83767b1d2b..b49d5894bc2 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp @@ -13,17 +13,18 @@ namespace ck { namespace tensor_operation { namespace device { -template -struct DeviceGemm_Xdl_CShuffle_v1 +struct DeviceGemm_Xdl_CShuffle : public DeviceGemm { - using DeviceOp = DeviceGemm_Xdl_CShuffle_v1; + using DeviceOp = DeviceGemm_Xdl_CShuffle; static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; - static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA) + static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) { - assert(K % AK1 == 0); - - const index_t K0 = K / AK1; - - const auto a_grid_desc_m_k = [&]() { - if constexpr(is_same::value) + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) { - return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); } - else if constexpr(is_same::value) + else if constexpr(is_same_v) { - return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); } }(); - const auto a_grid_desc_k0_m_k1 = transform_tensor_descriptor( - a_grid_desc_m_k, - make_tuple(make_unmerge_transform(make_tuple(K0, AK1)), make_pass_through_transform(M)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; - return a_grid_desc_k0_m_k1; - } + const auto MPad = M - MRaw; + const auto KPad = K - KRaw; - static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB) - { - assert(K % BK1 == 0); + if constexpr(GemmSpecialization == GemmSpecialization_t::MKPadding || + GemmSpecialization == GemmSpecialization_t::MNKPadding) + { + // pad both M and K + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpecialization == GemmSpecialization_t::MPadding || + GemmSpecialization == GemmSpecialization_t::MNPadding) + { + // pad M, but not K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpecialization == GemmSpecialization_t::KPadding || + GemmSpecialization == GemmSpecialization_t::NKPadding) + { + // pad K, but not M + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); - const index_t K0 = K / BK1; + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - const auto b_grid_desc_k_n = [&]() { + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b_grid_desc_nraw_kraw = [&]() { if constexpr(is_same::value) { - return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); } else if constexpr(is_same::value) { - return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); } }(); - const auto b_grid_desc_k0_n_k1 = transform_tensor_descriptor( - b_grid_desc_k_n, - make_tuple(make_unmerge_transform(make_tuple(K0, BK1)), make_pass_through_transform(N)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; - return b_grid_desc_k0_n_k1; + const auto NPad = N - NRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpecialization == GemmSpecialization_t::NKPadding || + GemmSpecialization == GemmSpecialization_t::MNKPadding) + { + // pad both N and K + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(NRaw, NPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpecialization == GemmSpecialization_t::NPadding || + GemmSpecialization == GemmSpecialization_t::MNPadding) + { + // pad N, but not K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpecialization == GemmSpecialization_t::KPadding || + GemmSpecialization == GemmSpecialization_t::MKPadding) + { + // pad K, but not N + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } } - static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) + static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC) { - if constexpr(is_same::value) + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideC)); + } + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + + const auto MPad = M - MRaw; + const auto NPad = N - NRaw; + + if constexpr(GemmSpecialization == GemmSpecialization_t::MNPadding || + GemmSpecialization == GemmSpecialization_t::MNKPadding) { - return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); } - else if constexpr(is_same::value) + else if constexpr(GemmSpecialization == GemmSpecialization_t::MPadding || + GemmSpecialization == GemmSpecialization_t::MKPadding) { - return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpecialization == GemmSpecialization_t::NPadding || + GemmSpecialization == GemmSpecialization_t::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; } } - using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1)); - using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)); - using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); + using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); // GridwiseGemm using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1< @@ -136,13 +337,13 @@ struct DeviceGemm_Xdl_CShuffle_v1 GemmAccDataType, CShuffleDataType, CDataType, - InMemoryDataOperationEnum_t::Set, - AGridDesc_K0_M_K1, - BGridDesc_K0_N_K1, - CGridDesc_M_N, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, + InMemoryDataOperationEnum_t::Set, + AGridDesc_AK0_M_AK1, + BGridDesc_BK0_N_BK1, + CGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, @@ -154,20 +355,20 @@ struct DeviceGemm_Xdl_CShuffle_v1 NPerXDL, MXdlPerWave, NXdlPerWave, - ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_K1, + ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, - BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_K1, + BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, @@ -181,44 +382,39 @@ struct DeviceGemm_Xdl_CShuffle_v1 Argument(const ADataType* p_a_grid, const BDataType* p_b_grid, CDataType* p_c_grid, - index_t M, - index_t N, - index_t K, + index_t MRaw, + index_t NRaw, + index_t KRaw, index_t StrideA, index_t StrideB, index_t StrideC, - index_t M01, - index_t N01, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) : p_a_grid_{p_a_grid}, p_b_grid_{p_b_grid}, p_c_grid_{p_c_grid}, - a_grid_desc_k0_m_k1_{}, - b_grid_desc_k0_n_k1_{}, + a_grid_desc_ak0_m_ak1_{}, + b_grid_desc_bk0_n_bk1_{}, c_grid_desc_m_n_{}, c_grid_desc_mblock_mperblock_nblock_nperblock_{}, block_2_ctile_map_{}, - M01_{M01}, - N01_{N01}, a_element_op_{a_element_op}, b_element_op_{b_element_op}, c_element_op_{c_element_op} { - a_grid_desc_k0_m_k1_ = DeviceOp::MakeAGridDescriptor_K0_M_K1(M, K, StrideA); - b_grid_desc_k0_n_k1_ = DeviceOp::MakeBGridDescriptor_K0_N_K1(K, N, StrideB); - c_grid_desc_m_n_ = DeviceOp::MakeCGridDescriptor_M_N(M, N, StrideC); + a_grid_desc_ak0_m_ak1_ = DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA); + b_grid_desc_bk0_n_bk1_ = DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB); + c_grid_desc_m_n_ = DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC); if(GridwiseGemm::CheckValidity( - a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_)) + a_grid_desc_ak0_m_ak1_, b_grid_desc_bk0_n_bk1_, c_grid_desc_m_n_)) { c_grid_desc_mblock_mperblock_nblock_nperblock_ = GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( c_grid_desc_m_n_); - block_2_ctile_map_ = - GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); + block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_); } } @@ -226,14 +422,12 @@ struct DeviceGemm_Xdl_CShuffle_v1 const ADataType* p_a_grid_; const BDataType* p_b_grid_; CDataType* p_c_grid_; - AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; - BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; CGridDesc_M_N c_grid_desc_m_n_; typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_; typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; - index_t M01_; - index_t N01_; AElementwiseOperation a_element_op_; BElementwiseOperation b_element_op_; CElementwiseOperation c_element_op_; @@ -244,26 +438,27 @@ struct DeviceGemm_Xdl_CShuffle_v1 { using Argument = DeviceOp::Argument; - float Run(const Argument& arg, int nrepeat = 1) + float Run(const Argument& arg, int /* nrepeat */ = 1) { +#if 0 { - std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) - << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " - << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + std::cout << "arg.a_grid_desc_ak0_m_ak1_{" + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", " + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", " + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl; - std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) - << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " - << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + std::cout << "arg.b_grid_desc_bk0_n_bk1_{" + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", " + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", " + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl; std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; } +#endif - if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m_n_, - arg.M01_, - arg.N01_)) + if(!GridwiseGemm::CheckValidity( + arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.c_grid_desc_m_n_)) { throw std::runtime_error( "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"); @@ -271,90 +466,74 @@ struct DeviceGemm_Xdl_CShuffle_v1 const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_); - const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0); + const auto K0 = arg.a_grid_desc_ak0_m_ak1_.GetLength(I0); const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); - float ave_time = 0; - if(has_main_k0_block_loop) { - const auto kernel = -#if 0 - kernel_gemm_xdlops_v4r1 -#else - kernel_gemm_xdl_cshuffle_v1 -#endif - , - remove_reference_t, - remove_reference_t< - typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - remove_reference_t, - true>; - - ave_time = - launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.block_2_ctile_map_); + const auto kernel = kernel_gemm_xdl_cshuffle_v1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, + remove_reference_t, + true>; + + launch_kernel(kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_ctile_map_); } else { - const auto kernel = -#if 0 - kernel_gemm_xdlops_v4r1 -#else - kernel_gemm_xdl_cshuffle_v1 -#endif - , - remove_reference_t, - remove_reference_t< - typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - remove_reference_t, - false>; - - ave_time = - launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.block_2_ctile_map_); + const auto kernel = kernel_gemm_xdl_cshuffle_v1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, + remove_reference_t, + false>; + + launch_kernel(kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_ctile_map_); } - return ave_time; + return 0; } // polymorphic @@ -372,11 +551,8 @@ struct DeviceGemm_Xdl_CShuffle_v1 static bool IsSupportedArgument(const Argument& arg) { - return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m_n_, - arg.M01_, - arg.N01_); + return GridwiseGemm::CheckValidity( + arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.c_grid_desc_m_n_); } // polymorphic @@ -388,9 +564,9 @@ struct DeviceGemm_Xdl_CShuffle_v1 static auto MakeArgument(const ADataType* p_a, const BDataType* p_b, CDataType* p_c, - index_t M, - index_t N, - index_t K, + index_t MRaw, + index_t NRaw, + index_t KRaw, index_t StrideA, index_t StrideB, index_t StrideC, @@ -401,14 +577,12 @@ struct DeviceGemm_Xdl_CShuffle_v1 return Argument{p_a, p_b, p_c, - M, - N, - K, + MRaw, + NRaw, + KRaw, StrideA, StrideB, StrideC, - 1, - 1, a_element_op, b_element_op, c_element_op}; @@ -420,9 +594,9 @@ struct DeviceGemm_Xdl_CShuffle_v1 std::unique_ptr MakeArgumentPointer(const void* p_a, const void* p_b, void* p_c, - index_t M, - index_t N, - index_t K, + index_t MRaw, + index_t NRaw, + index_t KRaw, index_t StrideA, index_t StrideB, index_t StrideC, @@ -434,14 +608,12 @@ struct DeviceGemm_Xdl_CShuffle_v1 return std::make_unique(static_cast(p_a), static_cast(p_b), static_cast(p_c), - M, - N, - K, + MRaw, + NRaw, + KRaw, StrideA, StrideB, StrideC, - 1, - 1, a_element_op, b_element_op, c_element_op); @@ -459,7 +631,7 @@ struct DeviceGemm_Xdl_CShuffle_v1 auto str = std::stringstream(); // clang-format off - str << "DeviceGemm_Xdl_CShuffle_v1" + str << "DeviceGemm_Xdl_CShuffle" << "<" << BlockSize << ", " << MPerBlock << ", " diff --git a/include/ck/tensor_operation/gpu/device/gemm_specialization.hpp b/include/ck/tensor_operation/gpu/device/gemm_specialization.hpp index 0daf482cf74..81029e88b17 100644 --- a/include/ck/tensor_operation/gpu/device/gemm_specialization.hpp +++ b/include/ck/tensor_operation/gpu/device/gemm_specialization.hpp @@ -5,7 +5,7 @@ namespace ck { namespace tensor_operation { namespace device { -enum GemmSpecialization_t +enum struct GemmSpecialization_t { Default, MPadding, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp index 54e269c09d3..0284bbd55ef 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp @@ -14,12 +14,12 @@ namespace ck { template __global__ void @@ -29,13 +29,13 @@ __global__ void kernel_gemm_xdl_cshuffle_v1(const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid, FloatC* __restrict__ p_c_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map) { __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; @@ -44,12 +44,12 @@ __global__ void p_b_grid, p_c_grid, p_shared, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock, a_element_op, b_element_op, c_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, block_2_ctile_map); } @@ -57,13 +57,13 @@ template >::value && // is_known_at_compile_time>::value, @@ -217,16 +214,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 return false; } - // check M01, N01 - constexpr auto M1 = Number{}; - constexpr auto N1 = Number{}; - - const auto M0 = M / M1; - const auto N0 = N / N1; - - if(!(M0 % M01 == 0 && N0 % N01 == 0)) - return false; - // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) return true; } @@ -271,7 +258,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 // return block_id to C matrix tile idx (m0, n0) mapping __host__ __device__ static constexpr auto - MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01) + MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n) { const auto M = c_grid_desc_m_n.GetLength(I0); const auto N = c_grid_desc_m_n.GetLength(I1); @@ -282,6 +269,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 const auto M0 = M / M1; const auto N0 = N / N1; + // FIXME: remove + constexpr auto M01 = I1; + constexpr auto N01 = I1; + const auto M00 = M0 / M01; const auto N00 = N0 / N01; @@ -309,20 +300,20 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>; using DefaultBlock2CTileMap = - remove_cvref_t; + remove_cvref_t; - template + template __device__ static void Run(const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid, FloatC* __restrict__ p_c_grid, void* __restrict__ p_shared, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& c_grid_desc_mblock_mperblock_nblock_nperblock, - const AElementwiseOperation& a_element_op, - const BElementwiseOperation& b_element_op, - const CElementwiseOperation& c_element_op, const Block2CTileMap& block_2_ctile_map) { const auto a_grid_buf = make_dynamic_buffer( @@ -370,7 +361,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ABlockTransferSrcVectorDim, 2, ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_K1, + ABlockTransferDstScalarPerVector_AK1, 1, 1, AThreadTransferSrcResetCoordinateAfterRun, @@ -401,7 +392,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 BBlockTransferSrcVectorDim, 2, BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_K1, + BBlockTransferDstScalarPerVector_BK1, 1, 1, BThreadTransferSrcResetCoordinateAfterRun, @@ -572,7 +563,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( make_multi_index(n_thread_data_on_block)); - // VGPR to LDS + // shuffle: threadwise copy C from VGPR to LDS auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3 struct DataType; @@ -53,6 +54,7 @@ template <> struct DataType : std::integral_constant { }; +#endif template auto call_f_unpack_args_impl(F f, T args, std::index_sequence) diff --git a/library/src/obselete_driver_offline/gemm_driver_offline.cpp b/library/src/obselete_driver_offline/gemm_driver_offline.cpp index bd8cb00390c..0c59bea6200 100644 --- a/library/src/obselete_driver_offline/gemm_driver_offline.cpp +++ b/library/src/obselete_driver_offline/gemm_driver_offline.cpp @@ -30,7 +30,7 @@ #define USE_GEMM_XDL_KM_KN_NM 0 #define USE_GEMM_XDL_KM_NK_NM 0 -enum GemmMatrixLayout +enum struct GemmMatrixLayout { MK_KN_MN, // 0 MK_NK_MN, // 1 @@ -42,7 +42,7 @@ enum GemmMatrixLayout KM_NK_NM // 7 }; -enum GemmAlgo +enum struct GemmAlgo { Xdl_MK_KN_MN, // 0 Xdl_MK_NK_MN, // 1 diff --git a/profiler/src/profile_batched_gemm.cpp b/profiler/src/profile_batched_gemm.cpp index 6a0edc09659..5309dad2dcd 100644 --- a/profiler/src/profile_batched_gemm.cpp +++ b/profiler/src/profile_batched_gemm.cpp @@ -15,7 +15,7 @@ #include "device_batched_gemm_xdl.hpp" #include "profile_batched_gemm_impl.hpp" -enum GemmMatrixLayout +enum struct GemmMatrixLayout { MK_KN_MN, // 0 MK_NK_MN, // 1 @@ -27,7 +27,7 @@ enum GemmMatrixLayout KM_NK_NM, // 7 }; -enum GemmDataType +enum struct GemmDataType { F32_F32_F32, // 0 F16_F16_F16, // 1 @@ -51,8 +51,8 @@ int profile_batched_gemm(int argc, char* argv[]) exit(1); } - const int data_type = static_cast(std::stoi(argv[2])); - const int layout = static_cast(std::stoi(argv[3])); + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); const bool do_verification = std::stoi(argv[4]); const int init_method = std::stoi(argv[5]); const bool do_log = std::stoi(argv[6]); diff --git a/profiler/src/profile_conv_bwd_data.cpp b/profiler/src/profile_conv_bwd_data.cpp index 613c6879e6b..c00c25f8b17 100644 --- a/profiler/src/profile_conv_bwd_data.cpp +++ b/profiler/src/profile_conv_bwd_data.cpp @@ -6,7 +6,7 @@ #include #include "profile_conv_bwd_data_impl.hpp" -enum ConvDataType +enum struct ConvDataType { F32_F32_F32, // 0 F16_F16_F16, // 1 @@ -14,19 +14,19 @@ enum ConvDataType INT8_INT8_INT8, // 3 }; -enum ConvInputLayout +enum struct ConvInputLayout { NCHW, // 0 NHWC, // 1 }; -enum ConvWeightLayout +enum struct ConvWeightLayout { KCYX, // 0 KYXC, // 1 }; -enum ConvOutputLayout +enum struct ConvOutputLayout { NKHW, // 0 NHWK, // 1 @@ -50,10 +50,10 @@ int profile_conv_bwd_data(int argc, char* argv[]) exit(1); } - const int data_type = static_cast(std::stoi(argv[2])); - const int in_layout = static_cast(std::stoi(argv[3])); - const int wei_layout = static_cast(std::stoi(argv[4])); - const int out_layout = static_cast(std::stoi(argv[5])); + const auto data_type = static_cast(std::stoi(argv[2])); + const auto in_layout = static_cast(std::stoi(argv[3])); + const auto wei_layout = static_cast(std::stoi(argv[4])); + const auto out_layout = static_cast(std::stoi(argv[5])); const bool do_verification = std::stoi(argv[6]); const int init_method = std::stoi(argv[7]); const bool do_log = std::stoi(argv[8]); diff --git a/profiler/src/profile_conv_fwd.cpp b/profiler/src/profile_conv_fwd.cpp index f087c1abbc0..3d4aa358f29 100644 --- a/profiler/src/profile_conv_fwd.cpp +++ b/profiler/src/profile_conv_fwd.cpp @@ -6,7 +6,7 @@ #include #include "profile_conv_fwd_impl.hpp" -enum ConvDataType +enum struct ConvDataType { F32_F32_F32, // 0 F16_F16_F16, // 1 @@ -14,19 +14,19 @@ enum ConvDataType INT8_INT8_INT8, // 3 }; -enum ConvInputLayout +enum struct ConvInputLayout { NCHW, // 0 NHWC, // 1 }; -enum ConvWeightLayout +enum struct ConvWeightLayout { KCYX, // 0 KYXC, // 1 }; -enum ConvOutputLayout +enum struct ConvOutputLayout { NKHW, // 0 NHWK, // 1 @@ -50,10 +50,10 @@ int profile_conv_fwd(int argc, char* argv[]) exit(1); } - const int data_type = static_cast(std::stoi(argv[2])); - const int in_layout = static_cast(std::stoi(argv[3])); - const int wei_layout = static_cast(std::stoi(argv[4])); - const int out_layout = static_cast(std::stoi(argv[5])); + const auto data_type = static_cast(std::stoi(argv[2])); + const auto in_layout = static_cast(std::stoi(argv[3])); + const auto wei_layout = static_cast(std::stoi(argv[4])); + const auto out_layout = static_cast(std::stoi(argv[5])); const bool do_verification = std::stoi(argv[6]); const int init_method = std::stoi(argv[7]); const bool do_log = std::stoi(argv[8]); diff --git a/profiler/src/profile_conv_fwd_bias_relu.cpp b/profiler/src/profile_conv_fwd_bias_relu.cpp index 3390a9e4728..1c447b483ea 100644 --- a/profiler/src/profile_conv_fwd_bias_relu.cpp +++ b/profiler/src/profile_conv_fwd_bias_relu.cpp @@ -6,25 +6,25 @@ #include #include "profile_conv_fwd_bias_relu_impl.hpp" -enum ConvDataType +enum struct ConvDataType { F32_F32_F32, // 0 F16_F16_F16, // 1 }; -enum ConvInputLayout +enum struct ConvInputLayout { NCHW, // 0 NHWC, // 1 }; -enum ConvWeightLayout +enum struct ConvWeightLayout { KCYX, // 0 KYXC, // 1 }; -enum ConvOutputLayout +enum struct ConvOutputLayout { NKHW, // 0 NHWK, // 1 @@ -48,10 +48,10 @@ int profile_conv_fwd_bias_relu(int argc, char* argv[]) exit(1); } - const int data_type = static_cast(std::stoi(argv[2])); - const int in_layout = static_cast(std::stoi(argv[3])); - const int wei_layout = static_cast(std::stoi(argv[4])); - const int out_layout = static_cast(std::stoi(argv[5])); + const auto data_type = static_cast(std::stoi(argv[2])); + const auto in_layout = static_cast(std::stoi(argv[3])); + const auto wei_layout = static_cast(std::stoi(argv[4])); + const auto out_layout = static_cast(std::stoi(argv[5])); const bool do_verification = std::stoi(argv[6]); const int init_method = std::stoi(argv[7]); const bool do_log = std::stoi(argv[8]); diff --git a/profiler/src/profile_conv_fwd_bias_relu_add.cpp b/profiler/src/profile_conv_fwd_bias_relu_add.cpp index b6b48222344..522487c77be 100644 --- a/profiler/src/profile_conv_fwd_bias_relu_add.cpp +++ b/profiler/src/profile_conv_fwd_bias_relu_add.cpp @@ -6,25 +6,25 @@ #include #include "profile_conv_fwd_bias_relu_add_impl.hpp" -enum ConvDataType +enum struct ConvDataType { F32_F32_F32, // 0 F16_F16_F16, // 1 }; -enum ConvInputLayout +enum struct ConvInputLayout { NCHW, // 0 NHWC, // 1 }; -enum ConvWeightLayout +enum struct ConvWeightLayout { KCYX, // 0 KYXC, // 1 }; -enum ConvOutputLayout +enum struct ConvOutputLayout { NKHW, // 0 NHWK, // 1 @@ -49,10 +49,10 @@ int profile_conv_fwd_bias_relu_add(int argc, char* argv[]) exit(1); } - const int data_type = static_cast(std::stoi(argv[2])); - const int in_layout = static_cast(std::stoi(argv[3])); - const int wei_layout = static_cast(std::stoi(argv[4])); - const int out_layout = static_cast(std::stoi(argv[5])); + const auto data_type = static_cast(std::stoi(argv[2])); + const auto in_layout = static_cast(std::stoi(argv[3])); + const auto wei_layout = static_cast(std::stoi(argv[4])); + const auto out_layout = static_cast(std::stoi(argv[5])); const bool do_verification = std::stoi(argv[6]); const int init_method = std::stoi(argv[7]); const bool do_log = std::stoi(argv[8]); diff --git a/profiler/src/profile_conv_fwd_bias_relu_atomic_add.cpp b/profiler/src/profile_conv_fwd_bias_relu_atomic_add.cpp index 3c179d36b2b..833f2851db3 100644 --- a/profiler/src/profile_conv_fwd_bias_relu_atomic_add.cpp +++ b/profiler/src/profile_conv_fwd_bias_relu_atomic_add.cpp @@ -6,25 +6,25 @@ #include #include "profile_conv_fwd_bias_relu_atomic_add_impl.hpp" -enum ConvDataType +enum struct ConvDataType { F32_F32_F32, // 0 F16_F16_F16, // 1 }; -enum ConvInputLayout +enum struct ConvInputLayout { NCHW, // 0 NHWC, // 1 }; -enum ConvWeightLayout +enum struct ConvWeightLayout { KCYX, // 0 KYXC, // 1 }; -enum ConvOutputLayout +enum struct ConvOutputLayout { NKHW, // 0 NHWK, // 1 @@ -49,10 +49,10 @@ int profile_conv_fwd_bias_relu_atomic_add(int argc, char* argv[]) exit(1); } - const int data_type = static_cast(std::stoi(argv[2])); - const int in_layout = static_cast(std::stoi(argv[3])); - const int wei_layout = static_cast(std::stoi(argv[4])); - const int out_layout = static_cast(std::stoi(argv[5])); + const auto data_type = static_cast(std::stoi(argv[2])); + const auto in_layout = static_cast(std::stoi(argv[3])); + const auto wei_layout = static_cast(std::stoi(argv[4])); + const auto out_layout = static_cast(std::stoi(argv[5])); const bool do_verification = std::stoi(argv[6]); const int init_method = std::stoi(argv[7]); const bool do_log = std::stoi(argv[8]); diff --git a/profiler/src/profile_gemm.cpp b/profiler/src/profile_gemm.cpp index d85eec14657..4c2b6b07ece 100644 --- a/profiler/src/profile_gemm.cpp +++ b/profiler/src/profile_gemm.cpp @@ -6,7 +6,7 @@ #include #include "profile_gemm_impl.hpp" -enum GemmMatrixLayout +enum struct GemmMatrixLayout { MK_KN_MN, // 0 MK_NK_MN, // 1 @@ -18,7 +18,7 @@ enum GemmMatrixLayout KM_NK_NM, // 7 }; -enum GemmDataType +enum struct GemmDataType { F32_F32_F32, // 0 F16_F16_F16, // 1 @@ -45,8 +45,8 @@ int profile_gemm(int argc, char* argv[]) exit(1); } - const int data_type = static_cast(std::stoi(argv[2])); - const int layout = static_cast(std::stoi(argv[3])); + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); const bool do_verification = std::stoi(argv[4]); const int init_method = std::stoi(argv[5]); const bool do_log = std::stoi(argv[6]); diff --git a/profiler/src/profile_gemm_bias_2d.cpp b/profiler/src/profile_gemm_bias_2d.cpp index ca941f203a1..dd7e4180878 100644 --- a/profiler/src/profile_gemm_bias_2d.cpp +++ b/profiler/src/profile_gemm_bias_2d.cpp @@ -6,7 +6,7 @@ #include #include "profile_gemm_bias_2d_impl.hpp" -enum GemmMatrixLayout +enum struct GemmMatrixLayout { MK_KN_MN, // 0 MK_NK_MN, // 1 @@ -18,7 +18,7 @@ enum GemmMatrixLayout KM_NK_NM, // 7 }; -enum GemmDataType +enum struct GemmDataType { F32_F32_F32, // 0 F16_F16_F16, // 1 @@ -45,8 +45,8 @@ int profile_gemm_bias_2d(int argc, char* argv[]) exit(1); } - const int data_type = static_cast(std::stoi(argv[2])); - const int layout = static_cast(std::stoi(argv[3])); + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); const bool do_verification = std::stoi(argv[4]); const int init_method = std::stoi(argv[5]); const bool do_log = std::stoi(argv[6]); diff --git a/profiler/src/profile_gemm_bias_relu.cpp b/profiler/src/profile_gemm_bias_relu.cpp index 709a0a1671c..67a47cf9ec3 100644 --- a/profiler/src/profile_gemm_bias_relu.cpp +++ b/profiler/src/profile_gemm_bias_relu.cpp @@ -6,7 +6,7 @@ #include #include "profile_gemm_bias_relu_impl.hpp" -enum GemmMatrixLayout +enum struct GemmMatrixLayout { MK_KN_MN, // 0 MK_NK_MN, // 1 @@ -18,7 +18,7 @@ enum GemmMatrixLayout KM_NK_NM, // 7 }; -enum GemmDataType +enum struct GemmDataType { F32_F32_F32, // 0 F16_F16_F16, // 1 @@ -43,8 +43,8 @@ int profile_gemm_bias_relu(int argc, char* argv[]) exit(1); } - const int data_type = static_cast(std::stoi(argv[2])); - const int layout = static_cast(std::stoi(argv[3])); + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); const bool do_verification = std::stoi(argv[4]); const int init_method = std::stoi(argv[5]); const bool do_log = std::stoi(argv[6]); diff --git a/profiler/src/profile_gemm_bias_relu_add.cpp b/profiler/src/profile_gemm_bias_relu_add.cpp index 592f10321c3..52406e93d6c 100644 --- a/profiler/src/profile_gemm_bias_relu_add.cpp +++ b/profiler/src/profile_gemm_bias_relu_add.cpp @@ -6,7 +6,7 @@ #include #include "profile_gemm_bias_relu_add_impl.hpp" -enum GemmMatrixLayout +enum struct GemmMatrixLayout { MK_KN_MN, // 0 MK_NK_MN, // 1 @@ -18,7 +18,7 @@ enum GemmMatrixLayout KM_NK_NM, // 7 }; -enum GemmDataType +enum struct GemmDataType { F32_F32_F32, // 0 F16_F16_F16, // 1 @@ -43,8 +43,8 @@ int profile_gemm_bias_relu_add(int argc, char* argv[]) exit(1); } - const int data_type = static_cast(std::stoi(argv[2])); - const int layout = static_cast(std::stoi(argv[3])); + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); const bool do_verification = std::stoi(argv[4]); const int init_method = std::stoi(argv[5]); const bool do_log = std::stoi(argv[6]); diff --git a/profiler/src/profile_gemm_reduce.cpp b/profiler/src/profile_gemm_reduce.cpp index 8335d0a33b2..2149f3ce471 100644 --- a/profiler/src/profile_gemm_reduce.cpp +++ b/profiler/src/profile_gemm_reduce.cpp @@ -8,7 +8,7 @@ int profile_gemm_reduce(int argc, char* argv[]) { - enum GemmMatrixLayout_t + enum struct GemmMatrixLayout_t { MK_KN_MN, // 0 MK_NK_MN, // 1 @@ -16,7 +16,7 @@ int profile_gemm_reduce(int argc, char* argv[]) KM_NK_MN, // 3 }; - enum GemmReduceDataType_t + enum struct GemmReduceDataType_t { F32_F32_F32_F32_F32, // 0 F16_F16_F16_F32_F32, // 1 @@ -39,8 +39,8 @@ int profile_gemm_reduce(int argc, char* argv[]) exit(1); } - const int data_type = static_cast(std::stoi(argv[2])); - const int layout = static_cast(std::stoi(argv[3])); + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); const bool do_verification = std::stoi(argv[4]); const int init_method = std::stoi(argv[5]); const bool do_log = std::stoi(argv[6]); diff --git a/profiler/src/profile_reduce.cpp b/profiler/src/profile_reduce.cpp index ef8fd1115bd..7c0819528d6 100644 --- a/profiler/src/profile_reduce.cpp +++ b/profiler/src/profile_reduce.cpp @@ -82,7 +82,7 @@ static std::vector getTypeValuesFromString(const char* cstr_values) return (values); } -typedef enum +enum struct appDataType_t { appHalf = 0, appFloat = 1, @@ -91,7 +91,7 @@ typedef enum appInt8x4 = 4, appBFloat16 = 5, appDouble = 6, -} appDataType_t; +}; static void check_reduce_dims(const int rank, const std::vector& reduceDims) { @@ -127,8 +127,8 @@ class AppArgs std::vector scales; ReduceTensorOp_t reduceOp = ReduceTensorOp_t::ADD; - appDataType_t compTypeId = appFloat; - appDataType_t outTypeId = appFloat; + appDataType_t compTypeId = appDataType_t::appFloat; + appDataType_t outTypeId = appDataType_t::appFloat; bool compType_assigned = false; bool outType_assigned = false; @@ -329,15 +329,16 @@ int profile_reduce(int argc, char* argv[]) if(args.use_half) { if(!args.compType_assigned) - args.compTypeId = appHalf; + args.compTypeId = appDataType_t::appHalf; - if(args.outType_assigned && (args.outTypeId != appHalf && args.outTypeId != appFloat)) - args.outTypeId = appFloat; + if(args.outType_assigned && + (args.outTypeId != appDataType_t::appHalf && args.outTypeId != appDataType_t::appFloat)) + args.outTypeId = appDataType_t::appFloat; if(!args.outType_assigned) - args.outTypeId = appHalf; + args.outTypeId = appDataType_t::appHalf; - if(args.compTypeId == appHalf) + if(args.compTypeId == appDataType_t::appHalf) { profile_reduce_impl(args.do_verification, args.init_method, @@ -352,7 +353,7 @@ int profile_reduce(int argc, char* argv[]) args.scales[0], args.scales[1]); } - else if(args.compTypeId == appFloat) + else if(args.compTypeId == appDataType_t::appFloat) { profile_reduce_impl(args.do_verification, args.init_method, @@ -387,7 +388,7 @@ int profile_reduce(int argc, char* argv[]) } else { - if(args.compTypeId == appFloat) + if(args.compTypeId == appDataType_t::appFloat) { profile_reduce_impl(args.do_verification, args.init_method, @@ -402,7 +403,7 @@ int profile_reduce(int argc, char* argv[]) args.scales[0], args.scales[1]); } - else if(args.compTypeId == appDouble) + else if(args.compTypeId == appDataType_t::appDouble) { profile_reduce_impl(args.do_verification, args.init_method, diff --git a/test/gemm_split_k/gemm_split_k.cpp b/test/gemm_split_k/gemm_split_k.cpp index 408336769c2..98a98b5518b 100644 --- a/test/gemm_split_k/gemm_split_k.cpp +++ b/test/gemm_split_k/gemm_split_k.cpp @@ -12,7 +12,7 @@ #include "tensor_layout.hpp" #include "device_gemm_xdl_splitk.hpp" -enum GemmMatrixLayout +enum struct GemmMatrixLayout { MK_KN_MN, // 0 MK_NK_MN, // 1 @@ -59,7 +59,7 @@ static bool check_out(const Tensor& ref, const Tensor& result) struct gemmArgs { - int layout; + GemmMatrixLayout layout; int M; int N; int K; @@ -216,13 +216,13 @@ int main(int argc, char* argv[]) std::vector test_cases; if(argc == 1) { - test_cases = {{0, 3, 3, 3, 3, 3, 3, 1}}; + test_cases = {{GemmMatrixLayout::MK_KN_MN, 3, 3, 3, 3, 3, 3, 1}}; // JD: Populate with more and meaningful return 0; } else if(argc == 9) { - const int layout = static_cast(std::stoi(argv[1])); + const auto layout = static_cast(std::stoi(argv[1])); const int M = std::stoi(argv[2]); const int N = std::stoi(argv[3]); From effd8e17b64c6adfce9790b3fea013b6368c2616 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Tue, 22 Mar 2022 17:30:04 +0000 Subject: [PATCH 27/32] clean up --- .../device_gemm_reduce_xdl_cshuffle.hpp | 32 +++++++++---------- .../gpu/device/device_gemm_xdl_cshuffle.hpp | 32 +++++++++---------- 2 files changed, 32 insertions(+), 32 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp index 47ad29c7560..43339cbc37e 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp @@ -31,29 +31,29 @@ template Date: Tue, 22 Mar 2022 23:09:46 +0000 Subject: [PATCH 28/32] add test for gemm+reduce --- .../ck/library/host_tensor/host_tensor.hpp | 4 +- profiler/include/profile_gemm_reduce_impl.hpp | 18 +++++-- test/CMakeLists.txt | 2 + test/gemm_reduce/CMakeLists.txt | 9 ++++ test/gemm_reduce/gemm_reduce_fp16.cpp | 52 +++++++++++++++++++ 5 files changed, 80 insertions(+), 5 deletions(-) create mode 100644 test/gemm_reduce/CMakeLists.txt create mode 100644 test/gemm_reduce/gemm_reduce_fp16.cpp diff --git a/library/include/ck/library/host_tensor/host_tensor.hpp b/library/include/ck/library/host_tensor/host_tensor.hpp index a7525253db0..90c5c43430f 100644 --- a/library/include/ck/library/host_tensor/host_tensor.hpp +++ b/library/include/ck/library/host_tensor/host_tensor.hpp @@ -319,7 +319,7 @@ float bf16_to_f32_(ck::bhalf_t src_val); void bf16_to_f32_(const Tensor& src, Tensor& dst); template -void check_error(const Tensor& ref, const Tensor& result) +float check_error(const Tensor& ref, const Tensor& result) { float error = 0; float max_diff = -1; @@ -356,6 +356,8 @@ void check_error(const Tensor& ref, const Tensor& result) std::cout << "error: " << error << std::endl; std::cout << "max_diff: " << max_diff << ", " << ref_value << ", " << result_value << std::endl; + + return max_diff; } template diff --git a/profiler/include/profile_gemm_reduce_impl.hpp b/profiler/include/profile_gemm_reduce_impl.hpp index 011ceecc1f5..a11a673ac23 100644 --- a/profiler/include/profile_gemm_reduce_impl.hpp +++ b/profiler/include/profile_gemm_reduce_impl.hpp @@ -50,7 +50,7 @@ template -void profile_gemm_reduce_impl(int do_verification, +bool profile_gemm_reduce_impl(int do_verification, int init_method, bool do_log, int nrepeat, @@ -61,6 +61,8 @@ void profile_gemm_reduce_impl(int do_verification, int StrideB, int StrideC) { + bool pass = true; + auto f_host_tensor_descriptor = [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { if(is_same::value) @@ -101,10 +103,12 @@ void profile_gemm_reduce_impl(int do_verification, { case 0: break; case 1: + std::srand(0); a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); break; default: + std::srand(0); a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); } @@ -284,9 +288,13 @@ void profile_gemm_reduce_impl(int do_verification, d0_device_buf.FromDevice(d0_m_device_result.mData.data()); d1_device_buf.FromDevice(d1_m_device_result.mData.data()); - check_error(c_m_n_host_result, c_m_n_device_result); - check_error(d0_m_host_result, d0_m_device_result); - check_error(d1_m_host_result, d1_m_device_result); + float c_error = check_error(c_m_n_host_result, c_m_n_device_result); + float d0_error = check_error(d0_m_host_result, d0_m_device_result); + float d1_error = check_error(d1_m_host_result, d1_m_device_result); + + pass = pass && (c_error < 1E-6); + pass = pass && (d0_error < 1E-6); + pass = pass && (d1_error < 1E-6); if(do_log) { @@ -315,6 +323,8 @@ void profile_gemm_reduce_impl(int do_verification, std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl; + + return pass; } } // namespace profiler diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index eec8b5b852e..378ae57899e 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -16,6 +16,7 @@ include_directories(BEFORE ${PROJECT_SOURCE_DIR}/library/include/ck/library/reference_tensor_operation/cpu ${PROJECT_SOURCE_DIR}/library/include/ck/library/reference_tensor_operation/gpu ${PROJECT_SOURCE_DIR}/test/include + ${PROJECT_SOURCE_DIR}/profiler/include ${PROJECT_SOURCE_DIR}/external/include/half ) @@ -36,6 +37,7 @@ add_subdirectory(conv_util) add_subdirectory(reference_conv_fwd) add_subdirectory(gemm) add_subdirectory(gemm_split_k) +add_subdirectory(gemm_reduce) add_subdirectory(conv2d_fwd) add_subdirectory(convnd_fwd) add_subdirectory(conv2d_bwd_data) diff --git a/test/gemm_reduce/CMakeLists.txt b/test/gemm_reduce/CMakeLists.txt new file mode 100644 index 00000000000..e474af32301 --- /dev/null +++ b/test/gemm_reduce/CMakeLists.txt @@ -0,0 +1,9 @@ +include_directories(BEFORE + ${PROJECT_SOURCE_DIR}/profiler/include + ${PROJECT_SOURCE_DIR}/test/include + ${PROJECT_SOURCE_DIR}/external/include/half +) + +add_test_executable(test_gemm_reduce_fp16 gemm_reduce_fp16.cpp) +target_link_libraries(test_gemm_reduce_fp16 PRIVATE host_tensor) +target_link_libraries(test_gemm_reduce_fp16 PRIVATE device_gemm_reduce_instance) diff --git a/test/gemm_reduce/gemm_reduce_fp16.cpp b/test/gemm_reduce/gemm_reduce_fp16.cpp new file mode 100644 index 00000000000..0b3421a667e --- /dev/null +++ b/test/gemm_reduce/gemm_reduce_fp16.cpp @@ -0,0 +1,52 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "profile_gemm_reduce_impl.hpp" + +int main() +{ + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + int M = 512; + int N = 256; + int K = 128; + + bool pass = true; + + pass = pass && + ck::profiler:: + profile_gemm_reduce_impl( + true, 1, false, 1, M, N, K, K, N, N); + + pass = pass && + ck::profiler:: + profile_gemm_reduce_impl( + true, 1, false, 1, M, N, K, K, K, N); + + pass = pass && + ck::profiler:: + profile_gemm_reduce_impl( + true, 1, false, 1, M, N, K, M, N, N); + + pass = pass && + ck::profiler:: + profile_gemm_reduce_impl( + true, 1, false, 1, M, N, K, M, K, N); + + if(pass) + { + std::cout << "test GEMM+Reduce fp16: Pass" << std::endl; + return 0; + } + else + { + std::cout << "test GEMM+Reduce fp16: Fail" << std::endl; + return -1; + } +} From 89afbc1b1015469e2391928e8ac4e2df9a19491b Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Tue, 22 Mar 2022 23:14:49 +0000 Subject: [PATCH 29/32] clean up --- .../ck/library/host_tensor/host_tensor.hpp | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/library/include/ck/library/host_tensor/host_tensor.hpp b/library/include/ck/library/host_tensor/host_tensor.hpp index 90c5c43430f..955b5b88a1e 100644 --- a/library/include/ck/library/host_tensor/host_tensor.hpp +++ b/library/include/ck/library/host_tensor/host_tensor.hpp @@ -40,22 +40,6 @@ std::ostream& LogRangeAsType(std::ostream& os, Range&& range, std::string delim) return os; } -#if 0 -enum struct DataType_t -{ - Half = 0, - Float = 1, -}; - -template -struct DataType; - -template <> -struct DataType : std::integral_constant -{ -}; -#endif - template auto call_f_unpack_args_impl(F f, T args, std::index_sequence) { From fb2779c551e3ef56e5aec2ec918c0ed8c66173e0 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Wed, 23 Mar 2022 18:30:12 +0000 Subject: [PATCH 30/32] refactor --- .../16_gemm_reduce/gemm_reduce_xdl_fp16.cpp | 50 ++++++++------- .../convolution_forward_specialization.hpp | 1 - ...evice_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp | 61 +++++++------------ .../device_gemm_reduce_xdl_cshuffle.hpp | 38 +++++------- .../gpu/device/device_gemm_xdl_cshuffle.hpp | 31 ++++------ include/ck/utility/data_type.hpp | 9 +-- .../ck/library/host_tensor/host_tensor.hpp | 60 +++++++++--------- profiler/include/profile_gemm_reduce_impl.hpp | 16 +++-- 8 files changed, 121 insertions(+), 145 deletions(-) diff --git a/example/16_gemm_reduce/gemm_reduce_xdl_fp16.cpp b/example/16_gemm_reduce/gemm_reduce_xdl_fp16.cpp index c42f767c254..6f173ae1de9 100644 --- a/example/16_gemm_reduce/gemm_reduce_xdl_fp16.cpp +++ b/example/16_gemm_reduce/gemm_reduce_xdl_fp16.cpp @@ -153,14 +153,14 @@ int main(int argc, char* argv[]) break; } - DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); - DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); - DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); - DeviceMem d0_m_device_buf(sizeof(DDataType) * d0_m_device_result.mDesc.GetElementSpace()); - DeviceMem d1_m_device_buf(sizeof(DDataType) * d1_m_device_result.mDesc.GetElementSpace()); + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); + DeviceMem d0_device_buf(sizeof(DDataType) * d0_m_device_result.mDesc.GetElementSpace()); + DeviceMem d1_device_buf(sizeof(DDataType) * d1_m_device_result.mDesc.GetElementSpace()); - a_m_k_device_buf.ToDevice(a_m_k.mData.data()); - b_k_n_device_buf.ToDevice(b_k_n.mData.data()); + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); auto a_element_op = AElementOp{}; auto b_element_op = BElementOp{}; @@ -171,11 +171,11 @@ int main(int argc, char* argv[]) // do GEMM auto gemm = DeviceGemmReduceInstance{}; auto invoker = gemm.MakeInvoker(); - auto argument = gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), - static_cast(b_k_n_device_buf.GetDeviceBuffer()), - static_cast(c_m_n_device_buf.GetDeviceBuffer()), - static_cast(d0_m_device_buf.GetDeviceBuffer()), - static_cast(d1_m_device_buf.GetDeviceBuffer()), + auto argument = gemm.MakeArgument(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + static_cast(d0_device_buf.GetDeviceBuffer()), + static_cast(d1_device_buf.GetDeviceBuffer()), M, N, K, @@ -199,22 +199,26 @@ int main(int argc, char* argv[]) invoker.Run(argument); // timing - KernelTimer timer; - - timer.Start(); + float total_time = 0; for(int i = 0; i < nrepeat; ++i) { // init DO, D1 to 0 - d0_m_device_buf.SetZero(); - d1_m_device_buf.SetZero(); + d0_device_buf.SetZero(); + d1_device_buf.SetZero(); + + KernelTimer timer; + + timer.Start(); invoker.Run(argument); - } - timer.End(); + timer.End(); + + total_time += timer.GetElapsedTime(); + } - float ave_time = timer.GetElapsedTime() / nrepeat; + float ave_time = total_time / nrepeat; std::size_t flop = std::size_t(2) * M * N * K; std::size_t num_btype = @@ -229,9 +233,9 @@ int main(int argc, char* argv[]) if(do_verification) { - c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); - d0_m_device_buf.FromDevice(d0_m_device_result.mData.data()); - d1_m_device_buf.FromDevice(d1_m_device_result.mData.data()); + c_device_buf.FromDevice(c_m_n_device_result.mData.data()); + d0_device_buf.FromDevice(d0_m_device_result.mData.data()); + d1_device_buf.FromDevice(d1_m_device_result.mData.data()); auto ref_gemm = ReferenceGemmInstance{}; auto ref_invoker = ref_gemm.MakeInvoker(); diff --git a/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp b/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp index 3ab0c7b2eba..da77193b101 100644 --- a/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp +++ b/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp @@ -11,7 +11,6 @@ enum struct ConvolutionForwardSpecialization_t Filter1x1Pad0, Filter1x1Stride1Pad0, OddC, - UnDefined, }; } // namespace device diff --git a/include/ck/tensor_operation/gpu/device/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp b/include/ck/tensor_operation/gpu/device/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp index 471cdb8bf69..cc434c0d42d 100644 --- a/include/ck/tensor_operation/gpu/device/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp @@ -207,45 +207,28 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_ const index_t Ho = output_spatial_lengths[1]; const index_t Wo = output_spatial_lengths[2]; - if constexpr(ConvForwardSpecialization == - ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) - { - static_assert(ConvForwardSpecialization == - ConvolutionForwardSpecialization_t::UnDefined, - "Not implemented!"); - } - else if constexpr(ConvForwardSpecialization == - ConvolutionForwardSpecialization_t::Filter1x1Pad0) - { - static_assert(ConvForwardSpecialization == - ConvolutionForwardSpecialization_t::UnDefined, - "Not implemented!"); - } - else - { - const auto in_desc_n_di_hi_wi_c = - make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C)); - const auto wei_desc_k_z_y_x_c = - make_naive_tensor_descriptor_packed(make_tuple(K, Z, Y, X, C)); - const auto out_desc_n_do_ho_wo_k = - make_naive_tensor_descriptor_packed(make_tuple(N, Do, Ho, Wo, K)); - - const auto descs = - transform_forward_convolution3d_into_gemm_v4r4r4_ndhwc_kzyxc_ndhwk_pad( - in_desc_n_di_hi_wi_c, - wei_desc_k_z_y_x_c, - out_desc_n_do_ho_wo_k, - make_tuple( - conv_filter_strides[0], conv_filter_strides[1], conv_filter_strides[2]), - make_tuple(conv_filter_dilations[0], - conv_filter_dilations[1], - conv_filter_dilations[2]), - make_tuple(input_left_pads[0], input_left_pads[1], input_left_pads[2]), - make_tuple(input_right_pads[0], input_right_pads[1], input_right_pads[2]), - Number{}); - - return descs; - } + static_assert(ConvForwardSpecialization == ConvolutionForwardSpecialization_t::Default, + "Wrong! This specialization not implemented!"); + + const auto in_desc_n_di_hi_wi_c = + make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C)); + const auto wei_desc_k_z_y_x_c = + make_naive_tensor_descriptor_packed(make_tuple(K, Z, Y, X, C)); + const auto out_desc_n_do_ho_wo_k = + make_naive_tensor_descriptor_packed(make_tuple(N, Do, Ho, Wo, K)); + + const auto descs = transform_forward_convolution3d_into_gemm_v4r4r4_ndhwc_kzyxc_ndhwk_pad( + in_desc_n_di_hi_wi_c, + wei_desc_k_z_y_x_c, + out_desc_n_do_ho_wo_k, + make_tuple(conv_filter_strides[0], conv_filter_strides[1], conv_filter_strides[2]), + make_tuple( + conv_filter_dilations[0], conv_filter_dilations[1], conv_filter_dilations[2]), + make_tuple(input_left_pads[0], input_left_pads[1], input_left_pads[2]), + make_tuple(input_right_pads[0], input_right_pads[1], input_right_pads[2]), + Number{}); + + return descs; } using ABCGridDescs = remove_cvref_t, - remove_reference_t, - remove_reference_t< - typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, - remove_reference_t, - remove_reference_t, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock, + typename GridwiseGemm::DefaultBlock2CTileMap, true>; launch_kernel(kernel, @@ -592,12 +585,11 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce, - remove_reference_t, - remove_reference_t< - typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, - remove_reference_t, - remove_reference_t, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock, + typename GridwiseGemm::DefaultBlock2CTileMap, false>; launch_kernel(kernel, diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp index bb685ddc072..0e0e092a993 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp @@ -394,19 +394,15 @@ struct DeviceGemm_Xdl_CShuffle : p_a_grid_{p_a_grid}, p_b_grid_{p_b_grid}, p_c_grid_{p_c_grid}, - a_grid_desc_ak0_m_ak1_{}, - b_grid_desc_bk0_n_bk1_{}, - c_grid_desc_m_n_{}, + a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)}, + b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)}, + c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)}, c_grid_desc_mblock_mperblock_nblock_nperblock_{}, block_2_ctile_map_{}, a_element_op_{a_element_op}, b_element_op_{b_element_op}, c_element_op_{c_element_op} { - a_grid_desc_ak0_m_ak1_ = DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA); - b_grid_desc_bk0_n_bk1_ = DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB); - c_grid_desc_m_n_ = DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC); - if(GridwiseGemm::CheckValidity( a_grid_desc_ak0_m_ak1_, b_grid_desc_bk0_n_bk1_, c_grid_desc_m_n_)) { @@ -460,8 +456,7 @@ struct DeviceGemm_Xdl_CShuffle if(!GridwiseGemm::CheckValidity( arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.c_grid_desc_m_n_)) { - throw std::runtime_error( - "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"); + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); } const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_); @@ -479,11 +474,10 @@ struct DeviceGemm_Xdl_CShuffle AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, - remove_reference_t, - remove_reference_t, - remove_reference_t< - typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, - remove_reference_t, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::DefaultBlock2CTileMap, true>; launch_kernel(kernel, @@ -510,11 +504,10 @@ struct DeviceGemm_Xdl_CShuffle AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, - remove_reference_t, - remove_reference_t, - remove_reference_t< - typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, - remove_reference_t, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::DefaultBlock2CTileMap, false>; launch_kernel(kernel, diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index 15701855707..81ab64b51cb 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -1,6 +1,4 @@ -#ifndef CK_FLOAT_TYPE_AMD_HPP -#define CK_FLOAT_TYPE_AMD_HPP - +#pragma once #include "statically_indexed_array.hpp" namespace ck { @@ -937,7 +935,7 @@ __host__ __device__ Y type_convert(X x) // convert bfp16 to fp32 template <> -inline __host__ __device__ float type_convert(bhalf_t x) +__host__ __device__ float type_convert(bhalf_t x) { union { @@ -950,7 +948,7 @@ inline __host__ __device__ float type_convert(bhalf_t x) // convert fp32 to bfp16 template <> -inline __host__ __device__ bhalf_t type_convert(float x) +__host__ __device__ bhalf_t type_convert(float x) { union { @@ -1090,4 +1088,3 @@ struct NumericLimits }; } // namespace ck -#endif diff --git a/library/include/ck/library/host_tensor/host_tensor.hpp b/library/include/ck/library/host_tensor/host_tensor.hpp index 955b5b88a1e..41e5fdcd0d6 100644 --- a/library/include/ck/library/host_tensor/host_tensor.hpp +++ b/library/include/ck/library/host_tensor/host_tensor.hpp @@ -305,43 +305,47 @@ void bf16_to_f32_(const Tensor& src, Tensor& dst); template float check_error(const Tensor& ref, const Tensor& result) { - float error = 0; - float max_diff = -1; - float ref_value = 0, result_value = 0; + float l1_error = 0; + float linf_error = -1; + float linf_rel_error = -1; - if constexpr(std::is_same::value) + float linf_ref_value = 0, linf_result_value = 0; + float linf_rel_ref_value = 0, linf_rel_result_value = 0; + + constexpr float eps = 1e-10; + + for(int i = 0; i < ref.mData.size(); ++i) { - for(int i = 0; i < ref.mData.size(); ++i) + float ref_v = ck::type_convert(ref.mData[i]); + float result_v = ck::type_convert(result.mData[i]); + + float diff = std::abs(ref_v - result_v); + float rel_diff = diff / std::max(std::abs(ref_v), eps); + + l1_error += diff; + + if(linf_error < diff) { - error += std::abs(bf16_to_f32_(ref.mData[i]) - bf16_to_f32_(result.mData[i])); - float diff = std::abs(bf16_to_f32_(ref.mData[i]) - bf16_to_f32_(result.mData[i])); - if(max_diff < diff) - { - max_diff = diff; - ref_value = bf16_to_f32_(ref.mData[i]); - result_value = bf16_to_f32_(result.mData[i]); - } + linf_error = diff; + linf_ref_value = ref_v; + linf_result_value = result_v; } - } - else - { - for(int i = 0; i < ref.mData.size(); ++i) + + if(linf_rel_error < rel_diff) { - error += std::abs(float(ref.mData[i]) - float(result.mData[i])); - float diff = std::abs(float(ref.mData[i]) - float(result.mData[i])); - if(max_diff < diff) - { - max_diff = diff; - ref_value = float(ref.mData[i]); - result_value = float(result.mData[i]); - } + linf_rel_error = rel_diff; + linf_rel_ref_value = ref_v; + linf_rel_result_value = result_v; } } - std::cout << "error: " << error << std::endl; - std::cout << "max_diff: " << max_diff << ", " << ref_value << ", " << result_value << std::endl; + std::cout << "Absolute Error L1 Norm (sum of abs diff): " << l1_error << std::endl; + std::cout << "Absolute Error L-inf Norm (max abs diff): " << linf_error << ", ref " + << linf_ref_value << ", result " << linf_result_value << std::endl; + std::cout << "Relative Error L-inf Norm (max relative abs diff): " << linf_rel_error << ", ref " + << linf_rel_ref_value << ", result " << linf_rel_result_value << std::endl; - return max_diff; + return linf_error; } template diff --git a/profiler/include/profile_gemm_reduce_impl.hpp b/profiler/include/profile_gemm_reduce_impl.hpp index a11a673ac23..8b3a85a2089 100644 --- a/profiler/include/profile_gemm_reduce_impl.hpp +++ b/profiler/include/profile_gemm_reduce_impl.hpp @@ -243,9 +243,7 @@ bool profile_gemm_reduce_impl(int do_verification, invoker_ptr->Run(argument_ptr.get()); // timing - KernelTimer timer; - - timer.Start(); + float total_time = 0; for(int i = 0; i < nrepeat; ++i) { @@ -253,12 +251,18 @@ bool profile_gemm_reduce_impl(int do_verification, d0_device_buf.SetZero(); d1_device_buf.SetZero(); + KernelTimer timer; + + timer.Start(); + invoker_ptr->Run(argument_ptr.get()); - } - timer.End(); + timer.End(); + + total_time += timer.GetElapsedTime(); + } - float ave_time = timer.GetElapsedTime() / nrepeat; + float ave_time = total_time / nrepeat; std::string gemm_name = gemm_ptr->GetTypeString(); From ddad6a802d15568a4595ac0d3cf98b1f01926f35 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Wed, 23 Mar 2022 20:50:43 +0000 Subject: [PATCH 31/32] fix build --- include/ck/utility/data_type.hpp | 4 ++-- .../ck/library/host_tensor/host_tensor.hpp | 4 ++++ library/src/host_tensor/host_tensor.cpp | 4 ++++ profiler/src/profile_reduce.cpp | 20 ++++++++++--------- 4 files changed, 21 insertions(+), 11 deletions(-) diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index 81ab64b51cb..f1e541313c5 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -935,7 +935,7 @@ __host__ __device__ Y type_convert(X x) // convert bfp16 to fp32 template <> -__host__ __device__ float type_convert(bhalf_t x) +inline __host__ __device__ float type_convert(bhalf_t x) { union { @@ -948,7 +948,7 @@ __host__ __device__ float type_convert(bhalf_t x) // convert fp32 to bfp16 template <> -__host__ __device__ bhalf_t type_convert(float x) +inline __host__ __device__ bhalf_t type_convert(float x) { union { diff --git a/library/include/ck/library/host_tensor/host_tensor.hpp b/library/include/ck/library/host_tensor/host_tensor.hpp index 41e5fdcd0d6..c70c0e55328 100644 --- a/library/include/ck/library/host_tensor/host_tensor.hpp +++ b/library/include/ck/library/host_tensor/host_tensor.hpp @@ -298,9 +298,13 @@ HostTensorDescriptor::HostTensorDescriptor(std::vector lens, std::vector s void ostream_HostTensorDescriptor(const HostTensorDescriptor& desc, std::ostream& os = std::cout); +#if 1 +// FIXME: remove float bf16_to_f32_(ck::bhalf_t src_val); +// FIXME: remove void bf16_to_f32_(const Tensor& src, Tensor& dst); +#endif template float check_error(const Tensor& ref, const Tensor& result) diff --git a/library/src/host_tensor/host_tensor.cpp b/library/src/host_tensor/host_tensor.cpp index 89b76f9a386..76d420e00b9 100644 --- a/library/src/host_tensor/host_tensor.cpp +++ b/library/src/host_tensor/host_tensor.cpp @@ -64,6 +64,8 @@ void ostream_HostTensorDescriptor(const HostTensorDescriptor& desc, std::ostream os << "}" << std::endl; } +#if 1 +// FIXME: remove float bf16_to_f32_(ck::bhalf_t src_val) { union @@ -74,8 +76,10 @@ float bf16_to_f32_(ck::bhalf_t src_val) return u.fp32; } +// FIXME: remove void bf16_to_f32_(const Tensor& src, Tensor& dst) { for(int i = 0; i < src.mData.size(); ++i) dst.mData[i] = bf16_to_f32_(src.mData[i]); } +#endif diff --git a/profiler/src/profile_reduce.cpp b/profiler/src/profile_reduce.cpp index dda5adf4bba..b6a515b61f8 100644 --- a/profiler/src/profile_reduce.cpp +++ b/profiler/src/profile_reduce.cpp @@ -399,15 +399,16 @@ int profile_reduce(int argc, char* argv[]) else if(args.use_int8) { if(!args.compType_assigned) - args.compTypeId = appInt8; + args.compTypeId = appDataType_t::appInt8; - if(args.outType_assigned && (args.outTypeId != appInt8 && args.outTypeId != appInt32)) - args.outTypeId = appInt32; + if(args.outType_assigned && + (args.outTypeId != appDataType_t::appInt8 && args.outTypeId != appDataType_t::appInt32)) + args.outTypeId = appDataType_t::appInt32; if(!args.outType_assigned) - args.outTypeId = appInt8; + args.outTypeId = appDataType_t::appInt8; - if(args.compTypeId == appInt8) + if(args.compTypeId == appDataType_t::appInt8) { profile_reduce_impl(args.do_verification, args.init_method, @@ -422,7 +423,7 @@ int profile_reduce(int argc, char* argv[]) args.scales[0], args.scales[1]); } - else if(args.compTypeId == appInt32) + else if(args.compTypeId == appDataType_t::appInt32) { profile_reduce_impl(args.do_verification, args.init_method, @@ -442,11 +443,12 @@ int profile_reduce(int argc, char* argv[]) } else if(args.use_bf16) { - if(args.outType_assigned && (args.outTypeId != appBFloat16 && args.outTypeId != appFloat)) - args.outTypeId = appFloat; + if(args.outType_assigned && (args.outTypeId != appDataType_t::appBFloat16 && + args.outTypeId != appDataType_t::appFloat)) + args.outTypeId = appDataType_t::appFloat; if(!args.outType_assigned) - args.outTypeId = appBFloat16; + args.outTypeId = appDataType_t::appBFloat16; profile_reduce_impl(args.do_verification, args.init_method, From 39b3267d5076f847debe880bc3b046e1845c147e Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Wed, 23 Mar 2022 21:28:42 +0000 Subject: [PATCH 32/32] fix build --- .../gpu/device/convolution_forward_specialization.hpp | 8 ++++---- test/CMakeLists.txt | 1 - 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp b/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp index 5824ff1a67a..e37a9913f94 100644 --- a/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp +++ b/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp @@ -19,10 +19,10 @@ inline std::string getConvFwdSpecializationStr(const ConvolutionForwardSpecializ { switch(s) { - case Default: return "Default"; - case Filter1x1Pad0: return "Filter1x1Pad0"; - case Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0"; - case OddC: return "OddC"; + case ConvolutionForwardSpecialization_t::Default: return "Default"; + case ConvolutionForwardSpecialization_t::Filter1x1Pad0: return "Filter1x1Pad0"; + case ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0"; + case ConvolutionForwardSpecialization_t::OddC: return "OddC"; default: return "Unrecognized specialization!"; } } diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 76add476c10..b3a7794c218 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -40,7 +40,6 @@ add_subdirectory(gemm_split_k) add_subdirectory(gemm_reduce) add_subdirectory(batched_gemm) add_subdirectory(grouped_gemm) -add_subdirectory(conv2d_fwd) add_subdirectory(convnd_fwd) add_subdirectory(conv2d_bwd_data) add_subdirectory(reduce)