From 2c53a47cbc0d830fca07ddc0f58851e9f27df53f Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Thu, 10 Feb 2022 00:02:45 +0000 Subject: [PATCH 01/11] clean up --- .../threadwise_tensor_slice_transfer_v3r1.hpp | 100 +++++------------- 1 file changed, 28 insertions(+), 72 deletions(-) diff --git a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v3r1.hpp b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v3r1.hpp index 438f925306b..094e96fca9b 100644 --- a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v3r1.hpp +++ b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v3r1.hpp @@ -78,6 +78,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1 using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{})); + static constexpr auto I0 = Number<0>{}; + __device__ constexpr ThreadwiseTensorSliceTransfer_v3r1( const SrcDesc& src_desc, const Index& src_slice_origin, @@ -102,9 +104,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1 dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); } - template - __device__ void - RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks) + template + __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf) { static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds, @@ -114,9 +115,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1 is_same, remove_cvref_t>::value, "wrong! 
SrcBuffer and SrcData data type are inconsistent"); - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - // scalar per access on each dim // TODO: don't use lambda_scalar_per_access constexpr auto src_scalar_per_access = generate_sequence( @@ -138,8 +136,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0; }); - return make_tensor_coordinate_step( - src_desc, forward_step_idx, src_step_hacks[I0][i]); + return make_tensor_coordinate_step(src_desc, forward_step_idx); }, Number{}); @@ -152,8 +149,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0; }); - return make_tensor_coordinate_step( - src_desc, backward_step_idx, src_step_hacks[I1][i]); + return make_tensor_coordinate_step(src_desc, backward_step_idx); }, Number{}); @@ -348,9 +344,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1 #endif } - template - __device__ void - RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf, const DstStepHacks& dst_step_hacks) + template + __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf) { // if there is transpose, it's done here // TODO move this elsewhere @@ -364,9 +359,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1 is_same, remove_cvref_t>::value, "wrong! SrcBuffer or DstBuffer data type is wrong"); - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - // src scalar per access on each dim // TODO: don't use this constexpr auto dst_scalar_per_access = generate_sequence( @@ -388,8 +380,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0; }); - return make_tensor_coordinate_step( - dst_desc, forward_step_idx, dst_step_hacks[I0][i]); + return make_tensor_coordinate_step(dst_desc, forward_step_idx); }, Number{}); @@ -402,8 +393,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 backward_step_idx(j) = (i.value == j.value) ? 
-dst_scalar_per_access[i] : 0; }); - return make_tensor_coordinate_step( - dst_desc, backward_step_idx, dst_step_hacks[I1][i]); + return make_tensor_coordinate_step(dst_desc, backward_step_idx); }, Number{}); @@ -515,39 +505,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1 } } - template - __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf) - { - constexpr index_t ntransform_src = remove_cvref_t::GetNumOfTransform(); - - constexpr auto zeros = typename uniform_sequence_gen::type{}; - - constexpr auto src_step_hacks = - make_tuple(generate_tuple([&](auto) { return zeros; }, Number{}), - generate_tuple([&](auto) { return zeros; }, Number{})); - - RunRead(src_desc, src_buf, src_step_hacks); - } - - template - __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf) - { - // TODO: why need remove_cvref_t ? - constexpr index_t ntransform_dst = remove_cvref_t::GetNumOfTransform(); - - constexpr auto zeros = typename uniform_sequence_gen::type{}; - - constexpr auto dst_step_hacks = - make_tuple(generate_tuple([&](auto) { return zeros; }, Number{}), - generate_tuple([&](auto) { return zeros; }, Number{})); - - RunWrite(dst_desc, dst_buf, dst_step_hacks); - } - __device__ static constexpr auto GetSrcCoordinateResetStep() { - constexpr auto I0 = Number<0>{}; - // scalar per access on each dim // TODO: don't use lambda_scalar_per_access constexpr auto src_scalar_per_access = generate_sequence( @@ -606,8 +565,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1 __device__ static constexpr auto GetDstCoordinateResetStep() { - constexpr auto I0 = Number<0>{}; - // scalar per access on each dim // TODO: don't use lambda_scalar_per_access constexpr auto dst_scalar_per_access = generate_sequence( @@ -679,25 +636,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1 move_tensor_coordinate(src_desc, src_coord_, adjusted_step); } - // src_slice_origin_step_idx need to be known at compile-time, for performance reason - template - __device__ void - 
MoveSrcSliceWindow(const SrcDesc& src_desc, - const Index& src_slice_origin_step_idx, - const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack) - { - // if src coord was not reset by RunRead(), then need to adjust the step here - const auto adjusted_step_idx = - SrcResetCoordinateAfterRun ? src_slice_origin_step_idx - : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); - - // is it OK to construct a new step every time? - const auto adjusted_step = make_tensor_coordinate_step( - src_desc, adjusted_step_idx, src_move_slice_window_step_hack); - - move_tensor_coordinate(src_desc, src_coord_, adjusted_step); - } - // dst_slice_origin_step_idx need to be known at compile-time, for performance reason __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& dst_slice_origin_step_idx) @@ -815,6 +753,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){}; static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){}; +#if 1 StaticTensorTupleOfVectorBuffer dst_thread_scratch_; +#else + using SrcThreadScratch = StaticTensorTupleOfVectorBuffer; + + using DstThreadScratch = StaticTensorTupleOfVectorBuffer; + + StaticallyIndexedArray src_thread_scratch_tuple_; + + DstThreadScratch dst_thread_scratch_; +#endif SrcCoord src_coord_; DstCoord dst_coord_; From a3fa42b17a6316e547dd83bcd87e7acc3bc8e023 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Thu, 10 Feb 2022 04:13:59 +0000 Subject: [PATCH 02/11] add multiple thread scratch to ThreadwiseTensorSliceTransfer_v3r1 --- .../blockwise_tensor_slice_transfer_v4r1.hpp | 56 ++++++------------- .../gridwise_gemm_xdlops_v3r1.hpp | 22 ++++---- .../threadwise_tensor_slice_transfer_v3r1.hpp | 49 +++++++--------- .../include/device_gemm_xdl_c_shuffle.hpp | 2 - example/1_gemm_xdl/gemm_xdl.cpp | 19 +++++++ 5 files changed, 70 insertions(+), 78 deletions(-) diff --git
a/composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer_v4r1.hpp b/composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer_v4r1.hpp index b2722bf0786..69a0a6c1a1d 100644 --- a/composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer_v4r1.hpp +++ b/composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer_v4r1.hpp @@ -13,7 +13,8 @@ namespace ck { // 1. Use StaticallyIndexedArray instead of C array for thread buffer // 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor // 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate -template - __device__ void - RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks) + template + __device__ void RunRead(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + Number thread_scratch_id) { if(BlockSize == thread_cluster_desc_.GetElementSize() or get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) { - threadwise_transfer_.RunRead(src_desc, src_buf, src_step_hacks); + threadwise_transfer_.RunRead(src_desc, src_buf, thread_scratch_id); } } - template - __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf) - { - if(BlockSize == thread_cluster_desc_.GetElementSize() or - get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) - { - threadwise_transfer_.RunRead(src_desc, src_buf); - } - } - - template - __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf) + template + __device__ void + RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf, Number thread_scratch_id) { if(BlockSize == thread_cluster_desc_.GetElementSize() or get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) { - threadwise_transfer_.RunWrite(dst_desc, dst_buf); + threadwise_transfer_.RunWrite(dst_desc, dst_buf, thread_scratch_id); } } - template + template __device__ void Run(const SrcDesc& src_desc, const SrcBuffer& src_buf, 
const DstDesc& dst_desc, - DstBuffer& dst_buf) + DstBuffer& dst_buf, + Number thread_scratch_id) { - RunRead(src_desc, src_buf); - RunWrite(dst_desc, dst_buf); + RunRead(src_desc, src_buf, thread_scratch_id); + RunWrite(dst_desc, dst_buf, thread_scratch_id); } __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step) @@ -136,21 +130,6 @@ struct BlockwiseTensorSliceTransfer_v4r1 } } - // SrcMoveSliceWindowStepHack to control index calculation move slice window - template - __device__ void - MoveSrcSliceWindow(const SrcDesc& src_desc, - const Index& step, - const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack) - { - if(BlockSize == thread_cluster_desc_.GetElementSize() or - get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) - { - threadwise_transfer_.MoveSrcSliceWindow( - src_desc, step, src_move_slice_window_step_hack); - } - } - __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step) { if(BlockSize == thread_cluster_desc_.GetElementSize() or @@ -165,7 +144,8 @@ struct BlockwiseTensorSliceTransfer_v4r1 make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); using ThreadwiseTransfer = - ThreadwiseTensorSliceTransfer_v3r1 - __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf) + template + __device__ void RunRead(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + Number thread_scratch_id) { static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds, @@ -211,8 +214,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1 }); // copy data from src_vector_container into src_thread_scratch_ - src_thread_scratch_.template SetAsType( - src_data_idx_seq, src_vector_container.template AsType()[I0]); + src_thread_scratch_tuple_(thread_scratch_id) + .template SetAsType( + src_data_idx_seq, src_vector_container.template AsType()[I0]); constexpr auto move_on_dim = [&]() constexpr { @@ -259,12 +263,15 
@@ struct ThreadwiseTensorSliceTransfer_v3r1 } } - __device__ void TransferDataFromSrcThreadScratchToDstThreadScratch() + template + __device__ void + TransferDataFromSrcThreadScratchToDstThreadScratch(Number thread_scratch_id) { #if !CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE static_ford{}([&](auto idx) { // convert from SrcData to DstData here - dst_thread_scratch_(idx) = type_convert(src_thread_scratch_[idx]); + dst_thread_scratch_(idx) = + type_convert(src_thread_scratch_tuple[thread_scratch_id][idx]); }); #else // sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_ @@ -314,7 +321,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 const auto src_vector_refs = generate_tie( [&](auto i) -> const src_vector_t& { // i increment corresponds to movement in DstVectorDim - return src_thread_scratch_.GetVectorTypeReference( + return src_thread_scratch_tuple_[thread_scratch_id].GetVectorTypeReference( data_idx_seq + i * dst_scalar_step_in_vector); }, Number{}); @@ -338,18 +345,20 @@ struct ThreadwiseTensorSliceTransfer_v3r1 { static_ford{}([&](auto idx) { // convert from SrcData to DstData here - dst_thread_scratch_(idx) = type_convert(src_thread_scratch_[idx]); + dst_thread_scratch_(idx) = + type_convert(src_thread_scratch_tuple_[thread_scratch_id][idx]); }); } #endif } - template - __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf) + template + __device__ void + RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf, Number thread_scratch_id) { // if there is transpose, it's done here // TODO move this elsewhere - TransferDataFromSrcThreadScratchToDstThreadScratch(); + TransferDataFromSrcThreadScratchToDstThreadScratch(thread_scratch_id); static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds, @@ -753,21 +762,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1 static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){}; 
static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){}; -#if 1 - StaticTensorTupleOfVectorBuffer - src_thread_scratch_; - - StaticTensorTupleOfVectorBuffer - dst_thread_scratch_; -#else using SrcThreadScratch = StaticTensorTupleOfVectorBuffer src_thread_scratch_tuple_; DstThreadScratch dst_thread_scratch_; -#endif SrcCoord src_coord_; DstCoord dst_coord_; diff --git a/device_operation/include/device_gemm_xdl_c_shuffle.hpp b/device_operation/include/device_gemm_xdl_c_shuffle.hpp index 6127e6e6fef..125c1ace3ed 100644 --- a/device_operation/include/device_gemm_xdl_c_shuffle.hpp +++ b/device_operation/include/device_gemm_xdl_c_shuffle.hpp @@ -4,9 +4,7 @@ #include #include #include "device.hpp" -#include "device_base.hpp" #include "device_gemm.hpp" -#include "device_gemm_xdl.hpp" #include "common_header.hpp" #include "tensor_layout.hpp" #include "tensor_descriptor.hpp" diff --git a/example/1_gemm_xdl/gemm_xdl.cpp b/example/1_gemm_xdl/gemm_xdl.cpp index d9ed011fbeb..bf604a8fe62 100644 --- a/example/1_gemm_xdl/gemm_xdl.cpp +++ b/example/1_gemm_xdl/gemm_xdl.cpp @@ -31,6 +31,7 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough; using CElementOp = ck::tensor_operation::element_wise::PassThrough; +#if 1 // clang-format off using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle< ADataType, // ADataType @@ -71,6 +72,24 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle S<1, 1, 32, 1, 1, 8>, // CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl 8>; // CBlockTransferScalarPerVector_NWaveNPerXdl // clang-format on +#else +// clang-format off +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + 
+using DeviceGemmInstance = ck::tensor_operation::device:: + //#####################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>; +// clang-format on +#endif using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; From 8340275a8583cf3195eaa59da22d8d2dc5fd899c Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Thu, 10 Feb 2022 07:19:34 +0000 Subject: [PATCH 03/11] add 2 stage prefetch --- .../gridwise_gemm_xdlops_v2r3.hpp | 22 +- .../gridwise_gemm_xdlops_v4r1.hpp | 797 ++++++++++++++++++ device_operation/CMakeLists.txt | 159 ++-- 
.../device_gemm_xdl_c_shuffle_2_stage.hpp | 473 +++++++++++ ..._2_stage_f16_f16_f16_mk_nk_mn_instance.cpp | 56 ++ example/1_gemm_xdl/gemm_xdl.cpp | 17 +- profiler/CMakeLists.txt | 24 +- profiler/include/profile_gemm_impl.hpp | 9 + profiler/src/profiler.cpp | 14 +- 9 files changed, 1456 insertions(+), 115 deletions(-) create mode 100644 composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v4r1.hpp create mode 100644 device_operation/include/device_gemm_xdl_c_shuffle_2_stage.hpp create mode 100644 device_operation/src/device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp index 0db11aedeff..87c64cb58fd 100644 --- a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp @@ -419,7 +419,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 // A matrix blockwise copy auto a_blockwise_copy = - BlockwiseTensorSliceTransfer_v4r1 +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_xdlops_v4r1( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, + const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const Block2CTileMap block_2_ctile_map) +{ + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + a_grid_desc_k0_m_k1, + 
b_grid_desc_k0_n_k1, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); +} + +template < + index_t BlockSize, + typename FloatAB, + typename FloatAcc, + typename FloatC, + InMemoryDataOperationEnum_t CGlobalMemoryDataOperation, + typename AGridDesc_K0_M_K1, + typename BGridDesc_K0_N_K1, + typename CGridDesc_M_N, + typename AElementwiseOperation, + typename BElementwiseOperation, + typename CElementwiseOperation, + index_t MPerBlock, + index_t NPerBlock, + index_t K0PerBlock, + index_t MPerXdl, + index_t NPerXdl, + index_t K1Value, + index_t MXdlPerWave, + index_t NXdlPerWave, + typename ABlockTransferThreadClusterLengths_K0_M_K1, + typename ABlockTransferThreadClusterArrangeOrder, + typename ABlockTransferSrcAccessOrder, + index_t ABlockTransferSrcVectorDim, + index_t ABlockTransferSrcScalarPerVector, + index_t ABlockTransferDstScalarPerVector_K1, + bool AThreadTransferSrcResetCoordinateAfterRun, + bool ABlockLdsExtraM, + typename BBlockTransferThreadClusterLengths_K0_N_K1, + typename BBlockTransferThreadClusterArrangeOrder, + typename BBlockTransferSrcAccessOrder, + index_t BBlockTransferSrcVectorDim, + index_t BBlockTransferSrcScalarPerVector, + index_t BBlockTransferDstScalarPerVector_K1, + bool BThreadTransferSrcResetCoordinateAfterRun, + bool BBlockLdsExtraN, + index_t CShuffleMXdlPerWavePerShuffle, + index_t CShuffleNXdlPerWavePerShuffle, + typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + index_t CBlockTransferScalarPerVector_NWaveNPerXdl> +struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v4r1 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = 
Number<7>{}; + + // K1 should be Number<...> + static constexpr auto K1 = Number{}; + + __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k0_m_k1 = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + return a_block_desc_k0_m_k1; + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1() + { + constexpr auto max_lds_align = K1; + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_k0_n_k1 = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + return b_block_desc_k0_n_k1; + } + + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + Number{}, + I1, + Number{}, + Number{})); + + return c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); + + constexpr 
auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); + + constexpr auto max_lds_align = K1; + + constexpr auto a_block_space_size_aligned = + math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = + math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(); + + constexpr auto c_block_size = + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * + sizeof(FloatAB), + c_block_size * sizeof(FloatC)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const CGridDesc_M_N& c_grid_desc_m_n, + index_t M01, + index_t N01) + { + static_assert(is_known_at_compile_time>::value, + "wrong! 
K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + const auto M = a_grid_desc_k0_m_k1.GetLength(I1); + const auto N = b_grid_desc_k0_n_k1.GetLength(I1); + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + + if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) && + K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) && + K1 == b_grid_desc_k0_n_k1.GetLength(I2))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) + return false; + + // check M01, N01 + constexpr auto M1 = Number{}; + constexpr auto N1 = Number{}; + + const auto M0 = M / M1; + const auto N0 = N / N1; + + if(!(M0 % M01 == 0 && N0 % N01 == 0)) + return false; + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr index_t + CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const index_t grid_size = (M / MPerBlock) * (N / NPerBlock); + + return grid_size; + } + + __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) + { + const bool has_main_k0_block_loop = (K0 / K0PerBlock) > 1; + + return has_main_k0_block_loop; + } + + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + const auto 
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple( + MBlock, Number{}, Number{})), + make_unmerge_transform(make_tuple( + NBlock, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); + + return c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto + MakeBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + constexpr auto M1 = Number{}; + constexpr auto N1 = Number{}; + + const auto M0 = M / M1; + const auto N0 = N / N1; + + const auto M00 = M0 / M01; + const auto N00 = N0 / N01; + + const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(M00, M01)), + make_unmerge_transform(make_tuple(N00, N01))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{})); + + const auto c_blockid_to_m00_m01_n00_n01_block_cluster_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + const auto c_blockid_to_m0_n0_block_cluster_adaptor = + chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor, + c_blockid_to_m00_m01_n00_n01_block_cluster_adaptor); + + return c_blockid_to_m0_n0_block_cluster_adaptor; + } + using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl = + remove_cvref_t; + + using Block2CTileMap = remove_cvref_t; + + template + __device__ static void + Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, 
+ FloatC* __restrict__ p_c_grid, + void* __restrict__ p_shared, + const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl& + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, + const Block2CTileMap& block_2_ctile_map) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetElementSpaceSize()); + + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + BlockwiseTensorSliceTransfer_v4r1<2, + BlockSize, + AElementwiseOperation, + ck::tensor_operation::element_wise::PassThrough, + InMemoryDataOperationEnum_t::Set, + Sequence, + ABlockTransferThreadClusterLengths_K0_M_K1, + 
ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_k0_m_k1), + decltype(a_block_desc_k0_m_k1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true>( + a_grid_desc_k0_m_k1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_k0_m_k1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + BlockwiseTensorSliceTransfer_v4r1<2, + BlockSize, + BElementwiseOperation, + ck::tensor_operation::element_wise::PassThrough, + InMemoryDataOperationEnum_t::Set, + Sequence, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_k0_n_k1), + decltype(b_block_desc_k0_n_k1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_grid_desc_k0_n_k1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_k0_n_k1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + + auto blockwise_gemm = + BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = + math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); + + auto 
a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_k0_n_k1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); + + // preload data into LDS + { + // Read 0 + a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf, I0); + b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf, I0); + + // Move + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, b_block_slice_copy_step); + + // Read 1 + a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf, I1); + b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf, I1); + } + + // Initialize C + c_thread_buf.Clear(); + + // main body + if constexpr(HasMainKBlockLoop) + { + index_t k0_block_data_begin = 0; + + do + { + // Move + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, b_block_slice_copy_step); + + // Write i + a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf, I0); + b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf, I0); + + // Read i+2 + a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf, I0); + b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf, I0); + + // Sync + block_sync_lds(); + + // Gemm i + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + // Sync + block_sync_lds(); + + // Move + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, b_block_slice_copy_step); + + // Write i+1 + a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf, I1); + b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, 
b_block_buf, I1); + + // Read i+3 + a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf, I1); + b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf, I1); + + // Sync + block_sync_lds(); + + // Gemm i+1 + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + // Sync + block_sync_lds(); + + k0_block_data_begin += 2 * K0PerBlock; + } while(k0_block_data_begin < (K0 - 2 * K0PerBlock)); + } + + // tail + { + // Write K0-2 + a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf, I0); + b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf, I0); + + // Sync + block_sync_lds(); + + // Gemm K0-2 + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + // Sync + block_sync_lds(); + + // Write K0-1 + a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf, I1); + b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf, I1); + + // Sync + block_sync_lds(); + + // Gemm K0-1 + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + } + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! 
+ // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(); + + auto c_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + make_tuple( + make_freeze_transform(I0), // freeze mblock + make_pass_through_transform( + Number{}), // M0 (MXdlPerWave) per shuffle + make_unmerge_transform( + make_tuple(M1, M2, M3, M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl + make_freeze_transform(I0), // freeze nblock + make_pass_through_transform( + Number{}), // N0 (NXdlPerWave) per shuffle + make_unmerge_transform( + make_tuple(N1, N2))), // N1 = NWave, N2 = NPerXdl + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<>{}, + Sequence<0>{}, + Sequence<2, 4, 5, 6>{}, + Sequence<>{}, +
Sequence<1>{}, + Sequence<3, 7>{}) + + ); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum_t::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + auto c_block_copy_lds_to_global = BlockwiseTensorSliceTransfer_v6r1< + BlockSize, // index_t BlockSize, + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle, + MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle, + 
NWave * NPerXdl>, // BlockSliceLengths, + CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + Sequence<0, 1, 2, 3, 4, 5>, // typename ThreadClusterArrangeOrder, + FloatC, // typename SrcData, + FloatC, // typename DstData, + decltype( + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), + decltype( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), + Sequence<0, 1, 2, 3, 4, 5>, // typename DimAccessOrder, + 5, // index_t VectorDim, + CBlockTransferScalarPerVector_NWaveNPerXdl, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + make_multi_index(0, 0, 0, 0, 0, 0), + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0), + c_element_op}; + + constexpr auto mxdlperwave_forward_step = + make_multi_index(0, CShuffleMXdlPerWavePerShuffle, 0, 0, 0, 0); + constexpr auto nxdlperwave_forward_step = + make_multi_index(0, 0, 0, 0, CShuffleNXdlPerWavePerShuffle, 0); + constexpr auto nxdlperwave_backward_step = + make_multi_index(0, 0, 0, 0, -CShuffleNXdlPerWavePerShuffle, 0); + + static_for<0, MXdlPerWave, CShuffleMXdlPerWavePerShuffle>{}([&](auto mxdlperwave_iter) { + constexpr auto mxdlperwave = mxdlperwave_iter; + + static_for<0, + NXdlPerWave, + CShuffleNXdlPerWavePerShuffle>{}([&](auto nxdlperwave_iter) { + constexpr bool nxdlperwave_forward_sweep = + (mxdlperwave % (2 * CShuffleMXdlPerWavePerShuffle) == 0); + + constexpr index_t nxdlperwave_value = + nxdlperwave_forward_sweep + ? 
nxdlperwave_iter + : (NXdlPerWave - nxdlperwave_iter - CShuffleNXdlPerWavePerShuffle); + + constexpr auto nxdlperwave = Number{}; + + // make sure it's safe to do ds_write + block_sync_lds(); + + // VGPR to LDS + c_thread_copy_vgpr_to_lds.Run( + c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_block_buf); + + // make sure it's safe to do ds_read + block_sync_lds(); + + // LDS to global + c_block_copy_lds_to_global.Run( + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c_block_buf, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c_grid_buf); + + // move on nxdlperwave dimension + if constexpr(nxdlperwave_forward_sweep && + (nxdlperwave < NXdlPerWave - CShuffleNXdlPerWavePerShuffle)) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + nxdlperwave_forward_step); + } + else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0)) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + nxdlperwave_backward_step); + } + }); + + // move on mxdlperwave dimension + if constexpr(mxdlperwave < MXdlPerWave - CShuffleMXdlPerWavePerShuffle) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + mxdlperwave_forward_step); + } + }); + } + } +}; + +} // namespace ck +#endif diff --git a/device_operation/CMakeLists.txt b/device_operation/CMakeLists.txt index d9a4ebb499c..9771c5302ae 100644 --- a/device_operation/CMakeLists.txt +++ b/device_operation/CMakeLists.txt @@ -26,86 +26,87 @@ set(DEVICE_GEMM_INSTANCE_SOURCE ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp; 
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp; ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp; + ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp; +# ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp; +# ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp; +# ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp; +# ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp; ) # device_gemm_bias_relu_instance -set(DEVICE_GEMM_BIAS_RELU_INSTANCE_SOURCE - ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_kn_mn_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_nk_mn_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_kn_mn_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_nk_mn_instance.cpp; -) - -# device_gemm_bias_relu_add_instance -set(DEVICE_GEMM_BIAS_RELU_ADD_INSTANCE_SOURCE - ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_kn_mn_instance.cpp; - 
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_nk_mn_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_kn_mn_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_nk_mn_instance.cpp; -) - -# device_conv2d_fwd_instance -set(DEVICE_CONV2D_FWD_INSTANCE_SOURCE - ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp; -) - -# device_conv2d_fwd_bias_relu_instance -set(DEVICE_CONV2D_FWD_BIAS_RELU_INSTANCE_SOURCE - ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp; -) - -# device_conv2d_fwd_bias_relu_add_instance -set(DEVICE_CONV2D_FWD_BIAS_RELU_ADD_INSTANCE_SOURCE - ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp; -) - -# device_conv2d_fwd_bias_relu_atomic_add_instance -set(DEVICE_CONV2D_FWD_BIAS_RELU_ATOMIC_ADD_INSTANCE_SOURCE - ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instance.cpp; -) - -add_library(device_gemm_instance SHARED ${DEVICE_GEMM_INSTANCE_SOURCE}) -add_library(device_gemm_bias_relu_instance SHARED ${DEVICE_GEMM_BIAS_RELU_INSTANCE_SOURCE}) -add_library(device_gemm_bias_relu_add_instance SHARED ${DEVICE_GEMM_BIAS_RELU_ADD_INSTANCE_SOURCE}) -add_library(device_conv2d_fwd_instance SHARED ${DEVICE_CONV2D_FWD_INSTANCE_SOURCE}) -add_library(device_conv2d_fwd_bias_relu_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_INSTANCE_SOURCE}) -add_library(device_conv2d_fwd_bias_relu_add_instance SHARED 
${DEVICE_CONV2D_FWD_BIAS_RELU_ADD_INSTANCE_SOURCE}) -add_library(device_conv2d_fwd_bias_relu_atomic_add_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_ATOMIC_ADD_INSTANCE_SOURCE}) - -target_include_directories(device_gemm_instance SYSTEM PUBLIC $) -target_include_directories(device_gemm_bias_relu_instance SYSTEM PUBLIC $) -target_include_directories(device_gemm_bias_relu_add_instance SYSTEM PUBLIC $) -target_include_directories(device_conv2d_fwd_instance SYSTEM PUBLIC $) -target_include_directories(device_conv2d_fwd_bias_relu_instance SYSTEM PUBLIC $) -target_include_directories(device_conv2d_fwd_bias_relu_add_instance SYSTEM PUBLIC $) -target_include_directories(device_conv2d_fwd_bias_relu_atomic_add_instance SYSTEM PUBLIC $) - -target_compile_features(device_gemm_instance PUBLIC) -target_compile_features(device_gemm_bias_relu_instance PUBLIC) -target_compile_features(device_gemm_bias_relu_add_instance PUBLIC) -target_compile_features(device_conv2d_fwd_instance PUBLIC) -target_compile_features(device_conv2d_fwd_bias_relu_instance PUBLIC) -target_compile_features(device_conv2d_fwd_bias_relu_add_instance PUBLIC) -target_compile_features(device_conv2d_fwd_bias_relu_atomic_add_instance PUBLIC) - -set_target_properties(device_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) -set_target_properties(device_gemm_bias_relu_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) -set_target_properties(device_gemm_bias_relu_add_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) -set_target_properties(device_conv2d_fwd_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) -set_target_properties(device_conv2d_fwd_bias_relu_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) -set_target_properties(device_conv2d_fwd_bias_relu_add_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) -set_target_properties(device_conv2d_fwd_bias_relu_atomic_add_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) +#set(DEVICE_GEMM_BIAS_RELU_INSTANCE_SOURCE +# 
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_kn_mn_instance.cpp; +# ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_nk_mn_instance.cpp; +# ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_kn_mn_instance.cpp; +# ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_nk_mn_instance.cpp; +#) +# +## device_gemm_bias_relu_add_instance +#set(DEVICE_GEMM_BIAS_RELU_ADD_INSTANCE_SOURCE +# ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_kn_mn_instance.cpp; +# ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_nk_mn_instance.cpp; +# ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_kn_mn_instance.cpp; +# ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_nk_mn_instance.cpp; +#) +# +## device_conv2d_fwd_instance +#set(DEVICE_CONV2D_FWD_INSTANCE_SOURCE +# ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp; +# ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp; +# ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp; +#) +# +## device_conv2d_fwd_bias_relu_instance +#set(DEVICE_CONV2D_FWD_BIAS_RELU_INSTANCE_SOURCE +# ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp; +#) +# +## device_conv2d_fwd_bias_relu_add_instance +#set(DEVICE_CONV2D_FWD_BIAS_RELU_ADD_INSTANCE_SOURCE +# ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp; +#) +# +## device_conv2d_fwd_bias_relu_atomic_add_instance 
+#set(DEVICE_CONV2D_FWD_BIAS_RELU_ATOMIC_ADD_INSTANCE_SOURCE +# ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instance.cpp; +#) -install(TARGETS device_gemm_instance LIBRARY DESTINATION lib) -install(TARGETS device_gemm_bias_relu_instance LIBRARY DESTINATION lib) -install(TARGETS device_gemm_bias_relu_add_instance LIBRARY DESTINATION lib) -install(TARGETS device_conv2d_fwd_instance LIBRARY DESTINATION lib) -install(TARGETS device_conv2d_fwd_bias_relu_instance LIBRARY DESTINATION lib) -install(TARGETS device_conv2d_fwd_bias_relu_add_instance LIBRARY DESTINATION lib) -install(TARGETS device_conv2d_fwd_bias_relu_atomic_add_instance LIBRARY DESTINATION lib) + add_library(device_gemm_instance SHARED ${DEVICE_GEMM_INSTANCE_SOURCE}) +#add_library(device_gemm_bias_relu_instance SHARED ${DEVICE_GEMM_BIAS_RELU_INSTANCE_SOURCE}) +#add_library(device_gemm_bias_relu_add_instance SHARED ${DEVICE_GEMM_BIAS_RELU_ADD_INSTANCE_SOURCE}) +#add_library(device_conv2d_fwd_instance SHARED ${DEVICE_CONV2D_FWD_INSTANCE_SOURCE}) +#add_library(device_conv2d_fwd_bias_relu_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_INSTANCE_SOURCE}) +#add_library(device_conv2d_fwd_bias_relu_add_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_ADD_INSTANCE_SOURCE}) +#add_library(device_conv2d_fwd_bias_relu_atomic_add_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_ATOMIC_ADD_INSTANCE_SOURCE}) +# + target_include_directories(device_gemm_instance SYSTEM PUBLIC $) +#target_include_directories(device_gemm_bias_relu_instance SYSTEM PUBLIC $) +#target_include_directories(device_gemm_bias_relu_add_instance SYSTEM PUBLIC $) +#target_include_directories(device_conv2d_fwd_instance SYSTEM PUBLIC $) +#target_include_directories(device_conv2d_fwd_bias_relu_instance SYSTEM PUBLIC $) +#target_include_directories(device_conv2d_fwd_bias_relu_add_instance SYSTEM PUBLIC $) +#target_include_directories(device_conv2d_fwd_bias_relu_atomic_add_instance SYSTEM PUBLIC $) 
+# + target_compile_features(device_gemm_instance PUBLIC) +#target_compile_features(device_gemm_bias_relu_instance PUBLIC) +#target_compile_features(device_gemm_bias_relu_add_instance PUBLIC) +#target_compile_features(device_conv2d_fwd_instance PUBLIC) +#target_compile_features(device_conv2d_fwd_bias_relu_instance PUBLIC) +#target_compile_features(device_conv2d_fwd_bias_relu_add_instance PUBLIC) +#target_compile_features(device_conv2d_fwd_bias_relu_atomic_add_instance PUBLIC) +# + set_target_properties(device_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) +#set_target_properties(device_gemm_bias_relu_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) +#set_target_properties(device_gemm_bias_relu_add_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) +#set_target_properties(device_conv2d_fwd_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) +#set_target_properties(device_conv2d_fwd_bias_relu_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) +#set_target_properties(device_conv2d_fwd_bias_relu_add_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) +#set_target_properties(device_conv2d_fwd_bias_relu_atomic_add_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) +# + install(TARGETS device_gemm_instance LIBRARY DESTINATION lib) +#install(TARGETS device_gemm_bias_relu_instance LIBRARY DESTINATION lib) +#install(TARGETS device_gemm_bias_relu_add_instance LIBRARY DESTINATION lib) +#install(TARGETS device_conv2d_fwd_instance LIBRARY DESTINATION lib) +#install(TARGETS device_conv2d_fwd_bias_relu_instance LIBRARY DESTINATION lib) +#install(TARGETS device_conv2d_fwd_bias_relu_add_instance LIBRARY DESTINATION lib) +#install(TARGETS device_conv2d_fwd_bias_relu_atomic_add_instance LIBRARY DESTINATION lib) diff --git a/device_operation/include/device_gemm_xdl_c_shuffle_2_stage.hpp b/device_operation/include/device_gemm_xdl_c_shuffle_2_stage.hpp new file mode 100644 index 00000000000..207f4073015 --- /dev/null +++ b/device_operation/include/device_gemm_xdl_c_shuffle_2_stage.hpp 
@@ -0,0 +1,473 @@ +#ifndef DEVICE_GEMM_XDL_C_SHUFFLE_2_STAGE_HPP +#define DEVICE_GEMM_XDL_C_SHUFFLE_2_STAGE_HPP + +#include +#include +#include "device.hpp" +#include "device_gemm.hpp" +#include "common_header.hpp" +#include "tensor_layout.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_xdlops_v4r1.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template < + typename ADataType, + typename BDataType, + typename CDataType, + typename AccDataType, + typename ALayout, + typename BLayout, + typename CLayout, + typename AElementwiseOperation, + typename BElementwiseOperation, + typename CElementwiseOperation, + ck::index_t BlockSize, + ck::index_t MPerBlock, + ck::index_t NPerBlock, + ck::index_t K0PerBlock, + ck::index_t K1, + ck::index_t MPerXDL, + ck::index_t NPerXDL, + ck::index_t MXdlPerWave, + ck::index_t NXdlPerWave, + typename ABlockTransferThreadClusterLengths_K0_M_K1, + typename ABlockTransferThreadClusterArrangeOrder, + typename ABlockTransferSrcAccessOrder, + ck::index_t ABlockTransferSrcVectorDim, + ck::index_t ABlockTransferSrcScalarPerVector, + ck::index_t ABlockTransferDstScalarPerVector_K1, + bool ABlockLdsAddExtraM, + typename BBlockTransferThreadClusterLengths_K0_N_K1, + typename BBlockTransferThreadClusterArrangeOrder, + typename BBlockTransferSrcAccessOrder, + ck::index_t BBlockTransferSrcVectorDim, + ck::index_t BBlockTransferSrcScalarPerVector, + ck::index_t BBlockTransferDstScalarPerVector_K1, + bool BBlockLdsAddExtraN, + index_t CShuffleMXdlPerWavePerShuffle, + index_t CShuffleNXdlPerWavePerShuffle, + typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + index_t CBlockTransferScalarPerVector_NWaveNPerXdl> +struct DeviceGemmXdl_C_Shuffle_2_Stage + : public DeviceGemm +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static constexpr 
auto K1Number = Number{}; + + static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA) + { + assert(K % K1 == 0); + + const index_t K0 = K / K1; + + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + const auto a_grid_desc_k0_m_k1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_k0_m_k1; + } + + static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB) + { + assert(K % K1 == 0); + + const index_t K0 = K / K1; + + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + const auto b_grid_desc_k0_n_k1 = + transform_tensor_descriptor(b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_k0_n_k1; + } + + static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) + { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + } + + using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1)); + using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 
1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v4r1< + BlockSize, + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum_t::Set, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, + BBlockLdsAddExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + CBlockTransferScalarPerVector_NWaveNPerXdl>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t M01, + index_t N01, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_c_grid_{p_c_grid}, + a_grid_desc_k0_m_k1_{}, + b_grid_desc_k0_n_k1_{}, + c_grid_desc_m_n_{}, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, + block_2_ctile_map_{}, + M01_{M01}, + N01_{N01}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + 
{ + a_grid_desc_k0_m_k1_ = + DeviceGemmXdl_C_Shuffle_2_Stage::MakeAGridDescriptor_K0_M_K1(M, K, StrideA); + b_grid_desc_k0_n_k1_ = + DeviceGemmXdl_C_Shuffle_2_Stage::MakeBGridDescriptor_K0_N_K1(K, N, StrideB); + c_grid_desc_m_n_ = + DeviceGemmXdl_C_Shuffle_2_Stage::MakeCGridDescriptor_M_N(M, N, StrideC); + + if(GridwiseGemm::CheckValidity( + a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_)) + { + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = + GridwiseGemm:: + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + c_grid_desc_m_n_); + + block_2_ctile_map_ = GridwiseGemm::MakeBlock2CTileMap(c_grid_desc_m_n_, M01, N01); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; + typename GridwiseGemm::Block2CTileMap block_2_ctile_map_; + index_t M01_; + index_t N01_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceGemmXdl_C_Shuffle_2_Stage::Argument; + + float Run(const Argument& arg, int nrepeat = 1) + { + { + std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) + << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) + << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << 
arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.M01_, + arg.N01_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"); + } + + const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0); + + const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + + float ave_time = 0; + + if(has_main_k0_block_loop) + { + const auto kernel = kernel_gemm_xdlops_v4r1< + GridwiseGemm, + ADataType, // TODO: distinguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + true>; + + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_gemm_xdlops_v4r1< + GridwiseGemm, + ADataType, // TODO: distinguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + false>; + + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, +
arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, int nrepeat = 1) override + { + return Run(*dynamic_cast(p_arg), nrepeat); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.M01_, + arg.N01_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, + p_b, + p_c, + M, + N, + K, + StrideA, + StrideB, + StrideC, + 1, + 1, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + ck::index_t KBatch = 1) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + 1, + 1, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() 
override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGemmXdl_C_Shuffle_2_Stage" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/device_operation/src/device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp b/device_operation/src/device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp new file mode 100644 index 00000000000..b6ed8fb8107 --- /dev/null +++ b/device_operation/src/device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,56 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_c_shuffle_2_stage.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances = std::tuple< + // clang-format off + //#############################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#############################| 
Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#############################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl_C_Shuffle_2_Stage< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_2_Stage< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_2_Stage< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_2_Stage< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_2_Stage< F16, F16, F16, F32, Row, Col, Row, PassThrough, 
PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_2_Stage< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_2_Stage< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_2_Stage< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_2_Stage< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_2_Stage< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_2_Stage< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_2_Stage< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 32, 4, 8, 32, 
32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_2_Stage< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/example/1_gemm_xdl/gemm_xdl.cpp b/example/1_gemm_xdl/gemm_xdl.cpp index bf604a8fe62..7564f0ce63f 100644 --- a/example/1_gemm_xdl/gemm_xdl.cpp +++ b/example/1_gemm_xdl/gemm_xdl.cpp @@ -12,6 +12,7 @@ #include "host_gemm.hpp" #include "device_tensor.hpp" #include "device_gemm_xdl_c_shuffle.hpp" +#include "device_gemm_xdl_c_shuffle_2_stage.hpp" #include "element_wise_operation.hpp" #include "reference_gemm.hpp" @@ -31,9 +32,9 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough; using CElementOp = ck::tensor_operation::element_wise::PassThrough; -#if 1 +#if 0 // clang-format off -using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle< +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle_2_Stage< ADataType, // ADataType BDataType, // BDataType CDataType, // CDataType @@ -82,12 +83,12 @@ using Col = ck::tensor_layout::gemm::ColumnMajor; using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using DeviceGemmInstance = ck::tensor_operation::device:: - //#####################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| 
B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| - //#####################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| - //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>; +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle_2_Stage + //|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| 
Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>; // clang-format on #endif diff --git a/profiler/CMakeLists.txt b/profiler/CMakeLists.txt index 71e795b4d49..46717248a19 100644 --- a/profiler/CMakeLists.txt +++ b/profiler/CMakeLists.txt @@ -17,21 +17,21 @@ include_directories(BEFORE set(PROFILER_SOURCE src/profiler.cpp src/profile_gemm.cpp - src/profile_gemm_bias_relu.cpp - src/profile_gemm_bias_relu_add.cpp - src/profile_conv_fwd.cpp - src/profile_conv_fwd_bias_relu.cpp - src/profile_conv_fwd_bias_relu_add.cpp - src/profile_conv_fwd_bias_relu_atomic_add.cpp +# src/profile_gemm_bias_relu.cpp +# src/profile_gemm_bias_relu_add.cpp +# src/profile_conv_fwd.cpp +# src/profile_conv_fwd_bias_relu.cpp +# src/profile_conv_fwd_bias_relu_add.cpp +# src/profile_conv_fwd_bias_relu_atomic_add.cpp ) add_executable(ckProfiler ${PROFILER_SOURCE}) target_link_libraries(ckProfiler PRIVATE host_tensor) target_link_libraries(ckProfiler PRIVATE device_gemm_instance) -target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_instance) -target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_add_instance) -target_link_libraries(ckProfiler 
PRIVATE device_conv2d_fwd_instance) -target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_instance) -target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_add_instance) -target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_atomic_add_instance) +#target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_instance) +#target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_add_instance) +#target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_instance) +#target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_instance) +#target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_add_instance) +#target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_atomic_add_instance) diff --git a/profiler/include/profile_gemm_impl.hpp b/profiler/include/profile_gemm_impl.hpp index 9962c6579d5..f01d795942a 100644 --- a/profiler/include/profile_gemm_impl.hpp +++ b/profiler/include/profile_gemm_impl.hpp @@ -30,6 +30,9 @@ void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(std::vector&); void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(std::vector&); + +#if 0 void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(std::vector&); void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(std::vector&); void add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(std::vector&); @@ -39,6 +42,7 @@ void add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(std::vector&); void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(std::vector&); void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(std::vector&); +#endif } // namespace device_gemm_instance } // namespace device @@ -141,6 +145,7 @@ void profile_gemm_impl(int do_verification, if constexpr(is_same::value && is_same::value && is_same::value) { +#if 0 if constexpr(is_same::value && is_same::value && 
is_same::value) @@ -202,6 +207,7 @@ void profile_gemm_impl(int do_verification, add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(gemm_ptrs); } } +#endif } else if constexpr(is_same::value && is_same::value && is_same::value) @@ -225,6 +231,9 @@ void profile_gemm_impl(int do_verification, ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); } else if constexpr(is_same::value && is_same::value && diff --git a/profiler/src/profiler.cpp b/profiler/src/profiler.cpp index 6855d5bdced..946495a6eca 100644 --- a/profiler/src/profiler.cpp +++ b/profiler/src/profiler.cpp @@ -6,12 +6,12 @@ #include int profile_gemm(int, char*[]); -int profile_gemm_bias_relu(int, char*[]); -int profile_gemm_bias_relu_add(int, char*[]); -int profile_conv_fwd(int, char*[]); -int profile_conv_fwd_bias_relu(int, char*[]); -int profile_conv_fwd_bias_relu_add(int, char*[]); -int profile_conv_fwd_bias_relu_atomic_add(int, char*[]); +//int profile_gemm_bias_relu(int, char*[]); +//int profile_gemm_bias_relu_add(int, char*[]); +//int profile_conv_fwd(int, char*[]); +//int profile_conv_fwd_bias_relu(int, char*[]); +//int profile_conv_fwd_bias_relu_add(int, char*[]); +//int profile_conv_fwd_bias_relu_atomic_add(int, char*[]); int main(int argc, char* argv[]) { @@ -19,6 +19,7 @@ int main(int argc, char* argv[]) { return profile_gemm(argc, argv); } +#if 0 if(strcmp(argv[1], "gemm_bias_relu") == 0) { return profile_gemm_bias_relu(argc, argv); @@ -43,6 +44,7 @@ int main(int argc, char* argv[]) { return profile_conv_fwd_bias_relu_atomic_add(argc, argv); } +#endif else { // clang-format off From 447b9f20a92e655d2d6150202e92f0eda9c93a72 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Sat, 12 Feb 2022 18:05:48 +0000 Subject: [PATCH 04/11] add more sanity check into transform_tensor_descriptor --- 
.../include/tensor_description/tensor_descriptor.hpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/composable_kernel/include/tensor_description/tensor_descriptor.hpp b/composable_kernel/include/tensor_description/tensor_descriptor.hpp index 8f6a5a3e43c..9cd51c61d66 100644 --- a/composable_kernel/include/tensor_description/tensor_descriptor.hpp +++ b/composable_kernel/include/tensor_description/tensor_descriptor.hpp @@ -307,6 +307,10 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc, { // sanity check { + static_assert(NewTransforms::Size() == NewLowerDimensionOldVisibleIdss::Size() && + NewTransforms::Size() == NewUpperDimensionNewVisibleIdss::Size(), + "wrong! inconsitent number of transform"); + constexpr auto all_old_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); }, NewLowerDimensionOldVisibleIdss{}); From b6885ac5a00d2ed15b382f873e989381003bb27b Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Sun, 13 Feb 2022 05:09:59 +0000 Subject: [PATCH 05/11] tweak --- .../gridwise_gemm_xdlops_v2r3.hpp | 49 +++++++ .../gridwise_gemm_xdlops_v3r1.hpp | 49 +++++++ device_operation/include/device_gemm_xdl.hpp | 7 +- example/1_gemm_xdl/gemm_xdl.cpp | 130 ++++++++++-------- example/4_conv2d_fwd_xdl/conv2d_fwd_xdl.cpp | 73 +++++----- profiler/CMakeLists.txt | 28 ++-- profiler/include/profile_gemm_impl.hpp | 9 +- profiler/src/profiler.cpp | 16 ++- 8 files changed, 238 insertions(+), 123 deletions(-) diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp index 0db11aedeff..e592e891dfe 100644 --- a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp @@ -513,6 +513,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); constexpr auto b_block_slice_copy_step = 
make_multi_index(K0PerBlock, 0, 0); +#if 0 // preload data into LDS { a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); @@ -558,6 +559,54 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); } +#else + // preload data into LDS + a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, b_block_slice_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); + + // main body + if constexpr(HasMainKBlockLoop) + { + index_t k0_block_data_begin = 0; + + do + { + a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); + + block_sync_lds(); + + b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, b_block_slice_copy_step); + + a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); + + k0_block_data_begin += K0PerBlock; + } while(k0_block_data_begin < (K0 - K0PerBlock)); + } + + // tail + { + block_sync_lds(); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + } +#endif // output: register to global memory { diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r1.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r1.hpp index 3022f3f0fc8..ff3f7b3bca9 100644 --- a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r1.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r1.hpp 
@@ -470,6 +470,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); +#if 0 // preload data into LDS { a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); @@ -515,6 +516,54 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); } +#else + // preload data into LDS + a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, b_block_slice_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); + + // main body + if constexpr(HasMainKBlockLoop) + { + index_t k0_block_data_begin = 0; + + do + { + a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); + + block_sync_lds(); + + b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, b_block_slice_copy_step); + + a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); + + k0_block_data_begin += K0PerBlock; + } while(k0_block_data_begin < (K0 - K0PerBlock)); + } + + // tail + { + block_sync_lds(); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + } +#endif // shuffle C and write out { diff --git a/device_operation/include/device_gemm_xdl.hpp b/device_operation/include/device_gemm_xdl.hpp index 956c66819eb..5a49a47f3c5 100644 --- 
a/device_operation/include/device_gemm_xdl.hpp +++ b/device_operation/include/device_gemm_xdl.hpp @@ -493,7 +493,12 @@ struct DeviceGemmXdl << BlockSize << ", " << MPerBlock << ", " << NPerBlock << ", " - << K0PerBlock + << K0PerBlock << ", " + << K1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ">"; // clang-format on diff --git a/example/1_gemm_xdl/gemm_xdl.cpp b/example/1_gemm_xdl/gemm_xdl.cpp index 7564f0ce63f..bf2ab1d5490 100644 --- a/example/1_gemm_xdl/gemm_xdl.cpp +++ b/example/1_gemm_xdl/gemm_xdl.cpp @@ -11,14 +11,24 @@ #include "host_tensor_generator.hpp" #include "host_gemm.hpp" #include "device_tensor.hpp" +#include "device_gemm_xdl.hpp" #include "device_gemm_xdl_c_shuffle.hpp" #include "device_gemm_xdl_c_shuffle_2_stage.hpp" #include "element_wise_operation.hpp" #include "reference_gemm.hpp" +#include "gemm_specialization.hpp" template using S = ck::Sequence; +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using ADataType = ck::half_t; using BDataType = ck::half_t; using CDataType = ck::half_t; @@ -32,63 +42,69 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough; using CElementOp = ck::tensor_operation::element_wise::PassThrough; +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default; +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization_t::MNPadding; + #if 0 -// clang-format off -using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle_2_Stage< - ADataType, // ADataType - BDataType, // BDataType - CDataType, // CDataType - AccDataType, // AccDataType - ALayout, // ALayout - BLayout, // BLayout - CLayout, // CLayout - AElementOp, // AElementwiseOperation 
- BElementOp, // BElementwiseOperation - CElementOp, // CElementwiseOperation - 256, // BlockSize - 256, // MPerBlock - 128, // NPerBlock - 4, // K0PerBlock - 8, // K1 - 32, // MPerXDL - 32, // NPerXDL - 4, // MXdlPerWave - 2, // NXdlPerWave - S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 - S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder - S<1, 0, 2>, // ABlockTransferSrcAccessOrder - 2, // ABlockTransferSrcVectorDim - 8, // ABlockTransferSrcScalarPerVector - 8, // ABlockTransferDstScalarPerVector_K1 - true, // ABlockLdsAddExtraM - S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 - S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder - S<1, 0, 2>, // BBlockTransferSrcAccessOrder - 2, // BBlockTransferSrcVectorDim - 8, // BBlockTransferSrcScalarPerVector - 8, // BBlockTransferDstScalarPerVector_K1 - true, // BBlockLdsAddExtraN - 1, // CShuffleMXdlPerWavePerShuffle - 1, // CShuffleNXdlPerWavePerShuffle - S<1, 1, 32, 1, 1, 8>, // CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl - 8>; // CBlockTransferScalarPerVector_NWaveNPerXdl +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl< + ADataType, // ADataType + BDataType, // BDataType + CDataType, // CDataType + AccDataType, // AccDataType + ALayout, // ALayout + BLayout, // BLayout + CLayout, // CLayout + AElementOp, // AElementwiseOperation + BElementOp, // BElementwiseOperation + CElementOp, // CElementwiseOperation + GemmDefault, // GemmSpecialization + 256, // BlockSize + 256, // MPerBlock + 128, // NPerBlock + 4, // K0PerBlock + 8, // K1 + 32, // MPerXDL + 32, // NPerXDL + 4, // MXdlPerWave + 2, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_K1 + true, // ABlockLdsAddExtraM + 
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_K1 + true, // BBlockLdsAddExtraN + 7, // CThreadTransferSrcDstVectorDim + 1>; // CThreadTransferDstScalarPerVector +#elif 0 +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl + // clang-format off + //#|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //#| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // 99TFlops, 1 wave per SIMD, limited by LDS, not enough blocks + < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 144, 8, 8, 16, 16, 2, 9, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 8, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>; + // 99TFlops, 1 wave per SIMD, not enough blocks + //< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 144, 4, 8, 16, 
16, 2, 9, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>; + // 97TFlops, 2 wave per SIMD + //< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 144, 8, 8, 16, 16, 1, 9, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>; // clang-format on -#else -// clang-format off -using F16 = ck::half_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle_2_Stage - //|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| - //| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| - //| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 
1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>; +#elif 1 +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle + // clang-format off + //#|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 144, 8, 8, 16, 16, 2, 9, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 8, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 1, 9, S<1, 1, 8, 1, 9, 2>, 8>; // clang-format on #endif @@ -219,8 +235,8 @@ int main(int argc, char* argv[]) float gb_per_sec = num_btype / 1.E6 / ave_time; - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" - << std::endl; + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << 
gemm.GetTypeString() << std::endl; c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); diff --git a/example/4_conv2d_fwd_xdl/conv2d_fwd_xdl.cpp b/example/4_conv2d_fwd_xdl/conv2d_fwd_xdl.cpp index 4c62a7af152..a2085c8f0e2 100644 --- a/example/4_conv2d_fwd_xdl/conv2d_fwd_xdl.cpp +++ b/example/4_conv2d_fwd_xdl/conv2d_fwd_xdl.cpp @@ -12,6 +12,7 @@ #include "device_tensor.hpp" #include "tensor_layout.hpp" #include "element_wise_operation.hpp" +#include "device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp" #include "device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp" #include "reference_conv_fwd.hpp" @@ -34,45 +35,41 @@ using OutElementOp = ck::tensor_operation::element_wise::PassThrough; static constexpr auto ConvFwdDefault = ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Default; -// clang-format off using DeviceConvFwdInstance = ck::tensor_operation::device:: - DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< - InDataType, // InDataType - WeiDataType, // WeiDataType - OutDataType, // OutDataType - AccDataType, // AccDataType - InElementOp, // InElementwiseOperation - WeiElementOp, // WeiElementwiseOperation - OutElementOp, // OutElementwiseOperation - ConvFwdDefault, // ConvForwardSpecialization - 256, // BlockSize - 128, // MPerBlock - 256, // NPerBlock - 4, // K0PerBlock - 8, // K1 - 32, // MPerXdl - 32, // NPerXdl - 2, // MXdlPerWave - 4, // NXdlPerWave - S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 - S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder - S<1, 0, 2>, // ABlockTransferSrcAccessOrder - 2, // ABlockTransferSrcVectorDim - 8, // ABlockTransferSrcScalarPerVector - 8, // ABlockTransferDstScalarPerVector_K1 - true, // ABlockLdsAddExtraM - S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 - S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder - S<1, 0, 2>, // BBlockTransferSrcAccessOrder - 2, // BBlockTransferSrcVectorDim - 8, // BBlockTransferSrcScalarPerVector - 8, // 
BBlockTransferDstScalarPerVector_K1 - true, // BBlockLdsAddExtraN - 1, // CShuffleMXdlPerWavePerShuffle - 1, // CShuffleNXdlPerWavePerShuffle - S<1, 1, 32, 1, 1, 8>, // CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl - 8>; // CBlockTransferScalarPerVector_NWaveNPerXdl -// clang-format on + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< + InDataType, // InDataType + WeiDataType, // WeiDataType + OutDataType, // OutDataType + AccDataType, // AccDataType + InElementOp, // InElementwiseOperation + WeiElementOp, // WeiElementwiseOperation + OutElementOp, // OutElementwiseOperation + ConvFwdDefault, // ConvForwardSpecialization + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 4, // K0PerBlock + 8, // K1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_K1 + true, // ABlockLdsAddExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_K1 + true, // BBlockLdsAddExtraN + 7, // CThreadTransferSrcDstVectorDim + 1>; // CThreadTransferDstScalarPerVector using ReferenceConvFwdInstance = ck::tensor_operation::host:: ReferenceConvFwd; diff --git a/profiler/CMakeLists.txt b/profiler/CMakeLists.txt index a25e64f5bab..6ab9c2ae4ef 100644 --- a/profiler/CMakeLists.txt +++ b/profiler/CMakeLists.txt @@ -17,23 +17,23 @@ include_directories(BEFORE set(PROFILER_SOURCE src/profiler.cpp src/profile_gemm.cpp - src/profile_gemm_bias_relu.cpp - src/profile_gemm_bias_relu_add.cpp - 
src/profile_conv_fwd.cpp - src/profile_conv_fwd_bias_relu.cpp - src/profile_conv_fwd_bias_relu_add.cpp - src/profile_conv_fwd_bias_relu_atomic_add.cpp - src/profile_batched_gemm.cpp +# src/profile_gemm_bias_relu.cpp +# src/profile_gemm_bias_relu_add.cpp +# src/profile_conv_fwd.cpp +# src/profile_conv_fwd_bias_relu.cpp +# src/profile_conv_fwd_bias_relu_add.cpp +# src/profile_conv_fwd_bias_relu_atomic_add.cpp +# src/profile_batched_gemm.cpp ) add_executable(ckProfiler ${PROFILER_SOURCE}) target_link_libraries(ckProfiler PRIVATE host_tensor) target_link_libraries(ckProfiler PRIVATE device_gemm_instance) -target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_instance) -target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_add_instance) -target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_instance) -target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_instance) -target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_add_instance) -target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_atomic_add_instance) -target_link_libraries(ckProfiler PRIVATE device_batched_gemm_instance) +#target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_instance) +#target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_add_instance) +#target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_instance) +#target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_instance) +#target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_add_instance) +#target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_atomic_add_instance) +#target_link_libraries(ckProfiler PRIVATE device_batched_gemm_instance) diff --git a/profiler/include/profile_gemm_impl.hpp b/profiler/include/profile_gemm_impl.hpp index 969edba322b..b3924e44a19 100644 --- a/profiler/include/profile_gemm_impl.hpp +++ b/profiler/include/profile_gemm_impl.hpp @@ -1,4 +1,5 @@ #pragma once +#include #include "config.hpp" 
#include "device.hpp" #include "host_tensor.hpp" @@ -33,7 +34,6 @@ void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(std::vector&); -#if 0 void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(std::vector&); void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(std::vector&); void add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(std::vector&); @@ -43,7 +43,6 @@ void add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(std::vector&); void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(std::vector&); void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(std::vector&); -#endif } // namespace device_gemm_instance } // namespace device @@ -146,7 +145,6 @@ void profile_gemm_impl(int do_verification, if constexpr(is_same::value && is_same::value && is_same::value) { -#if 0 if constexpr(is_same::value && is_same::value && is_same::value) @@ -208,7 +206,6 @@ void profile_gemm_impl(int do_verification, add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(gemm_ptrs); } } -#endif } else if constexpr(is_same::value && is_same::value && is_same::value) @@ -303,8 +300,8 @@ void profile_gemm_impl(int do_verification, float gb_per_sec = num_btype / 1.E6 / ave_time; - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec - << " GB/s, " << gemm_name << std::endl; + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << gemm_name << std::endl; if(tflops > best_tflops) { diff --git a/profiler/src/profiler.cpp b/profiler/src/profiler.cpp index 399ea8ee4db..761a992ed7a 100644 --- a/profiler/src/profiler.cpp +++ b/profiler/src/profiler.cpp @@ -6,13 +6,13 @@ #include int profile_gemm(int, char*[]); -int profile_batched_gemm(int, char*[]); -int profile_gemm_bias_relu(int, char*[]); -int profile_gemm_bias_relu_add(int, char*[]); -int profile_conv_fwd(int, char*[]); -int profile_conv_fwd_bias_relu(int, char*[]); -int profile_conv_fwd_bias_relu_add(int, 
char*[]); -int profile_conv_fwd_bias_relu_atomic_add(int, char*[]); +// int profile_batched_gemm(int, char*[]); +// int profile_gemm_bias_relu(int, char*[]); +// int profile_gemm_bias_relu_add(int, char*[]); +// int profile_conv_fwd(int, char*[]); +// int profile_conv_fwd_bias_relu(int, char*[]); +// int profile_conv_fwd_bias_relu_add(int, char*[]); +// int profile_conv_fwd_bias_relu_atomic_add(int, char*[]); int main(int argc, char* argv[]) { @@ -20,6 +20,7 @@ int main(int argc, char* argv[]) { return profile_gemm(argc, argv); } +#if 0 else if(strcmp(argv[1], "gemm_bias_relu") == 0) { return profile_gemm_bias_relu(argc, argv); @@ -48,6 +49,7 @@ int main(int argc, char* argv[]) { return profile_conv_fwd_bias_relu_atomic_add(argc, argv); } +#endif else { // clang-format off From d55aa1c22690233461fa19a299341642c7ab5231 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Mon, 14 Feb 2022 03:57:17 +0000 Subject: [PATCH 06/11] enabling 2 stage prefetch to exsiting gridwise gemm; tweak --- .../gridwise_gemm_xdlops_v2r3.hpp | 233 +++++++++++++----- .../gridwise_gemm_xdlops_v3r1.hpp | 227 ++++++++++++----- device_operation/include/device_gemm_xdl.hpp | 6 +- example/1_gemm_xdl/gemm_xdl.cpp | 101 ++++---- 4 files changed, 389 insertions(+), 178 deletions(-) diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp index e592e891dfe..6b5603f17c9 100644 --- a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp @@ -148,7 +148,8 @@ template + index_t CThreadTransferDstScalarPerVector, + index_t NumPrefetch = 1> struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 { static constexpr auto I0 = Number<0>{}; @@ -439,7 +440,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 1, 1, AThreadTransferSrcResetCoordinateAfterRun, - true>( + true, + NumPrefetch>( a_grid_desc_k0_m_k1, make_multi_index(0, 
m_block_data_idx_on_grid, 0), a_element_op, @@ -469,7 +471,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 1, 1, BThreadTransferSrcResetCoordinateAfterRun, - true>( + true, + NumPrefetch>( b_grid_desc_k0_n_k1, make_multi_index(0, n_block_data_idx_on_grid, 0), b_element_op, @@ -513,100 +516,210 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); -#if 0 - // preload data into LDS + static_assert(NumPrefetch == 1 || NumPrefetch == 2, "wrong!"); + + if constexpr(NumPrefetch == 1) { +#if 0 + // preload data into LDS a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf); + // Initialize C + c_thread_buf.Clear(); + a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); - } - // Initialize C - c_thread_buf.Clear(); + // main body + if constexpr(HasMainKBlockLoop) + { + index_t k0_block_data_begin = 0; - // main body - if constexpr(HasMainKBlockLoop) - { - index_t k0_block_data_begin = 0; + do + { + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, + a_block_slice_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, + b_block_slice_copy_step); - do - { - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, b_block_slice_copy_step); + a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); - a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); + block_sync_lds(); - block_sync_lds(); + b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf); - b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf); + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + block_sync_lds(); + 
a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); + + k0_block_data_begin += K0PerBlock; + } while(k0_block_data_begin < (K0 - K0PerBlock)); + } + + // tail + { block_sync_lds(); - a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); - b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + } +#elif 1 + // preload data into LDS + a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, b_block_slice_copy_step); - k0_block_data_begin += K0PerBlock; - } while(k0_block_data_begin < (K0 - K0PerBlock)); - } + // Initialize C + c_thread_buf.Clear(); - // tail - { - block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - } -#else - // preload data into LDS - a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); - b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf); + // main body + if constexpr(HasMainKBlockLoop) + { + index_t k0_block_data_begin = 0; - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, b_block_slice_copy_step); + do + { + a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); - // Initialize C - c_thread_buf.Clear(); + block_sync_lds(); - a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); - b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); + b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf); - // main body - if constexpr(HasMainKBlockLoop) - { - index_t k0_block_data_begin = 0; + blockwise_gemm.Run(a_block_buf, 
b_block_buf, c_thread_buf); - do - { - a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); + block_sync_lds(); - block_sync_lds(); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, + a_block_slice_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, + b_block_slice_copy_step); - b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf); + a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + k0_block_data_begin += K0PerBlock; + } while(k0_block_data_begin < (K0 - K0PerBlock)); + } + // tail + { block_sync_lds(); + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + } +#endif + } + else if constexpr(NumPrefetch == 2) + { + // preload data into LDS + { + // Read 0 + a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf, I0); + b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf, I0); + + // Move a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step); b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, b_block_slice_copy_step); - a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); - b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); + // Read 1 + a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf, I1); + b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf, I1); + } - k0_block_data_begin += K0PerBlock; - } while(k0_block_data_begin < (K0 - K0PerBlock)); - } + // Initialize C + c_thread_buf.Clear(); - // tail - { - block_sync_lds(); + // main body + if constexpr(HasMainKBlockLoop) + { + index_t k0_block_data_begin = 0; + + do + { + // Move + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, + a_block_slice_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, + b_block_slice_copy_step); + + // Write i + a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf, I0); + 
b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf, I0); + + // Read i+2 + a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf, I0); + b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf, I0); + + // Sync + block_sync_lds(); + + // Gemm i + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + // Sync + block_sync_lds(); + + // Move + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, + a_block_slice_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, + b_block_slice_copy_step); + + // Write i+1 + a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf, I1); + b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf, I1); + + // Read i+3 + a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf, I1); + b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf, I1); + + // Sync + block_sync_lds(); - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + // Gemm i+1 + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + // Sync + block_sync_lds(); + + k0_block_data_begin += 2 * K0PerBlock; + } while(k0_block_data_begin < (K0 - 2 * K0PerBlock)); + } + + // tail + { + // Write K0-2 + a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf, I0); + b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf, I0); + + // Sync + block_sync_lds(); + + // Gemm K0-2 + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + // Sync + block_sync_lds(); + + // Write K0-1 + a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf, I1); + b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf, I1); + + // Sync + block_sync_lds(); + + // Gemm K0-1 + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + } } -#endif // output: register to global memory { diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r1.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r1.hpp index ff3f7b3bca9..dedd895559f 100644 --- 
a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r1.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r1.hpp @@ -95,7 +95,8 @@ template < index_t CShuffleMXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle, typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, - index_t CBlockTransferScalarPerVector_NWaveNPerXdl> + index_t CBlockTransferScalarPerVector_NWaveNPerXdl, + index_t NumPrefetch = 1> struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 { static constexpr auto I0 = Number<0>{}; @@ -470,100 +471,210 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); -#if 0 - // preload data into LDS + static_assert(NumPrefetch == 1 || NumPrefetch == 2, "wrong!"); + + if constexpr(NumPrefetch == 1) { +#if 0 + // preload data into LDS a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf); + // Initialize C + c_thread_buf.Clear(); + a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); - } - // Initialize C - c_thread_buf.Clear(); + // main body + if constexpr(HasMainKBlockLoop) + { + index_t k0_block_data_begin = 0; - // main body - if constexpr(HasMainKBlockLoop) - { - index_t k0_block_data_begin = 0; + do + { + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, + a_block_slice_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, + b_block_slice_copy_step); - do - { - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, b_block_slice_copy_step); + a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); - a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); + block_sync_lds(); - 
block_sync_lds(); + b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf); - b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf); + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + block_sync_lds(); + + a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); + k0_block_data_begin += K0PerBlock; + } while(k0_block_data_begin < (K0 - K0PerBlock)); + } + + // tail + { block_sync_lds(); - a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); - b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + } +#elif 1 + // preload data into LDS + a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, b_block_slice_copy_step); - k0_block_data_begin += K0PerBlock; - } while(k0_block_data_begin < (K0 - K0PerBlock)); - } + // Initialize C + c_thread_buf.Clear(); - // tail - { - block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - } -#else - // preload data into LDS - a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); - b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf); + // main body + if constexpr(HasMainKBlockLoop) + { + index_t k0_block_data_begin = 0; - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, b_block_slice_copy_step); + do + { + a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); - // Initialize C - c_thread_buf.Clear(); + block_sync_lds(); - 
a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); - b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); + b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf); - // main body - if constexpr(HasMainKBlockLoop) - { - index_t k0_block_data_begin = 0; + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - do - { - a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); + block_sync_lds(); - block_sync_lds(); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, + a_block_slice_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, + b_block_slice_copy_step); - b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf); + a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + k0_block_data_begin += K0PerBlock; + } while(k0_block_data_begin < (K0 - K0PerBlock)); + } + // tail + { block_sync_lds(); + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + } +#endif + } + else if constexpr(NumPrefetch == 2) + { + // preload data into LDS + { + // Read 0 + a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf, I0); + b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf, I0); + + // Move a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step); b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, b_block_slice_copy_step); - a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); - b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); + // Read 1 + a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf, I1); + b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf, I1); + } - k0_block_data_begin += K0PerBlock; - } while(k0_block_data_begin < (K0 - K0PerBlock)); - } + // Initialize C + c_thread_buf.Clear(); - // tail - { - block_sync_lds(); + // main body + if constexpr(HasMainKBlockLoop) + { + index_t k0_block_data_begin = 0; 
+ + do + { + // Move + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, + a_block_slice_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, + b_block_slice_copy_step); + + // Write i + a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf, I0); + b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf, I0); - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + // Read i+2 + a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf, I0); + b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf, I0); + + // Sync + block_sync_lds(); + + // Gemm i + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + // Sync + block_sync_lds(); + + // Move + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, + a_block_slice_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, + b_block_slice_copy_step); + + // Write i+1 + a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf, I1); + b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf, I1); + + // Read i+3 + a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf, I1); + b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf, I1); + + // Sync + block_sync_lds(); + + // Gemm i+1 + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + // Sync + block_sync_lds(); + + k0_block_data_begin += 2 * K0PerBlock; + } while(k0_block_data_begin < (K0 - 2 * K0PerBlock)); + } + + // tail + { + // Write K0-2 + a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf, I0); + b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf, I0); + + // Sync + block_sync_lds(); + + // Gemm K0-2 + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + // Sync + block_sync_lds(); + + // Write K0-1 + a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf, I1); + b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf, I1); + + // Sync + block_sync_lds(); + + // Gemm K0-1 + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + 
} } -#endif // shuffle C and write out { diff --git a/device_operation/include/device_gemm_xdl.hpp b/device_operation/include/device_gemm_xdl.hpp index 5a49a47f3c5..188cf26080a 100644 --- a/device_operation/include/device_gemm_xdl.hpp +++ b/device_operation/include/device_gemm_xdl.hpp @@ -52,7 +52,8 @@ template + ck::index_t CThreadTransferDstScalarPerVector, + ck::index_t NumPrefetch = 1> struct DeviceGemmXdl : public DeviceGemm { @@ -218,7 +219,8 @@ struct DeviceGemmXdl BBlockLdsAddExtraN, Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, - CThreadTransferDstScalarPerVector>; + CThreadTransferDstScalarPerVector, + NumPrefetch>; // Argument struct Argument : public BaseArgument diff --git a/example/1_gemm_xdl/gemm_xdl.cpp b/example/1_gemm_xdl/gemm_xdl.cpp index bf2ab1d5490..ed613f69473 100644 --- a/example/1_gemm_xdl/gemm_xdl.cpp +++ b/example/1_gemm_xdl/gemm_xdl.cpp @@ -45,68 +45,53 @@ using CElementOp = ck::tensor_operation::element_wise::PassThrough; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default; static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization_t::MNPadding; -#if 0 -using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl< - ADataType, // ADataType - BDataType, // BDataType - CDataType, // CDataType - AccDataType, // AccDataType - ALayout, // ALayout - BLayout, // BLayout - CLayout, // CLayout - AElementOp, // AElementwiseOperation - BElementOp, // BElementwiseOperation - CElementOp, // CElementwiseOperation - GemmDefault, // GemmSpecialization - 256, // BlockSize - 256, // MPerBlock - 128, // NPerBlock - 4, // K0PerBlock - 8, // K1 - 32, // MPerXDL - 32, // NPerXDL - 4, // MXdlPerWave - 2, // NXdlPerWave - S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 - S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder - S<1, 0, 2>, // ABlockTransferSrcAccessOrder - 2, // ABlockTransferSrcVectorDim - 8, // 
ABlockTransferSrcScalarPerVector - 8, // ABlockTransferDstScalarPerVector_K1 - true, // ABlockLdsAddExtraM - S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 - S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder - S<1, 0, 2>, // BBlockTransferSrcAccessOrder - 2, // BBlockTransferSrcVectorDim - 8, // BBlockTransferSrcScalarPerVector - 8, // BBlockTransferDstScalarPerVector_K1 - true, // BBlockLdsAddExtraN - 7, // CThreadTransferSrcDstVectorDim - 1>; // CThreadTransferDstScalarPerVector -#elif 0 +// clang-format off +#if 1 using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl - // clang-format off - //#|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //#| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| - //#| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| - //#| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - // 99TFlops, 1 wave per SIMD, limited by LDS, not enough blocks - < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 144, 8, 8, 16, 16, 2, 9, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 8, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>; - // 99TFlops, 1 wave per SIMD, 
not enough blocks - //< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 144, 4, 8, 16, 16, 2, 9, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>; - // 97TFlops, 2 wave per SIMD - //< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 144, 8, 8, 16, 16, 1, 9, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>; -// clang-format on +//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| Num| +//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| Prefetch| +//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +// [256, 128, 4, 8], 1 stage, 2 occupancy + < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1>; #elif 1 +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl +//######| AData| 
BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| Num| +//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| Prefetch| +//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +// [128, 144, 8, 8], 1 stage, 1 occupancy, bounded by LDS size +// 99 TFlops, 120 blocks (1024x2160x3840) +// 99 TFlops, 960 blocks (4096x4320x3840) + < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 144, 8, 8, 16, 16, 2, 9, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 8, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1, 1>; +// [128, 144, 4, 8], 1 stage, 2 occupancy, +// 92 TFlops, 120 blocks (1024x2160x3840) +// 120 TFlops, 240 blocks (1024x4320x3840) +// 128 TFlops, 960 blocks (4096x4320x3840) +// < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 144, 4, 8, 16, 16, 2, 9, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1, 1>; +// [ 64, 144, 8, 8], 1 stage, 2 occupancy/ +// 96 TFlops, 240 blocks (1024x2160x3840) +// 96 TFlops, 480 blocks (1024x4320x3840) +// 99 
TFlops,1920 blocks (4096x4320x3840) +// < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 144, 8, 8, 16, 16, 1, 9, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 8, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1, 1>; +// [ 64, 144, 8, 8], 2 stage, 2 occupancy +// 93 TFlops +// < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 144, 8, 8, 16, 16, 1, 9, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 8, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1, 2>; +// [ 64, 144, 4, 8], 1 stage, 2 occupancy +// 87 TFlops +// < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 144, 4, 8, 16, 16, 1, 9, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1, 1>; +// [ 64, 144, 4, 8], 2 stage, 2 occupancy +// 85 TFlops +// < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 144, 4, 8, 16, 16, 1, 9, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1, 2>; +#elif 0 using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle - // clang-format off - //#|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| 
MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| - //#| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| - //#| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 144, 8, 8, 16, 16, 2, 9, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 8, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 1, 9, S<1, 1, 8, 1, 9, 2>, 8>; -// clang-format on +//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 144, 8, 8, 16, 16, 2, 9, S<8, 32, 1>, 
S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 8, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 1, 9, S<1, 1, 8, 1, 9, 2>, 8>; #endif +// clang-format on using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; From f3caf6edc1e462c64f29d1aaa6ec4ef646b97f32 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Mon, 14 Feb 2022 04:35:11 +0000 Subject: [PATCH 07/11] enabling 2 stage prefetch to exsiting gridwise gemm --- .../gridwise_gemm_xdlops_v2r3.hpp | 19 + .../gridwise_gemm_xdlops_v3r1.hpp | 25 +- .../gridwise_gemm_xdlops_v4r1.hpp | 795 ------------------ .../include/device_gemm_xdl_c_shuffle.hpp | 6 +- .../device_gemm_xdl_c_shuffle_2_stage.hpp | 473 ----------- ..._2_stage_f16_f16_f16_mk_nk_mn_instance.cpp | 36 +- example/1_gemm_xdl/gemm_xdl.cpp | 20 +- 7 files changed, 74 insertions(+), 1300 deletions(-) delete mode 100644 composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v4r1.hpp delete mode 100644 device_operation/include/device_gemm_xdl_c_shuffle_2_stage.hpp diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp index 6b5603f17c9..eaca45af2ea 100644 --- a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp @@ -253,6 +253,25 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) return false; + // check NumPrefetch + if constexpr(NumPrefetch == 1) + { + // 1-stage prefetch always supported + } + else if constexpr(NumPrefetch == 2) + { + // 2-stage prefetch currently only support even number of K0 loop + // TODO: add support for odd number of K0 loop + if(!((K0 / K0PerBlock) % 2 == 0)) + { + return false; + } + } + else + { + return false; + } + // check M01, N01 constexpr auto M1 = Number{}; constexpr auto N1 = Number{}; diff --git 
a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r1.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r1.hpp index dedd895559f..54a9fdfded2 100644 --- a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r1.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r1.hpp @@ -229,6 +229,25 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) return false; + // check NumPrefetch + if constexpr(NumPrefetch == 1) + { + // 1-stage prefetch always supported + } + else if constexpr(NumPrefetch == 2) + { + // 2-stage prefetch currently only support even number of K0 loop + // TODO: add support for odd number of K0 loop + if(!((K0 / K0PerBlock) % 2 == 0)) + { + return false; + } + } + else + { + return false; + } + // check M01, N01 constexpr auto M1 = Number{}; constexpr auto N1 = Number{}; @@ -397,7 +416,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 1, 1, AThreadTransferSrcResetCoordinateAfterRun, - true>( + true, + NumPrefetch>( a_grid_desc_k0_m_k1, make_multi_index(0, m_block_data_idx_on_grid, 0), a_element_op, @@ -427,7 +447,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 1, 1, BThreadTransferSrcResetCoordinateAfterRun, - true>( + true, + NumPrefetch>( b_grid_desc_k0_n_k1, make_multi_index(0, n_block_data_idx_on_grid, 0), b_element_op, diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v4r1.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v4r1.hpp deleted file mode 100644 index ae781e0ed6c..00000000000 --- a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v4r1.hpp +++ /dev/null @@ -1,795 +0,0 @@ -#ifndef CK_GRIDWISE_GEMM_XDLOPS_v4r1_HPP -#define CK_GRIDWISE_GEMM_XDLOPS_v4r1_HPP - -#include "common_header.hpp" -#include "multi_index_transform_helper.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include 
"blockwise_gemm_xdlops.hpp" -#include "blockwise_tensor_slice_transfer_v4r1.hpp" -#include "blockwise_tensor_slice_transfer_v6r1.hpp" -#include "threadwise_tensor_slice_transfer.hpp" - -namespace ck { - -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - kernel_gemm_xdlops_v4r1( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, - const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, - const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op, - const Block2CTileMap block_2_ctile_map) -{ - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - - GridwiseGemm::template Run( - p_a_grid, - p_b_grid, - p_c_grid, - p_shared, - a_grid_desc_k0_m_k1, - b_grid_desc_k0_n_k1, - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, - a_element_op, - b_element_op, - c_element_op, - block_2_ctile_map); -} - -template < - index_t BlockSize, - typename FloatAB, - typename FloatAcc, - typename FloatC, - InMemoryDataOperationEnum_t CGlobalMemoryDataOperation, - typename AGridDesc_K0_M_K1, - typename BGridDesc_K0_N_K1, - typename CGridDesc_M_N, - typename AElementwiseOperation, - typename BElementwiseOperation, - typename CElementwiseOperation, - index_t MPerBlock, - index_t NPerBlock, - index_t K0PerBlock, - index_t MPerXdl, - index_t NPerXdl, - index_t K1Value, - index_t MXdlPerWave, - index_t NXdlPerWave, - typename ABlockTransferThreadClusterLengths_K0_M_K1, - typename ABlockTransferThreadClusterArrangeOrder, - typename ABlockTransferSrcAccessOrder, - index_t ABlockTransferSrcVectorDim, - index_t ABlockTransferSrcScalarPerVector, - 
index_t ABlockTransferDstScalarPerVector_K1, - bool AThreadTransferSrcResetCoordinateAfterRun, - bool ABlockLdsExtraM, - typename BBlockTransferThreadClusterLengths_K0_N_K1, - typename BBlockTransferThreadClusterArrangeOrder, - typename BBlockTransferSrcAccessOrder, - index_t BBlockTransferSrcVectorDim, - index_t BBlockTransferSrcScalarPerVector, - index_t BBlockTransferDstScalarPerVector_K1, - bool BThreadTransferSrcResetCoordinateAfterRun, - bool BBlockLdsExtraN, - index_t CShuffleMXdlPerWavePerShuffle, - index_t CShuffleNXdlPerWavePerShuffle, - typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, - index_t CBlockTransferScalarPerVector_NWaveNPerXdl> -struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v4r1 -{ - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - static constexpr auto I6 = Number<6>{}; - static constexpr auto I7 = Number<7>{}; - - // K1 should be Number<...> - static constexpr auto K1 = Number{}; - - __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() - { - constexpr auto max_lds_align = K1; - - // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_k0_m_k1 = [&]() { - if constexpr(ABlockLdsExtraM) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - return a_block_desc_k0_m_k1; - } - - __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1() - { - constexpr auto max_lds_align = K1; - - // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_k0_n_k1 = [&]() { - if constexpr(BBlockLdsExtraN) - { - return 
make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - return b_block_desc_k0_n_k1; - } - - __host__ __device__ static constexpr auto - GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl() - { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - - constexpr auto - c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = - make_naive_tensor_descriptor_packed( - make_tuple(I1, - Number{}, - Number{}, - I1, - Number{}, - Number{})); - - return c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl; - } - - __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); - - constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); - - constexpr auto max_lds_align = K1; - - constexpr auto a_block_space_size_aligned = - math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); - - constexpr auto b_block_space_size_aligned = - math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align); - - // LDS allocation for C shuffle in LDS - constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = - GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(); - - constexpr auto c_block_size = - c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl - .GetElementSpaceSize(); - - return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * - sizeof(FloatAB), - c_block_size * sizeof(FloatC)); - } - - // block_id to 
matrix tile idx (m0, n0) mapping are controlled by {M01, N01} - __host__ __device__ static constexpr bool - CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, - const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, - const CGridDesc_M_N& c_grid_desc_m_n, - index_t M01, - index_t N01) - { - static_assert(is_known_at_compile_time>::value, - "wrong! K1 need to be known at compile-time"); - - static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && - (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, - "Invalid tuning param!"); - - const auto M = a_grid_desc_k0_m_k1.GetLength(I1); - const auto N = b_grid_desc_k0_n_k1.GetLength(I1); - const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); - - if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) && - K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) && - K1 == b_grid_desc_k0_n_k1.GetLength(I2))) - return false; - - if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) - return false; - - // check M01, N01 - constexpr auto M1 = Number{}; - constexpr auto N1 = Number{}; - - const auto M0 = M / M1; - const auto N0 = N / N1; - - if(!(M0 % M01 == 0 && N0 % N01 == 0)) - return false; - - // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) - return true; - } - - __host__ __device__ static constexpr index_t - CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) - { - const auto M = c_grid_desc_m_n.GetLength(I0); - const auto N = c_grid_desc_m_n.GetLength(I1); - - const index_t grid_size = (M / MPerBlock) * (N / NPerBlock); - - return grid_size; - } - - __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) - { - const bool has_main_k0_block_loop = (K0 / K0PerBlock) > 1; - - return has_main_k0_block_loop; - } - - __host__ __device__ static constexpr auto - MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( - const CGridDesc_M_N& c_grid_desc_m_n) - { - const auto M = 
c_grid_desc_m_n.GetLength(I0); - const auto N = c_grid_desc_m_n.GetLength(I1); - - const auto MBlock = M / MPerBlock; - const auto NBlock = N / NPerBlock; - - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - - const auto c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = - transform_tensor_descriptor( - c_grid_desc_m_n, - make_tuple(make_unmerge_transform(make_tuple( - MBlock, Number{}, Number{})), - make_unmerge_transform(make_tuple( - NBlock, Number{}, Number{}))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); - - return c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl; - } - - // return block_id to C matrix tile idx (m0, n0) mapping - __host__ __device__ static constexpr auto - MakeBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01) - { - const auto M = c_grid_desc_m_n.GetLength(I0); - const auto N = c_grid_desc_m_n.GetLength(I1); - - constexpr auto M1 = Number{}; - constexpr auto N1 = Number{}; - - const auto M0 = M / M1; - const auto N0 = N / N1; - - const auto M00 = M0 / M01; - const auto N00 = N0 / N01; - - const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_unmerge_transform(make_tuple(M00, M01)), - make_unmerge_transform(make_tuple(N00, N01))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{})); - - const auto c_blockid_to_m00_m01_n00_n01_block_cluster_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))), - make_tuple(Sequence<0, 1, 2, 3>{}), - make_tuple(Sequence<0>{})); - - const auto c_blockid_to_m0_n0_block_cluster_adaptor = - chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor, - c_blockid_to_m00_m01_n00_n01_block_cluster_adaptor); - - return 
c_blockid_to_m0_n0_block_cluster_adaptor; - } - using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl = - remove_cvref_t; - - using Block2CTileMap = remove_cvref_t; - - template - __device__ static void - Run(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - void* __restrict__ p_shared, - const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, - const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, - const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl& - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, - const AElementwiseOperation& a_element_op, - const BElementwiseOperation& b_element_op, - const CElementwiseOperation& c_element_op, - const Block2CTileMap& block_2_ctile_map) - { - const auto a_grid_buf = make_dynamic_buffer( - p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize()); - const auto b_grid_buf = make_dynamic_buffer( - p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize()); - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl - .GetElementSpaceSize()); - - const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); - - // divide block work by [M, N] - const auto block_work_idx = - block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); - - // HACK: this force m/n_block_data_idx_on_grid into SGPR - const index_t m_block_data_idx_on_grid = - __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); - - const index_t n_block_data_idx_on_grid = - __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); - - // lds max alignment - constexpr auto max_lds_align = K1; - - // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); - - // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_k0_n_k1 = 
GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); - - // A matrix blockwise copy - auto a_blockwise_copy = - BlockwiseTensorSliceTransfer_v4r1, - ABlockTransferThreadClusterLengths_K0_M_K1, - ABlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(a_grid_desc_k0_m_k1), - decltype(a_block_desc_k0_m_k1), - ABlockTransferSrcAccessOrder, - Sequence<1, 0, 2>, - ABlockTransferSrcVectorDim, - 2, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_K1, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true, - 2>(a_grid_desc_k0_m_k1, - make_multi_index(0, m_block_data_idx_on_grid, 0), - a_element_op, - a_block_desc_k0_m_k1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}); - - // B matrix blockwise copy - auto b_blockwise_copy = - BlockwiseTensorSliceTransfer_v4r1, - BBlockTransferThreadClusterLengths_K0_N_K1, - BBlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(b_grid_desc_k0_n_k1), - decltype(b_block_desc_k0_n_k1), - BBlockTransferSrcAccessOrder, - Sequence<1, 0, 2>, - BBlockTransferSrcVectorDim, - 2, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_K1, - 1, - 1, - BThreadTransferSrcResetCoordinateAfterRun, - true, - 2>(b_grid_desc_k0_n_k1, - make_multi_index(0, n_block_data_idx_on_grid, 0), - b_element_op, - b_block_desc_k0_n_k1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}); - - // GEMM definition - // c_mtx += transpose(a_mtx) * b_mtx - // a_mtx[K0PerBlock, MPerBlock] is in LDS - // b_mtx[K0PerBlock, NPerBlock] is in LDS - // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in - // register - // sanity check - - auto blockwise_gemm = - BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; - - auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); - - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_space_size_aligned = - 
math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); - - auto a_block_buf = make_dynamic_buffer( - static_cast(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize()); - - auto b_block_buf = make_dynamic_buffer( - static_cast(p_shared) + a_block_space_size_aligned, - b_block_desc_k0_n_k1.GetElementSpaceSize()); - - constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); - constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); - - // preload data into LDS - { - // Read 0 - a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf, I0); - b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf, I0); - - // Move - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, b_block_slice_copy_step); - - // Read 1 - a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf, I1); - b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf, I1); - } - - // Initialize C - c_thread_buf.Clear(); - - // main body - if constexpr(HasMainKBlockLoop) - { - index_t k0_block_data_begin = 0; - - do - { - // Move - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, b_block_slice_copy_step); - - // Write i - a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf, I0); - b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf, I0); - - // Read i+2 - a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf, I0); - b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf, I0); - - // Sync - block_sync_lds(); - - // Gemm i - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - - // Sync - block_sync_lds(); - - // Move - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, b_block_slice_copy_step); - - // Write i+1 - 
a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf, I1); - b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf, I1); - - // Read i+3 - a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf, I1); - b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf, I1); - - // Sync - block_sync_lds(); - - // Gemm i+1 - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - - // Sync - block_sync_lds(); - - k0_block_data_begin += 2 * K0PerBlock; - } while(k0_block_data_begin < (K0 - 2 * K0PerBlock)); - } - - // tail - { - // Write K0-2 - a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf, I0); - b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf, I0); - - // Sync - block_sync_lds(); - - // Gemm K0-2 - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - - // Sync - block_sync_lds(); - - // Write K0-1 - a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf, I1); - b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf, I1); - - // Sync - block_sync_lds(); - - // Gemm K0-1 - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - } - - // shuffle C and write out - { - static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && - NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, - "wrong!"); - - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - - // TODO: hacky, fix it! - constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - // TODO: hacky, fix it! 
- // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = - blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); - constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); - constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); - constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); - constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); - constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); - constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); - - constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = - GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(); - - auto c_block_buf = make_dynamic_buffer( - static_cast(p_shared), - c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl - .GetElementSpaceSize()); - - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( - c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, - make_tuple( - make_freeze_transform(I0), // freeze mblock - make_pass_through_transform( - Number{}), // M0 (MXdlPerWave) per shuffle - make_unmerge_transform( - make_tuple(M1, M2, M3, M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl - make_freeze_transform(I0), // freeze nblock - make_pass_through_transform( - Number{}), // N0 (NXdlPerWave) per shuffle - make_unmerge_transform( - make_tuple(N1, N2))), // M1 = MWave, M2 * M3 * M4 = MPerXdl - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<>{}, - Sequence<0>{}, - Sequence<2, 4, 5, 6>{}, - Sequence<>{}, - 
Sequence<1>{}, - Sequence<3, 7>{}) - - ); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); - - // VGPR to LDS - auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum_t::Set, - 1, - true>{ - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - ck::tensor_operation::element_wise::PassThrough{}}; - - auto c_block_copy_lds_to_global = BlockwiseTensorSliceTransfer_v6r1< - BlockSize, // index_t BlockSize, - CElementwiseOperation, // ElementwiseOperation, - CGlobalMemoryDataOperation, // DstInMemOp, - Sequence<1, - CShuffleMXdlPerWavePerShuffle, - MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle, - 
NWave * NPerXdl>, // BlockSliceLengths, - CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, - Sequence<0, 1, 2, 3, 4, 5>, // typename ThreadClusterArrangeOrder, - FloatC, // typename SrcData, - FloatC, // typename DstData, - decltype( - c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), - decltype( - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), - Sequence<0, 1, 2, 3, 4, 5>, // typename DimAccessOrder, - 5, // index_t VectorDim, - CBlockTransferScalarPerVector_NWaveNPerXdl, // index_t ScalarPerVector, - true, // bool ThreadTransferSrcResetCoordinateAfterRun, - false> // bool ThreadTransferDstResetCoordinateAfterRun> - {c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, - make_multi_index(0, 0, 0, 0, 0, 0), - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, - make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0), - c_element_op}; - - constexpr auto mxdlperwave_forward_step = - make_multi_index(0, CShuffleMXdlPerWavePerShuffle, 0, 0, 0, 0); - constexpr auto nxdlperwave_forward_step = - make_multi_index(0, 0, 0, 0, CShuffleNXdlPerWavePerShuffle, 0); - constexpr auto nxdlperwave_backward_step = - make_multi_index(0, 0, 0, 0, -CShuffleNXdlPerWavePerShuffle, 0); - - static_for<0, MXdlPerWave, CShuffleMXdlPerWavePerShuffle>{}([&](auto mxdlperwave_iter) { - constexpr auto mxdlperwave = mxdlperwave_iter; - - static_for<0, - NXdlPerWave, - CShuffleNXdlPerWavePerShuffle>{}([&](auto nxdlperwave_iter) { - constexpr bool nxdlperwave_forward_sweep = - (mxdlperwave % (2 * CShuffleMXdlPerWavePerShuffle) == 0); - - constexpr index_t nxdlperwave_value = - nxdlperwave_forward_sweep - ? 
nxdlperwave_iter - : (NXdlPerWave - nxdlperwave_iter - CShuffleNXdlPerWavePerShuffle); - - constexpr auto nxdlperwave = Number{}; - - // make sure it's safe to do ds_write - block_sync_lds(); - - // VGPR to LDS - c_thread_copy_vgpr_to_lds.Run( - c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0), - c_thread_buf, - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c_block_buf); - - // make sure it's safe to do ds_read - block_sync_lds(); - - // LDS to global - c_block_copy_lds_to_global.Run( - c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, - c_block_buf, - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, - c_grid_buf); - - // move on nxdlperwave dimension - if constexpr(nxdlperwave_forward_sweep && - (nxdlperwave < NXdlPerWave - CShuffleNXdlPerWavePerShuffle)) - { - c_block_copy_lds_to_global.MoveDstSliceWindow( - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, - nxdlperwave_forward_step); - } - else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0)) - { - c_block_copy_lds_to_global.MoveDstSliceWindow( - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, - nxdlperwave_backward_step); - } - }); - - // move on mxdlperwave dimension - if constexpr(mxdlperwave < MXdlPerWave - CShuffleMXdlPerWavePerShuffle) - { - c_block_copy_lds_to_global.MoveDstSliceWindow( - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, - mxdlperwave_forward_step); - } - }); - } - } -}; - -} // namespace ck -#endif diff --git a/device_operation/include/device_gemm_xdl_c_shuffle.hpp b/device_operation/include/device_gemm_xdl_c_shuffle.hpp index 6e58fdec7a2..0b7a01c138c 100644 --- a/device_operation/include/device_gemm_xdl_c_shuffle.hpp +++ b/device_operation/include/device_gemm_xdl_c_shuffle.hpp @@ -52,7 +52,8 @@ template < index_t CShuffleMXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle, 
typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, - index_t CBlockTransferScalarPerVector_NWaveNPerXdl> + index_t CBlockTransferScalarPerVector_NWaveNPerXdl, + index_t NumPrefetch = 1> struct DeviceGemmXdl_C_Shuffle : public DeviceGemm { @@ -172,7 +173,8 @@ struct DeviceGemmXdl_C_Shuffle CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, - CBlockTransferScalarPerVector_NWaveNPerXdl>; + CBlockTransferScalarPerVector_NWaveNPerXdl, + NumPrefetch>; // Argument struct Argument : public BaseArgument diff --git a/device_operation/include/device_gemm_xdl_c_shuffle_2_stage.hpp b/device_operation/include/device_gemm_xdl_c_shuffle_2_stage.hpp deleted file mode 100644 index 6a1d53c72ca..00000000000 --- a/device_operation/include/device_gemm_xdl_c_shuffle_2_stage.hpp +++ /dev/null @@ -1,473 +0,0 @@ -#ifndef DEVICE_GEMM_XDL_C_SHUFFLE_2_STAGE_HPP -#define DEVICE_GEMM_XDL_C_SHUFFLE_2_STAGE_HPP - -#include -#include -#include "device.hpp" -#include "device_gemm.hpp" -#include "common_header.hpp" -#include "tensor_layout.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "gridwise_gemm_xdlops_v4r1.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { - -template < - typename ADataType, - typename BDataType, - typename CDataType, - typename AccDataType, - typename ALayout, - typename BLayout, - typename CLayout, - typename AElementwiseOperation, - typename BElementwiseOperation, - typename CElementwiseOperation, - ck::index_t BlockSize, - ck::index_t MPerBlock, - ck::index_t NPerBlock, - ck::index_t K0PerBlock, - ck::index_t K1, - ck::index_t MPerXDL, - ck::index_t NPerXDL, - ck::index_t MXdlPerWave, - ck::index_t NXdlPerWave, - typename ABlockTransferThreadClusterLengths_K0_M_K1, - typename ABlockTransferThreadClusterArrangeOrder, - typename ABlockTransferSrcAccessOrder, 
- ck::index_t ABlockTransferSrcVectorDim, - ck::index_t ABlockTransferSrcScalarPerVector, - ck::index_t ABlockTransferDstScalarPerVector_K1, - bool ABlockLdsAddExtraM, - typename BBlockTransferThreadClusterLengths_K0_N_K1, - typename BBlockTransferThreadClusterArrangeOrder, - typename BBlockTransferSrcAccessOrder, - ck::index_t BBlockTransferSrcVectorDim, - ck::index_t BBlockTransferSrcScalarPerVector, - ck::index_t BBlockTransferDstScalarPerVector_K1, - bool BBlockLdsAddExtraN, - index_t CShuffleMXdlPerWavePerShuffle, - index_t CShuffleNXdlPerWavePerShuffle, - typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, - index_t CBlockTransferScalarPerVector_NWaveNPerXdl> -struct DeviceGemmXdl_C_Shuffle_2_Stage - : public DeviceGemm -{ - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - - static constexpr auto K1Number = Number{}; - - static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA) - { - assert(K % K1 == 0); - - const index_t K0 = K / K1; - - const auto a_grid_desc_m_k = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); - } - }(); - - const auto a_grid_desc_k0_m_k1 = - transform_tensor_descriptor(a_grid_desc_m_k, - make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), - make_pass_through_transform(M)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return a_grid_desc_k0_m_k1; - } - - static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB) - { - assert(K % K1 == 0); - - const index_t K0 = K / K1; - - const auto b_grid_desc_k_n = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(K, N), 
make_tuple(StrideB, I1)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); - } - }(); - - const auto b_grid_desc_k0_n_k1 = - transform_tensor_descriptor(b_grid_desc_k_n, - make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), - make_pass_through_transform(N)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return b_grid_desc_k0_n_k1; - } - - static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) - { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); - } - } - - using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1)); - using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)); - using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); - - // GridwiseGemm - using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v4r1< - BlockSize, - ADataType, // TODO: distinguish A/B datatype - AccDataType, - CDataType, - InMemoryDataOperationEnum_t::Set, - AGridDesc_K0_M_K1, - BGridDesc_K0_N_K1, - CGridDesc_M_N, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - MPerBlock, - NPerBlock, - K0PerBlock, - MPerXDL, - NPerXDL, - K1, - MXdlPerWave, - NXdlPerWave, - ABlockTransferThreadClusterLengths_K0_M_K1, - ABlockTransferThreadClusterArrangeOrder, - ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorDim, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_K1, - false, - ABlockLdsAddExtraM, - BBlockTransferThreadClusterLengths_K0_N_K1, - BBlockTransferThreadClusterArrangeOrder, - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorDim, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_K1, - false, - BBlockLdsAddExtraN, - 
CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, - CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, - CBlockTransferScalarPerVector_NWaveNPerXdl>; - - // Argument - struct Argument : public BaseArgument - { - Argument(const ADataType* p_a_grid, - const BDataType* p_b_grid, - CDataType* p_c_grid, - index_t M, - index_t N, - index_t K, - index_t StrideA, - index_t StrideB, - index_t StrideC, - index_t M01, - index_t N01, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) - : p_a_grid_{p_a_grid}, - p_b_grid_{p_b_grid}, - p_c_grid_{p_c_grid}, - a_grid_desc_k0_m_k1_{}, - b_grid_desc_k0_n_k1_{}, - c_grid_desc_m_n_{}, - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, - block_2_ctile_map_{}, - M01_{M01}, - N01_{N01}, - a_element_op_{a_element_op}, - b_element_op_{b_element_op}, - c_element_op_{c_element_op} - { - a_grid_desc_k0_m_k1_ = - DeviceGemmXdl_C_Shuffle_2_Stage::MakeAGridDescriptor_K0_M_K1(M, K, StrideA); - b_grid_desc_k0_n_k1_ = - DeviceGemmXdl_C_Shuffle_2_Stage::MakeBGridDescriptor_K0_N_K1(K, N, StrideB); - c_grid_desc_m_n_ = - DeviceGemmXdl_C_Shuffle_2_Stage::MakeCGridDescriptor_M_N(M, N, StrideC); - - if(GridwiseGemm::CheckValidity( - a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_)) - { - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = - GridwiseGemm:: - MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( - c_grid_desc_m_n_); - - block_2_ctile_map_ = GridwiseGemm::MakeBlock2CTileMap(c_grid_desc_m_n_, M01, N01); - } - } - - // private: - const ADataType* p_a_grid_; - const BDataType* p_b_grid_; - CDataType* p_c_grid_; - AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; - BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; - CGridDesc_M_N c_grid_desc_m_n_; - typename GridwiseGemm:: - 
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; - typename GridwiseGemm::Block2CTileMap block_2_ctile_map_; - index_t M01_; - index_t N01_; - AElementwiseOperation a_element_op_; - BElementwiseOperation b_element_op_; - CElementwiseOperation c_element_op_; - }; - - // Invoker - struct Invoker : public BaseInvoker - { - using Argument = DeviceGemmXdl_C_Shuffle_2_Stage::Argument; - - float Run(const Argument& arg, int nrepeat = 1) - { - { - std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) - << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " - << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; - - std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) - << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " - << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; - - std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " - << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; - } - - if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m_n_, - arg.M01_, - arg.N01_)) - { - throw std::runtime_error( - "wrong! 
GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"); - } - - const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_); - - const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0); - - const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); - - float ave_time = 0; - - if(has_main_k0_block_loop) - { - const auto kernel = kernel_gemm_xdlops_v4r1< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t< - typename GridwiseGemm:: - CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - remove_reference_t, - true>; - - ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.block_2_ctile_map_); - } - else - { - const auto kernel = kernel_gemm_xdlops_v4r1< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t< - typename GridwiseGemm:: - CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - remove_reference_t, - false>; - - ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.block_2_ctile_map_); - } - - return ave_time; - } - - // polymorphic - 
float Run(const BaseArgument* p_arg, int nrepeat = 1) override - { - return Run(*dynamic_cast(p_arg), nrepeat); - } - }; - - static constexpr bool IsValidCompilationParameter() - { - // TODO: properly implement this check - return true; - } - - static bool IsSupportedArgument(const Argument& arg) - { - return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m_n_, - arg.M01_, - arg.N01_); - } - - // polymorphic - bool IsSupportedArgument(const BaseArgument* p_arg) override - { - return IsSupportedArgument(*dynamic_cast(p_arg)); - } - - static auto MakeArgument(const ADataType* p_a, - const BDataType* p_b, - CDataType* p_c, - index_t M, - index_t N, - index_t K, - index_t StrideA, - index_t StrideB, - index_t StrideC, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) - { - return Argument{p_a, - p_b, - p_c, - M, - N, - K, - StrideA, - StrideB, - StrideC, - 1, - 1, - a_element_op, - b_element_op, - c_element_op}; - } - - static auto MakeInvoker() { return Invoker{}; } - - // polymorphic - std::unique_ptr MakeArgumentPointer(const void* p_a, - const void* p_b, - void* p_c, - index_t M, - index_t N, - index_t K, - index_t StrideA, - index_t StrideB, - index_t StrideC, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op, - index_t /* KBatch */ = 1) override - { - return std::make_unique(static_cast(p_a), - static_cast(p_b), - static_cast(p_c), - M, - N, - K, - StrideA, - StrideB, - StrideC, - 1, - 1, - a_element_op, - b_element_op, - c_element_op); - } - - // polymorphic - std::unique_ptr MakeInvokerPointer() override - { - return std::make_unique(Invoker{}); - } - - // polymorphic - std::string GetTypeString() const override - { - auto str = std::stringstream(); - - // clang-format off - str << "DeviceGemmXdl_C_Shuffle_2_Stage" - << "<" - << BlockSize << ", " - << MPerBlock << ", " - << NPerBlock 
<< ", " - << K0PerBlock - << ">"; - // clang-format on - - return str.str(); - } -}; - -} // namespace device -} // namespace tensor_operation -} // namespace ck -#endif diff --git a/device_operation/src/device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp b/device_operation/src/device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp index b6ed8fb8107..ee25f2ba40f 100644 --- a/device_operation/src/device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp +++ b/device_operation/src/device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp @@ -1,6 +1,6 @@ #include #include "config.hpp" -#include "device_gemm_xdl_c_shuffle_2_stage.hpp" +#include "device_gemm_xdl_c_shuffle.hpp" #include "element_wise_operation.hpp" #include "device_operation_instance.hpp" @@ -23,23 +23,23 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; // Compilation parameters for a[m, k] * b[n, k] = c[m, n] using device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances = std::tuple< // clang-format off - //#############################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#############################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| - //#############################| | | | | | | | Operation| Operation| Operation| | | | | 
| | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| - //#############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdl_C_Shuffle_2_Stage< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_2_Stage< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_2_Stage< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_2_Stage< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_2_Stage< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, - DeviceGemmXdl_C_Shuffle_2_Stage< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 
8>, - DeviceGemmXdl_C_Shuffle_2_Stage< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, - DeviceGemmXdl_C_Shuffle_2_Stage< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_2_Stage< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_2_Stage< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, - DeviceGemmXdl_C_Shuffle_2_Stage< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, - DeviceGemmXdl_C_Shuffle_2_Stage< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, - DeviceGemmXdl_C_Shuffle_2_Stage< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> + //#####################| AData| BData| CData| 
AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Num| + //#####################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Prefetch| + //#####################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| | + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8, 2>, + DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8, 2>, + DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, 
S<1, 1, 16, 1, 1, 8>, 8, 2>, + DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8, 2>, + DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8, 2>, + DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8, 2>, + DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8, 2>, + DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8, 2>, + DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8, 2>, + DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8, 2>, + DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, 
Col, Row, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8, 2>, + DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8, 2>, + DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8, 2> // clang-format on >; diff --git a/example/1_gemm_xdl/gemm_xdl.cpp b/example/1_gemm_xdl/gemm_xdl.cpp index ed613f69473..cf6d24634d5 100644 --- a/example/1_gemm_xdl/gemm_xdl.cpp +++ b/example/1_gemm_xdl/gemm_xdl.cpp @@ -13,7 +13,6 @@ #include "device_tensor.hpp" #include "device_gemm_xdl.hpp" #include "device_gemm_xdl_c_shuffle.hpp" -#include "device_gemm_xdl_c_shuffle_2_stage.hpp" #include "element_wise_operation.hpp" #include "reference_gemm.hpp" #include "gemm_specialization.hpp" @@ -46,15 +45,15 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpeciali static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization_t::MNPadding; // clang-format off -#if 1 +#if 0 using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl //######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| Num| //######| Type| Type| Type| Type| | | | Elementwise| 
Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| Prefetch| //######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // [256, 128, 4, 8], 1 stage, 2 occupancy - < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1>; -#elif 1 + < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 2>; +#elif 0 using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl //######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| Num| //######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| Prefetch| @@ -83,13 
+82,14 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl // [ 64, 144, 4, 8], 2 stage, 2 occupancy // 85 TFlops // < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 144, 4, 8, 16, 16, 1, 9, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1, 2>; -#elif 0 +#elif 1 using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle -//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| -//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| -//######| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| -//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 144, 8, 8, 16, 16, 2, 9, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 8, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 1, 9, S<1, 1, 8, 1, 9, 2>, 8>; +//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| 
MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Num| +//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Prefetch| +//######| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| | +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +// [128, 144, 8, 8], 1 stage, 1 occupancy, bounded by LDS size + < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 144, 8, 8, 16, 16, 2, 9, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 8, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 1, 9, S<1, 1, 8, 1, 9, 2>, 8, 1>; #endif // clang-format on From 4f5950785d76bcbc288781efb9145609e7df7813 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Mon, 14 Feb 2022 05:51:15 +0000 Subject: [PATCH 08/11] move gridwise gemm pipeline in class; clean up --- .../gridwise_gemm_pipeline_v1.hpp | 245 +++++++ .../gridwise_gemm_xdlops_v2r3.hpp | 289 ++------ .../gridwise_gemm_xdlops_v2r5.hpp | 635 ------------------ .../gridwise_gemm_xdlops_v2r6.hpp | 617 ----------------- .../gridwise_gemm_xdlops_v3r1.hpp | 247 ++----- .../gridwise_gemm_xdlops_v3r2.hpp | 116 ++-- .../gridwise_gemm_xdlops_v3r3.hpp | 110 +-- 
example/1_gemm_xdl/gemm_xdl.cpp | 4 +- 8 files changed, 474 insertions(+), 1789 deletions(-) create mode 100644 composable_kernel/include/tensor_operation/gridwise_gemm_pipeline_v1.hpp delete mode 100644 composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r5.hpp delete mode 100644 composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r6.hpp diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_pipeline_v1.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_pipeline_v1.hpp new file mode 100644 index 00000000000..bc6068ead03 --- /dev/null +++ b/composable_kernel/include/tensor_operation/gridwise_gemm_pipeline_v1.hpp @@ -0,0 +1,245 @@ +#ifndef CK_GRIDWISE_GEMM_PIPELINE_V1_HPP +#define CK_GRIDWISE_GEMM_PIPELINE_V1_HPP + +#include "common_header.hpp" + +namespace ck { + +template +struct GridwiseGemmPipeline_v1 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + const BlockwiseGemm& blockwise_gemm, + CThreadBuffer& c_thread_buf, + index_t NumLoop) + { + static_assert(NumPrefetch == 1 || NumPrefetch == 2, "wrong!"); + + if constexpr(NumPrefetch == 1) + { +#if 0 + // preload data into LDS + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + // Initialize C + c_thread_buf.Clear(); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + + do + { + 
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + + block_sync_lds(); + + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + ++i; + } while(i < (NumLoop - 1)); + } + + // tail + { + block_sync_lds(); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + } +#else + // preload data into LDS + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + + do + { + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + + block_sync_lds(); + + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + ++i; + } while(i < (NumLoop - 1)); + } + + // tail + { + block_sync_lds(); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + } +#endif + } + else if constexpr(NumPrefetch == 2) + { + // preload data into LDS + { + // Read 0 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0); + + // Move + 
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Read 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1); + } + + // Initialize C + c_thread_buf.Clear(); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + + do + { + // Move + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Write i + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0); + + // Read i+2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0); + + // Sync + block_sync_lds(); + + // Gemm i + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + // Sync + block_sync_lds(); + + // Move + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Write i+1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I1); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I1); + + // Read i+3 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1); + + // Sync + block_sync_lds(); + + // Gemm i+1 + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + // Sync + block_sync_lds(); + + i += 2; + } while(i < (NumLoop - 2)); + } + + // tail + { + // Write NumLoop - 2 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0); + + // Sync + block_sync_lds(); + + // Gemm NumLoop - 2 + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + // Sync + block_sync_lds(); + + // Write NumLoop - 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I1); + b_blockwise_copy.RunWrite(b_block_desc, 
b_block_buf, I1); + + // Sync + block_sync_lds(); + + // Gemm NumLoop - 1 + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + } + } + } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp index eaca45af2ea..dd43e870cb0 100644 --- a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp @@ -8,6 +8,7 @@ #include "blockwise_gemm_xdlops.hpp" #include "blockwise_tensor_slice_transfer_v4r1.hpp" #include "threadwise_tensor_slice_transfer.hpp" +#include "gridwise_gemm_pipeline_v1.hpp" namespace ck { @@ -22,7 +23,7 @@ template + bool HasMainK0BlockLoop> __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) @@ -41,17 +42,17 @@ __global__ void { __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run(p_a_grid, - p_b_grid, - p_c_grid, - p_shared, - a_grid_desc_k0_m_k1, - b_grid_desc_k0_n_k1, - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - a_element_op, - b_element_op, - c_element_op, - block_2_ctile_map); + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); } #elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER template (p_a_grid, - p_b_grid, - p_c_grid, - p_shared, - a_grid_desc_k0_m_k1, - b_grid_desc_k0_n_k1, - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - a_element_op, - b_element_op, - c_element_op, - block_2_ctile_map); + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); } #endif @@ -394,7 
+395,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{})); using Block2CTileMap = decltype(MakeBlock2CTileMap(CGridDesc_M_N{}, 1, 1)); - template + template __device__ static void Run(const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid, @@ -535,210 +536,42 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); - static_assert(NumPrefetch == 1 || NumPrefetch == 2, "wrong!"); - - if constexpr(NumPrefetch == 1) - { -#if 0 - // preload data into LDS - a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); - b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf); - - // Initialize C - c_thread_buf.Clear(); - - a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); - b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); - - // main body - if constexpr(HasMainKBlockLoop) - { - index_t k0_block_data_begin = 0; - - do - { - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, - a_block_slice_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, - b_block_slice_copy_step); - - a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); - - block_sync_lds(); - - b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf); - - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - - block_sync_lds(); - - a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); - b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); - - k0_block_data_begin += K0PerBlock; - } while(k0_block_data_begin < (K0 - K0PerBlock)); - } - - // tail - { - block_sync_lds(); - - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - } -#elif 1 - // preload data into LDS - a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); - b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf); - - 
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, b_block_slice_copy_step); - - // Initialize C - c_thread_buf.Clear(); - - a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); - b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); - - // main body - if constexpr(HasMainKBlockLoop) - { - index_t k0_block_data_begin = 0; - - do - { - a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); - - block_sync_lds(); - - b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf); - - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - - block_sync_lds(); - - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, - a_block_slice_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, - b_block_slice_copy_step); - - a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); - b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); - - k0_block_data_begin += K0PerBlock; - } while(k0_block_data_begin < (K0 - K0PerBlock)); - } - - // tail - { - block_sync_lds(); - - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - } -#endif - } - else if constexpr(NumPrefetch == 2) - { - // preload data into LDS - { - // Read 0 - a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf, I0); - b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf, I0); - - // Move - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, b_block_slice_copy_step); - - // Read 1 - a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf, I1); - b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf, I1); - } - - // Initialize C - c_thread_buf.Clear(); - - // main body - if constexpr(HasMainKBlockLoop) - { - index_t k0_block_data_begin = 0; - - do - { - // Move - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, - a_block_slice_copy_step); - 
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, - b_block_slice_copy_step); - - // Write i - a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf, I0); - b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf, I0); - - // Read i+2 - a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf, I0); - b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf, I0); - - // Sync - block_sync_lds(); - - // Gemm i - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - - // Sync - block_sync_lds(); - - // Move - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, - a_block_slice_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, - b_block_slice_copy_step); - - // Write i+1 - a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf, I1); - b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf, I1); - - // Read i+3 - a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf, I1); - b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf, I1); - - // Sync - block_sync_lds(); - - // Gemm i+1 - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - - // Sync - block_sync_lds(); - - k0_block_data_begin += 2 * K0PerBlock; - } while(k0_block_data_begin < (K0 - 2 * K0PerBlock)); - } - - // tail - { - // Write K0-2 - a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf, I0); - b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf, I0); - - // Sync - block_sync_lds(); - - // Gemm K0-2 - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - - // Sync - block_sync_lds(); - - // Write K0-1 - a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf, I1); - b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf, I1); - - // Sync - block_sync_lds(); - - // Gemm K0-1 - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - } - } + // gridwise GEMM pipeline + const auto gridwise_gemm_pipeline = + GridwiseGemmPipeline_v1, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + 
remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + NumPrefetch, + HasMainK0BlockLoop>{}; + + const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock); + + gridwise_gemm_pipeline.Run(a_grid_desc_k0_m_k1, + a_block_desc_k0_m_k1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_k0_n_k1, + b_block_desc_k0_n_k1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + K0BlockMainLoop); // output: register to global memory { diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r5.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r5.hpp deleted file mode 100644 index 986809de9c6..00000000000 --- a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r5.hpp +++ /dev/null @@ -1,635 +0,0 @@ -#ifndef CK_GRIDWISE_GEMM_XDLOPS_V2R5_HPP -#define CK_GRIDWISE_GEMM_XDLOPS_V2R5_HPP - -#include "common_header.hpp" -#include "multi_index_transform_helper.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "blockwise_gemm_xdlops.hpp" -#include "blockwise_tensor_slice_transfer_v4r1.hpp" -#include "threadwise_tensor_slice_transfer_v1r4.hpp" - -namespace ck { - -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - kernel_gemm_xdlops_v2r5( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const FloatC* __restrict__ p_c0_grid, - const FloatC* __restrict__ p_c1_grid, - const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, - const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, - const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - const C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2 
c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - const C1GridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op, - const Block2CTileMap block_2_ctile_map) -{ - constexpr index_t shared_block_size = - GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); - - __shared__ FloatAB p_shared_block[shared_block_size]; - - GridwiseGemm::template Run(p_a_grid, - p_b_grid, - p_c_grid, - p_c0_grid, - p_c1_grid, - p_shared_block, - a_grid_desc_k0_m_k1, - b_grid_desc_k0_n_k1, - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - a_element_op, - b_element_op, - c_element_op, - block_2_ctile_map); -} - -template -struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5 -{ - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - static constexpr auto I6 = Number<6>{}; - static constexpr auto I7 = Number<7>{}; - - // K1 should be Number<...> - static constexpr auto K1 = Number{}; - - __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - constexpr auto max_lds_align = K1; - - // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_k0_m_k1 = [&]() { - if constexpr(ABlockLdsExtraM) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_k0_n_k1 = [&]() { - if constexpr(BBlockLdsExtraN) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - 
make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_space_size = - math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); - - constexpr auto b_block_space_size = - math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align); - - return (a_block_space_size + b_block_space_size) * sizeof(FloatAB); - } - - // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} - __host__ __device__ static constexpr bool - CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, - const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, - const CGridDesc_M_N& c_grid_desc_m_n, - index_t M01, - index_t N01) - { - static_assert(is_known_at_compile_time>::value, - "wrong! K1 need to be known at compile-time"); - - static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) && - (NPerBlock % (NRepeat * NPerXDL)) == 0, - "Invalid tuning param!"); - - const auto M = a_grid_desc_k0_m_k1.GetLength(I1); - const auto N = b_grid_desc_k0_n_k1.GetLength(I1); - const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); - - if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) && - K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) && - K1 == b_grid_desc_k0_n_k1.GetLength(I2))) - return false; - - if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) - return false; - - // check M01, N01 - constexpr auto M1 = Number{}; - constexpr auto N1 = Number{}; - - const auto M0 = M / M1; - const auto N0 = N / N1; - - if(!(M0 % M01 == 0 && N0 % N01 == 0)) - return false; - - // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) - return true; - } - - __host__ __device__ static constexpr index_t - CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) - { - 
const auto M = c_grid_desc_m_n.GetLength(I0); - const auto N = c_grid_desc_m_n.GetLength(I1); - - const index_t grid_size = (M / MPerBlock) * (N / NPerBlock); - - return grid_size; - } - - __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) - { - const bool has_main_k0_block_loop = (K0 / K0PerBlock) > 1; - - return has_main_k0_block_loop; - } - - // TODO fix this - template - __host__ __device__ static constexpr auto - MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N_any& c_grid_desc_m_n) - { - constexpr auto max_lds_align = K1; - - // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_k0_m_k1 = [&]() { - if constexpr(ABlockLdsExtraM) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_k0_n_k1 = [&]() { - if constexpr(BBlockLdsExtraN) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - using BlockwiseGemm = - BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1; - - return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n); - } - - // return block_id to C matrix tile idx (m0, n0) mapping - __host__ __device__ static constexpr auto - MakeBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01) - { - const auto M = c_grid_desc_m_n.GetLength(I0); - const auto N = c_grid_desc_m_n.GetLength(I1); - - constexpr auto M1 = Number{}; - constexpr auto N1 = Number{}; - - const auto M0 = M / M1; - const auto N0 = N / N1; - - const auto M00 = M0 / M01; - const auto N00 = N0 / N01; - - const auto 
m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_unmerge_transform(make_tuple(M00, M01)), - make_unmerge_transform(make_tuple(N00, N01))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{})); - - const auto c_blockid_to_m00_m01_n00_n01_block_cluster_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))), - make_tuple(Sequence<0, 1, 2, 3>{}), - make_tuple(Sequence<0>{})); - - const auto c_blockid_to_m0_n0_block_cluster_adaptor = - chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor, - c_blockid_to_m00_m01_n00_n01_block_cluster_adaptor); - - return c_blockid_to_m0_n0_block_cluster_adaptor; - } - - using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = - decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{})); - - using C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = - decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(C0GridDesc_M_N{})); - - using C1GridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = - decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(C1GridDesc_M_N{})); - - using Block2CTileMap = decltype(MakeBlock2CTileMap(CGridDesc_M_N{}, 1, 1)); - - template - __device__ static void - Run(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const FloatC* __restrict__ p_c0_grid, - const FloatC* __restrict__ p_c1_grid, - FloatAB* __restrict__ p_shared_block, - const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, - const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, - const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - const C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - const C1GridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - const AElementwiseOperation& a_element_op, - const BElementwiseOperation& b_element_op, - const CElementwiseOperation& c_element_op, - const Block2CTileMap& 
block_2_ctile_map) - { - const auto a_grid_buf = make_dynamic_buffer( - p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize()); - const auto b_grid_buf = make_dynamic_buffer( - p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize()); - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize()); - - auto c0_grid_buf = make_dynamic_buffer( - p_c0_grid, c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize()); - - auto c1_grid_buf = make_dynamic_buffer( - p_c1_grid, c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize()); - - const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); - - // divide block work by [M, N] - const auto block_work_idx = - block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); - - // HACK: this force m/n_block_data_idx_on_grid into SGPR - const index_t m_block_data_idx_on_grid = - __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); - - const index_t n_block_data_idx_on_grid = - __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); - - // lds max alignment - constexpr auto max_lds_align = K1; - - // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_k0_m_k1 = [&]() { - if constexpr(ABlockLdsExtraM) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_k0_n_k1 = [&]() { - if constexpr(BBlockLdsExtraN) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - // A matrix blockwise copy - auto a_blockwise_copy = - BlockwiseTensorSliceTransfer_v4r1, - 
ABlockTransferThreadClusterLengths_K0_M_K1, - ABlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(a_grid_desc_k0_m_k1), - decltype(a_block_desc_k0_m_k1), - ABlockTransferSrcAccessOrder, - Sequence<1, 0, 2>, - ABlockTransferSrcVectorDim, - 2, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_K1, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true>( - a_grid_desc_k0_m_k1, - make_multi_index(0, m_block_data_idx_on_grid, 0), - a_element_op, - a_block_desc_k0_m_k1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}); - - // B matrix blockwise copy - auto b_blockwise_copy = - BlockwiseTensorSliceTransfer_v4r1, - BBlockTransferThreadClusterLengths_K0_N_K1, - BBlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(b_grid_desc_k0_n_k1), - decltype(b_block_desc_k0_n_k1), - BBlockTransferSrcAccessOrder, - Sequence<1, 0, 2>, - BBlockTransferSrcVectorDim, - 2, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_K1, - 1, - 1, - BThreadTransferSrcResetCoordinateAfterRun, - true>( - b_grid_desc_k0_n_k1, - make_multi_index(0, n_block_data_idx_on_grid, 0), - b_element_op, - b_block_desc_k0_n_k1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}); - - // GEMM definition - // c_mtx += transpose(a_mtx) * b_mtx - // a_mtx[K0PerBlock, MPerBlock] is in LDS - // b_mtx[K0PerBlock, NPerBlock] is in LDS - // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in - // register - // sanity check - - auto blockwise_gemm = - BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; - - auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); - - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_space_size = - math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); - - FloatAB* p_a_block = p_shared_block; - FloatAB* p_b_block = p_shared_block + a_block_space_size; - - 
constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); - constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); - - auto a_block_buf = make_dynamic_buffer( - p_a_block, a_block_desc_k0_m_k1.GetElementSpaceSize()); - auto b_block_buf = make_dynamic_buffer( - p_b_block, b_block_desc_k0_n_k1.GetElementSpaceSize()); - - // preload data into LDS - { - a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); - b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf); - - a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); - b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); - } - - // Initialize C - c_thread_buf.Clear(); - - // main body - if constexpr(HasMainKBlockLoop) - { - index_t k0_block_data_begin = 0; - - do - { - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, b_block_slice_copy_step); - - a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); - - block_sync_lds(); - - b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf); - - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - - block_sync_lds(); - - a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); - b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); - - k0_block_data_begin += K0PerBlock; - } while(k0_block_data_begin < (K0 - K0PerBlock)); - } - - // tail - { - block_sync_lds(); - - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - } - - // output: register to global memory - { - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0); - constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1); - constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2); - constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3); - constexpr 
auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4); - constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5); - constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7); - - constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - make_naive_tensor_descriptor_packed(make_tuple( - Number{}, Number{}, I1, I1, Number{}, I1, Number{}, I1)); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_grid = - m_block_data_idx_on_grid + c_thread_mtx_on_block[I0]; - - const index_t n_thread_data_on_grid = - n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_grid_idx = - m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_grid)); - - const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_grid_idx = - n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_grid)); - - auto c_thread_copy = - ThreadwiseTensorSliceTransfer_v1r4, - CThreadTransferSrcDstAccessOrder, - CThreadTransferSrcDstVectorDim, - CThreadTransferDstScalarPerVector, - CGlobalMemoryDataOperation, - 1, - true>{ - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - 
make_multi_index(m_thread_data_on_grid_idx[I0], - n_thread_data_on_grid_idx[I0], - m_thread_data_on_grid_idx[I1], - n_thread_data_on_grid_idx[I1], - m_thread_data_on_grid_idx[I2], - m_thread_data_on_grid_idx[I3], - m_thread_data_on_grid_idx[I4], - n_thread_data_on_grid_idx[I2]), - c_element_op}; - - c_thread_copy.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), - c_thread_buf, - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c_grid_buf, - c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c0_grid_buf, - c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c1_grid_buf); - } - } -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r6.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r6.hpp deleted file mode 100644 index a96cd6e74ac..00000000000 --- a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r6.hpp +++ /dev/null @@ -1,617 +0,0 @@ -#ifndef CK_GRIDWISE_GEMM_XDLOPS_V2R6_HPP -#define CK_GRIDWISE_GEMM_XDLOPS_V2R6_HPP - -#include "common_header.hpp" -#include "multi_index_transform_helper.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "blockwise_gemm_xdlops.hpp" -#include "blockwise_tensor_slice_transfer_v4r1.hpp" -#include "threadwise_tensor_slice_transfer_v1r5.hpp" - -namespace ck { - -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - kernel_gemm_xdlops_v2r6( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const FloatC* __restrict__ p_c0_grid, - const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, - const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, - const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - const C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - const AElementwiseOperation a_element_op, - const BElementwiseOperation 
b_element_op, - const CElementwiseOperation c_element_op, - const Block2CTileMap block_2_ctile_map) -{ - constexpr index_t shared_block_size = - GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); - - __shared__ FloatAB p_shared_block[shared_block_size]; - - GridwiseGemm::template Run(p_a_grid, - p_b_grid, - p_c_grid, - p_c0_grid, - p_shared_block, - a_grid_desc_k0_m_k1, - b_grid_desc_k0_n_k1, - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - a_element_op, - b_element_op, - c_element_op, - block_2_ctile_map); -} - -template -struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r6 -{ - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - static constexpr auto I6 = Number<6>{}; - static constexpr auto I7 = Number<7>{}; - - // K1 should be Number<...> - static constexpr auto K1 = Number{}; - - __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - constexpr auto max_lds_align = K1; - - // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_k0_m_k1 = [&]() { - if constexpr(ABlockLdsExtraM) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_k0_n_k1 = [&]() { - if constexpr(BBlockLdsExtraN) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - // LDS allocation for A and B: be careful of alignment - constexpr auto 
a_block_space_size = - math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); - - constexpr auto b_block_space_size = - math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align); - - return (a_block_space_size + b_block_space_size) * sizeof(FloatAB); - } - - // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} - __host__ __device__ static constexpr bool - CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, - const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, - const CGridDesc_M_N& c_grid_desc_m_n, - index_t M01, - index_t N01) - { - static_assert(is_known_at_compile_time>::value, - "wrong! K1 need to be known at compile-time"); - - static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) && - (NPerBlock % (NRepeat * NPerXDL)) == 0, - "Invalid tuning param!"); - - const auto M = a_grid_desc_k0_m_k1.GetLength(I1); - const auto N = b_grid_desc_k0_n_k1.GetLength(I1); - const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); - - if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) && - K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) && - K1 == b_grid_desc_k0_n_k1.GetLength(I2))) - return false; - - if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) - return false; - - // check M01, N01 - constexpr auto M1 = Number{}; - constexpr auto N1 = Number{}; - - const auto M0 = M / M1; - const auto N0 = N / N1; - - if(!(M0 % M01 == 0 && N0 % N01 == 0)) - return false; - - // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) - return true; - } - - __host__ __device__ static constexpr index_t - CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) - { - const auto M = c_grid_desc_m_n.GetLength(I0); - const auto N = c_grid_desc_m_n.GetLength(I1); - - const index_t grid_size = (M / MPerBlock) * (N / NPerBlock); - - return grid_size; - } - - __host__ __device__ static constexpr bool 
CalculateHasMainK0BlockLoop(index_t K0) - { - const bool has_main_k0_block_loop = (K0 / K0PerBlock) > 1; - - return has_main_k0_block_loop; - } - - // TODO fix this - template - __host__ __device__ static constexpr auto - MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N_any& c_grid_desc_m_n) - { - constexpr auto max_lds_align = K1; - - // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_k0_m_k1 = [&]() { - if constexpr(ABlockLdsExtraM) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_k0_n_k1 = [&]() { - if constexpr(BBlockLdsExtraN) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - using BlockwiseGemm = - BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1; - - return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n); - } - - // return block_id to C matrix tile idx (m0, n0) mapping - __host__ __device__ static constexpr auto - MakeBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01) - { - const auto M = c_grid_desc_m_n.GetLength(I0); - const auto N = c_grid_desc_m_n.GetLength(I1); - - constexpr auto M1 = Number{}; - constexpr auto N1 = Number{}; - - const auto M0 = M / M1; - const auto N0 = N / N1; - - const auto M00 = M0 / M01; - const auto N00 = N0 / N01; - - const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_unmerge_transform(make_tuple(M00, M01)), - make_unmerge_transform(make_tuple(N00, N01))), - make_tuple(Sequence<0>{}, 
Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{})); - - const auto c_blockid_to_m00_m01_n00_n01_block_cluster_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))), - make_tuple(Sequence<0, 1, 2, 3>{}), - make_tuple(Sequence<0>{})); - - const auto c_blockid_to_m0_n0_block_cluster_adaptor = - chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor, - c_blockid_to_m00_m01_n00_n01_block_cluster_adaptor); - - return c_blockid_to_m0_n0_block_cluster_adaptor; - } - - using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = - decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{})); - - using C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = - decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(C0GridDesc_M_N{})); - - using Block2CTileMap = decltype(MakeBlock2CTileMap(CGridDesc_M_N{}, 1, 1)); - - template - __device__ static void - Run(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const FloatC* __restrict__ p_c0_grid, - FloatAB* __restrict__ p_shared_block, - const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, - const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, - const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - const C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - const AElementwiseOperation& a_element_op, - const BElementwiseOperation& b_element_op, - const CElementwiseOperation& c_element_op, - const Block2CTileMap& block_2_ctile_map) - { - const auto a_grid_buf = make_dynamic_buffer( - p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize()); - const auto b_grid_buf = make_dynamic_buffer( - p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize()); - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize()); - - auto c0_grid_buf = make_dynamic_buffer( - p_c0_grid, c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize()); - - const auto 
K0 = a_grid_desc_k0_m_k1.GetLength(I0); - - // divide block work by [M, N] - const auto block_work_idx = - block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); - - // HACK: this force m/n_block_data_idx_on_grid into SGPR - const index_t m_block_data_idx_on_grid = - __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); - - const index_t n_block_data_idx_on_grid = - __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); - - // lds max alignment - constexpr auto max_lds_align = K1; - - // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_k0_m_k1 = [&]() { - if constexpr(ABlockLdsExtraM) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_k0_n_k1 = [&]() { - if constexpr(BBlockLdsExtraN) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - // A matrix blockwise copy - auto a_blockwise_copy = - BlockwiseTensorSliceTransfer_v4r1, - ABlockTransferThreadClusterLengths_K0_M_K1, - ABlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(a_grid_desc_k0_m_k1), - decltype(a_block_desc_k0_m_k1), - ABlockTransferSrcAccessOrder, - Sequence<1, 0, 2>, - ABlockTransferSrcVectorDim, - 2, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_K1, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true>( - a_grid_desc_k0_m_k1, - make_multi_index(0, m_block_data_idx_on_grid, 0), - a_element_op, - a_block_desc_k0_m_k1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}); - - // B matrix 
blockwise copy - auto b_blockwise_copy = - BlockwiseTensorSliceTransfer_v4r1, - BBlockTransferThreadClusterLengths_K0_N_K1, - BBlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(b_grid_desc_k0_n_k1), - decltype(b_block_desc_k0_n_k1), - BBlockTransferSrcAccessOrder, - Sequence<1, 0, 2>, - BBlockTransferSrcVectorDim, - 2, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_K1, - 1, - 1, - BThreadTransferSrcResetCoordinateAfterRun, - true>( - b_grid_desc_k0_n_k1, - make_multi_index(0, n_block_data_idx_on_grid, 0), - b_element_op, - b_block_desc_k0_n_k1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}); - - // GEMM definition - // c_mtx += transpose(a_mtx) * b_mtx - // a_mtx[K0PerBlock, MPerBlock] is in LDS - // b_mtx[K0PerBlock, NPerBlock] is in LDS - // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in - // register - // sanity check - - auto blockwise_gemm = - BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; - - auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); - - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_space_size = - math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); - - FloatAB* p_a_block = p_shared_block; - FloatAB* p_b_block = p_shared_block + a_block_space_size; - - constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); - constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); - - auto a_block_buf = make_dynamic_buffer( - p_a_block, a_block_desc_k0_m_k1.GetElementSpaceSize()); - auto b_block_buf = make_dynamic_buffer( - p_b_block, b_block_desc_k0_n_k1.GetElementSpaceSize()); - - // preload data into LDS - { - a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); - b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf); - - a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); - 
b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); - } - - // Initialize C - c_thread_buf.Clear(); - - // main body - if constexpr(HasMainKBlockLoop) - { - index_t k0_block_data_begin = 0; - - do - { - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, b_block_slice_copy_step); - - a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); - - block_sync_lds(); - - b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf); - - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - - block_sync_lds(); - - a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); - b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); - - k0_block_data_begin += K0PerBlock; - } while(k0_block_data_begin < (K0 - K0PerBlock)); - } - - // tail - { - block_sync_lds(); - - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - } - - // output: register to global memory - { - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0); - constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1); - constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2); - constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3); - constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4); - constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5); - constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7); - - constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - make_naive_tensor_descriptor_packed(make_tuple( - Number{}, Number{}, I1, I1, Number{}, I1, Number{}, I1)); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto 
c_thread_mtx_on_block = - blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_grid = - m_block_data_idx_on_grid + c_thread_mtx_on_block[I0]; - - const index_t n_thread_data_on_grid = - n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_grid_idx = - m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_grid)); - - const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_grid_idx = - n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_grid)); - - auto c_thread_copy = - ThreadwiseTensorSliceTransfer_v1r5, - CThreadTransferSrcDstAccessOrder, - CThreadTransferSrcDstVectorDim, - CThreadTransferDstScalarPerVector, - CGlobalMemoryDataOperation, - 1, - true>{ - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(m_thread_data_on_grid_idx[I0], - n_thread_data_on_grid_idx[I0], - m_thread_data_on_grid_idx[I1], - n_thread_data_on_grid_idx[I1], - m_thread_data_on_grid_idx[I2], - m_thread_data_on_grid_idx[I3], - m_thread_data_on_grid_idx[I4], - n_thread_data_on_grid_idx[I2]), - c_element_op}; - - c_thread_copy.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), - c_thread_buf, - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c_grid_buf, - c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c0_grid_buf); - } - } -}; - -} // namespace ck -#endif diff --git 
a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r1.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r1.hpp index 54a9fdfded2..55f6b8e6c6a 100644 --- a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r1.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r1.hpp @@ -9,6 +9,7 @@ #include "blockwise_tensor_slice_transfer_v4r1.hpp" #include "blockwise_tensor_slice_transfer_v6r1.hpp" #include "threadwise_tensor_slice_transfer.hpp" +#include "gridwise_gemm_pipeline_v1.hpp" namespace ck { @@ -22,7 +23,7 @@ template + bool HasMainK0BlockLoop> __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) @@ -42,7 +43,7 @@ __global__ void { __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run( + GridwiseGemm::template Run( p_a_grid, p_b_grid, p_c_grid, @@ -348,7 +349,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 using Block2CTileMap = remove_cvref_t; - template + template __device__ static void Run(const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid, @@ -492,210 +493,42 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); - static_assert(NumPrefetch == 1 || NumPrefetch == 2, "wrong!"); - - if constexpr(NumPrefetch == 1) - { -#if 0 - // preload data into LDS - a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); - b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf); - - // Initialize C - c_thread_buf.Clear(); - - a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); - b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); - - // main body - if constexpr(HasMainKBlockLoop) - { - index_t k0_block_data_begin = 0; - - do - { - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, - 
a_block_slice_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, - b_block_slice_copy_step); - - a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); - - block_sync_lds(); - - b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf); - - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - - block_sync_lds(); - - a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); - b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); - - k0_block_data_begin += K0PerBlock; - } while(k0_block_data_begin < (K0 - K0PerBlock)); - } - - // tail - { - block_sync_lds(); - - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - } -#elif 1 - // preload data into LDS - a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); - b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf); - - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, b_block_slice_copy_step); - - // Initialize C - c_thread_buf.Clear(); - - a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); - b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); - - // main body - if constexpr(HasMainKBlockLoop) - { - index_t k0_block_data_begin = 0; - - do - { - a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); - - block_sync_lds(); - - b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf); - - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - - block_sync_lds(); - - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, - a_block_slice_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, - b_block_slice_copy_step); - - a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); - b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); - - k0_block_data_begin += K0PerBlock; - } while(k0_block_data_begin < (K0 - K0PerBlock)); - } - - // tail - { - block_sync_lds(); - - blockwise_gemm.Run(a_block_buf, b_block_buf, 
c_thread_buf); - } -#endif - } - else if constexpr(NumPrefetch == 2) - { - // preload data into LDS - { - // Read 0 - a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf, I0); - b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf, I0); - - // Move - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, b_block_slice_copy_step); - - // Read 1 - a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf, I1); - b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf, I1); - } - - // Initialize C - c_thread_buf.Clear(); - - // main body - if constexpr(HasMainKBlockLoop) - { - index_t k0_block_data_begin = 0; - - do - { - // Move - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, - a_block_slice_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, - b_block_slice_copy_step); - - // Write i - a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf, I0); - b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf, I0); - - // Read i+2 - a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf, I0); - b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf, I0); - - // Sync - block_sync_lds(); - - // Gemm i - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - - // Sync - block_sync_lds(); - - // Move - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, - a_block_slice_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, - b_block_slice_copy_step); - - // Write i+1 - a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf, I1); - b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf, I1); - - // Read i+3 - a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf, I1); - b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf, I1); - - // Sync - block_sync_lds(); - - // Gemm i+1 - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - - // Sync - block_sync_lds(); - - k0_block_data_begin += 2 * 
K0PerBlock; - } while(k0_block_data_begin < (K0 - 2 * K0PerBlock)); - } - - // tail - { - // Write K0-2 - a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf, I0); - b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf, I0); - - // Sync - block_sync_lds(); - - // Gemm K0-2 - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - - // Sync - block_sync_lds(); - - // Write K0-1 - a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf, I1); - b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf, I1); - - // Sync - block_sync_lds(); - - // Gemm K0-1 - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - } - } + // gridwise GEMM pipeline + const auto gridwise_gemm_pipeline = + GridwiseGemmPipeline_v1, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + NumPrefetch, + HasMainK0BlockLoop>{}; + + const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock); + + gridwise_gemm_pipeline.Run(a_grid_desc_k0_m_k1, + a_block_desc_k0_m_k1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_k0_n_k1, + b_block_desc_k0_n_k1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + K0BlockMainLoop); // shuffle C and write out { diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r2.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r2.hpp index 30059525c71..973d058526e 100644 --- a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r2.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r2.hpp @@ -9,6 +9,7 @@ #include "blockwise_tensor_slice_transfer_v4r1.hpp" #include "blockwise_tensor_slice_transfer_v6r2.hpp" #include "threadwise_tensor_slice_transfer.hpp" 
+#include "gridwise_gemm_pipeline_v1.hpp" namespace ck { @@ -23,7 +24,7 @@ template + bool HasMainK0BlockLoop> __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) @@ -46,7 +47,7 @@ __global__ void { __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run( + GridwiseGemm::template Run( p_a_grid, p_b_grid, p_c_grid, @@ -102,7 +103,8 @@ template < index_t CShuffleMXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle, typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, - index_t CBlockTransferScalarPerVector_NWaveNPerXdl> + index_t CBlockTransferScalarPerVector_NWaveNPerXdl, + index_t NumPrefetch = 1> struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 { static constexpr auto I0 = Number<0>{}; @@ -235,6 +237,25 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) return false; + // check NumPrefetch + if constexpr(NumPrefetch == 1) + { + // 1-stage prefetch always supported + } + else if constexpr(NumPrefetch == 2) + { + // 2-stage prefetch currently only support even number of K0 loop + // TODO: add support for odd number of K0 loop + if(!((K0 / K0PerBlock) % 2 == 0)) + { + return false; + } + } + else + { + return false; + } + // check M01, N01 constexpr auto M1 = Number{}; constexpr auto N1 = Number{}; @@ -341,7 +362,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 using Block2CTileMap = remove_cvref_t; - template + template __device__ static void Run(const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid, @@ -416,7 +437,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 1, 1, AThreadTransferSrcResetCoordinateAfterRun, - true>( + true, + NumPrefetch>( a_grid_desc_k0_m_k1, make_multi_index(0, m_block_data_idx_on_grid, 0), a_element_op, @@ -446,7 +468,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 1, 1, 
BThreadTransferSrcResetCoordinateAfterRun, - true>( + true, + NumPrefetch>( b_grid_desc_k0_n_k1, make_multi_index(0, n_block_data_idx_on_grid, 0), b_element_op, @@ -490,51 +513,42 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); - // preload data into LDS - { - a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); - b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf); - - a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); - b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); - } - - // Initialize C - c_thread_buf.Clear(); - - // main body - if constexpr(HasMainKBlockLoop) - { - index_t k0_block_data_begin = 0; - - do - { - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, b_block_slice_copy_step); - - a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); - - block_sync_lds(); - - b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf); - - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - - block_sync_lds(); - - a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); - b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); - - k0_block_data_begin += K0PerBlock; - } while(k0_block_data_begin < (K0 - K0PerBlock)); - } - - // tail - { - block_sync_lds(); - - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - } + // gridwise GEMM pipeline + const auto gridwise_gemm_pipeline = + GridwiseGemmPipeline_v1, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + NumPrefetch, + HasMainK0BlockLoop>{}; + + const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / 
K0PerBlock); + + gridwise_gemm_pipeline.Run(a_grid_desc_k0_m_k1, + a_block_desc_k0_m_k1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_k0_n_k1, + b_block_desc_k0_n_k1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + K0BlockMainLoop); // shuffle C and write out { diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r3.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r3.hpp index 7601aa6a07e..3d14a223c6b 100644 --- a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r3.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r3.hpp @@ -9,6 +9,7 @@ #include "blockwise_tensor_slice_transfer_v4r1.hpp" #include "blockwise_tensor_slice_transfer_v6r3.hpp" #include "threadwise_tensor_slice_transfer.hpp" +#include "gridwise_gemm_pipeline_v1.hpp" namespace ck { @@ -24,7 +25,7 @@ template + bool HasMainK0BlockLoop> __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) @@ -50,7 +51,7 @@ __global__ void { __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run( + GridwiseGemm::template Run( p_a_grid, p_b_grid, p_c_grid, @@ -109,7 +110,8 @@ template < index_t CShuffleMXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle, typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, - index_t CBlockTransferScalarPerVector_NWaveNPerXdl> + index_t CBlockTransferScalarPerVector_NWaveNPerXdl, + index_t NumPrefetch = 1> struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 { static constexpr auto I0 = Number<0>{}; @@ -242,6 +244,25 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) return false; + // check NumPrefetch + if constexpr(NumPrefetch == 1) + { + // 1-stage prefetch always supported + 
} + else if constexpr(NumPrefetch == 2) + { + // 2-stage prefetch currently only support even number of K0 loop + // TODO: add support for odd number of K0 loop + if(!((K0 / K0PerBlock) % 2 == 0)) + { + return false; + } + } + else + { + return false; + } + // check M01, N01 constexpr auto M1 = Number{}; constexpr auto N1 = Number{}; @@ -353,7 +374,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 using Block2CTileMap = remove_cvref_t; - template + template __device__ static void Run(const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid, @@ -509,51 +530,42 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); - // preload data into LDS - { - a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); - b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf); - - a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); - b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); - } - - // Initialize C - c_thread_buf.Clear(); - - // main body - if constexpr(HasMainKBlockLoop) - { - index_t k0_block_data_begin = 0; - - do - { - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, b_block_slice_copy_step); - - a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); - - block_sync_lds(); - - b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf); - - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - - block_sync_lds(); - - a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); - b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); - - k0_block_data_begin += K0PerBlock; - } while(k0_block_data_begin < (K0 - K0PerBlock)); - } - - // tail - { - block_sync_lds(); - - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - } + // gridwise GEMM pipeline + const auto 
gridwise_gemm_pipeline = + GridwiseGemmPipeline_v1, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + NumPrefetch, + HasMainK0BlockLoop>{}; + + const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock); + + gridwise_gemm_pipeline.Run(a_grid_desc_k0_m_k1, + a_block_desc_k0_m_k1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_k0_n_k1, + b_block_desc_k0_n_k1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + K0BlockMainLoop); // shuffle C and write out { diff --git a/example/1_gemm_xdl/gemm_xdl.cpp b/example/1_gemm_xdl/gemm_xdl.cpp index cf6d24634d5..9819bd2d998 100644 --- a/example/1_gemm_xdl/gemm_xdl.cpp +++ b/example/1_gemm_xdl/gemm_xdl.cpp @@ -45,14 +45,14 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpeciali static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization_t::MNPadding; // clang-format off -#if 0 +#if 1 using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl //######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| Num| //######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| 
SrcDstVectorDim| DstScalar| Prefetch| //######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // [256, 128, 4, 8], 1 stage, 2 occupancy - < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 2>; + < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1>; #elif 0 using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl //######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| Num| From b67083c1b3a36e93276418343a47e13bbdf2a057 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Mon, 14 Feb 2022 05:58:48 +0000 Subject: [PATCH 09/11] add some irregular tile size --- ...gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp | 49 ++++++++++++------- 1 file changed, 32 insertions(+), 17 deletions(-) diff --git a/device_operation/src/device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp b/device_operation/src/device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp index 41479f60e88..42b20fe21f7 100644 --- a/device_operation/src/device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp +++ b/device_operation/src/device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp @@ -26,23 
+26,36 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa using device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances = std::tuple< // clang-format off - //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| - //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| - //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - 
DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - 
DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + //###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 
true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 
S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +// irregular tile size +using device_gemm_xdl_f16_f16_f16_mk_nk_mn_irregular_tile_instances = + std::tuple< + // clang-format off + //###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //###########| | | | | | | | | | | | | | | | | | | | | 
| | | | | | | | | | | | | | | | + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 144, 8, 8, 16, 16, 2, 9, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 8, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 144, 4, 8, 16, 16, 2, 9, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1> // clang-format on >; @@ -50,6 +63,8 @@ void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances( std::vector>& instances) { add_device_operation_instances(instances, device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances{}); + add_device_operation_instances(instances, + device_gemm_xdl_f16_f16_f16_mk_nk_mn_irregular_tile_instances{}); } } // namespace device_gemm_instance From a3311ebfd18d1c839cbc077e3dcb54a180ab6693 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Wed, 23 Feb 2022 06:45:02 +0000 Subject: [PATCH 10/11] update CalculateHasMainK0BlockLoop for multi-stage-prefetch --- .../gridwise_gemm_xdlops_v2r3.hpp | 3 +- .../gridwise_gemm_xdlops_v2r5.hpp | 635 ------------------ .../gridwise_gemm_xdlops_v2r6.hpp | 617 ----------------- .../gridwise_gemm_xdlops_v3r1.hpp | 3 +- .../gridwise_gemm_xdlops_v3r2.hpp | 3 +- .../gridwise_gemm_xdlops_v3r3.hpp | 3 +- 6 files changed, 8 insertions(+), 1256 deletions(-) delete mode 100644 composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r5.hpp delete mode 100644 composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r6.hpp diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp index 64ebcf55f3b..47622ad148f 100644 --- a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp @@ -240,9 +240,10 
@@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 return grid_size; } + // TODO move this function into GEMM-pipeline class __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) { - const bool has_main_k0_block_loop = (K0 / K0PerBlock) > 1; + const bool has_main_k0_block_loop = (K0 / (NumPrefetch * K0PerBlock)) > 1; return has_main_k0_block_loop; } diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r5.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r5.hpp deleted file mode 100644 index b4d7ef7d841..00000000000 --- a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r5.hpp +++ /dev/null @@ -1,635 +0,0 @@ -#ifndef CK_GRIDWISE_GEMM_XDLOPS_V2R5_HPP -#define CK_GRIDWISE_GEMM_XDLOPS_V2R5_HPP - -#include "common_header.hpp" -#include "multi_index_transform_helper.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "blockwise_gemm_xdlops.hpp" -#include "blockwise_tensor_slice_transfer_v4r1.hpp" -#include "threadwise_tensor_slice_transfer_v1r4.hpp" - -namespace ck { - -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - kernel_gemm_xdlops_v2r5( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const FloatC* __restrict__ p_c0_grid, - const FloatC* __restrict__ p_c1_grid, - const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, - const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, - const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - const C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - const C1GridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op, - const Block2CTileMap block_2_ctile_map) -{ - constexpr index_t 
shared_block_size = - GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); - - __shared__ FloatAB p_shared_block[shared_block_size]; - - GridwiseGemm::template Run(p_a_grid, - p_b_grid, - p_c_grid, - p_c0_grid, - p_c1_grid, - p_shared_block, - a_grid_desc_k0_m_k1, - b_grid_desc_k0_n_k1, - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - a_element_op, - b_element_op, - c_element_op, - block_2_ctile_map); -} - -template -struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5 -{ - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - static constexpr auto I6 = Number<6>{}; - static constexpr auto I7 = Number<7>{}; - - // K1 should be Number<...> - static constexpr auto K1 = Number{}; - - __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - constexpr auto max_lds_align = K1; - - // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_k0_m_k1 = [&]() { - if constexpr(ABlockLdsExtraM) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_k0_n_k1 = [&]() { - if constexpr(BBlockLdsExtraN) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_space_size = - 
math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); - - constexpr auto b_block_space_size = - math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align); - - return (a_block_space_size + b_block_space_size) * sizeof(FloatAB); - } - - // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} - __host__ __device__ static constexpr bool - CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, - const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, - const CGridDesc_M_N& c_grid_desc_m_n, - index_t M01, - index_t N01) - { - static_assert(is_known_at_compile_time>::value, - "wrong! K1 need to be known at compile-time"); - - static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) && - (NPerBlock % (NRepeat * NPerXDL)) == 0, - "Invalid tuning param!"); - - const auto M = a_grid_desc_k0_m_k1.GetLength(I1); - const auto N = b_grid_desc_k0_n_k1.GetLength(I1); - const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); - - if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) && - K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) && - K1 == b_grid_desc_k0_n_k1.GetLength(I2))) - return false; - - if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) - return false; - - // check M01, N01 - constexpr auto M1 = Number{}; - constexpr auto N1 = Number{}; - - const auto M0 = M / M1; - const auto N0 = N / N1; - - if(!(M0 % M01 == 0 && N0 % N01 == 0)) - return false; - - // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) - return true; - } - - __host__ __device__ static constexpr index_t - CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) - { - const auto M = c_grid_desc_m_n.GetLength(I0); - const auto N = c_grid_desc_m_n.GetLength(I1); - - const index_t grid_size = (M / MPerBlock) * (N / NPerBlock); - - return grid_size; - } - - __host__ __device__ static constexpr bool 
CalculateHasMainK0BlockLoop(index_t K0) - { - const bool has_main_k0_block_loop = (K0 / K0PerBlock) > 1; - - return has_main_k0_block_loop; - } - - // TODO fix this - template - __host__ __device__ static constexpr auto - MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N_any& c_grid_desc_m_n) - { - constexpr auto max_lds_align = K1; - - // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_k0_m_k1 = [&]() { - if constexpr(ABlockLdsExtraM) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_k0_n_k1 = [&]() { - if constexpr(BBlockLdsExtraN) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - using BlockwiseGemm = - BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1; - - return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n); - } - - // return block_id to C matrix tile idx (m0, n0) mapping - __host__ __device__ static constexpr auto - MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01) - { - const auto M = c_grid_desc_m_n.GetLength(I0); - const auto N = c_grid_desc_m_n.GetLength(I1); - - constexpr auto M1 = Number{}; - constexpr auto N1 = Number{}; - - const auto M0 = M / M1; - const auto N0 = N / N1; - - const auto M00 = M0 / M01; - const auto N00 = N0 / N01; - - const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_unmerge_transform(make_tuple(M00, M01)), - make_unmerge_transform(make_tuple(N00, N01))), - make_tuple(Sequence<0>{}, 
Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{})); - - const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))), - make_tuple(Sequence<0, 1, 2, 3>{}), - make_tuple(Sequence<0>{})); - - const auto cblockid_to_m0_n0_block_cluster_adaptor = - chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor, - cblockid_to_m00_m01_n00_n01_block_cluster_adaptor); - - return cblockid_to_m0_n0_block_cluster_adaptor; - } - - using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = - decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{})); - - using C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = - decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(C0GridDesc_M_N{})); - - using C1GridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = - decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(C1GridDesc_M_N{})); - - using DefaultBlock2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1)); - - template - __device__ static void - Run(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const FloatC* __restrict__ p_c0_grid, - const FloatC* __restrict__ p_c1_grid, - FloatAB* __restrict__ p_shared_block, - const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, - const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, - const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - const C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - const C1GridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - const AElementwiseOperation& a_element_op, - const BElementwiseOperation& b_element_op, - const CElementwiseOperation& c_element_op, - const Block2CTileMap& block_2_ctile_map) - { - const auto a_grid_buf = make_dynamic_buffer( - p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize()); - const auto b_grid_buf = make_dynamic_buffer( - p_b_grid, 
b_grid_desc_k0_n_k1.GetElementSpaceSize()); - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize()); - - auto c0_grid_buf = make_dynamic_buffer( - p_c0_grid, c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize()); - - auto c1_grid_buf = make_dynamic_buffer( - p_c1_grid, c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize()); - - const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); - - // divide block work by [M, N] - const auto block_work_idx = - block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); - - // HACK: this force m/n_block_data_idx_on_grid into SGPR - const index_t m_block_data_idx_on_grid = - __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); - - const index_t n_block_data_idx_on_grid = - __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); - - // lds max alignment - constexpr auto max_lds_align = K1; - - // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_k0_m_k1 = [&]() { - if constexpr(ABlockLdsExtraM) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_k0_n_k1 = [&]() { - if constexpr(BBlockLdsExtraN) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - // A matrix blockwise copy - auto a_blockwise_copy = - BlockwiseTensorSliceTransfer_v4r1, - ABlockTransferThreadClusterLengths_K0_M_K1, - ABlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(a_grid_desc_k0_m_k1), - decltype(a_block_desc_k0_m_k1), - ABlockTransferSrcAccessOrder, - 
Sequence<1, 0, 2>, - ABlockTransferSrcVectorDim, - 2, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_K1, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true>( - a_grid_desc_k0_m_k1, - make_multi_index(0, m_block_data_idx_on_grid, 0), - a_element_op, - a_block_desc_k0_m_k1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}); - - // B matrix blockwise copy - auto b_blockwise_copy = - BlockwiseTensorSliceTransfer_v4r1, - BBlockTransferThreadClusterLengths_K0_N_K1, - BBlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(b_grid_desc_k0_n_k1), - decltype(b_block_desc_k0_n_k1), - BBlockTransferSrcAccessOrder, - Sequence<1, 0, 2>, - BBlockTransferSrcVectorDim, - 2, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_K1, - 1, - 1, - BThreadTransferSrcResetCoordinateAfterRun, - true>( - b_grid_desc_k0_n_k1, - make_multi_index(0, n_block_data_idx_on_grid, 0), - b_element_op, - b_block_desc_k0_n_k1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}); - - // GEMM definition - // c_mtx += transpose(a_mtx) * b_mtx - // a_mtx[K0PerBlock, MPerBlock] is in LDS - // b_mtx[K0PerBlock, NPerBlock] is in LDS - // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in - // register - // sanity check - - auto blockwise_gemm = - BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; - - auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); - - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_space_size = - math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); - - FloatAB* p_a_block = p_shared_block; - FloatAB* p_b_block = p_shared_block + a_block_space_size; - - constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); - constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); - - auto a_block_buf = make_dynamic_buffer( - p_a_block, 
a_block_desc_k0_m_k1.GetElementSpaceSize()); - auto b_block_buf = make_dynamic_buffer( - p_b_block, b_block_desc_k0_n_k1.GetElementSpaceSize()); - - // preload data into LDS - { - a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); - b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf); - - a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); - b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); - } - - // Initialize C - c_thread_buf.Clear(); - - // main body - if constexpr(HasMainKBlockLoop) - { - index_t k0_block_data_begin = 0; - - do - { - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, b_block_slice_copy_step); - - a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); - - block_sync_lds(); - - b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf); - - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - - block_sync_lds(); - - a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); - b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); - - k0_block_data_begin += K0PerBlock; - } while(k0_block_data_begin < (K0 - K0PerBlock)); - } - - // tail - { - block_sync_lds(); - - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - } - - // output: register to global memory - { - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0); - constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1); - constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2); - constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3); - constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4); - constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5); - constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6); - 
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7); - - constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - make_naive_tensor_descriptor_packed(make_tuple( - Number{}, Number{}, I1, I1, Number{}, I1, Number{}, I1)); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_grid = - m_block_data_idx_on_grid + c_thread_mtx_on_block[I0]; - - const index_t n_thread_data_on_grid = - n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_grid_idx = - m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_grid)); - - const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_grid_idx = - n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_grid)); - - auto c_thread_copy = - ThreadwiseTensorSliceTransfer_v1r4, - CThreadTransferSrcDstAccessOrder, - CThreadTransferSrcDstVectorDim, - CThreadTransferDstScalarPerVector, - CGlobalMemoryDataOperation, - 1, - true>{ - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(m_thread_data_on_grid_idx[I0], - n_thread_data_on_grid_idx[I0], - m_thread_data_on_grid_idx[I1], - n_thread_data_on_grid_idx[I1], - m_thread_data_on_grid_idx[I2], - m_thread_data_on_grid_idx[I3], - m_thread_data_on_grid_idx[I4], 
- n_thread_data_on_grid_idx[I2]), - c_element_op}; - - c_thread_copy.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), - c_thread_buf, - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c_grid_buf, - c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c0_grid_buf, - c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c1_grid_buf); - } - } -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r6.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r6.hpp deleted file mode 100644 index 7d6c86f5165..00000000000 --- a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r6.hpp +++ /dev/null @@ -1,617 +0,0 @@ -#ifndef CK_GRIDWISE_GEMM_XDLOPS_V2R6_HPP -#define CK_GRIDWISE_GEMM_XDLOPS_V2R6_HPP - -#include "common_header.hpp" -#include "multi_index_transform_helper.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "blockwise_gemm_xdlops.hpp" -#include "blockwise_tensor_slice_transfer_v4r1.hpp" -#include "threadwise_tensor_slice_transfer_v1r5.hpp" - -namespace ck { - -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - kernel_gemm_xdlops_v2r6( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const FloatC* __restrict__ p_c0_grid, - const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, - const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, - const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - const C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op, - const Block2CTileMap block_2_ctile_map) -{ - constexpr index_t shared_block_size = - GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); - - __shared__ FloatAB 
p_shared_block[shared_block_size]; - - GridwiseGemm::template Run(p_a_grid, - p_b_grid, - p_c_grid, - p_c0_grid, - p_shared_block, - a_grid_desc_k0_m_k1, - b_grid_desc_k0_n_k1, - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - a_element_op, - b_element_op, - c_element_op, - block_2_ctile_map); -} - -template -struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r6 -{ - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - static constexpr auto I6 = Number<6>{}; - static constexpr auto I7 = Number<7>{}; - - // K1 should be Number<...> - static constexpr auto K1 = Number{}; - - __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - constexpr auto max_lds_align = K1; - - // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_k0_m_k1 = [&]() { - if constexpr(ABlockLdsExtraM) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_k0_n_k1 = [&]() { - if constexpr(BBlockLdsExtraN) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_space_size = - math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); - - constexpr auto b_block_space_size = - math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align); - - 
return (a_block_space_size + b_block_space_size) * sizeof(FloatAB); - } - - // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} - __host__ __device__ static constexpr bool - CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, - const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, - const CGridDesc_M_N& c_grid_desc_m_n, - index_t M01, - index_t N01) - { - static_assert(is_known_at_compile_time>::value, - "wrong! K1 need to be known at compile-time"); - - static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) && - (NPerBlock % (NRepeat * NPerXDL)) == 0, - "Invalid tuning param!"); - - const auto M = a_grid_desc_k0_m_k1.GetLength(I1); - const auto N = b_grid_desc_k0_n_k1.GetLength(I1); - const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); - - if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) && - K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) && - K1 == b_grid_desc_k0_n_k1.GetLength(I2))) - return false; - - if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) - return false; - - // check M01, N01 - constexpr auto M1 = Number{}; - constexpr auto N1 = Number{}; - - const auto M0 = M / M1; - const auto N0 = N / N1; - - if(!(M0 % M01 == 0 && N0 % N01 == 0)) - return false; - - // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) - return true; - } - - __host__ __device__ static constexpr index_t - CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) - { - const auto M = c_grid_desc_m_n.GetLength(I0); - const auto N = c_grid_desc_m_n.GetLength(I1); - - const index_t grid_size = (M / MPerBlock) * (N / NPerBlock); - - return grid_size; - } - - __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) - { - const bool has_main_k0_block_loop = (K0 / K0PerBlock) > 1; - - return has_main_k0_block_loop; - } - - // TODO fix this - template - __host__ __device__ static constexpr auto - 
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N_any& c_grid_desc_m_n) - { - constexpr auto max_lds_align = K1; - - // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_k0_m_k1 = [&]() { - if constexpr(ABlockLdsExtraM) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_k0_n_k1 = [&]() { - if constexpr(BBlockLdsExtraN) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - using BlockwiseGemm = - BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1; - - return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n); - } - - // return block_id to C matrix tile idx (m0, n0) mapping - __host__ __device__ static constexpr auto - MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01) - { - const auto M = c_grid_desc_m_n.GetLength(I0); - const auto N = c_grid_desc_m_n.GetLength(I1); - - constexpr auto M1 = Number{}; - constexpr auto N1 = Number{}; - - const auto M0 = M / M1; - const auto N0 = N / N1; - - const auto M00 = M0 / M01; - const auto N00 = N0 / N01; - - const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_unmerge_transform(make_tuple(M00, M01)), - make_unmerge_transform(make_tuple(N00, N01))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{})); - - const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M00, N00, 
M01, N01))), - make_tuple(Sequence<0, 1, 2, 3>{}), - make_tuple(Sequence<0>{})); - - const auto cblockid_to_m0_n0_block_cluster_adaptor = - chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor, - cblockid_to_m00_m01_n00_n01_block_cluster_adaptor); - - return cblockid_to_m0_n0_block_cluster_adaptor; - } - - using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = - decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{})); - - using C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = - decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(C0GridDesc_M_N{})); - - using DefaultBlock2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1)); - - template - __device__ static void - Run(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const FloatC* __restrict__ p_c0_grid, - FloatAB* __restrict__ p_shared_block, - const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, - const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, - const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - const C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - const AElementwiseOperation& a_element_op, - const BElementwiseOperation& b_element_op, - const CElementwiseOperation& c_element_op, - const Block2CTileMap& block_2_ctile_map) - { - const auto a_grid_buf = make_dynamic_buffer( - p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize()); - const auto b_grid_buf = make_dynamic_buffer( - p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize()); - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize()); - - auto c0_grid_buf = make_dynamic_buffer( - p_c0_grid, c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize()); - - const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); - - // divide block work by [M, N] - const auto block_work_idx = - block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); - - // HACK: this force 
m/n_block_data_idx_on_grid into SGPR - const index_t m_block_data_idx_on_grid = - __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); - - const index_t n_block_data_idx_on_grid = - __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); - - // lds max alignment - constexpr auto max_lds_align = K1; - - // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_k0_m_k1 = [&]() { - if constexpr(ABlockLdsExtraM) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_k0_n_k1 = [&]() { - if constexpr(BBlockLdsExtraN) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - // A matrix blockwise copy - auto a_blockwise_copy = - BlockwiseTensorSliceTransfer_v4r1, - ABlockTransferThreadClusterLengths_K0_M_K1, - ABlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(a_grid_desc_k0_m_k1), - decltype(a_block_desc_k0_m_k1), - ABlockTransferSrcAccessOrder, - Sequence<1, 0, 2>, - ABlockTransferSrcVectorDim, - 2, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_K1, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true>( - a_grid_desc_k0_m_k1, - make_multi_index(0, m_block_data_idx_on_grid, 0), - a_element_op, - a_block_desc_k0_m_k1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}); - - // B matrix blockwise copy - auto b_blockwise_copy = - BlockwiseTensorSliceTransfer_v4r1, - BBlockTransferThreadClusterLengths_K0_N_K1, - BBlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - 
decltype(b_grid_desc_k0_n_k1), - decltype(b_block_desc_k0_n_k1), - BBlockTransferSrcAccessOrder, - Sequence<1, 0, 2>, - BBlockTransferSrcVectorDim, - 2, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_K1, - 1, - 1, - BThreadTransferSrcResetCoordinateAfterRun, - true>( - b_grid_desc_k0_n_k1, - make_multi_index(0, n_block_data_idx_on_grid, 0), - b_element_op, - b_block_desc_k0_n_k1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}); - - // GEMM definition - // c_mtx += transpose(a_mtx) * b_mtx - // a_mtx[K0PerBlock, MPerBlock] is in LDS - // b_mtx[K0PerBlock, NPerBlock] is in LDS - // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in - // register - // sanity check - - auto blockwise_gemm = - BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; - - auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); - - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_space_size = - math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); - - FloatAB* p_a_block = p_shared_block; - FloatAB* p_b_block = p_shared_block + a_block_space_size; - - constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); - constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); - - auto a_block_buf = make_dynamic_buffer( - p_a_block, a_block_desc_k0_m_k1.GetElementSpaceSize()); - auto b_block_buf = make_dynamic_buffer( - p_b_block, b_block_desc_k0_n_k1.GetElementSpaceSize()); - - // preload data into LDS - { - a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); - b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf); - - a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); - b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); - } - - // Initialize C - c_thread_buf.Clear(); - - // main body - if constexpr(HasMainKBlockLoop) - { - index_t k0_block_data_begin = 0; - - do - { - 
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, b_block_slice_copy_step); - - a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); - - block_sync_lds(); - - b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf); - - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - - block_sync_lds(); - - a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); - b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); - - k0_block_data_begin += K0PerBlock; - } while(k0_block_data_begin < (K0 - K0PerBlock)); - } - - // tail - { - block_sync_lds(); - - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - } - - // output: register to global memory - { - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0); - constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1); - constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2); - constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3); - constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4); - constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5); - constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7); - - constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - make_naive_tensor_descriptor_packed(make_tuple( - Number{}, Number{}, I1, I1, Number{}, I1, Number{}, I1)); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_grid = - m_block_data_idx_on_grid + c_thread_mtx_on_block[I0]; - - const index_t 
n_thread_data_on_grid = - n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_grid_idx = - m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_grid)); - - const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_grid_idx = - n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_grid)); - - auto c_thread_copy = - ThreadwiseTensorSliceTransfer_v1r5, - CThreadTransferSrcDstAccessOrder, - CThreadTransferSrcDstVectorDim, - CThreadTransferDstScalarPerVector, - CGlobalMemoryDataOperation, - 1, - true>{ - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(m_thread_data_on_grid_idx[I0], - n_thread_data_on_grid_idx[I0], - m_thread_data_on_grid_idx[I1], - n_thread_data_on_grid_idx[I1], - m_thread_data_on_grid_idx[I2], - m_thread_data_on_grid_idx[I3], - m_thread_data_on_grid_idx[I4], - n_thread_data_on_grid_idx[I2]), - c_element_op}; - - c_thread_copy.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), - c_thread_buf, - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c_grid_buf, - c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c0_grid_buf); - } - } -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r1.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r1.hpp index 0b325c106f4..336617d9d49 100644 --- a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r1.hpp +++ 
b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r1.hpp @@ -274,9 +274,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 return grid_size; } + // TODO move this function into GEMM-pipeline class __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) { - const bool has_main_k0_block_loop = (K0 / K0PerBlock) > 1; + const bool has_main_k0_block_loop = (K0 / (NumPrefetch * K0PerBlock)) > 1; return has_main_k0_block_loop; } diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r2.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r2.hpp index d6ae2240495..588c16d01b4 100644 --- a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r2.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r2.hpp @@ -281,9 +281,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 return grid_size; } + // TODO move this function into GEMM-pipeline class __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) { - const bool has_main_k0_block_loop = (K0 / K0PerBlock) > 1; + const bool has_main_k0_block_loop = (K0 / (NumPrefetch * K0PerBlock)) > 1; return has_main_k0_block_loop; } diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r3.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r3.hpp index 758ab75ce05..3f8b74f5445 100644 --- a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r3.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r3.hpp @@ -288,9 +288,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 return grid_size; } + // TODO move this function into GEMM-pipeline class __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) { - const bool has_main_k0_block_loop = (K0 / K0PerBlock) > 1; + const bool has_main_k0_block_loop = (K0 / (NumPrefetch * K0PerBlock)) > 1; return has_main_k0_block_loop; } From 
a301414107b95fd9fd475e302df09240ddcb6690 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Wed, 23 Feb 2022 17:40:09 +0000 Subject: [PATCH 11/11] refactor gridwise gemm pipeline class --- .../gridwise_gemm_pipeline_v1.hpp | 358 +++++++++++------- 1 file changed, 219 insertions(+), 139 deletions(-) diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_pipeline_v1.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_pipeline_v1.hpp index bc6068ead03..dcacd99ae17 100644 --- a/composable_kernel/include/tensor_operation/gridwise_gemm_pipeline_v1.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_gemm_pipeline_v1.hpp @@ -21,7 +21,40 @@ template -struct GridwiseGemmPipeline_v1 +struct GridwiseGemmPipeline_v1; + +// 1-stage prefetch +template +struct GridwiseGemmPipeline_v1 { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -40,203 +73,250 @@ struct GridwiseGemmPipeline_v1 const BBlockTransferStep& b_block_copy_step, const BlockwiseGemm& blockwise_gemm, CThreadBuffer& c_thread_buf, - index_t NumLoop) + index_t num_loop) { - static_assert(NumPrefetch == 1 || NumPrefetch == 2, "wrong!"); - - if constexpr(NumPrefetch == 1) - { #if 0 - // preload data into LDS - a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); - b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); - - // Initialize C - c_thread_buf.Clear(); - - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); - b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); - - // main body - if constexpr(HasMainLoop) - { - index_t i = 0; + // preload data into LDS + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); - do - { - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + // Initialize C + c_thread_buf.Clear(); - a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); 
+ b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); - block_sync_lds(); + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; - b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + do + { + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); - block_sync_lds(); + block_sync_lds(); - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); - b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); - ++i; - } while(i < (NumLoop - 1)); - } + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - // tail - { block_sync_lds(); - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - } -#else - // preload data into LDS - a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); - b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); - - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); - // Initialize C - c_thread_buf.Clear(); + ++i; + } while(i < (num_loop - 1)); + } - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); - b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + // tail + { + block_sync_lds(); - // main body - if constexpr(HasMainLoop) - { - index_t i = 0; + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + } +#else + // preload data into LDS + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); - do - { - a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); - block_sync_lds(); + // 
Initialize C + c_thread_buf.Clear(); - b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; - block_sync_lds(); + do + { + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + block_sync_lds(); - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); - b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); - ++i; - } while(i < (NumLoop - 1)); - } + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - // tail - { block_sync_lds(); - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - } -#endif - } - else if constexpr(NumPrefetch == 2) - { - // preload data into LDS - { - // Read 0 - a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); - b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0); - - // Move a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); - // Read 1 - a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1); - b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1); - } - - // Initialize C - c_thread_buf.Clear(); - - // main body - if constexpr(HasMainLoop) - { - index_t i = 0; - - do - { - // Move - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); - // Write i - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); - b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0); - - // Read i+2 - 
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); - b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0); - - // Sync - block_sync_lds(); - - // Gemm i - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + ++i; + } while(i < (num_loop - 1)); + } - // Sync - block_sync_lds(); + // tail + { + block_sync_lds(); - // Move - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + } +#endif + } +}; - // Write i+1 - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I1); - b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I1); +// 2-stage prefetch +template +struct GridwiseGemmPipeline_v1 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; - // Read i+3 - a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1); - b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1); + static __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + const BlockwiseGemm& blockwise_gemm, + CThreadBuffer& c_thread_buf, + index_t num_loop) + { + // preload data into LDS + { + // Read 0 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0); - // Sync - block_sync_lds(); + // Move + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); - // Gemm i+1 - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + // Read 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1); 
+ b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1); + } - // Sync - block_sync_lds(); + // Initialize C + c_thread_buf.Clear(); - i += 2; - } while(i < (NumLoop - 2)); - } + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; - // tail + do { - // Write NumLoop - 2 + // Move + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Write i a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0); + // Read i+2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0); + // Sync block_sync_lds(); - // Gemm NumLoop - 2 + // Gemm i blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); // Sync block_sync_lds(); - // Write NumLoop - 1 + // Move + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Write i+1 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I1); b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I1); + // Read i+3 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1); + // Sync block_sync_lds(); - // Gemm NumLoop - 1 + // Gemm i+1 blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - } + + // Sync + block_sync_lds(); + + i += 2; + } while(i < (num_loop - 2)); + } + + // tail + { + // Write num_loop - 2 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0); + + // Sync + block_sync_lds(); + + // Gemm num_loop - 2 + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + // Sync + block_sync_lds(); + + // Write num_loop - 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I1); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I1); + + // Sync + block_sync_lds(); + + // Gemm num_loop - 1 + 
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); } } };