From eed702b73d81dd15189cd81885c354934172fa4b Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Tue, 4 Feb 2025 14:47:22 +0000 Subject: [PATCH 1/4] Use SmemPack in HotLoop scheduler --- .../gemm_pipeline_ag_bg_cr_comp_v3.hpp | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp index 0bd78072380..f42392a0add 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp @@ -103,21 +103,27 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{}); constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{}); - constexpr index_t A_LDS_Read_Width = KPerXDL; - constexpr index_t B_LDS_Read_Width = KPerXDL; + // Below should be equal to AK1|BK1 + constexpr index_t A_LDS_Read_Width = Policy::template GetSmemPackA(); + constexpr index_t B_LDS_Read_Width = Policy::template GetSmemPackB(); + + constexpr index_t A_LDS_Write_Width = Policy::template GetSmemPackA(); + constexpr index_t B_LDS_Write_Width = Policy::template GetSmemPackB(); constexpr index_t A_Buffer_Load_Inst_Num = MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA()); constexpr index_t B_Buffer_Load_Inst_Num = NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB()); - constexpr index_t A_LDS_Write_Inst_Num = MPerBlock * KPerBlock / (BlockSize * KPerXDL); - constexpr index_t B_LDS_Write_Inst_Num = NPerBlock * KPerBlock / (BlockSize * KPerXDL); + constexpr index_t A_LDS_Write_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width); + constexpr index_t B_LDS_Write_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * B_LDS_Write_Width); constexpr index_t A_LDS_Read_Inst_Num = - WaveNumN * MPerBlock * KPerBlock / (BlockSize * KPerXDL); + WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width); constexpr index_t B_LDS_Read_Inst_Num = - WaveNumM * MPerBlock * KPerBlock / (BlockSize * KPerXDL); + WaveNumM * MPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width); constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock / (BlockSize / WaveSize) / From 69bda85098f25570170176a75e4085edea875231 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Thu, 6 Feb 2025 14:20:42 +0000 Subject: [PATCH 2/4] Additional debug print information --- .../impl/device_gemm_xdl_cshuffle_v3.hpp | 5 +- include/ck/utility/blkgemmpipe_scheduler.hpp | 12 ++++- .../gemm_pipeline_ag_bg_cr_comp_v3.hpp | 53 +++++++++++++++++++ ...emm_universal_pipeline_ag_bg_cr_policy.hpp | 1 - 4 files changed, 67 insertions(+), 4 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp index 600f12139d6..1c144966504 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp @@ -138,6 +138,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2 0) { arg.Print(); + GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print(); } if(!GridwiseGemm::CheckValidity(arg)) @@ -733,7 +734,9 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2 +#include + #include "ck_tile/core.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" @@ -83,6 +86,56 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 return Policy::template GetSmemSize(); } + CK_TILE_HOST static std::string Print() + { + constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM; + constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN; + constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK; + + constexpr index_t WaveSize = 64; + constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{}); + constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{}); + + // Below should be equal to AK1|BK1 + constexpr index_t A_LDS_Read_Width = Policy::template GetSmemPackA(); + constexpr index_t B_LDS_Read_Width = Policy::template GetSmemPackB(); + + constexpr index_t A_LDS_Write_Width = Policy::template GetSmemPackA(); + constexpr index_t B_LDS_Write_Width = Policy::template GetSmemPackB(); + + constexpr index_t A_Buffer_Load_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA()); + constexpr index_t B_Buffer_Load_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB()); + + constexpr index_t A_LDS_Write_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width); + constexpr index_t B_LDS_Write_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * B_LDS_Write_Width); + + constexpr index_t A_LDS_Read_Inst_Num = + WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width); + constexpr index_t B_LDS_Read_Inst_Num = + WaveNumM * MPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width); + + constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock / + (BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL); + + auto str = std::stringstream{}; + + str << "A/B vector size: " << GetVectorSizeA() << ", " << GetVectorSizeB() << "\n" + << "A/B LDS read/write width: " << A_LDS_Read_Width << ", " << B_LDS_Read_Width << "\n" + << "A/B buffer load inst: " << A_Buffer_Load_Inst_Num << ", " << B_Buffer_Load_Inst_Num + << "\n" + << "A/B LDS write inst: " << A_LDS_Write_Inst_Num << ", " << B_LDS_Write_Inst_Num + << "\n" + << "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n" + << "C MFMA inst: " << C_MFMA_Inst_Num << "\n" + << "KPack: " << BlockGemm::Traits::KPack << "\n" + << "PrefetchStages: " << PrefetchStages << "\n"; + return str.str(); + } + template struct PipelineImpl : public PipelineImplBase { diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index 33f105a4354..babf6e91154 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -185,7 +185,6 @@ struct UniversalGemmPipelineAgBgCrPolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() { - using ADataType = remove_cvref_t; constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; From 8fb20fc0ef1f55e2780855ead8b48c1f473815c8 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Thu, 6 Feb 2025 14:21:14 +0000 Subject: [PATCH 3/4] Change KPack value. Hardcode for now, as without AK1/BK1 there's no good way to determine its value. --- .../ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp index 646d380a185..ab21398b99b 100644 --- a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp +++ b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp @@ -79,7 +79,10 @@ struct BlockUniversalGemmAsBsCr // TODO: Should we have two policies? Interwave & Intrawave ?? static constexpr index_t InterWaveSchedulingMacClusters = 1; - static constexpr index_t KPack = WarpGemm::kKPerThread; + // should be at least equal to: WarpGemm::Impl::kABKPerLane + // and the question is how to assess upper limit or exact value? + // TODO: Should we introduce AK1/BK1 parameters ? + static constexpr index_t KPack = 8; static constexpr index_t KPerThread = KIterPerWarp * KPack; static constexpr index_t KRepeat = KPerThread / KPack; }; From 96c8d9480eb90d4f51a6c05ec610e8d7453b7395 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Thu, 6 Feb 2025 14:22:59 +0000 Subject: [PATCH 4/4] Fix HotLoopScheduler MFMA instr parameters. --- .../ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp index 1c50092a8d5..0a40ca359ee 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp @@ -148,9 +148,9 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 CK_TILE_DEVICE static constexpr auto HotLoopScheduler() { - constexpr index_t MPerXDL = BlockGemmShape::WarpTile::at(I0{}); - constexpr index_t NPerXDL = BlockGemmShape::WarpTile::at(I1{}); - constexpr index_t KPerXDL = BlockGemmShape::WarpTile::at(I2{}); + constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM; + constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN; + constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK; constexpr index_t WaveSize = 64; constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});