Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
if(stream_config.log_level_ > 0)
{
arg.Print();
GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print();
}

if(!GridwiseGemm::CheckValidity(arg))
Expand Down Expand Up @@ -733,7 +734,9 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
<< "BlkGemmPipelineVersion: "
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
<< "BlkGemmPipelinePrefetchStages: "
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages;
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages << ", "
<< "Kpack: "
<< GridwiseGemm::BlockwiseGemmPipe::AMmaKStride;
// clang-format on

return str.str();
Expand Down
12 changes: 10 additions & 2 deletions include/ck/utility/blkgemmpipe_scheduler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,22 @@ struct BlockwiseGemmXdlops_pipeline_hotloop_inst
KPerXDL);

printf(" A/B buffer load inst: %d, %d\n A/B LDS write inst: %d, %d\n A/B LDS read inst: "
"%d, %d\n C MFMA inst: %d\n",
"%d, %d\n C MFMA inst: %d\n"
"A/B LDS read width: %d, %d, A/B LDS write width: %d, %d, A/B buffer load width: "
"%d/ %d\n",
A_Buffer_Load_Inst_Num,
B_Buffer_Load_Inst_Num,
A_LDS_Write_Inst_Num,
B_LDS_Write_Inst_Num,
A_LDS_Read_Inst_Num,
B_LDS_Read_Inst_Num,
C_MFMA_Inst_Num);
C_MFMA_Inst_Num,
A_LDS_Read_Width,
B_LDS_Read_Width,
ALDSWriteWidth,
BLDSWriteWidth,
ABufferLoadWidth,
BBufferLoadWidth);
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,10 @@ struct BlockUniversalGemmAsBsCr
// TODO: Should we have two policies? Interwave & Intrawave ??
static constexpr index_t InterWaveSchedulingMacClusters = 1;

static constexpr index_t KPack = WarpGemm::kKPerThread;
// should be at least equal to: WarpGemm::Impl::kABKPerLane
// and the question is how to assess upper limit or exact value?
// TODO: Should we introduce AK1/BK1 parameters ?
static constexpr index_t KPack = 8;
static constexpr index_t KPerThread = KIterPerWarp * KPack;
static constexpr index_t KRepeat = KPerThread / KPack;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@

#pragma once

#include <string>
#include <sstream>

#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
Expand Down Expand Up @@ -83,6 +86,56 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
return Policy::template GetSmemSize<Problem>();
}

CK_TILE_HOST static std::string Print()
{
constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM;
constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN;
constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK;

constexpr index_t WaveSize = 64;
constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});

// Below should be equal to AK1|BK1
constexpr index_t A_LDS_Read_Width = Policy::template GetSmemPackA<Problem>();
constexpr index_t B_LDS_Read_Width = Policy::template GetSmemPackB<Problem>();

constexpr index_t A_LDS_Write_Width = Policy::template GetSmemPackA<Problem>();
constexpr index_t B_LDS_Write_Width = Policy::template GetSmemPackB<Problem>();

constexpr index_t A_Buffer_Load_Inst_Num =
MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA());
constexpr index_t B_Buffer_Load_Inst_Num =
NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB());

constexpr index_t A_LDS_Write_Inst_Num =
MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width);
constexpr index_t B_LDS_Write_Inst_Num =
NPerBlock * KPerBlock / (BlockSize * B_LDS_Write_Width);

constexpr index_t A_LDS_Read_Inst_Num =
WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width);
constexpr index_t B_LDS_Read_Inst_Num =
WaveNumM * MPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width);

constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
(BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL);

auto str = std::stringstream{};

str << "A/B vector size: " << GetVectorSizeA() << ", " << GetVectorSizeB() << "\n"
<< "A/B LDS read/write width: " << A_LDS_Read_Width << ", " << B_LDS_Read_Width << "\n"
<< "A/B buffer load inst: " << A_Buffer_Load_Inst_Num << ", " << B_Buffer_Load_Inst_Num
<< "\n"
<< "A/B LDS write inst: " << A_LDS_Write_Inst_Num << ", " << B_LDS_Write_Inst_Num
<< "\n"
<< "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n"
<< "C MFMA inst: " << C_MFMA_Inst_Num << "\n"
<< "KPack: " << BlockGemm::Traits::KPack << "\n"
<< "PrefetchStages: " << PrefetchStages << "\n";
return str.str();
}

template <GemmPipelineScheduler Scheduler>
struct PipelineImpl : public PipelineImplBase
{
Expand All @@ -95,29 +148,35 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>

CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
{
constexpr index_t MPerXDL = BlockGemmShape::WarpTile::at(I0{});
constexpr index_t NPerXDL = BlockGemmShape::WarpTile::at(I1{});
constexpr index_t KPerXDL = BlockGemmShape::WarpTile::at(I2{});
constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM;
constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN;
constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK;

constexpr index_t WaveSize = 64;
constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});

constexpr index_t A_LDS_Read_Width = KPerXDL;
constexpr index_t B_LDS_Read_Width = KPerXDL;
// Below should be equal to AK1|BK1
constexpr index_t A_LDS_Read_Width = Policy::template GetSmemPackA<Problem>();
constexpr index_t B_LDS_Read_Width = Policy::template GetSmemPackB<Problem>();

constexpr index_t A_LDS_Write_Width = Policy::template GetSmemPackA<Problem>();
constexpr index_t B_LDS_Write_Width = Policy::template GetSmemPackB<Problem>();

constexpr index_t A_Buffer_Load_Inst_Num =
MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA());
constexpr index_t B_Buffer_Load_Inst_Num =
NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB());

constexpr index_t A_LDS_Write_Inst_Num = MPerBlock * KPerBlock / (BlockSize * KPerXDL);
constexpr index_t B_LDS_Write_Inst_Num = NPerBlock * KPerBlock / (BlockSize * KPerXDL);
constexpr index_t A_LDS_Write_Inst_Num =
MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width);
constexpr index_t B_LDS_Write_Inst_Num =
NPerBlock * KPerBlock / (BlockSize * B_LDS_Write_Width);

constexpr index_t A_LDS_Read_Inst_Num =
WaveNumN * MPerBlock * KPerBlock / (BlockSize * KPerXDL);
WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width);
constexpr index_t B_LDS_Read_Inst_Num =
WaveNumM * MPerBlock * KPerBlock / (BlockSize * KPerXDL);
WaveNumM * MPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width);

constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
(BlockSize / WaveSize) /
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,6 @@ struct UniversalGemmPipelineAgBgCrPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
{

using ADataType = remove_cvref_t<typename Problem::ADataType>;

constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
Expand Down