Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion example/19_binary_elementwise/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
add_example_executable(example_broadcast_add_2d broadcast_add_2d.cpp)
add_example_executable(example_broadcast_add_2d_amn_bn broadcast_add_2d_amn_bn.cpp)
add_example_executable(example_broadcast_add_3d_am_bmnk broadcast_add_3d_am_bmnk.cpp)
add_example_executable(example_elementwise_add_1d elementwise_add_1d.cpp)
add_example_executable(example_elementwise_add_4d elementwise_add_4d.cpp)
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ int main()
if(!broadcastAdd.IsSupportedArgument(argument.get()))
{
throw std::runtime_error("The runtime parameters seems not supported by the "
"DeviceBinaryElementwise_2D instance, exiting!");
"DeviceBinaryElementwise instance, exiting!");
};

auto broadcastAdd_invoker_ptr = broadcastAdd.MakeInvokerPointer();
Expand All @@ -123,7 +123,7 @@ int main()
0>(host_c_m_n, a_m_n, b_n, M, N, Add{});

pass &= ck::utils::check_err(
c_m_n.mData, host_c_m_n.mData, "Error: Incorrect results d1", 1e-3, 1e-3);
c_m_n.mData, host_c_m_n.mData, "Error: Incorrect results c", 1e-3, 1e-3);
}

return pass ? 0 : 1;
Expand Down
122 changes: 122 additions & 0 deletions example/19_binary_elementwise/broadcast_add_3d_am_bmnk.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
#include <iostream>
#include <cstdlib>
#include "check_err.hpp"
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"

#include "device_tensor.hpp"
#include "binary_element_wise_operation.hpp"
#include "device_binary_elementwise.hpp"

using F16 = ck::half_t;
using F32 = float;

using ABDataType = F16;
using CDataType = F16;
using EltwiseComputeDataType = F32;

using Add = ck::tensor_operation::binary_element_wise::Add;

using DeviceElementwiseAddInstance =
ck::tensor_operation::device::DeviceBinaryElementwise<ABDataType,
ABDataType,
CDataType,
EltwiseComputeDataType,
Add,
3,
8,
1,
8>;

template <typename HostTensorA,
typename HostTensorB,
typename HostTensorC,
typename ComputeDataType,
typename Functor>
void host_broadcast3D_am_bmnk(HostTensorC& C,
const HostTensorA& A,
const HostTensorB& B,
const std::vector<std::size_t>& shape,
Functor functor)
{
using ctype = ck::remove_reference_t<decltype(C(0, 0))>;

for(std::size_t m = 0; m < shape[0]; ++m)
for(std::size_t n = 0; n < shape[1]; ++n)
for(std::size_t k = 0; k < shape[2]; ++k)
{
ComputeDataType a_val = static_cast<ComputeDataType>(A(m));
ComputeDataType b_val = static_cast<ComputeDataType>(B(m, n, k));
ComputeDataType c_val = 0;
functor(c_val, a_val, b_val);
C(m, n, k) = static_cast<ctype>(c_val);
}
}

int main()
{
bool do_verification = true;
bool time_kernel = false;

std::vector<std::size_t> mnk = {4, 16, 32};
ck::index_t M = mnk[0];

Tensor<ABDataType> a_m({M});
Tensor<ABDataType> b_m_n_k(mnk);
Tensor<CDataType> c_m_n_k(mnk);

a_m.GenerateTensorValue(GeneratorTensor_3<ABDataType>{0.0, 1.0});
b_m_n_k.GenerateTensorValue(GeneratorTensor_3<ABDataType>{0.0, 1.0});

DeviceMem a_m_device_buf(sizeof(ABDataType) * a_m.mDesc.GetElementSpace());
DeviceMem b_m_n_k_device_buf(sizeof(ABDataType) * b_m_n_k.mDesc.GetElementSpace());
DeviceMem c_m_n_k_device_buf(sizeof(CDataType) * c_m_n_k.mDesc.GetElementSpace());

a_m_device_buf.ToDevice(a_m.mData.data());
b_m_n_k_device_buf.ToDevice(b_m_n_k.mData.data());

auto broadcastAdd = DeviceElementwiseAddInstance{};
auto argument = broadcastAdd.MakeArgumentPointer(
a_m_device_buf.GetDeviceBuffer(),
b_m_n_k_device_buf.GetDeviceBuffer(),
c_m_n_k_device_buf.GetDeviceBuffer(),
std::vector<ck::index_t>{mnk.begin(), mnk.end()},
{1, 0, 0}, // broadcast A on second and third dimension
std::vector<ck::index_t>{b_m_n_k.mDesc.GetStrides().begin(),
b_m_n_k.mDesc.GetStrides().end()},
std::vector<ck::index_t>{c_m_n_k.mDesc.GetStrides().begin(),
c_m_n_k.mDesc.GetStrides().end()},
Add{});

if(!broadcastAdd.IsSupportedArgument(argument.get()))
{
throw std::runtime_error("The runtime parameters seems not supported by the "
"DeviceBinaryElementwise instance, exiting!");
};

auto broadcastAdd_invoker_ptr = broadcastAdd.MakeInvokerPointer();
float ave_time =
broadcastAdd_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel});

std::cout << "Perf: " << ave_time << " ms" << std::endl;

bool pass = true;
if(do_verification)
{
c_m_n_k_device_buf.FromDevice(c_m_n_k.mData.data());
Tensor<CDataType> host_c_m_n_k(mnk);

host_broadcast3D_am_bmnk<Tensor<ABDataType>,
Tensor<ABDataType>,
Tensor<CDataType>,
EltwiseComputeDataType,
Add>(host_c_m_n_k, a_m, b_m_n_k, mnk, Add{});

pass &= ck::utils::check_err(
c_m_n_k.mData, host_c_m_n_k.mData, "Error: Incorrect results c", 1e-3, 1e-3);
}

return pass ? 0 : 1;
}
4 changes: 2 additions & 2 deletions example/19_binary_elementwise/elementwise_add_1d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ int main()
if(!broadcastAdd.IsSupportedArgument(argument.get()))
{
throw std::runtime_error("The runtime parameters seems not supported by the "
"DeviceBinaryElementwise_2D instance, exiting!");
"DeviceBinaryElementwise instance, exiting!");
};

auto broadcastAdd_invoker_ptr = broadcastAdd.MakeInvokerPointer();
Expand All @@ -103,7 +103,7 @@ int main()
Add>(host_c_m, a_m, b_m, M, Add{});

pass &= ck::utils::check_err(
c_m.mData, host_c_m.mData, "Error: Incorrect results d1", 1e-3, 1e-3);
c_m.mData, host_c_m.mData, "Error: Incorrect results c", 1e-3, 1e-3);
}

return pass ? 0 : 1;
Expand Down
4 changes: 2 additions & 2 deletions example/19_binary_elementwise/elementwise_add_4d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ int main()
if(!broadcastAdd.IsSupportedArgument(argument.get()))
{
throw std::runtime_error("The runtime parameters seems not supported by the "
"DeviceBinaryElementwise_2D instance, exiting!");
"DeviceBinaryElementwise instance, exiting!");
};

auto broadcastAdd_invoker_ptr = broadcastAdd.MakeInvokerPointer();
Expand All @@ -105,7 +105,7 @@ int main()
Add>(host_c, a, b, nchw, Add{});

pass &=
ck::utils::check_err(c.mData, host_c.mData, "Error: Incorrect results d1", 1e-3, 1e-3);
ck::utils::check_err(c.mData, host_c.mData, "Error: Incorrect results c", 1e-3, 1e-3);
}

return pass ? 0 : 1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ template <typename ADataType,
typename ComputeDataType,
typename ElementwiseFunctor,
index_t Dim,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NDim

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

index_t ScalarPerVector>
index_t M0PerThread,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering why call the dimension M0 instead of just M?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I only support [M,N] before. So I used M0.
Now I can use M directly
221146a

index_t AScalarPerVector = M0PerThread,
index_t BScalarPerVector = M0PerThread>
Comment thread
asroy marked this conversation as resolved.
Outdated
struct DeviceBinaryElementwise : public BaseOperator
{
static constexpr auto I0 = Number<0>{};
Expand All @@ -25,7 +27,7 @@ struct DeviceBinaryElementwise : public BaseOperator
static auto PadDescriptor_M0_1d(Desc_M0 desc_m0, index_t gridSize, index_t blockSize)
{
const auto m0 = desc_m0.GetLength(I0);
const index_t loop_step = gridSize * blockSize * ScalarPerVector;
const index_t loop_step = gridSize * blockSize * M0PerThread;
const auto pad = math::integer_least_multiple(m0, loop_step) - m0;
const auto desc_m0_pad =
transform_tensor_descriptor(desc_m0,
Expand Down Expand Up @@ -68,7 +70,9 @@ struct DeviceBinaryElementwise : public BaseOperator
ComputeDataType,
GridDesc_M0,
ElementwiseFunctor,
ScalarPerVector>;
M0PerThread,
AScalarPerVector,
BScalarPerVector>;

struct Argument : public BaseArgument
{
Expand All @@ -84,6 +88,8 @@ struct DeviceBinaryElementwise : public BaseOperator
p_b_(p_b),
p_c_(p_c),
shape_(shape),
Comment thread
asroy marked this conversation as resolved.
Outdated
stride_a_(stride_a),
stride_b_(stride_b),
functor_(functor),
blockSize_(256),
gridSize_(120) // FIXME - Calculate the grid size by number of CU in the future
Expand All @@ -100,6 +106,8 @@ struct DeviceBinaryElementwise : public BaseOperator
GridDesc_M0 a_grid_desc_m0_;
GridDesc_M0 b_grid_desc_m0_;
GridDesc_M0 c_grid_desc_m0_;
std::vector<index_t> stride_a_;
Comment thread
asroy marked this conversation as resolved.
Outdated
std::vector<index_t> stride_b_;
ElementwiseFunctor functor_;
index_t blockSize_;
index_t gridSize_;
Expand Down Expand Up @@ -139,14 +147,35 @@ struct DeviceBinaryElementwise : public BaseOperator
}
};

bool IsScalarPerVectorValid(bool broadcastOnFastest, int scalarPerVector)
Comment thread
asroy marked this conversation as resolved.
Outdated
{
bool ret = true;

if(broadcastOnFastest)
ret = scalarPerVector == 1;
else
ret = M0PerThread % scalarPerVector == 0;

return ret;
}

bool IsSupportedArgument(const BaseArgument* p_arg) override
{
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);

if(pArg == nullptr)
return false;

if(pArg->shape_.back() % ScalarPerVector != 0)
if(pArg->shape_.size() != Dim)
return false;

if(pArg->shape_.back() % M0PerThread != 0)
return false;

if(!IsScalarPerVectorValid(pArg->stride_a_.back() == 0, AScalarPerVector))
Comment thread
asroy marked this conversation as resolved.
Outdated
return false;

if(!IsScalarPerVectorValid(pArg->stride_b_.back() == 0, BScalarPerVector))
Comment thread
asroy marked this conversation as resolved.
Outdated
return false;

return true;
Expand Down Expand Up @@ -180,7 +209,7 @@ struct DeviceBinaryElementwise : public BaseOperator
// clang-format off
str << "DeviceBinaryElementwise"
<< "<"
<< "ScalarPerVector = " << ScalarPerVector
<< "M0PerThread = " << M0PerThread
<< ">";
// clang-format on

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,19 +36,21 @@ template <typename ADataType,
typename ComputeDataType,
typename GridDesc_M0,
Comment thread
asroy marked this conversation as resolved.
Outdated
typename ElementwiseFunctor,
index_t ScalarPerVector>
index_t M0PerThread,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, wondering why M0 instead of simply M

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

index_t AScalarPerVector = M0PerThread,
index_t BScalarPerVector = M0PerThread>
Comment thread
asroy marked this conversation as resolved.
Outdated
struct GridwiseBinaryElementwise_1D
{
static constexpr auto I0 = Number<0>{};
static constexpr auto thread_desc_m0 =
make_naive_tensor_descriptor_packed(make_tuple(Number<ScalarPerVector>{}));
make_naive_tensor_descriptor_packed(make_tuple(Number<M0PerThread>{}));

using PassThrough = tensor_operation::element_wise::PassThrough;

static __device__ auto CalculateElementwiseIndex()
{
const index_t global_thread_id = get_thread_global_1d_id();
return make_multi_index(global_thread_id * ScalarPerVector);
return make_multi_index(global_thread_id * M0PerThread);
}

__device__ static void Run(const ADataType* __restrict__ p_a_global,
Expand All @@ -66,9 +68,9 @@ struct GridwiseBinaryElementwise_1D
auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_global, c_grid_desc_m0.GetElementSpaceSize());

StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, ScalarPerVector, true> a_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, ScalarPerVector, true> b_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, ScalarPerVector, true> c_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, M0PerThread, true> a_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, M0PerThread, true> b_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, M0PerThread, true> c_thread_buf;

const auto thread_store_global_offset = CalculateElementwiseIndex();

Expand All @@ -77,10 +79,10 @@ struct GridwiseBinaryElementwise_1D
ComputeDataType,
GridDesc_M0,
decltype(thread_desc_m0),
Sequence<ScalarPerVector>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
ScalarPerVector,
Sequence<M0PerThread>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
AScalarPerVector,
1, // SrcScalarStrideInVector
false>{a_grid_desc_m0, thread_store_global_offset};

Expand All @@ -89,10 +91,10 @@ struct GridwiseBinaryElementwise_1D
ComputeDataType,
GridDesc_M0,
decltype(thread_desc_m0),
Sequence<ScalarPerVector>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
ScalarPerVector,
Sequence<M0PerThread>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
BScalarPerVector,
1, // SrcScalarStrideInVector
false>{b_grid_desc_m0, thread_store_global_offset};

Expand All @@ -102,10 +104,10 @@ struct GridwiseBinaryElementwise_1D
decltype(thread_desc_m0),
GridDesc_M0,
PassThrough,
Sequence<ScalarPerVector>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // DstVectorDim
ScalarPerVector,
Sequence<M0PerThread>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // DstVectorDim
M0PerThread,
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
false>{
Expand All @@ -114,20 +116,20 @@ struct GridwiseBinaryElementwise_1D
const index_t blockSize = get_block_size();
const index_t blockPerGrid = get_grid_size();
const auto m0 = c_grid_desc_m0.GetLength(I0);

@asroy asroy May 25, 2022

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

const auto M = c_grid_desc_m.GetLength(I0);

For length and stride of a dimension (instead of index), let's use capital letter

https://github.com/ROCmSoftwarePlatform/composable_kernel/wiki/Coding-Style

@rocking5566 rocking5566 May 25, 2022

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

const index_t loop_step = blockPerGrid * blockSize * ScalarPerVector;
const index_t loop_step = blockPerGrid * blockSize * M0PerThread;
const auto loop_step_index = make_multi_index(loop_step);

index_t num_iter = m0 / (loop_step);
do
{
// read and process ScalarPerVector elements
// read and process M0PerThread elements
a_global_load.Run(
a_grid_desc_m0, a_global_buf, thread_desc_m0, make_tuple(I0), a_thread_buf);

b_global_load.Run(
b_grid_desc_m0, b_global_buf, thread_desc_m0, make_tuple(I0), b_thread_buf);

static_for<0, ScalarPerVector, 1>{}([&](auto m) {
static_for<0, M0PerThread, 1>{}([&](auto m) {
constexpr auto offset = thread_desc_m0.CalculateOffset(make_tuple(m));
functor(c_thread_buf(Number<offset>{}),
a_thread_buf(Number<offset>{}),
Expand Down