Skip to content
Closed
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -157,10 +157,13 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
__host__ __device__ static constexpr auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);

const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL)),
make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))),
make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_access_idx[I0];

static_for<0, i, 1>{}([&](auto j) {
static_for<1, i, 1>{}([&](auto j) {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice catch!

tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j];
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,7 @@ struct XdlopsGemm
const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
const auto N2 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I5);

return transform_tensor_descriptor(
c_desc_m0_n0_m1_n1_m2_n2,
Expand All @@ -599,7 +600,7 @@ struct XdlopsGemm
make_unmerge_transform(make_tuple(mfma_instr.num_groups_per_blk,
mfma_instr.num_input_blks,
mfma_instr.group_size)),
make_pass_through_transform(mfma_instr.num_threads_per_blk)),
make_pass_through_transform(N2)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@ using device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn = std::tuple<
DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>,
DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>,
DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>,
DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 160, 128, 4, 4, 16, 16, 5, 4, S<1, 5, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>,
DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 160, 4, 4, 16, 16, 4, 5, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 5, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 7, 1, true, true>,
DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 192, 128, 4, 4, 32, 32, 3, 2, S<1, 3, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>,
DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 192, 4, 4, 32, 32, 2, 3, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 3, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 7, 1, true, true>,
DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 96, 128, 4, 4, 16, 16, 3, 4, S<1, 3, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>,
DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 96, 4, 4, 16, 16, 4, 3, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 3, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 7, 1, true, true>,
DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>,
DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>,
DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 1, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 7, 1, true, true>,
Expand Down
77 changes: 51 additions & 26 deletions device_operation/include/device_gemm_xdl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,14 @@ struct DeviceGemmXdl
}
}();

const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;

std::cout << "PadM = " << PadM << " M = " << M + PadM << std::endl;

const auto a_grid_desc_k0_m_k1 =
transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_pass_through_transform(M)),
make_pad_transform(M, I0, PadM)),
Comment thread
asroy marked this conversation as resolved.
Outdated
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

Expand All @@ -105,10 +109,14 @@ struct DeviceGemmXdl
}
}();

const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;

std::cout << "PadN = " << PadN << " N = " << N + PadN << std::endl;

const auto b_grid_desc_k0_n_k1 =
transform_tensor_descriptor(b_grid_desc_k_n,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_pass_through_transform(N)),
make_pad_transform(N, I0, PadN)),
Comment thread
asroy marked this conversation as resolved.
Outdated
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

Expand All @@ -117,14 +125,27 @@ struct DeviceGemmXdl

static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
{
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
}
const auto c_grid_desc_m_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
}
}();

const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;

const auto c_grid_desc_m_n_ = transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_pad_transform(M, I0, PadM), make_pad_transform(N, I0, PadN)),
Comment thread
asroy marked this conversation as resolved.
Outdated
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));

return c_grid_desc_m_n_;
}

using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
Expand All @@ -149,22 +170,22 @@ struct DeviceGemmXdl
Sequence<0, 0, 0>{})); // 2-: K1

static constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2

static constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{};

Expand Down Expand Up @@ -293,6 +314,10 @@ struct DeviceGemmXdl
float Run(const Argument& arg, int nrepeat = 1)
{
{
std::cout << "MPerBlock = " << MPerBlock << " NPerBlock = " << NPerBlock
<< " MXdlPerWave = " << MXdlPerWave << " NXdlPerWave = " << NXdlPerWave
<< std::endl;

std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
Expand Down
41 changes: 34 additions & 7 deletions host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ void device_gemm_xdlops_mk_kn_mn(const Tensor<ABType>& a_m_k,
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;

constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 1
#elif 0
// [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16
constexpr index_t BlockSize = 256;

Expand Down Expand Up @@ -274,7 +274,7 @@ void device_gemm_xdlops_mk_kn_mn(const Tensor<ABType>& a_m_k,
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;

constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 1
#elif 0
// [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16
constexpr index_t BlockSize = 256;

Expand Down Expand Up @@ -302,7 +302,7 @@ void device_gemm_xdlops_mk_kn_mn(const Tensor<ABType>& a_m_k,
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;

constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 1
#elif 0
// [M, N, K0, K1] = [64, 128, 4, 8], C = 32, for fp16
constexpr index_t BlockSize = 256;

Expand All @@ -329,15 +329,42 @@ void device_gemm_xdlops_mk_kn_mn(const Tensor<ABType>& a_m_k,
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;

constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 1
constexpr index_t BlockSize = 256;

constexpr index_t MPerBlock = 96;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;

constexpr index_t MPerXDL = 16;
constexpr index_t NPerXDL = 16;
constexpr index_t K1 = 8;

constexpr index_t MRepeat = 3;
constexpr index_t NRepeat = 4;

using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 3, 8>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>;

constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;

using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;

constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;

constexpr index_t CThreadTransferDstScalarPerVector = 1;
#endif

const auto K = a_m_k.mDesc.GetLengths()[1];
const auto M = a_m_k.mDesc.GetLengths()[0];
const auto N = b_k_n.mDesc.GetLengths()[1];
const index_t K = a_m_k.mDesc.GetLengths()[1];
const index_t M = a_m_k.mDesc.GetLengths()[0];
const index_t N = b_k_n.mDesc.GetLengths()[1];

constexpr auto K1Number = Number<K1>{};
const auto K0 = K / K1Number;
const index_t K0 = K / K1Number;

const auto a_k0_m_k1_grid_desc =
make_naive_tensor_descriptor(make_tuple(K0, M, K1Number),
Expand Down
Loading