Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,99 +10,99 @@ template <index_t BlockSize,
typename FloatA,
typename FloatB,
typename FloatC,
typename BlockMatrixA,
typename BlockMatrixB,
typename ThreadMatrixC,
index_t KPerThread,
index_t HPerThread,
index_t WPerThread,
typename ABlockDesc_E1_K1_E2,
typename BBlockDesc_E1_N_Ho_Wo_E2,
typename CThreadDesc_K_N_Ho_Wo,
index_t EPerThreadLoop,
index_t ThreadGemmADataPerRead_K,
index_t ThreadGemmBDataPerRead_W>
index_t KPerThreadLoop>
struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
{
struct MatrixIndex
{
index_t k;
index_t h;
index_t w;
};
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};

using AIndex = MultiIndex<3>;
using BIndex = MultiIndex<3>;
using CIndex = MultiIndex<4>;

static constexpr auto E1 = ABlockDesc_E1_K1_E2{}.GetLength(I0);
static constexpr auto KPerBlock = ABlockDesc_E1_K1_E2{}.GetLength(I1);
static constexpr auto E2 = ABlockDesc_E1_K1_E2{}.GetLength(I2);

static constexpr auto HoPerBlock = BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I2);
static constexpr auto WoPerBlock = BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I3);

// HACK: fix this @Jing Zhang
static constexpr index_t KPerThreadSubC = 4;
static constexpr auto KPerThread = CThreadDesc_K_N_Ho_Wo{}.GetLength(I0);
static constexpr auto HoPerThread = CThreadDesc_K_N_Ho_Wo{}.GetLength(I2);
static constexpr auto WoPerThread = CThreadDesc_K_N_Ho_Wo{}.GetLength(I3);

static constexpr auto a_thread_mtx_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<EPerThreadLoop>{}, Number<KPerThreadSubC>{}));
make_tuple(Number<EPerThreadLoop>{}, Number<KPerThreadLoop>{}, Number<E2>{}));

static constexpr auto b_thread_mtx_ = make_naive_tensor_descriptor_packed(make_tuple(
Number<EPerThreadLoop>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
static constexpr auto b_thread_mtx_ =
make_naive_tensor_descriptor_packed(make_tuple(Number<EPerThreadLoop>{},
Number<1>{},
Number<HoPerThread>{},
Number<WoPerThread>{},
Number<E2>{}));

static constexpr auto c_thread_mtx_ = make_naive_tensor_descriptor_packed(make_tuple(
Number<KPerThreadSubC>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));

using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatA,
FloatA,
BlockMatrixA,
decltype(a_thread_mtx_),
Sequence<EPerThreadLoop, KPerThreadSubC>,
Sequence<0, 1>,
1,
ThreadGemmADataPerRead_K,
1>;
Number<KPerThreadLoop>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));

__device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v3()
: c_thread_begin_mtx_idx_{GetBeginOfThreadMatrixC(get_thread_local_1d_id())},
a_thread_copy_{make_tuple(0, c_thread_begin_mtx_idx_.k * KPerThread)}
: c_thread_origin_data_idx_{GetBeginOfCThreadDesc_K_N_Ho_Wo(get_thread_local_1d_id())},
a_thread_copy_{make_tuple(0, c_thread_origin_data_idx_[I0] * KPerThread, 0)}
{
static_assert(BlockMatrixA::IsKnownAtCompileTime() &&
BlockMatrixB::IsKnownAtCompileTime() &&
ThreadMatrixC::IsKnownAtCompileTime(),
static_assert(ABlockDesc_E1_K1_E2::IsKnownAtCompileTime() &&
BBlockDesc_E1_N_Ho_Wo_E2::IsKnownAtCompileTime() &&
CThreadDesc_K_N_Ho_Wo::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");

constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};

static_assert(BlockMatrixA{}.GetLength(I0) == BlockMatrixB{}.GetLength(I0),
"wrong! K dimension not consistent\n");
static_assert(
ABlockDesc_E1_K1_E2{}.GetLength(I0) == BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I0) &&
ABlockDesc_E1_K1_E2{}.GetLength(I2) == BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I4),
"wrong! E dimension not consistent\n");

constexpr index_t K = BlockMatrixA{}.GetLength(I1); // A is transposed
constexpr index_t H = BlockMatrixB{}.GetLength(I2);
constexpr index_t W = BlockMatrixB{}.GetLength(I3);
static_assert(E1 % EPerThreadLoop == 0, "");
static_assert(KPerThread % KPerThreadLoop == 0, "");

static_assert(K % KPerThread == 0 && H % HPerThread == 0 && W % WPerThread == 0,
static_assert(KPerBlock % KPerThread == 0 && HoPerBlock % HoPerThread == 0 &&
WoPerBlock % WoPerThread == 0,
"wrong! Cannot evenly divide work among\n");

constexpr auto KThreadCluster = K / KPerThread;
constexpr auto HThreadCluster = H / HPerThread;
constexpr auto WThreadCluster = W / WPerThread;
constexpr auto KThreadCluster = KPerBlock / KPerThread;
constexpr auto HThreadCluster = HoPerBlock / HoPerThread;
constexpr auto WThreadCluster = WoPerBlock / WoPerThread;

static_assert(BlockSize == KThreadCluster * HThreadCluster * WThreadCluster,
"wrong! wrong blocksize\n");
}

__device__ static constexpr auto GetThreadMatrixCLengths()
__device__ static constexpr auto GetCThreadDesc_K_N_Ho_WoLengths()
{
return Sequence<KPerThread, 1, HPerThread, WPerThread>{};
return Sequence<KPerThread, I1, HoPerThread, WoPerThread>{};
}

__device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id)
__device__ static CIndex GetBeginOfCThreadDesc_K_N_Ho_Wo(index_t thread_id)
{
constexpr index_t H = BlockMatrixB{}.GetLength(Number<2>{});
constexpr index_t W = BlockMatrixB{}.GetLength(Number<3>{});

constexpr auto num_w_threads = W / WPerThread;
constexpr auto num_h_threads = H / HPerThread;
constexpr auto num_hw_threads = num_w_threads * num_h_threads;

index_t k_thread_id = thread_id / num_hw_threads;
index_t hw_thread_id = thread_id % num_hw_threads;

index_t h_thread_id = hw_thread_id / num_w_threads;
index_t w_thread_id = hw_thread_id % num_w_threads;

return MatrixIndex{k_thread_id, h_thread_id, w_thread_id};
constexpr auto K0 = KPerBlock / KPerThread;
constexpr auto N0 = I1;
constexpr auto H0 = HoPerBlock / HoPerThread;
constexpr auto W0 = WoPerBlock / WoPerThread;

constexpr auto c_threadid_to_k_n_h_w_thread_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(K0, N0, H0, W0))),
make_tuple(Sequence<0, 1, 2, 3>{}),
make_tuple(Sequence<0>{}));

const auto c_k_n_h_w_thread_cluster_idx =
c_threadid_to_k_n_h_w_thread_cluster_adaptor.CalculateBottomIndex(
make_multi_index(thread_id));

return c_k_n_h_w_thread_cluster_idx;
}

template <typename ABlockBuffer, typename BThreadBuffer, typename CThreadBuffer>
Expand All @@ -116,19 +116,7 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
is_same<remove_cvref_t<typename CThreadBuffer::type>, remove_cvref_t<FloatC>>::value &&
"wrong! inconsistent type");

constexpr auto I0 = Number<0>{};

constexpr auto a_block_mtx = BlockMatrixA{};

constexpr auto EPerBlock = a_block_mtx.GetLength(I0);

// HACK: fix this @Jing Zhang
constexpr auto HoPerThreadSubC = 2;
constexpr auto WoPerThreadSubC = 2;

static_assert(KPerThread % KPerThreadSubC == 0, "");
static_assert(HPerThread % HoPerThreadSubC == 0, "");
static_assert(WPerThread % WoPerThreadSubC == 0, "");
constexpr auto a_block_mtx = ABlockDesc_E1_K1_E2{};

// thread A buffer for GEMM
StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatA, a_thread_mtx_.GetElementSpaceSize(), true>
Expand All @@ -139,42 +127,46 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
FloatC,
decltype(a_thread_mtx_),
decltype(b_thread_mtx_),
decltype(c_thread_mtx_),
HoPerThreadSubC,
WoPerThreadSubC>{};
decltype(c_thread_mtx_)>{};

static_for<0, EPerBlock, EPerThreadLoop>{}([&](auto e_begin) {
static_for<0, KPerThread, KPerThreadSubC>{}([&](auto k_begin) {
static_for<0, E1, EPerThreadLoop>{}([&](auto e_begin) {
static_for<0, KPerThread, KPerThreadLoop>{}([&](auto k_begin) {
a_thread_copy_.Run(a_block_mtx,
make_tuple(e_begin, k_begin),
make_tuple(e_begin, k_begin, I0),
a_block_buf,
a_thread_mtx_,
make_tuple(I0, I0),
make_tuple(I0, I0, I0),
a_thread_buf);

static_for<0, HPerThread, HoPerThreadSubC>{}([&](auto h_begin) {
static_for<0, WPerThread, WoPerThreadSubC>{}([&](auto w_begin) {
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I0),
b_thread_buf,
make_tuple(e_begin, I0, h_begin, w_begin),
c_thread_buf,
make_tuple(k_begin, I0, h_begin, w_begin));
});
});
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I0, I0),
b_thread_buf,
make_tuple(e_begin, I0, I0, I0, I0),
c_thread_buf,
make_tuple(k_begin, I0, I0, I0));
});
});
}

template <typename ABlockSliceMoveStepIdx>
__device__ void MoveASliceWindow(const BlockMatrixA&,
const ABlockSliceMoveStepIdx& a_block_slice_move_step_idx)
__device__ void MoveABlockSliceWindow(const ABlockSliceMoveStepIdx& a_block_slice_move_step_idx)
{
a_thread_copy_.MoveSrcSliceWindow(BlockMatrixA{}, a_block_slice_move_step_idx);
a_thread_copy_.MoveSrcSliceWindow(ABlockDesc_E1_K1_E2{}, a_block_slice_move_step_idx);
}

private:
MatrixIndex c_thread_begin_mtx_idx_;
using AThreadCopy =
ThreadwiseTensorSliceTransfer_v4<FloatA,
FloatA,
ABlockDesc_E1_K1_E2,
decltype(a_thread_mtx_),
Sequence<EPerThreadLoop, KPerThreadLoop, E2>,
Sequence<0, 1, 2>,
2,
E2,
E2>;

CIndex c_thread_origin_data_idx_;

AThreadCopy a_thread_copy_;
};
Expand Down
Loading