Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
fd26911
simplify karg in device/grid split-k op
carlushuang Mar 20, 2023
e21267f
fix mk_kn_mn instances
carlushuang Mar 22, 2023
2456e9c
Merge remote-tracking branch 'origin/develop' into simplified_karg_un…
carlushuang Mar 22, 2023
2863635
add more instances
carlushuang Mar 23, 2023
115f8a4
Merge remote-tracking branch 'origin/develop' into simplified_karg_un…
carlushuang Mar 23, 2023
baf6868
Merge branch 'simplified_karg_unify_op' into aosewski/ggemm_splitk
Mar 24, 2023
10cabc2
Merge remote-tracking branch 'origin/develop' into aosewski/ggemm_splitk
Mar 30, 2023
a84d502
B2C with 3D grid for KSplit
Apr 3, 2023
96b535b
Remove unused code.
Apr 3, 2023
5f7eda6
Use default B2C (3D grid) in grid gemm v2r4r2.
Apr 3, 2023
de33033
Device gemm splitk use B2C map.
Apr 3, 2023
8c0f936
Device GroupedGemmXdlSplitKCShuffle
Apr 4, 2023
53d3ee2
Example for GroupedGemm Xdl SplitK
Apr 4, 2023
24c7c49
Introduce Device GroupedGemmSplitK
Apr 5, 2023
f1814e2
Fix updating kbatch size.
Apr 5, 2023
3af2a90
Add instance mk-nk-mn
Apr 5, 2023
e29e5bf
Enable set kbatch in profiler.
Apr 5, 2023
9ff4fff
Add GGemmSplitK mk-kn-mn instances
Apr 5, 2023
ed0cc89
Merge branch 'develop' into aosewski/ggemm_splitk
Apr 5, 2023
6d550ac
Add more instances & split into multiple files.
Apr 7, 2023
2f55336
Merge remote-tracking branch 'origin/develop' into aosewski/ggemm_splitk
Apr 7, 2023
0b7a77c
Merge branch 'develop' into aosewski/ggemm_splitk
zjing14 Apr 10, 2023
a5abe1a
Merge branch 'develop' into aosewski/ggemm_splitk
zjing14 Apr 17, 2023
d06ce45
Merge branch 'develop' into aosewski/ggemm_splitk
zjing14 Apr 20, 2023
c38d8fd
merge develop
Apr 22, 2023
7277329
minor fix
Apr 22, 2023
4ab3cad
tuning
Apr 23, 2023
08426a8
clean
Apr 24, 2023
fa95b4f
disabled failed instances
Apr 24, 2023
cbf1494
use pipe v2
Apr 24, 2023
0b25af7
Ignore arg on not supported arch.
Apr 24, 2023
0b8319b
Merge remote-tracking branch 'origin/develop' into aosewski/ggemm_splitk
Apr 24, 2023
923b798
fix warning
Apr 24, 2023
540d43b
Merge branch 'aosewski/ggemm_splitk' of https://github.com/ROCmSoftwa…
Apr 24, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion example/15_grouped_gemm/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,16 @@ add_example_executable(example_grouped_gemm_xdl_fp16 grouped_gemm_xdl_fp16.cpp)
add_example_executable(example_grouped_gemm_xdl_bfp16 grouped_gemm_xdl_bfp16.cpp)
add_example_executable(example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp)
add_example_executable(example_grouped_gemm_multiple_d_dl_fp16 grouped_gemm_multiple_d_dl_fp16.cpp)
add_example_executable(example_grouped_gemm_xdl_splitk_fp16 grouped_gemm_xdl_splitk_fp16.cpp)


add_dependencies(example_grouped_gemm_xdl
example_grouped_gemm_xdl_fp32
example_grouped_gemm_xdl_fp16
example_grouped_gemm_xdl_bfp16
example_grouped_gemm_xdl_int8
example_grouped_gemm_multiple_d_dl_fp16)
example_grouped_gemm_multiple_d_dl_fp16
example_grouped_gemm_xdl_splitk_fp16)

if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp)
Expand Down
97 changes: 97 additions & 0 deletions example/15_grouped_gemm/grouped_gemm_xdl_splitk_fp16.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using F16 = ck::half_t;
using F32 = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

using ADataType = F16;
using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F16;
using DsDataType = ck::Tuple<>;
using EDataType = F16;

using ALayout = Row;
using BLayout = Col;
using DsLayout = ck::Tuple<>;
using ELayout = Row;

using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = PassThrough;

static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding;

using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemmXdlSplitKCShuffle
// clang-format off
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// clang-format on

#include "run_grouped_gemm_example.inc"

int main(int argc, char* argv[])
{
ProblemSize problem_size;
ExecutionConfig config;

problem_size.group_count = 16;

problem_size.Ms = {
167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148};

for(int i = 0; i < problem_size.group_count; i++)
{
problem_size.Ns.push_back(768);
problem_size.Ks.push_back(4608);

problem_size.stride_As.push_back(problem_size.Ks[i]);
problem_size.stride_Bs.push_back(problem_size.Ks[i]);
problem_size.stride_Cs.push_back(problem_size.Ns[i]);
}

if(argc == 4)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=n0, 1=yes)\n");
exit(0);
}

return !run_grouped_gemm(problem_size, config);
}
1 change: 1 addition & 0 deletions example/15_grouped_gemm/run_grouped_gemm_example.inc
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
#else
a_tensors_device[i]->ToDevice(a_tensors[i].mData.data());
b_tensors_device[i]->ToDevice(b_tensors[i].mData.data());
c_tensors_device[i]->SetZero();
#endif

p_a.push_back(a_tensors_device[i]->GetDeviceBuffer());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ struct DeviceGroupedGemm : public BaseOperator
{
static constexpr index_t NumDTensor = DsDataType::Size();

static_assert(DsLayout::Size() == DsDataType::Size(), "wrong! inconsisiten NumDTensor");
static_assert(DsLayout::Size() == DsDataType::Size(), "wrong! inconsistent NumDTensor");

virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(std::vector<const void*>& p_a,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#pragma once
#include <iostream>
#include <vector>

#include "device_grouped_gemm.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct DeviceGroupedGemmSplitK : public DeviceGroupedGemm<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
{
virtual void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const = 0;
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,8 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
CBlockTransferScalarPerVector_NWaveNPerXDL,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>;

using Argument = typename GridwiseGemm::Argument;
using Argument = typename GridwiseGemm::Argument;
using DefaultBlock2CTileMap = typename GridwiseGemm::DefaultBlock2CTileMap;

// Invoker
struct Invoker : public BaseInvoker
Expand All @@ -138,8 +139,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
"setting");
}

const auto b2c_map = DefaultBlock2CTileMap{};
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(karg);
std::tie(gdx, gdy, gdz) = b2c_map.CalculateGridSize(karg.M, karg.N, karg.k_batch);
const auto K0 = karg.K0;

const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
Expand All @@ -152,7 +154,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
hipMemset(karg.p_c_grid, 0, karg.M * karg.N * sizeof(CDataType)));

ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg, b2c_map);
};

if(has_main_k0_block_loop)
Expand All @@ -162,7 +164,8 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
const auto kernel =
kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set>;
InMemoryDataOperationEnum::Set,
DefaultBlock2CTileMap>;

Run(kernel);
}
Expand All @@ -171,7 +174,8 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
const auto kernel =
kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd>;
InMemoryDataOperationEnum::AtomicAdd,
DefaultBlock2CTileMap>;

Run(kernel);
}
Expand All @@ -183,7 +187,8 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
const auto kernel =
kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm,
false,
InMemoryDataOperationEnum::Set>;
InMemoryDataOperationEnum::Set,
DefaultBlock2CTileMap>;

Run(kernel);
}
Expand All @@ -192,7 +197,8 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
const auto kernel =
kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd>;
InMemoryDataOperationEnum::AtomicAdd,
DefaultBlock2CTileMap>;

Run(kernel);
}
Expand Down
Loading