-
Notifications
You must be signed in to change notification settings - Fork 300
Hotfix binary elementwise (for broadcast on fastest axis) #254
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
414f375
11d873f
ffa66b3
07e8d29
a24dc28
4a77339
d2f0f98
495d05b
c343472
0334a92
5912fa0
221146a
f564d7f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,3 +1,4 @@ | ||
| add_example_executable(example_broadcast_add_2d broadcast_add_2d.cpp) | ||
| add_example_executable(example_broadcast_add_2d_amn_bn broadcast_add_2d_amn_bn.cpp) | ||
| add_example_executable(example_broadcast_add_3d_am_bmnk broadcast_add_3d_am_bmnk.cpp) | ||
| add_example_executable(example_elementwise_add_1d elementwise_add_1d.cpp) | ||
| add_example_executable(example_elementwise_add_4d elementwise_add_4d.cpp) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,122 @@ | ||
| #include <iostream> | ||
| #include <cstdlib> | ||
| #include "check_err.hpp" | ||
| #include "config.hpp" | ||
| #include "device.hpp" | ||
| #include "host_tensor.hpp" | ||
| #include "host_tensor_generator.hpp" | ||
|
|
||
| #include "device_tensor.hpp" | ||
| #include "binary_element_wise_operation.hpp" | ||
| #include "device_binary_elementwise.hpp" | ||
|
|
||
| using F16 = ck::half_t; | ||
| using F32 = float; | ||
|
|
||
| using ABDataType = F16; | ||
| using CDataType = F16; | ||
| using EltwiseComputeDataType = F32; | ||
|
|
||
| using Add = ck::tensor_operation::binary_element_wise::Add; | ||
|
|
||
| using DeviceElementwiseAddInstance = | ||
| ck::tensor_operation::device::DeviceBinaryElementwise<ABDataType, | ||
| ABDataType, | ||
| CDataType, | ||
| EltwiseComputeDataType, | ||
| Add, | ||
| 3, | ||
| 8, | ||
| 1, | ||
| 8>; | ||
|
|
||
| template <typename HostTensorA, | ||
| typename HostTensorB, | ||
| typename HostTensorC, | ||
| typename ComputeDataType, | ||
| typename Functor> | ||
| void host_broadcast3D_am_bmnk(HostTensorC& C, | ||
| const HostTensorA& A, | ||
| const HostTensorB& B, | ||
| const std::vector<std::size_t>& shape, | ||
| Functor functor) | ||
| { | ||
| using ctype = ck::remove_reference_t<decltype(C(0, 0))>; | ||
|
|
||
| for(std::size_t m = 0; m < shape[0]; ++m) | ||
| for(std::size_t n = 0; n < shape[1]; ++n) | ||
| for(std::size_t k = 0; k < shape[2]; ++k) | ||
| { | ||
| ComputeDataType a_val = static_cast<ComputeDataType>(A(m)); | ||
| ComputeDataType b_val = static_cast<ComputeDataType>(B(m, n, k)); | ||
| ComputeDataType c_val = 0; | ||
| functor(c_val, a_val, b_val); | ||
| C(m, n, k) = static_cast<ctype>(c_val); | ||
| } | ||
| } | ||
|
|
||
| int main() | ||
| { | ||
| bool do_verification = true; | ||
| bool time_kernel = false; | ||
|
|
||
| std::vector<std::size_t> mnk = {4, 16, 32}; | ||
| ck::index_t M = mnk[0]; | ||
|
|
||
| Tensor<ABDataType> a_m({M}); | ||
| Tensor<ABDataType> b_m_n_k(mnk); | ||
| Tensor<CDataType> c_m_n_k(mnk); | ||
|
|
||
| a_m.GenerateTensorValue(GeneratorTensor_3<ABDataType>{0.0, 1.0}); | ||
| b_m_n_k.GenerateTensorValue(GeneratorTensor_3<ABDataType>{0.0, 1.0}); | ||
|
|
||
| DeviceMem a_m_device_buf(sizeof(ABDataType) * a_m.mDesc.GetElementSpace()); | ||
| DeviceMem b_m_n_k_device_buf(sizeof(ABDataType) * b_m_n_k.mDesc.GetElementSpace()); | ||
| DeviceMem c_m_n_k_device_buf(sizeof(CDataType) * c_m_n_k.mDesc.GetElementSpace()); | ||
|
|
||
| a_m_device_buf.ToDevice(a_m.mData.data()); | ||
| b_m_n_k_device_buf.ToDevice(b_m_n_k.mData.data()); | ||
|
|
||
| auto broadcastAdd = DeviceElementwiseAddInstance{}; | ||
| auto argument = broadcastAdd.MakeArgumentPointer( | ||
| a_m_device_buf.GetDeviceBuffer(), | ||
| b_m_n_k_device_buf.GetDeviceBuffer(), | ||
| c_m_n_k_device_buf.GetDeviceBuffer(), | ||
| std::vector<ck::index_t>{mnk.begin(), mnk.end()}, | ||
| {1, 0, 0}, // broadcast A on second and third dimension | ||
| std::vector<ck::index_t>{b_m_n_k.mDesc.GetStrides().begin(), | ||
| b_m_n_k.mDesc.GetStrides().end()}, | ||
| std::vector<ck::index_t>{c_m_n_k.mDesc.GetStrides().begin(), | ||
| c_m_n_k.mDesc.GetStrides().end()}, | ||
| Add{}); | ||
|
|
||
| if(!broadcastAdd.IsSupportedArgument(argument.get())) | ||
| { | ||
| throw std::runtime_error("The runtime parameters seems not supported by the " | ||
| "DeviceBinaryElementwise instance, exiting!"); | ||
| }; | ||
|
|
||
| auto broadcastAdd_invoker_ptr = broadcastAdd.MakeInvokerPointer(); | ||
| float ave_time = | ||
| broadcastAdd_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel}); | ||
|
|
||
| std::cout << "Perf: " << ave_time << " ms" << std::endl; | ||
|
|
||
| bool pass = true; | ||
| if(do_verification) | ||
| { | ||
| c_m_n_k_device_buf.FromDevice(c_m_n_k.mData.data()); | ||
| Tensor<CDataType> host_c_m_n_k(mnk); | ||
|
|
||
| host_broadcast3D_am_bmnk<Tensor<ABDataType>, | ||
| Tensor<ABDataType>, | ||
| Tensor<CDataType>, | ||
| EltwiseComputeDataType, | ||
| Add>(host_c_m_n_k, a_m, b_m_n_k, mnk, Add{}); | ||
|
|
||
| pass &= ck::utils::check_err( | ||
| c_m_n_k.mData, host_c_m_n_k.mData, "Error: Incorrect results c", 1e-3, 1e-3); | ||
| } | ||
|
|
||
| return pass ? 0 : 1; | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,7 +16,9 @@ template <typename ADataType, | |
| typename ComputeDataType, | ||
| typename ElementwiseFunctor, | ||
| index_t Dim, | ||
| index_t ScalarPerVector> | ||
| index_t M0PerThread, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'm wondering why the dimension is called `M0`.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I only supported [M, N] before, so I used M0. |
||
| index_t AScalarPerVector = M0PerThread, | ||
| index_t BScalarPerVector = M0PerThread> | ||
|
asroy marked this conversation as resolved.
Outdated
|
||
| struct DeviceBinaryElementwise : public BaseOperator | ||
| { | ||
| static constexpr auto I0 = Number<0>{}; | ||
|
|
@@ -25,7 +27,7 @@ struct DeviceBinaryElementwise : public BaseOperator | |
| static auto PadDescriptor_M0_1d(Desc_M0 desc_m0, index_t gridSize, index_t blockSize) | ||
| { | ||
| const auto m0 = desc_m0.GetLength(I0); | ||
| const index_t loop_step = gridSize * blockSize * ScalarPerVector; | ||
| const index_t loop_step = gridSize * blockSize * M0PerThread; | ||
| const auto pad = math::integer_least_multiple(m0, loop_step) - m0; | ||
| const auto desc_m0_pad = | ||
| transform_tensor_descriptor(desc_m0, | ||
|
|
@@ -68,7 +70,9 @@ struct DeviceBinaryElementwise : public BaseOperator | |
| ComputeDataType, | ||
| GridDesc_M0, | ||
| ElementwiseFunctor, | ||
| ScalarPerVector>; | ||
| M0PerThread, | ||
| AScalarPerVector, | ||
| BScalarPerVector>; | ||
|
|
||
| struct Argument : public BaseArgument | ||
| { | ||
|
|
@@ -84,6 +88,8 @@ struct DeviceBinaryElementwise : public BaseOperator | |
| p_b_(p_b), | ||
| p_c_(p_c), | ||
| shape_(shape), | ||
|
asroy marked this conversation as resolved.
Outdated
|
||
| stride_a_(stride_a), | ||
| stride_b_(stride_b), | ||
| functor_(functor), | ||
| blockSize_(256), | ||
| gridSize_(120) // FIXME - Calculate the grid size by number of CU in the future | ||
|
|
@@ -100,6 +106,8 @@ struct DeviceBinaryElementwise : public BaseOperator | |
| GridDesc_M0 a_grid_desc_m0_; | ||
| GridDesc_M0 b_grid_desc_m0_; | ||
| GridDesc_M0 c_grid_desc_m0_; | ||
| std::vector<index_t> stride_a_; | ||
|
asroy marked this conversation as resolved.
Outdated
|
||
| std::vector<index_t> stride_b_; | ||
| ElementwiseFunctor functor_; | ||
| index_t blockSize_; | ||
| index_t gridSize_; | ||
|
|
@@ -139,14 +147,35 @@ struct DeviceBinaryElementwise : public BaseOperator | |
| } | ||
| }; | ||
|
|
||
| bool IsScalarPerVectorValid(bool broadcastOnFastest, int scalarPerVector) | ||
|
asroy marked this conversation as resolved.
Outdated
|
||
| { | ||
| bool ret = true; | ||
|
|
||
| if(broadcastOnFastest) | ||
| ret = scalarPerVector == 1; | ||
| else | ||
| ret = M0PerThread % scalarPerVector == 0; | ||
|
|
||
| return ret; | ||
| } | ||
|
|
||
| bool IsSupportedArgument(const BaseArgument* p_arg) override | ||
| { | ||
| const Argument* pArg = dynamic_cast<const Argument*>(p_arg); | ||
|
|
||
| if(pArg == nullptr) | ||
| return false; | ||
|
|
||
| if(pArg->shape_.back() % ScalarPerVector != 0) | ||
| if(pArg->shape_.size() != Dim) | ||
| return false; | ||
|
|
||
| if(pArg->shape_.back() % M0PerThread != 0) | ||
| return false; | ||
|
|
||
| if(!IsScalarPerVectorValid(pArg->stride_a_.back() == 0, AScalarPerVector)) | ||
|
asroy marked this conversation as resolved.
Outdated
|
||
| return false; | ||
|
|
||
| if(!IsScalarPerVectorValid(pArg->stride_b_.back() == 0, BScalarPerVector)) | ||
|
asroy marked this conversation as resolved.
Outdated
|
||
| return false; | ||
|
|
||
| return true; | ||
|
|
@@ -180,7 +209,7 @@ struct DeviceBinaryElementwise : public BaseOperator | |
| // clang-format off | ||
| str << "DeviceBinaryElementwise" | ||
| << "<" | ||
| << "ScalarPerVector = " << ScalarPerVector | ||
| << "M0PerThread = " << M0PerThread | ||
| << ">"; | ||
| // clang-format on | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -36,19 +36,21 @@ template <typename ADataType, | |
| typename ComputeDataType, | ||
| typename GridDesc_M0, | ||
|
asroy marked this conversation as resolved.
Outdated
|
||
| typename ElementwiseFunctor, | ||
| index_t ScalarPerVector> | ||
| index_t M0PerThread, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Again, wondering why
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| index_t AScalarPerVector = M0PerThread, | ||
| index_t BScalarPerVector = M0PerThread> | ||
|
asroy marked this conversation as resolved.
Outdated
|
||
| struct GridwiseBinaryElementwise_1D | ||
| { | ||
| static constexpr auto I0 = Number<0>{}; | ||
| static constexpr auto thread_desc_m0 = | ||
| make_naive_tensor_descriptor_packed(make_tuple(Number<ScalarPerVector>{})); | ||
| make_naive_tensor_descriptor_packed(make_tuple(Number<M0PerThread>{})); | ||
|
|
||
| using PassThrough = tensor_operation::element_wise::PassThrough; | ||
|
|
||
| static __device__ auto CalculateElementwiseIndex() | ||
| { | ||
| const index_t global_thread_id = get_thread_global_1d_id(); | ||
| return make_multi_index(global_thread_id * ScalarPerVector); | ||
| return make_multi_index(global_thread_id * M0PerThread); | ||
| } | ||
|
|
||
| __device__ static void Run(const ADataType* __restrict__ p_a_global, | ||
|
|
@@ -66,9 +68,9 @@ struct GridwiseBinaryElementwise_1D | |
| auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( | ||
| p_c_global, c_grid_desc_m0.GetElementSpaceSize()); | ||
|
|
||
| StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, ScalarPerVector, true> a_thread_buf; | ||
| StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, ScalarPerVector, true> b_thread_buf; | ||
| StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, ScalarPerVector, true> c_thread_buf; | ||
| StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, M0PerThread, true> a_thread_buf; | ||
| StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, M0PerThread, true> b_thread_buf; | ||
| StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, M0PerThread, true> c_thread_buf; | ||
|
|
||
| const auto thread_store_global_offset = CalculateElementwiseIndex(); | ||
|
|
||
|
|
@@ -77,10 +79,10 @@ struct GridwiseBinaryElementwise_1D | |
| ComputeDataType, | ||
| GridDesc_M0, | ||
| decltype(thread_desc_m0), | ||
| Sequence<ScalarPerVector>, // SliceLengths | ||
| Sequence<0>, // DimAccessOrder | ||
| 0, // SrcVectorDim | ||
| ScalarPerVector, | ||
| Sequence<M0PerThread>, // SliceLengths | ||
| Sequence<0>, // DimAccessOrder | ||
| 0, // SrcVectorDim | ||
| AScalarPerVector, | ||
| 1, // SrcScalarStrideInVector | ||
| false>{a_grid_desc_m0, thread_store_global_offset}; | ||
|
|
||
|
|
@@ -89,10 +91,10 @@ struct GridwiseBinaryElementwise_1D | |
| ComputeDataType, | ||
| GridDesc_M0, | ||
| decltype(thread_desc_m0), | ||
| Sequence<ScalarPerVector>, // SliceLengths | ||
| Sequence<0>, // DimAccessOrder | ||
| 0, // SrcVectorDim | ||
| ScalarPerVector, | ||
| Sequence<M0PerThread>, // SliceLengths | ||
| Sequence<0>, // DimAccessOrder | ||
| 0, // SrcVectorDim | ||
| BScalarPerVector, | ||
| 1, // SrcScalarStrideInVector | ||
| false>{b_grid_desc_m0, thread_store_global_offset}; | ||
|
|
||
|
|
@@ -102,10 +104,10 @@ struct GridwiseBinaryElementwise_1D | |
| decltype(thread_desc_m0), | ||
| GridDesc_M0, | ||
| PassThrough, | ||
| Sequence<ScalarPerVector>, // SliceLengths | ||
| Sequence<0>, // DimAccessOrder | ||
| 0, // DstVectorDim | ||
| ScalarPerVector, | ||
| Sequence<M0PerThread>, // SliceLengths | ||
| Sequence<0>, // DimAccessOrder | ||
| 0, // DstVectorDim | ||
| M0PerThread, | ||
| InMemoryDataOperationEnum::Set, | ||
| 1, // DstScalarStrideInVector | ||
| false>{ | ||
|
|
@@ -114,20 +116,20 @@ struct GridwiseBinaryElementwise_1D | |
| const index_t blockSize = get_block_size(); | ||
| const index_t blockPerGrid = get_grid_size(); | ||
| const auto m0 = c_grid_desc_m0.GetLength(I0); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
For the length and stride of a dimension (as opposed to an index), let's use a capital letter; see https://github.com/ROCmSoftwarePlatform/composable_kernel/wiki/Coding-Style
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| const index_t loop_step = blockPerGrid * blockSize * ScalarPerVector; | ||
| const index_t loop_step = blockPerGrid * blockSize * M0PerThread; | ||
| const auto loop_step_index = make_multi_index(loop_step); | ||
|
|
||
| index_t num_iter = m0 / (loop_step); | ||
| do | ||
| { | ||
| // read and process ScalarPerVector elements | ||
| // read and process M0PerThread elements | ||
| a_global_load.Run( | ||
| a_grid_desc_m0, a_global_buf, thread_desc_m0, make_tuple(I0), a_thread_buf); | ||
|
|
||
| b_global_load.Run( | ||
| b_grid_desc_m0, b_global_buf, thread_desc_m0, make_tuple(I0), b_thread_buf); | ||
|
|
||
| static_for<0, ScalarPerVector, 1>{}([&](auto m) { | ||
| static_for<0, M0PerThread, 1>{}([&](auto m) { | ||
| constexpr auto offset = thread_desc_m0.CalculateOffset(make_tuple(m)); | ||
| functor(c_thread_buf(Number<offset>{}), | ||
| a_thread_buf(Number<offset>{}), | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
NDim
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
221146a