Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
a624666
add DeviceGemmSplitKXdl
ltqin Nov 22, 2021
d199894
add file device_gemm_splitk_xdl.hpp
ltqin Nov 23, 2021
b7ec207
set c matrix zero
ltqin Nov 23, 2021
114f929
using atomic
ltqin Nov 24, 2021
6a2157f
add all tuning parameter to f32 mkkn
ltqin Nov 24, 2021
b98e339
grid size change to 720
ltqin Nov 25, 2021
000db48
add tunning parameter for NT
ltqin Nov 25, 2021
b282e62
add tunning parameter for TN
ltqin Nov 25, 2021
0694d6e
add tunning parameter for TT
ltqin Nov 25, 2021
a037693
Merge branch 'develop' into conv_splitk_f32
ltqin Dec 1, 2021
5576da2
add m=96tunning parameter
ltqin Dec 2, 2021
134af43
add lost config
ltqin Dec 2, 2021
c29dc4c
Merge branch 'develop' into conv_splitk_f32
ltqin Dec 9, 2021
0eed507
add element wise operation
ltqin Dec 9, 2021
b59d549
fixed MPerBlock=96
ltqin Dec 9, 2021
f683fed
remove marco for slpitk swtich
ltqin Dec 9, 2021
982e59b
Merge branch 'develop' into conv_splitk_f32
ltqin Dec 14, 2021
1b4ae8b
add test
ltqin Dec 16, 2021
f880480
add new line at the end of device_gemm_xdl_instance.hpp
ltqin Dec 24, 2021
aaa8991
Merge branch 'develop' into conv_splitk_f32
ltqin Dec 27, 2021
303b1a8
remove step hack
ltqin Dec 27, 2021
1b9e6e1
seperate split-k instance files
ltqin Dec 29, 2021
cca0cee
add tunning parameters
ltqin Dec 30, 2021
7b01dbe
change disired grid size to parameters
ltqin Dec 30, 2021
adc79bd
remove slice length
ltqin Dec 30, 2021
d862fdf
add desiredgridsize parameter to ckProfiler
ltqin Dec 31, 2021
accb4ca
add losting file device_gemm_xdl_splitk_instance.hpp
ltqin Dec 31, 2021
62a860a
change desired gride size to kbatch
ltqin Jan 4, 2022
615fb48
Merge remote-tracking branch 'origin/develop' into conv_splitk_f32
Jan 27, 2022
fe027ba
format
Jan 27, 2022
0e67221
format
Jan 27, 2022
8160c31
clean up
Jan 28, 2022
25751e3
add selection of device_instances
Jan 29, 2022
df6f43d
clean code
Jan 29, 2022
30eb2cc
fix build issue
Feb 3, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,10 @@ template <typename GridwiseGemm,
typename ABK0MK1GridDesc,
typename BBK0NK1GridDesc,
typename CM0N0M1N1M2M3M4N2GridDesc,
typename CBlockClusterAdaptor,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename Block2CTileMap,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
Expand All @@ -74,7 +77,10 @@ __global__ void
const void CONSTANT* p_a_b_k0_m_k1_grid_desc,
const void CONSTANT* p_b_b_k0_n_k1_grid_desc,
const void CONSTANT* p_c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
const void CONSTANT* p_c_block_cluster_adaptor)
const void CONSTANT* p_a_element_op,
const void CONSTANT* p_b_element_op,
const void CONSTANT* p_c_element_op,
const void CONSTANT* p_block_2_ctile_map)
{
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
Expand All @@ -86,8 +92,14 @@ __global__ void
const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc =
*reinterpret_cast<const CM0N0M1N1M2M3M4N2GridDesc*>(
cast_pointer_to_generic_address_space(p_c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc));
const auto c_block_cluster_adaptor = *reinterpret_cast<const CBlockClusterAdaptor*>(
cast_pointer_to_generic_address_space(p_c_block_cluster_adaptor));
const auto block_2_ctile_map = *reinterpret_cast<const Block2CTileMap*>(
cast_pointer_to_generic_address_space(p_block_2_ctile_map));
const auto a_element_op = *reinterpret_cast<const AElementwiseOperation*>(
cast_pointer_to_generic_address_space(p_a_element_op));
const auto b_element_op = *reinterpret_cast<const BElementwiseOperation*>(
cast_pointer_to_generic_address_space(p_b_element_op));
const auto c_element_op = *reinterpret_cast<const CElementwiseOperation*>(
cast_pointer_to_generic_address_space(p_c_element_op));

__shared__ FloatAB p_shared_block[shared_block_size];

Expand All @@ -98,7 +110,10 @@ __global__ void
a_b_k0_m_k1_grid_desc,
b_b_k0_n_k1_grid_desc,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_block_cluster_adaptor);
a_element_op,
b_element_op,
c_element_op,
block_2_ctile_map);
}
#endif

Expand All @@ -110,6 +125,9 @@ template <index_t BlockSize,
typename ABK0MK1GridDesc,
typename BBK0NK1GridDesc,
typename CMNGridDesc,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
index_t MPerBlock,
index_t NPerBlock,
index_t K0PerBlock,
Expand All @@ -118,28 +136,25 @@ template <index_t BlockSize,
index_t K1Value,
index_t MRepeat,
index_t NRepeat,
typename ABlockTransferThreadSliceLengths_K0_M_K1,
typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_K1,
bool AThreadTransferSrcResetCoordinateAfterRun,
typename BBlockTransferThreadSliceLengths_K0_N_K1,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_K0_N_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_K1,
bool BThreadTransferSrcResetCoordinateAfterRun,
bool BBlockLdsExtraN,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
bool CAccessOrderMRepeatNRepeat,
bool ABlockLdsExtraM,
bool BBlockLdsExtraN>
index_t CThreadTransferDstScalarPerVector>
struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
{
static constexpr auto I0 = Number<0>{};
Expand Down Expand Up @@ -358,6 +373,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
const ABK0MK1GridDesc& a_b_k0_m_k1_grid_desc,
const BBK0NK1GridDesc& b_b_k0_n_k1_grid_desc,
const CM0N0M1N1M2M3M4N2GridDesc& c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op,
const CBlockClusterAdaptor& c_block_cluster_adaptor)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
Expand Down Expand Up @@ -456,7 +474,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum_t::Set,
Sequence<1, K0PerBlock, MPerBlock, K1>,
ABlockTransferThreadSliceLengths_K0_M_K1,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
Expand Down Expand Up @@ -487,7 +504,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum_t::Set,
Sequence<1, K0PerBlock, NPerBlock, K1>,
BBlockTransferThreadSliceLengths_K0_N_K1,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
Expand Down Expand Up @@ -583,8 +599,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf);

k_block_data_begin += K0PerBlock;
} while(k_block_data_begin < (K0 - K0PerBlock));
k0_block_data_begin += K0PerBlock;
} while(k0_block_data_begin < (K0 - K0PerBlock));
}

// tail
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_xdl.hpp"
#include "device_gemm_instance.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
Expand All @@ -21,7 +21,7 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

// Compilation parameters for a[k, m] * b[k, n] = c[m, n]
using device_gemm_xdl_instance_f16_f16_f16_km_kn_mn =
using device_gemm_xdl_f16_f16_f16_km_kn_mn_instances =
std::tuple<
// clang-format off
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
Expand All @@ -39,21 +39,10 @@ using device_gemm_xdl_instance_f16_f16_f16_km_kn_mn =
// clang-format on
>;

template <>
void add_device_gemm_instance<F16, F16, F16, Col, Row, Row>(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& device_op_instances)
void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
using DeviceGemms = device_gemm_instance::device_gemm_xdl_instance_f16_f16_f16_km_kn_mn;

const auto device_gemms = DeviceGemms{};

ck::static_for<0, std::tuple_size_v<DeviceGemms>, 1>{}([&](auto i) {
using Gemm = remove_cvref_t<decltype(std::get<i>(device_gemms))>;

auto gemm = Gemm{};

device_op_instances.push_back(std::make_unique<Gemm>(gemm));
});
add_device_operation_instances(instances, device_gemm_xdl_f16_f16_f16_km_kn_mn_instances{});
}

} // namespace device_gemm_instance
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_xdl.hpp"
#include "device_gemm_instance.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
Expand All @@ -21,7 +21,7 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

// Compilation parameters for a[k, m] * b[n, k] = c[m, n]
using device_gemm_xdl_instance_f16_f16_f16_km_nk_mn =
using device_gemm_xdl_f16_f16_f16_km_nk_mn_instances =
std::tuple<
// clang-format off
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
Expand All @@ -39,21 +39,10 @@ using device_gemm_xdl_instance_f16_f16_f16_km_nk_mn =
// clang-format on
>;

template <>
void add_device_gemm_instance<F16, F16, F16, Col, Col, Row>(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& device_op_instances)
void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
using DeviceGemms = device_gemm_instance::device_gemm_xdl_instance_f16_f16_f16_km_nk_mn;

const auto device_gemms = DeviceGemms{};

ck::static_for<0, std::tuple_size_v<DeviceGemms>, 1>{}([&](auto i) {
using Gemm = remove_cvref_t<decltype(std::get<i>(device_gemms))>;

auto gemm = Gemm{};

device_op_instances.push_back(std::make_unique<Gemm>(gemm));
});
add_device_operation_instances(instances, device_gemm_xdl_f16_f16_f16_km_nk_mn_instances{});
}

} // namespace device_gemm_instance
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_xdl.hpp"
#include "device_gemm_instance.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
Expand All @@ -21,7 +21,7 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn =
using device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances =
std::tuple<
// clang-format off
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
Expand All @@ -39,21 +39,10 @@ using device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn =
// clang-format on
>;

template <>
void add_device_gemm_instance<F16, F16, F16, Row, Row, Row>(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& device_op_instances)
void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
using DeviceGemms = device_gemm_instance::device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn;

const auto device_gemms = DeviceGemms{};

ck::static_for<0, std::tuple_size_v<DeviceGemms>, 1>{}([&](auto i) {
using Gemm = remove_cvref_t<decltype(std::get<i>(device_gemms))>;

auto gemm = Gemm{};

device_op_instances.push_back(std::make_unique<Gemm>(gemm));
});
add_device_operation_instances(instances, device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances{});
}

} // namespace device_gemm_instance
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_xdl.hpp"
#include "device_gemm_instance.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
Expand All @@ -21,7 +21,7 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

// Compilation parameters for a[m, k] * b[n, k] = c[m, n]
using device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn =
using device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances =
std::tuple<
// clang-format off
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
Expand All @@ -44,21 +44,10 @@ using device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn =
// clang-format on
>;

template <>
void add_device_gemm_instance<F16, F16, F16, Row, Col, Row>(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& device_op_instances)
void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
using DeviceGemms = device_gemm_instance::device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn;

const auto device_gemms = DeviceGemms{};

ck::static_for<0, std::tuple_size_v<DeviceGemms>, 1>{}([&](auto i) {
using Gemm = remove_cvref_t<decltype(std::get<i>(device_gemms))>;

auto gemm = Gemm{};

device_op_instances.push_back(std::make_unique<Gemm>(gemm));
});
add_device_operation_instances(instances, device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances{});
}

} // namespace device_gemm_instance
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_xdl.hpp"
#include "device_gemm_instance.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
Expand All @@ -21,7 +21,7 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

// Compilation parameters for a[k, m] * b[k, n] = c[m, n]
using device_gemm_xdl_instance_f32_f32_f32_km_kn_mn =
using device_gemm_xdl_f32_f32_f32_km_kn_mn_instances =
std::tuple<
// clang-format off
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
Expand All @@ -39,21 +39,10 @@ using device_gemm_xdl_instance_f32_f32_f32_km_kn_mn =
// clang-format on
>;

template <>
void add_device_gemm_instance<F32, F32, F32, Col, Row, Row>(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& device_op_instances)
void add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
using DeviceGemms = device_gemm_instance::device_gemm_xdl_instance_f32_f32_f32_km_kn_mn;

const auto device_gemms = DeviceGemms{};

ck::static_for<0, std::tuple_size_v<DeviceGemms>, 1>{}([&](auto i) {
using Gemm = remove_cvref_t<decltype(std::get<i>(device_gemms))>;

auto gemm = Gemm{};

device_op_instances.push_back(std::make_unique<Gemm>(gemm));
});
add_device_operation_instances(instances, device_gemm_xdl_f32_f32_f32_km_kn_mn_instances{});
}

} // namespace device_gemm_instance
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_xdl.hpp"
#include "device_gemm_instance.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
Expand All @@ -21,7 +21,7 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

// Compilation parameters for a[k, m] * b[n, k] = c[m, n]
using device_gemm_xdl_instance_f32_f32_f32_km_nk_mn =
using device_gemm_xdl_f32_f32_f32_km_nk_mn_instances =
std::tuple<
// clang-format off
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
Expand All @@ -39,21 +39,10 @@ using device_gemm_xdl_instance_f32_f32_f32_km_nk_mn =
// clang-format on
>;

template <>
void add_device_gemm_instance<F32, F32, F32, Col, Col, Row>(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& device_op_instances)
void add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
using DeviceGemms = device_gemm_instance::device_gemm_xdl_instance_f32_f32_f32_km_nk_mn;

const auto device_gemms = DeviceGemms{};

ck::static_for<0, std::tuple_size_v<DeviceGemms>, 1>{}([&](auto i) {
using Gemm = remove_cvref_t<decltype(std::get<i>(device_gemms))>;

auto gemm = Gemm{};

device_op_instances.push_back(std::make_unique<Gemm>(gemm));
});
add_device_operation_instances(instances, device_gemm_xdl_f32_f32_f32_km_nk_mn_instances{});
}

} // namespace device_gemm_instance
Expand Down
Loading