Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
5aaa11c
[What] Rename the example
rocking5566 Apr 28, 2022
484497a
Add global oparation to the parameter
rocking5566 Apr 28, 2022
52bb531
Add atomicmax
rocking5566 Apr 28, 2022
d4ebc8f
Fix compile error
rocking5566 Apr 28, 2022
119a532
Support atomicMax (hip library)
rocking5566 Apr 28, 2022
ca77a79
Rename the reduction example
rocking5566 Apr 28, 2022
9bd7b22
Fix target name
rocking5566 Apr 28, 2022
0162956
use p_d1_grid as the indicator directly
rocking5566 Apr 28, 2022
f1efd7a
Prevent performance issue. Let passthrough handle it.
rocking5566 Apr 28, 2022
d43efc6
Implement the function template the specialize the float2
rocking5566 Apr 29, 2022
04e5a26
No need to separate into two lines
rocking5566 Apr 29, 2022
bb8ee9e
Remove empty line
rocking5566 Apr 29, 2022
43a9777
add comment
Apr 30, 2022
0343326
Merge branch 'develop' into gemm_reduce_max
rocking5566 May 3, 2022
3856b42
Fix compile error due to merge from develop
rocking5566 May 3, 2022
f7b15bc
Merge branch 'gemm_reduce_max' of https://github.com/ROCmSoftwarePlat…
rocking5566 May 3, 2022
e046063
make the implementation of atomic_max / atomic_add explicit for each …
rocking5566 May 3, 2022
4f703e7
Refine typo
rocking5566 May 3, 2022
61ceeb2
For future CI test
rocking5566 May 3, 2022
b7a4036
Fix compiler error in ckProfiler
rocking5566 May 3, 2022
b920399
Merge commit 'de2769e3a6695b38a20529261273ddc5cdaab2fe'
rocking5566 May 10, 2022
420fe89
simply use remove_pointer
rocking5566 May 10, 2022
b1dfdb3
Rename type and var
rocking5566 May 10, 2022
52b2d1c
Refine example
rocking5566 May 10, 2022
4446797
Modify reducemax example
rocking5566 May 10, 2022
006a90c
Fix bug in reduction
rocking5566 May 11, 2022
08506cd
Change initialize range
rocking5566 May 11, 2022
a771da2
Implement F64 version of atomicMax
rocking5566 May 11, 2022
2c91be4
Move reduction code together
rocking5566 May 11, 2022
4ca460a
Add buffer atomic_max
May 12, 2022
d19ecd5
Fix coding style by clang-format
rocking5566 May 12, 2022
e698864
Integrate new api of DeviceGemmReduce_Xdl_CShuffle
rocking5566 May 12, 2022
5aab5aa
Integrate Batch gemm reduction
rocking5566 May 13, 2022
b348bf5
Fix example
rocking5566 May 13, 2022
afc62fa
Merge branch 'develop' into gemm_reduce_max
rocking5566 May 13, 2022
5172a16
Merge remote-tracking branch 'origin/develop' into gemm_reduce_max
May 14, 2022
bc8853c
fix example
May 14, 2022
727388b
clean up
May 14, 2022
00e5bd4
Fix batch gemm tensor operation
rocking5566 May 15, 2022
879dc62
Merge branch 'gemm_reduce_max' of https://github.com/ROCmSoftwarePlat…
rocking5566 May 15, 2022
a98b335
Fix coding style
rocking5566 May 15, 2022
980c095
Fix template augument
rocking5566 May 15, 2022
628fb87
Fix clang format
rocking5566 May 15, 2022
20ae1ce
Keep flexible of different stride for each D tensor
rocking5566 May 16, 2022
b96f809
Fix compile error for ckProfiler
rocking5566 May 16, 2022
d9e57ea
Fix typo
rocking5566 May 16, 2022
d11c7ec
[What] Fix naming
rocking5566 May 19, 2022
ec22933
Add DoutElementOp
rocking5566 May 19, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion example/16_gemm_reduce/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
add_example_executable(example_gemm_reduce_xdl_fp16 gemm_reduce_xdl_fp16.cpp)
add_example_executable(example_gemm_reduce_xdl_max_fp16 gemm_reduce_xdl_max_fp16.cpp)
add_example_executable(example_gemm_reduce_xdl_sum_squaresum_fp16 gemm_reduce_xdl_sum_squaresum_fp16.cpp)
249 changes: 249 additions & 0 deletions example/16_gemm_reduce/gemm_reduce_xdl_max_fp16.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,249 @@
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>

#include "check_err.hpp"
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "device_tensor.hpp"
#include "device_gemm_reduce_xdl_cshuffle.hpp"
#include "element_wise_operation.hpp"
#include "reference_gemm.hpp"
#include "gemm_specialization.hpp"
#include "element_wise_reduce_operation.hpp"

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using F16 = ck::half_t;
using F32 = float;
using F64 = double;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

using ADataType = F16;
using BDataType = F16;
using CDataType = F16;
using ReduceAccDataType = F32;
using DDataType = F64;
using DPtrsGlobal = ck::Tuple<DDataType*>;

using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;

using AElementOp = ck::tensor_operation::element_wise::PassThrough;
Comment thread
asroy marked this conversation as resolved.
Outdated
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
using DsReduceOp = ck::Tuple<ck::reduce::Max<ReduceAccDataType>>;
using DsElementOp = ck::Tuple<
ck::tensor_operation::element_wise::UnaryIdentic<ReduceAccDataType, ReduceAccDataType, false>>;
using DGlobalMemOp =
ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicMax>;

static constexpr auto GemmSpecialization =
ck::tensor_operation::device::GemmSpecialization::Default;

// clang-format off
using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_Xdl_CShuffle
//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsOutEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F16, F32, F32, ReduceAccDataType, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DsReduceOp, DsElementOp, DsElementOp, DGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
// clang-format on

using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AElementOp, BElementOp, CElementOp>;

int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;

// GEMM shape
ck::index_t M = 3840;
ck::index_t N = 4096;
ck::index_t K = 4096;

ck::index_t StrideA = 4096;
ck::index_t StrideB = 4096;
ck::index_t StrideC = 4096;

if(argc == 1)
{
// do nothing
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 10)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);

M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);

StrideA = std::stoi(argv[7]);
StrideB = std::stoi(argv[8]);
StrideC = std::stoi(argv[9]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: run kernel # of times (>1)\n");
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
exit(0);
}

auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({stride, 1}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({1, stride}));
}
};

Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));

Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<DDataType> d_m_host_result(
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));

Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<DDataType> d_m_device_result(
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));

std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
std::cout << "d_m: " << d_m_host_result.mDesc << std::endl;

switch(init_method)
{
case 0: break;
case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
break;
}

DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
DeviceMem d_device_buf(sizeof(DDataType) * d_m_device_result.mDesc.GetElementSpace());

a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());

auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
auto ds_element_op = DsElementOp{};
auto p_ds_global = ck::make_tuple(static_cast<DDataType*>(d_device_buf.GetDeviceBuffer()));

// do GEMM
auto gemm = DeviceGemmReduceInstance{};
auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
p_ds_global,
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op,
ds_element_op,
ds_element_op);

if(!gemm.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}

// init D
d_device_buf.SetValue(ck::NumericLimits<DDataType>::Lowest());

float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});

std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;

float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

float gb_per_sec = num_btype / 1.E6 / ave_time;

std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;

bool pass = true;

if(do_verification)
{
c_device_buf.FromDevice(c_m_n_device_result.mData.data());
d_device_buf.FromDevice(d_m_device_result.mData.data());

auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();

auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);

ref_invoker.Run(ref_argument);

auto d_reduce_op = DsReduceOp{}[ck::Number<0>{}];

for(int m = 0; m < M; ++m)
{
ReduceAccDataType d_acc = d_reduce_op.GetReductionZeroVal();

for(int n = 0; n < N; ++n)
d_reduce_op(d_acc, c_m_n_host_result(m, n));

d_m_host_result(m) = d_acc;
}

pass = ck::utils::check_err(c_m_n_device_result.mData,
c_m_n_host_result.mData,
"Error: Incorrect results c") &&
ck::utils::check_err(d_m_device_result.mData,
d_m_host_result.mData,
"Error: Incorrect results d",
1e-3,
1e-3);
}

return pass ? 0 : 1;
}
Loading