diff --git a/projects/composablekernel/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_bpreshuffle_selector.hpp b/projects/composablekernel/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_bpreshuffle_selector.hpp index 09364350b1d..50a7bfd66bc 100644 --- a/projects/composablekernel/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_bpreshuffle_selector.hpp +++ b/projects/composablekernel/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_bpreshuffle_selector.hpp @@ -3,7 +3,7 @@ #pragma once -#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx_bpreshuffle.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp" namespace ck { template {}; + return BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v3{}; } else { diff --git a/projects/composablekernel/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx_bpreshuffle.hpp b/projects/composablekernel/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx_bpreshuffle.hpp deleted file mode 100644 index 67a9769acab..00000000000 --- a/projects/composablekernel/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx_bpreshuffle.hpp +++ /dev/null @@ -1,1148 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#pragma once - -#include "ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp" - -namespace ck { - -// Naive pipeline with lowest resource request per WGP -// GlobalPrefetchStages: 2 -// LocalPreFillStages: 1 -// LocalPreFetchStages: 1 -// LocalSharedMemoryBuffer: 1 - -template -struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle -{ -}; - -template -struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle - : BlockwiseGemmXdlops_mx_pipeline_base - -{ - - using Base = BlockwiseGemmXdlops_mx_pipeline_base; - using Base::A_K1; - using Base::I0; - using Base::I1; - using Base::KRepeat; - using Base::MWaves; - using Base::NWaves; - using Base::WaveSize; - using Base::xdlops_gemm; - using typename Base::HotLoopInstList; - - using Base::CalculateCThreadOriginDataIndex; - using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; - using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; - using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; - using Base::GetCThreadBuffer; - using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; - using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; - using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; - using Base::GetWaveIdx; - using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; - using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; - - using Base::a_block_desc_m0_m1_m2_m3_k; - using Base::b_block_desc_n0_n1_n2_n3_k; - - using Base::AMmaKStride; - using Base::APackedSize; - using Base::BMmaKStride; - using Base::BPackedSize; - using Base::KThreadChunk; - - using Base::KXdlPack; - using Base::MXdlPack; - using Base::NXdlPack; - - using AccType = typename Base::AccType; - using Tuple5 = typename Base::Tuple5; - using ComputeTypeA = typename Base::ComputeTypeA; - using ComputeTypeB = typename Base::ComputeTypeB; - - static constexpr index_t PrefetchStages = 2; - static constexpr index_t LocalPrefetchStages = 2; - static constexpr index_t PrefillStages = 1; - static constexpr index_t GlobalBufferNum = 1; - static constexpr index_t HotloopLocalBufSwitch = MRepeat % 2 == 0 ? 0 : 1; - - static constexpr auto num_buffer_load_a_scale = MRepeat / MXdlPack * KRepeat / KXdlPack; - static constexpr auto num_buffer_load_b_scale = NRepeat / NXdlPack * KRepeat / KXdlPack; - static constexpr auto async_vmcnt = - num_buffer_load_a_scale + num_buffer_load_b_scale + HotLoopInstList::B_Buffer_Load_Inst_Num; - static constexpr auto async_vmcnt_encoding = 3952 + async_vmcnt % 16 + async_vmcnt / 16 * 16384; - - static constexpr auto ScalesPerKBlockSize = - KPerBlock / ScaleBlockSize; // How many mx-vectors per K block - - //> How many mx-vectors in each row/col is processed in one call to xdlops_gemm.Run() - static constexpr auto ScalesPerXdlopsRun = - (APackedSize * KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize; - - //> How many scales a thread must read to accommodate one call to xdlops_gemm.Run() - static constexpr auto ScalesPerXdlopsRunPerThread = - ScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks; - - using mx_scale_t = e8m0_bexp_t; - static constexpr auto scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t); - static constexpr auto scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t); - static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0, - "A scale pack data type too large!"); - static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0, - "B scale pack data type too large!"); - static constexpr auto a_scale_thread_vec_size = KXdlPack * MXdlPack / scale_pack_size_a; - static constexpr auto b_scale_thread_vec_size = KXdlPack * NXdlPack / scale_pack_size_b; - - __host__ static constexpr bool BlockHasHotloop(index_t num_loop) - { - return num_loop > PrefetchStages; - } - - __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) - { - return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd; - } - - __device__ static constexpr auto HotLoopScheduler() - { - // A/B split schedule - // compiler is likely to use ds_read2 when instruction width smaller than 16bytes - constexpr auto num_ds_read_inst_a = - HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 - ? HotLoopInstList::A_LDS_Read_Inst_Num - : HotLoopInstList::A_LDS_Read_Inst_Num / 2; - - constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; - constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num; - constexpr auto num_buffer_load_stage1 = - num_buffer_load_inst_b + num_buffer_load_a_scale + num_buffer_load_b_scale; - - constexpr auto num_buffer_load_stage2 = num_buffer_load_inst_a; - - constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num * APackedSize; - constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle; - - constexpr auto ds_read_a_issue_cycle = - HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4; - constexpr auto ds_read_a_mfma_rate = - math::integer_divide_ceil(mfma_cycle - 8, 2 * ds_read_a_issue_cycle); - - // constexpr auto num_dsread_a_mfma = - // (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; - - constexpr auto num_total_stages = std::max(2, MRepeat); - - if constexpr(num_total_stages > 2) - { - // Group num_mfma_perstage num_ds_read_a_perstage - // since we want to reuse a local register buffer - constexpr auto num_mfma_perstage = num_mfma_inst / num_total_stages; - constexpr auto num_ds_read_a_perstage = num_ds_read_inst_a / num_total_stages; - - constexpr auto num_ds_read_a_mfma_perstage = - math::integer_divide_ceil(num_ds_read_a_perstage, ds_read_a_mfma_rate); - - constexpr auto num_ds_read_a_prefetch_stages = 2; - - constexpr auto buffer_load_perstage_more = - math::integer_divide_ceil((num_buffer_load_stage1), (num_total_stages - 2)); - constexpr auto buffer_load_perstage_less = - math::integer_divide_floor((num_buffer_load_stage1), (num_total_stages - 2)); - constexpr auto buffer_load_perstage_stage2 = - math::integer_divide_floor((num_buffer_load_stage2), 2); - - constexpr auto buffer_load_stages_more = - num_buffer_load_stage1 - - math::integer_divide_floor(num_buffer_load_stage1, (num_total_stages - 2)) * - ((num_total_stages - 2)); - - constexpr auto buffer_load_issue_point_interval_more = - num_mfma_perstage / buffer_load_perstage_more; - constexpr auto buffer_load_issue_point_interval_less = - num_mfma_perstage / buffer_load_perstage_less; - constexpr auto buffer_load_issue_point_interval_stage2 = - num_mfma_perstage / buffer_load_perstage_stage2; - - // Stage 1 - // global read more - static_ford>{}([&](auto ii) { - constexpr auto imfma = Number{}]>{}; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - - if constexpr(imfma % buffer_load_issue_point_interval_more == 0) - { - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - } - - if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) - { - __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read - } - }); - - // global read less - static_ford>{}([&](auto ii) { - constexpr auto imfma = Number{}]>{}; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - if constexpr(imfma % buffer_load_issue_point_interval_less == 0) - { - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - } - if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) - { - __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read - } - }); - - // Stage 2, Sync - // lds synchronization, prefetch next loop local A - static_ford>{}([&](auto ii) { - constexpr auto imfma = Number{}]>{}; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - if constexpr(imfma % buffer_load_issue_point_interval_stage2 == 0) - { - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - } - if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) - { - __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read - } - }); - } - else - { - constexpr auto num_buffer_load_total = num_buffer_load_inst_a + num_buffer_load_inst_b + - num_buffer_load_a_scale + - num_buffer_load_b_scale; - constexpr auto num_dsread_a_mfma = math::integer_divide_ceil( - num_ds_read_inst_a, ds_read_a_mfma_rate); // how many mfma per dsread_a - - // stage 1 - constexpr auto num_mfma_stage1 = num_mfma_inst - num_dsread_a_mfma; - - constexpr auto mfma_perstage_more = - math::integer_divide_ceil(num_mfma_stage1, num_buffer_load_total); - constexpr auto mfma_perstage_less = - math::integer_divide_floor(num_mfma_stage1, num_buffer_load_total); - - constexpr auto mfma_stages_more = - num_mfma_stage1 - mfma_perstage_less * num_buffer_load_total; - - static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { - if constexpr(i < mfma_stages_more) - { - static_for<0, mfma_perstage_more, 1>{}([&](auto) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - }); - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - } - else - { - static_for<0, mfma_perstage_less, 1>{}([&](auto) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - }); - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - } - }); - - static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { - if constexpr((i + num_buffer_load_inst_a) < mfma_stages_more) - { - static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - }); - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - } - else - { - static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - }); - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - } - }); - - static_for<0, num_buffer_load_a_scale, 1>{}([&](auto i) { - if constexpr((i + num_buffer_load_inst_a + num_buffer_load_inst_b) < - mfma_stages_more) - { - static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - }); - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - } - else - { - static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - }); - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - } - }); - - static_for<0, num_buffer_load_b_scale, 1>{}([&](auto i) { - if constexpr((i + num_buffer_load_inst_a + num_buffer_load_inst_b + - num_buffer_load_a_scale) < mfma_stages_more) - { - static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - }); - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - } - else - { - static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - }); - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - } - }); - - // stage 2 - static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >= - ds_read_a_mfma_rate) - { - __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read - } - else - { - __builtin_amdgcn_sched_group_barrier( - 0x100, - num_ds_read_inst_a - (num_dsread_a_mfma - 1) * ds_read_a_mfma_rate, - 0); // DS read - } - }); - } - } - - template - __device__ void Run( - // ABlockCopy - const AGridDesc& a_grid_desc, - const ABlockDesc& a_block_desc, - ABlockTransfer& a_blockwise_copy, - const AGridBuffer& a_grid_buf, - ABlockBuffer& a_block_bufs, - const ABlockTransferStep& a_block_copy_step, - // BBlockCopy - const BGridDesc& b_grid_desc, - const BBlockDesc& b_block_desc, - BBlockTransfer& b_blockwise_copy, - const BGridBuffer& b_grid_buf, - BBlockBuffer& b_block_bufs, - const BBlockTransferStep& b_block_copy_step, - // CThread - CThreadBuffer& c_thread_buf, - // A and B scales - const AScaleGridDesc& a_scale_grid_desc, - AScaleThreadTransfer& a_scale_thread_copy, - const AScaleGridBuffer& a_scale_grid_buf, - const BScaleGridDesc& b_scale_grid_desc, - BScaleThreadTransfer& b_scale_thread_copy, - const BScaleGridBuffer& b_scale_grid_buf, - index_t num_loop) const - { - ignore = b_block_bufs; - auto a_thread_buf = make_static_buffer( - a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( - b_thread_desc_.GetElementSpaceSize()); - StaticallyIndexedArray{}> b_thread_bufs; - constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0); - - auto a_scale_thread_buf = make_static_buffer( - a_scale_thread_desc.GetElementSpaceSize()); - - auto b_scale_thread_buf = make_static_buffer( - b_scale_thread_desc.GetElementSpaceSize()); - - StaticallyIndexedArray{}> a_scale_thread_bufs; - StaticallyIndexedArray{}> b_scale_thread_bufs; - - // Global prefetch 1 - a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_bufs(I0)); - b_blockwise_copy.Run( - b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_thread_bufs(I0)); - - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); - - // Prefetch a_scales - static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { - static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { - a_scale_thread_copy.Run(a_scale_grid_desc, - a_scale_grid_buf, - a_scale_thread_desc, - make_tuple(m0, k0, I0), - a_scale_thread_bufs(I0)); - - a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, - make_multi_index(0, I1, 0)); - }); - a_scale_thread_copy.MoveSrcSliceWindow( - a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0)); - }); - - // restore row id and advance to the next set of scales - a_scale_thread_copy.MoveSrcSliceWindow( - a_scale_grid_desc, - make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0)); - - // Prefetch b_scales - static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { - static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { - b_scale_thread_copy.Run(b_scale_grid_desc, - b_scale_grid_buf, - b_scale_thread_desc, - make_tuple(n0, k0, I0), - b_scale_thread_bufs(I0)); - - b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, - make_multi_index(0, I1, 0)); - }); - b_scale_thread_copy.MoveSrcSliceWindow( - b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); - }); - - // restore col id and advance to the next set of scales - // NWaves * NPerXDL * NRepeat == NPerBlock - b_scale_thread_copy.MoveSrcSliceWindow( - b_scale_grid_desc, - make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0)); - - // Local prefetch 1, sync the async load - __builtin_amdgcn_s_waitcnt(async_vmcnt_encoding); - block_sync_lds(); - static_ford>{}([&](auto mk) { - constexpr auto m0 = Number{}]>{}; - constexpr auto k = Number{}]>{}; - constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * - (APackedSize * KPack / xdlops_gemm.K1PerXdlops); - static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto a_k_step_chunk = - k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_m3_k, - make_tuple(I0, I0, Number{}, I0, Number{}), - a_block_bufs(I0), - a_thread_desc_, - make_tuple( - I0, I0, Number{}, k, Number{}), - a_thread_buf); - }); - }); - - // Global prefetch 2 - a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_bufs(I1)); - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); - - // Initialize C - c_thread_buf.Clear(); - __builtin_amdgcn_sched_barrier(0); - constexpr index_t SwitchM = MRepeat - LocalPrefetchStages; - // main body - if constexpr(HasMainLoop) - { - // loop over k with the step KPerBlock - index_t i = 0; - do - { - auto LoopFunc = [&](auto scale_comp_buf, auto scale_mem_buf) { - b_blockwise_copy.Run(b_grid_desc, - b_grid_buf, - b_block_desc, - b_block_origin_idx, - b_thread_bufs(scale_mem_buf)); - - // Prefetch a_scales - static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { - static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { - a_scale_thread_copy.Run(a_scale_grid_desc, - a_scale_grid_buf, - a_scale_thread_desc, - make_tuple(m0, k0, I0), - a_scale_thread_bufs(scale_mem_buf)); - - a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, - make_multi_index(0, I1, 0)); - }); - a_scale_thread_copy.MoveSrcSliceWindow( - a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0)); - }); - - // restore row id and advance to the next set of scales - a_scale_thread_copy.MoveSrcSliceWindow( - a_scale_grid_desc, - make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0)); - - // Prefetch b_scales - static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { - static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { - b_scale_thread_copy.Run(b_scale_grid_desc, - b_scale_grid_buf, - b_scale_thread_desc, - make_tuple(n0, k0, I0), - b_scale_thread_bufs(scale_mem_buf)); - - b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, - make_multi_index(0, I1, 0)); - }); - b_scale_thread_copy.MoveSrcSliceWindow( - b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); - }); - - // restore col id and advance to the next set of scales - // NWaves * NPerXDL * NRepeat == NPerBlock - b_scale_thread_copy.MoveSrcSliceWindow( - b_scale_grid_desc, - make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0)); - - // a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); - - static_for<0, MRepeat, 1>{}([&](auto m0) { - constexpr auto im_major = m0 / MXdlPack; - constexpr auto im_minor = m0 % MXdlPack; - static_for<0, KRepeat, 1>{}([&](auto k0) { - constexpr auto ik_major = k0 / KXdlPack; - constexpr auto ik_minor = k0 % KXdlPack; - static_for<0, NRepeat, 1>{}([&](auto n0) { - constexpr auto in_major = n0 / NXdlPack; - constexpr auto in_minor = n0 % NXdlPack; - - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset( - make_tuple(im_major, ik_major, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset( - make_tuple(in_major, ik_major, I0)); - - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); - - vector_type - a_scale_thread_vec; - vector_type - b_scale_thread_vec; - - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs( - scale_comp_buf)[Number{}]; - }); - - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs( - scale_comp_buf)[Number{}]; - }); - - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = b_thread_bufs - [scale_comp_buf][Number{}]; - }); - - using mfma_input_type_a = - typename vector_type::type; - - using mfma_input_type_b = - typename vector_type::type; - - using mfma_scale_input_type_a = - typename vector_type::type; - using mfma_scale_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(im_major, in_major, im_minor, in_minor, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec.template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); - - if constexpr(m0.value == SwitchM) - { - __builtin_amdgcn_s_waitcnt(async_vmcnt_encoding); - block_sync_lds(); - a_blockwise_copy.Run(a_grid_desc, - a_grid_buf, - a_block_desc, - a_block_bufs(scale_comp_buf)); - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); - } - - constexpr auto lds_buf = - m0.value >= SwitchM ? scale_mem_buf : scale_comp_buf; - - static_for<0, KRepeat, 1>{}([&](auto k) { - constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * - (APackedSize * KPack / xdlops_gemm.K1PerXdlops); - static_for<0, - xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), - 1>{}([&](auto chunk) { - constexpr auto a_k_step_chunk = - k_step + - chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_m3_k, - make_tuple(Number<((m0 + LocalPrefetchStages) / MXdlPack) % - (MRepeat / MXdlPack)>{}, - I0, - Number{}, - I0, - Number{}), - a_block_bufs(Number{}), - a_thread_desc_, - make_tuple(I0, - I0, - Number{}, - k, - Number{}), - a_thread_buf); - }); - }); - }); - - HotLoopScheduler(); - __builtin_amdgcn_sched_barrier(0); - }; - - LoopFunc(I0, I1); - LoopFunc(I1, I0); - - i += 2; - } while(i < (num_loop - 2)); - } - - // tail - if constexpr(TailNum == TailNumber::Even) - { - b_blockwise_copy.Run( - b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_thread_bufs(I1)); - - // Prefetch a_scales - static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { - static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { - a_scale_thread_copy.Run(a_scale_grid_desc, - a_scale_grid_buf, - a_scale_thread_desc, - make_tuple(m0, k0, I0), - a_scale_thread_bufs(I1)); - - a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, - make_multi_index(0, I1, 0)); - }); - a_scale_thread_copy.MoveSrcSliceWindow( - a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0)); - }); - - // Prefetch b_scales - static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { - static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { - b_scale_thread_copy.Run(b_scale_grid_desc, - b_scale_grid_buf, - b_scale_thread_desc, - make_tuple(n0, k0, I0), - b_scale_thread_bufs(I1)); - - b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, - make_multi_index(0, I1, 0)); - }); - b_scale_thread_copy.MoveSrcSliceWindow( - b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); - }); - - static_for<0, MRepeat, 1>{}([&](auto m0) { - constexpr auto im_major = m0 / MXdlPack; - constexpr auto im_minor = m0 % MXdlPack; - static_for<0, KRepeat, 1>{}([&](auto k0) { - constexpr auto ik_major = k0 / KXdlPack; - constexpr auto ik_minor = k0 % KXdlPack; - static_for<0, NRepeat, 1>{}([&](auto n0) { - constexpr auto in_major = n0 / NXdlPack; - constexpr auto in_minor = n0 % NXdlPack; - - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0)); - - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); - - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; - - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs(I0)[Number{}]; - }); - - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs(I0)[Number{}]; - }); - - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I0][Number{}]; - }); - - using mfma_input_type_a = - typename vector_type::type; - - using mfma_input_type_b = - typename vector_type::type; - - using mfma_scale_input_type_a = - typename vector_type::type; - using mfma_scale_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(im_major, in_major, im_minor, in_minor, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec.template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); - if constexpr(m0.value == SwitchM) - { - __builtin_amdgcn_s_waitcnt(async_vmcnt_encoding); - block_sync_lds(); - } - - constexpr auto lds_buf = m0.value >= SwitchM ? I1 : I0; - - static_for<0, KRepeat, 1>{}([&](auto k) { - constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * - (APackedSize * KPack / xdlops_gemm.K1PerXdlops); - static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto a_k_step_chunk = - k_step + - chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_m3_k, - make_tuple(Number<((m0 + LocalPrefetchStages) / MXdlPack) % - (MRepeat / MXdlPack)>{}, - I0, - Number{}, - I0, - Number{}), - a_block_bufs(Number{}), - a_thread_desc_, - make_tuple( - I0, I0, Number{}, k, Number{}), - a_thread_buf); - }); - }); - }); - - static_for<0, MRepeat, 1>{}([&](auto m0) { - constexpr auto im_major = m0 / MXdlPack; - constexpr auto im_minor = m0 % MXdlPack; - static_for<0, KRepeat, 1>{}([&](auto k0) { - constexpr auto ik_major = k0 / KXdlPack; - constexpr auto ik_minor = k0 % KXdlPack; - static_for<0, NRepeat, 1>{}([&](auto n0) { - constexpr auto in_major = n0 / NXdlPack; - constexpr auto in_minor = n0 % NXdlPack; - - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0)); - - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); - - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; - - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs(I1)[Number{}]; - }); - - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs(I1)[Number{}]; - }); - - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I1][Number{}]; - }); - - using mfma_input_type_a = - typename vector_type::type; - - using mfma_input_type_b = - typename vector_type::type; - - using mfma_scale_input_type_a = - typename vector_type::type; - using mfma_scale_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(im_major, in_major, im_minor, in_minor, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec.template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); - if constexpr(m0.value < (MRepeat - LocalPrefetchStages)) - { - static_for<0, KRepeat, 1>{}([&](auto k) { - constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * - (APackedSize * KPack / xdlops_gemm.K1PerXdlops); - static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto a_k_step_chunk = - k_step + - chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_m3_k, - make_tuple(Number<((m0 + LocalPrefetchStages) / MXdlPack) % - (MRepeat / MXdlPack)>{}, - I0, - Number{}, - I0, - Number{}), - a_block_bufs(I1), - a_thread_desc_, - make_tuple(I0, - I0, - Number{}, - k, - Number{}), - a_thread_buf); - }); - }); - } - }); - } - else if constexpr(TailNum == TailNumber::Odd) - { - static_for<0, MRepeat, 1>{}([&](auto m0) { - constexpr auto im_major = m0 / MXdlPack; - constexpr auto im_minor = m0 % MXdlPack; - static_for<0, KRepeat, 1>{}([&](auto k0) { - constexpr auto ik_major = k0 / KXdlPack; - constexpr auto ik_minor = k0 % KXdlPack; - static_for<0, NRepeat, 1>{}([&](auto n0) { - constexpr auto in_major = n0 / NXdlPack; - constexpr auto in_minor = n0 % NXdlPack; - - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0)); - - static_assert(0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops " - "per Thread."); - - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; - - // Pack scale_thread_buf into scale_thread_vec - static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs(I0)[Number{}]; - }); - - static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs(I0)[Number{}]; - }); - - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I0][Number{}]; - }); - - using mfma_input_type_a = - typename vector_type::type; - - using mfma_input_type_b = - typename vector_type::type; - - using mfma_scale_input_type_a = - typename vector_type::type; - using mfma_scale_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(im_major, in_major, im_minor, in_minor, 0)); - - // MFMA accumulation - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - a_scale_thread_vec.template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); - if constexpr(m0.value < (MRepeat - LocalPrefetchStages)) - { - static_for<0, KRepeat, 1>{}([&](auto k) { - constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * - (APackedSize * KPack / xdlops_gemm.K1PerXdlops); - static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( - [&](auto chunk) { - constexpr auto a_k_step_chunk = - k_step + - chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_m3_k, - make_tuple(Number<((m0 + LocalPrefetchStages) / MXdlPack) % - (MRepeat / MXdlPack)>{}, - I0, - Number{}, - I0, - Number{}), - a_block_bufs(I0), - a_thread_desc_, - make_tuple(I0, - I0, - Number{}, - k, - Number{}), - a_thread_buf); - }); - }); - } - }); - } - } - - // Length: A[ARegBuf, MWave, MXdlPack, KRepeat, KPack] - // Order: 1 0 3 2 4 - static constexpr auto ARegBuf = 2; - static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, I1, Number{}, Number{}, Number{})); - - using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, - Sequence<0, 1, 2, 3, 4>, - 4, - A_K1, - A_K1>; - AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()}; - - // TODO: make this field protected when a_scale_thread_copy_ is moved - // here - static constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, - Number{}, - Number{})); - - // TODO: make this field protected when b_scale_thread_copy_ is moved - // here - static constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, - Number{}, - Number{})); - - protected: - // using Base::a_thread_copy_; - // using Base::a_thread_desc_; - using Base::b_thread_copy_; - using Base::b_thread_desc_; - using Base::c_thread_desc_; -}; - -} // namespace ck