Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ template <typename GridwiseGemm,
typename AK0MK1GridDesc,
typename BK0NK1GridDesc,
typename CM0N0M1N1M2M3M4N2GridDesc,
typename CBlockClusterAdaptor>
typename CBlockClusterAdaptor,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
Expand All @@ -37,14 +38,14 @@ __global__ void

__shared__ FloatAB p_shared_block[shared_block_size];

GridwiseGemm::Run(p_a_grid,
p_b_grid,
p_c_grid,
p_shared_block,
a_k0_m_k1_grid_desc,
b_k0_n_k1_grid_desc,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_block_cluster_adaptor);
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
p_c_grid,
p_shared_block,
a_k0_m_k1_grid_desc,
b_k0_n_k1_grid_desc,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_block_cluster_adaptor);
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
template <typename GridwiseGemm,
Expand Down Expand Up @@ -81,14 +82,14 @@ __global__ void

__shared__ FloatAB p_shared_block[shared_block_size];

GridwiseGemm::Run(p_a_grid,
p_b_grid,
p_c_grid,
p_shared_block,
a_k0_m_k1_grid_desc,
b_k0_n_k1_grid_desc,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_block_cluster_adaptor);
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
p_c_grid,
p_shared_block,
a_k0_m_k1_grid_desc,
b_k0_n_k1_grid_desc,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_block_cluster_adaptor);
}
#endif

Expand All @@ -102,7 +103,7 @@ template <index_t BlockSize,
typename CMNGridDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t K0PerBlock,
index_t MPerXDL,
index_t NPerXDL,
index_t K1Value,
Expand Down Expand Up @@ -158,13 +159,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
if constexpr(ABlockLdsExtraM)
{
return make_naive_tensor_descriptor(
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1),
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
}
}();

Expand All @@ -173,13 +174,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
if constexpr(BBlockLdsExtraN)
{
return make_naive_tensor_descriptor(
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1),
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
}
}();

Expand Down Expand Up @@ -217,7 +218,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
K1 == b_k0_n_k1_grid_desc.GetLength(I2)))
return false;

if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % KPerBlock == 0))
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
return false;

// check M01, N01
Expand Down Expand Up @@ -245,6 +246,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
return grid_size;
}

__host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
{
const bool has_main_k0_block_loop = (K0 / K0PerBlock) > 1;

return has_main_k0_block_loop;
}

__host__ __device__ static constexpr auto
MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
{
Expand All @@ -255,13 +263,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
if constexpr(ABlockLdsExtraM)
{
return make_naive_tensor_descriptor(
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1),
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
}
}();

Expand All @@ -270,13 +278,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
if constexpr(BBlockLdsExtraN)
{
return make_naive_tensor_descriptor(
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1),
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
}
}();

Expand Down Expand Up @@ -334,6 +342,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
using CM0N0M1N1M2M3M4N2GridDesc = decltype(MakeCM0N0M1N1M2M3M4N2GridDescriptor(CMNGridDesc{}));
using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1));

template <bool HasMainKBlockLoop>
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
Expand Down Expand Up @@ -371,13 +380,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
if constexpr(ABlockLdsExtraM)
{
return make_naive_tensor_descriptor(
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1),
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
}
}();

Expand All @@ -386,21 +395,21 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
if constexpr(BBlockLdsExtraN)
{
return make_naive_tensor_descriptor(
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1),
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
}
}();

// A matrix blockwise copy
auto a_blockwise_copy =
BlockwiseTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperationEnum_t::Set,
Sequence<KPerBlock, MPerBlock, K1>,
Sequence<K0PerBlock, MPerBlock, K1>,
ABlockTransferThreadSliceLengths_K0_M_K1,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
Expand All @@ -426,7 +435,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
auto b_blockwise_copy =
BlockwiseTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperationEnum_t::Set,
Sequence<KPerBlock, NPerBlock, K1>,
Sequence<K0PerBlock, NPerBlock, K1>,
BBlockTransferThreadSliceLengths_K0_N_K1,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
Expand All @@ -450,8 +459,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3

// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[KPerBlock, MPerBlock] is in LDS
// b_mtx[KPerBlock, NPerBlock] is in LDS
// a_mtx[K0PerBlock, MPerBlock] is in LDS
// b_mtx[K0PerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
Expand All @@ -477,8 +486,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
FloatAB* p_a_block = p_shared_block;
FloatAB* p_b_block = p_shared_block + a_block_space_size;

constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);

// hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr auto a_k0_m_k1_grid_step_hacks = AGridStepHacks{};
Expand All @@ -504,32 +513,37 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
}

// main body
index_t k_block_data_begin = 0;
index_t k0_block_data_begin = 0;

do
if constexpr(HasMainKBlockLoop)
{
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m_k1_grid_desc,
a_block_slice_copy_step,
a_k0_m_k1_grid_move_slice_window_step_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n_k1_grid_desc,
b_block_slice_copy_step,
b_k0_n_k1_grid_move_slice_window_step_hack);
do
{
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m_k1_grid_desc,
a_block_slice_copy_step,
a_k0_m_k1_grid_move_slice_window_step_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n_k1_grid_desc,
b_block_slice_copy_step,
b_k0_n_k1_grid_move_slice_window_step_hack);

a_blockwise_copy.RunRead(a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks);
a_blockwise_copy.RunRead(
a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks);

block_sync_lds();
block_sync_lds();

b_blockwise_copy.RunRead(b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks);
b_blockwise_copy.RunRead(
b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks);

blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);

block_sync_lds();
block_sync_lds();

a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf);
a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf);

k_block_data_begin += KPerBlock;
} while(k_block_data_begin < (K0 - KPerBlock));
k0_block_data_begin += K0PerBlock;
} while(k0_block_data_begin < (K0 - K0PerBlock));
}

// tail
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;

constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 0
#elif 1
// [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16
constexpr index_t BlockSize = 256;

Expand Down Expand Up @@ -188,7 +188,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;

constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 1
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
constexpr index_t BlockSize = 256;

Expand Down
Loading