-
Notifications
You must be signed in to change notification settings - Fork 294
Gemm reduce max #209
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Gemm reduce max #209
Changes from all commits
Commits
Show all changes
48 commits
Select commit
Hold shift + click to select a range
5aaa11c
[What] Rename the example
rocking5566 484497a
Add global oparation to the parameter
rocking5566 52bb531
Add atomicmax
rocking5566 d4ebc8f
Fix compile error
rocking5566 119a532
Support atomicMax (hip library)
rocking5566 ca77a79
Rename the reduction example
rocking5566 9bd7b22
Fix target name
rocking5566 0162956
use p_d1_grid as the indicator directly
rocking5566 f1efd7a
Prevent performance issue. Let passthrough handle it.
rocking5566 d43efc6
Implement the function template the specialize the float2
rocking5566 04e5a26
No need to separate into two lines
rocking5566 bb8ee9e
Remove empty line
rocking5566 43a9777
add comment
0343326
Merge branch 'develop' into gemm_reduce_max
rocking5566 3856b42
Fix compile error due to merge from develop
rocking5566 f7b15bc
Merge branch 'gemm_reduce_max' of https://github.com/ROCmSoftwarePlat…
rocking5566 e046063
make the implementation of atomic_max / atomic_add explicit for each …
rocking5566 4f703e7
Refine typo
rocking5566 61ceeb2
For future CI test
rocking5566 b7a4036
Fix compiler error in ckProfiler
rocking5566 b920399
Merge commit 'de2769e3a6695b38a20529261273ddc5cdaab2fe'
rocking5566 420fe89
simply use remove_pointer
rocking5566 b1dfdb3
Rename type and var
rocking5566 52b2d1c
Refine example
rocking5566 4446797
Modify reducemax example
rocking5566 006a90c
Fix bug in reduction
rocking5566 08506cd
Change initialize range
rocking5566 a771da2
Implement F64 version of atomicMax
rocking5566 2c91be4
Move reduction code together
rocking5566 4ca460a
Add buffer atomic_max
d19ecd5
Fix coding style by clang-format
rocking5566 e698864
Integrate new api of DeviceGemmReduce_Xdl_CShuffle
rocking5566 5aab5aa
Integrate Batch gemm reduction
rocking5566 b348bf5
Fix example
rocking5566 afc62fa
Merge branch 'develop' into gemm_reduce_max
rocking5566 5172a16
Merge remote-tracking branch 'origin/develop' into gemm_reduce_max
bc8853c
fix example
727388b
clean up
00e5bd4
Fix batch gemm tensor operation
rocking5566 879dc62
Merge branch 'gemm_reduce_max' of https://github.com/ROCmSoftwarePlat…
rocking5566 a98b335
Fix coding style
rocking5566 980c095
Fix template augument
rocking5566 628fb87
Fix clang format
rocking5566 20ae1ce
Keep flexible of different stride for each D tensor
rocking5566 b96f809
Fix compile error for ckProfiler
rocking5566 d9e57ea
Fix typo
rocking5566 d11c7ec
[What] Fix naming
rocking5566 ec22933
Add DoutElementOp
rocking5566 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1 +1,2 @@ | ||
| add_example_executable(example_gemm_reduce_xdl_fp16 gemm_reduce_xdl_fp16.cpp) | ||
| add_example_executable(example_gemm_reduce_xdl_max_fp16 gemm_reduce_xdl_max_fp16.cpp) | ||
| add_example_executable(example_gemm_reduce_xdl_sum_squaresum_fp16 gemm_reduce_xdl_sum_squaresum_fp16.cpp) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,249 @@ | ||
| #include <iostream> | ||
| #include <numeric> | ||
| #include <initializer_list> | ||
| #include <cstdlib> | ||
| #include <stdlib.h> | ||
|
|
||
| #include "check_err.hpp" | ||
| #include "config.hpp" | ||
| #include "device.hpp" | ||
| #include "host_tensor.hpp" | ||
| #include "host_tensor_generator.hpp" | ||
| #include "device_tensor.hpp" | ||
| #include "device_gemm_reduce_xdl_cshuffle.hpp" | ||
| #include "element_wise_operation.hpp" | ||
| #include "reference_gemm.hpp" | ||
| #include "gemm_specialization.hpp" | ||
| #include "element_wise_reduce_operation.hpp" | ||
|
|
||
| template <ck::index_t... Is> | ||
| using S = ck::Sequence<Is...>; | ||
|
|
||
| using F16 = ck::half_t; | ||
| using F32 = float; | ||
| using F64 = double; | ||
|
|
||
| using Row = ck::tensor_layout::gemm::RowMajor; | ||
| using Col = ck::tensor_layout::gemm::ColumnMajor; | ||
|
|
||
| using ADataType = F16; | ||
| using BDataType = F16; | ||
| using CDataType = F16; | ||
| using ReduceAccDataType = F32; | ||
| using DDataType = F64; | ||
| using DPtrsGlobal = ck::Tuple<DDataType*>; | ||
|
|
||
| using ALayout = ck::tensor_layout::gemm::RowMajor; | ||
| using BLayout = ck::tensor_layout::gemm::ColumnMajor; | ||
| using CLayout = ck::tensor_layout::gemm::RowMajor; | ||
|
|
||
| using AElementOp = ck::tensor_operation::element_wise::PassThrough; | ||
| using BElementOp = ck::tensor_operation::element_wise::PassThrough; | ||
| using CElementOp = ck::tensor_operation::element_wise::PassThrough; | ||
| using DsReduceOp = ck::Tuple<ck::reduce::Max<ReduceAccDataType>>; | ||
| using DsElementOp = ck::Tuple< | ||
| ck::tensor_operation::element_wise::UnaryIdentic<ReduceAccDataType, ReduceAccDataType, false>>; | ||
| using DGlobalMemOp = | ||
| ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicMax>; | ||
|
|
||
| static constexpr auto GemmSpecialization = | ||
| ck::tensor_operation::device::GemmSpecialization::Default; | ||
|
|
||
| // clang-format off | ||
| using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_Xdl_CShuffle | ||
| //######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsOutEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| | ||
| //######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| | ||
| //######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| | ||
| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | ||
| < Row, Col, Row, F16, F16, F16, F32, F32, ReduceAccDataType, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DsReduceOp, DsElementOp, DsElementOp, DGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>; | ||
| // clang-format on | ||
|
|
||
| using ReferenceGemmInstance = ck::tensor_operation::host:: | ||
| ReferenceGemm<ADataType, BDataType, CDataType, AElementOp, BElementOp, CElementOp>; | ||
|
|
||
| int main(int argc, char* argv[]) | ||
| { | ||
| bool do_verification = true; | ||
| int init_method = 1; | ||
| bool time_kernel = false; | ||
|
|
||
| // GEMM shape | ||
| ck::index_t M = 3840; | ||
| ck::index_t N = 4096; | ||
| ck::index_t K = 4096; | ||
|
|
||
| ck::index_t StrideA = 4096; | ||
| ck::index_t StrideB = 4096; | ||
| ck::index_t StrideC = 4096; | ||
|
|
||
| if(argc == 1) | ||
| { | ||
| // do nothing | ||
| } | ||
| else if(argc == 4) | ||
| { | ||
| do_verification = std::stoi(argv[1]); | ||
| init_method = std::stoi(argv[2]); | ||
| time_kernel = std::stoi(argv[3]); | ||
| } | ||
| else if(argc == 10) | ||
| { | ||
| do_verification = std::stoi(argv[1]); | ||
| init_method = std::stoi(argv[2]); | ||
| time_kernel = std::stoi(argv[3]); | ||
|
|
||
| M = std::stoi(argv[4]); | ||
| N = std::stoi(argv[5]); | ||
| K = std::stoi(argv[6]); | ||
|
|
||
| StrideA = std::stoi(argv[7]); | ||
| StrideB = std::stoi(argv[8]); | ||
| StrideC = std::stoi(argv[9]); | ||
| } | ||
| else | ||
| { | ||
| printf("arg1: verification (0=no, 1=yes)\n"); | ||
| printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); | ||
| printf("arg3: run kernel # of times (>1)\n"); | ||
| printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); | ||
| exit(0); | ||
| } | ||
|
|
||
| auto f_host_tensor_descriptor = | ||
| [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { | ||
| if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value) | ||
| { | ||
| return HostTensorDescriptor(std::vector<std::size_t>({row, col}), | ||
| std::vector<std::size_t>({stride, 1})); | ||
| } | ||
| else | ||
| { | ||
| return HostTensorDescriptor(std::vector<std::size_t>({row, col}), | ||
| std::vector<std::size_t>({1, stride})); | ||
| } | ||
| }; | ||
|
|
||
| Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); | ||
| Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); | ||
|
|
||
| Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); | ||
| Tensor<DDataType> d_m_host_result( | ||
| HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)}))); | ||
|
|
||
| Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); | ||
| Tensor<DDataType> d_m_device_result( | ||
| HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)}))); | ||
|
|
||
| std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; | ||
| std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; | ||
| std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; | ||
| std::cout << "d_m: " << d_m_host_result.mDesc << std::endl; | ||
|
|
||
| switch(init_method) | ||
| { | ||
| case 0: break; | ||
| case 1: | ||
| a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5}); | ||
| b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}); | ||
| break; | ||
| default: | ||
| a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}); | ||
| b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}); | ||
| break; | ||
| } | ||
|
|
||
| DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); | ||
| DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); | ||
| DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); | ||
| DeviceMem d_device_buf(sizeof(DDataType) * d_m_device_result.mDesc.GetElementSpace()); | ||
|
|
||
| a_device_buf.ToDevice(a_m_k.mData.data()); | ||
| b_device_buf.ToDevice(b_k_n.mData.data()); | ||
|
|
||
| auto a_element_op = AElementOp{}; | ||
| auto b_element_op = BElementOp{}; | ||
| auto c_element_op = CElementOp{}; | ||
| auto ds_element_op = DsElementOp{}; | ||
| auto p_ds_global = ck::make_tuple(static_cast<DDataType*>(d_device_buf.GetDeviceBuffer())); | ||
|
|
||
| // do GEMM | ||
| auto gemm = DeviceGemmReduceInstance{}; | ||
| auto invoker = gemm.MakeInvoker(); | ||
| auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()), | ||
| static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()), | ||
| static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()), | ||
| p_ds_global, | ||
| M, | ||
| N, | ||
| K, | ||
| StrideA, | ||
| StrideB, | ||
| StrideC, | ||
| a_element_op, | ||
| b_element_op, | ||
| c_element_op, | ||
| ds_element_op, | ||
| ds_element_op); | ||
|
|
||
| if(!gemm.IsSupportedArgument(argument)) | ||
| { | ||
| throw std::runtime_error( | ||
| "wrong! device_gemm with the specified compilation parameters does " | ||
| "not support this GEMM problem"); | ||
| } | ||
|
|
||
| // init D | ||
| d_device_buf.SetValue(ck::NumericLimits<DDataType>::Lowest()); | ||
|
|
||
| float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); | ||
|
|
||
| std::size_t flop = std::size_t(2) * M * N * K; | ||
| std::size_t num_btype = | ||
| sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; | ||
|
|
||
| float tflops = static_cast<float>(flop) / 1.E9 / ave_time; | ||
|
|
||
| float gb_per_sec = num_btype / 1.E6 / ave_time; | ||
|
|
||
| std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " | ||
| << gemm.GetTypeString() << std::endl; | ||
|
|
||
| bool pass = true; | ||
|
|
||
| if(do_verification) | ||
| { | ||
| c_device_buf.FromDevice(c_m_n_device_result.mData.data()); | ||
| d_device_buf.FromDevice(d_m_device_result.mData.data()); | ||
|
|
||
| auto ref_gemm = ReferenceGemmInstance{}; | ||
| auto ref_invoker = ref_gemm.MakeInvoker(); | ||
|
|
||
| auto ref_argument = ref_gemm.MakeArgument( | ||
| a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); | ||
|
|
||
| ref_invoker.Run(ref_argument); | ||
|
|
||
| auto d_reduce_op = DsReduceOp{}[ck::Number<0>{}]; | ||
|
|
||
| for(int m = 0; m < M; ++m) | ||
| { | ||
| ReduceAccDataType d_acc = d_reduce_op.GetReductionZeroVal(); | ||
|
|
||
| for(int n = 0; n < N; ++n) | ||
| d_reduce_op(d_acc, c_m_n_host_result(m, n)); | ||
|
|
||
| d_m_host_result(m) = d_acc; | ||
| } | ||
|
|
||
| pass = ck::utils::check_err(c_m_n_device_result.mData, | ||
| c_m_n_host_result.mData, | ||
| "Error: Incorrect results c") && | ||
| ck::utils::check_err(d_m_device_result.mData, | ||
| d_m_host_result.mData, | ||
| "Error: Incorrect results d", | ||
| 1e-3, | ||
| 1e-3); | ||
| } | ||
|
|
||
| return pass ? 0 : 1; | ||
| } | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.