Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 58 additions & 5 deletions onnxruntime/core/providers/cpu/tensor/tile.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,31 @@ Status TileCoreForFixedSizeTypes(const Tensor& input_tensor, Tensor& output_tens
return Status::OK();
}

namespace TileOp {
// Decides whether tiling can be carried out by writing the entire input buffer
// into the output several times in a row.
//
// The repeats are scanned from the innermost axis outwards to locate the
// innermost axis whose repeat value is not 1. If the input dimensions to the
// left of that axis have a combined size of 1, the tiled output is just the
// input buffer repeated, and the repeat count is the product of the repeat
// values from axis 0 up to and including that axis.
//
// `num_of_copies` is written only when the function returns true.
// Returns false when every repeat is 1 or when the check above does not hold.
bool IsTileMemcpy(const TensorShape& input_shape,
                  const int64_t* repeats,
                  size_t rank,
                  /*out*/ size_t& num_of_copies) {
  // Skip over trailing axes that are not tiled at all.
  size_t axes_remaining = rank;
  while (axes_remaining > 0 && repeats[axes_remaining - 1] == 1) {
    --axes_remaining;
  }

  // No axis is tiled: the caller has to handle this case separately.
  if (axes_remaining == 0) {
    return false;
  }

  const size_t tiled_axis = axes_remaining - 1;

  // Everything to the left of the tiled axis must collapse to a single element.
  if (input_shape.SizeToDimension(tiled_axis) != 1) {
    return false;
  }

  size_t copies = 1;
  for (size_t axis = 0; axis <= tiled_axis; ++axis) {
    copies *= repeats[axis];
  }

  num_of_copies = copies;
  return true;
}
}  // namespace TileOp

Status Tile::Compute(OpKernelContext* ctx) const {
const auto* tensor_pointer = ctx->Input<Tensor>(0);
if (tensor_pointer == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "Input count of Tile OP mismatch, the first one is empty");
Expand All @@ -116,19 +141,47 @@ Status Tile::Compute(OpKernelContext* ctx) const {
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "'repeat' input tensor must have the same length as the 'input' tensor");

// Calculate the shape of the output tensor
auto* repeats = repeats_tensor.template Data<int64_t>();
const auto* repeats = repeats_tensor.template Data<int64_t>();
std::vector<int64_t> output_dims = input_shape.GetDims();
for (size_t axis = 0; axis < input_rank; axis++) {
output_dims[axis] *= repeats[axis];
}

TensorShape outputShape(output_dims);
auto& output_tensor = *ctx->Output(0, outputShape);
TensorShape output_shape(output_dims);
auto& output_tensor = *ctx->Output(0, output_shape);

// Repeat tensor input can have 0 as a valid value
// check if the computed outputshape size is 0 and
// check if the computed output_shape size is 0 and
// return an empty tensor if so.
if (outputShape.Size() == 0) {
if (output_shape.Size() == 0) {
return Status::OK();
}

// Repeat tensor has all 1s in it
if (output_shape == input_shape) {
// TODO: Handle string copies when the kernel eventually supports string type.
// For now, it shouldn't throw in the enforce as the kernel doesn't claim string support
ORT_ENFORCE(!input_tensor.IsDataType<std::string>(), "Tile doesn't support string type yet");
memcpy(output_tensor.MutableDataRaw(), input_tensor.DataRaw(), input_tensor.SizeInBytes());
return Status::OK();
}

size_t num_of_copies = 1;
if (TileOp::IsTileMemcpy(input_shape, repeats, input_rank, num_of_copies)) {
// TODO: Handle string copies when the kernel eventually supports string type.
// For now, it shouldn't throw in the enforce as the kernel doesn't claim string support
ORT_ENFORCE(!input_tensor.IsDataType<std::string>(), "Tile doesn't support string type yet");

int8_t* output_data_casted = reinterpret_cast<int8_t*>(output_tensor.MutableDataRaw());
const void* input_data_raw = input_tensor.DataRaw();
size_t tensor_size_in_bytes = input_tensor.SizeInBytes();

// TODO: Add multi-threading logic if num_of_copies is large enough
for (size_t i = 0; i < num_of_copies; ++i) {
memcpy(static_cast<void*>(output_data_casted), input_data_raw, tensor_size_in_bytes);
output_data_casted += tensor_size_in_bytes;
}

return Status::OK();
}

Expand Down
17 changes: 15 additions & 2 deletions onnxruntime/core/providers/cpu/tensor/tile.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,21 @@

namespace onnxruntime {

struct Tile final : OpKernel {
Tile(const OpKernelInfo& info) : OpKernel(info) {
namespace TileOp {
// Function to determine if the tiling operation is just multiple copies
// of the input data buffer written one after another into the output.
// E.g.: input_shape: [1, 1, 256 * 50]
// repeats: [1, 200, 1]
// output shape: [1, 200, 256 * 50]
//
// input_shape   - shape of the tensor being tiled
// repeats       - per-axis repeat values; at least `rank` entries are read
// rank          - number of axes to examine
// num_of_copies - (out) written only when the function returns true; holds the
//                 product of the repeat values from axis 0 up to and including
//                 the innermost axis whose repeat value is not 1
//
// Returns false when all repeat values are 1, or when the input dimensions to
// the left of the innermost tiled axis do not have a combined size of 1.

bool IsTileMemcpy(const TensorShape& input_shape,
const int64_t* repeats,
size_t rank,
/*out*/ size_t& num_of_copies);
} // namespace TileOp

struct Tile : OpKernel {
explicit Tile(const OpKernelInfo& info) : OpKernel(info) {
}

Status Compute(OpKernelContext* context) const override;
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/cuda/tensor/scatter_nd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ Status ScatterND::ComputeInternal(OpKernelContext* context) const {

if (input_data != output_data) {
// TODO: Run benchmarks to determine if a dedicated kernel doing data copy will be faster than invoking cudaMemcpy ?
cudaMemcpyAsync(output_data, input_data, element_size * input_shape.Size(), cudaMemcpyDeviceToDevice);
cudaMemcpyAsync(output_data, input_data, input_tensor->SizeInBytes(), cudaMemcpyDeviceToDevice);
}

// Bail out early
Expand Down
59 changes: 52 additions & 7 deletions onnxruntime/core/providers/cuda/tensor/tile.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "core/providers/cuda/tensor/tile.h"
#include "core/providers/cpu/tensor/utils.h"
#include "tile_impl.h"

using namespace onnxruntime::common;
namespace onnxruntime {
namespace cuda {
Expand Down Expand Up @@ -51,22 +52,66 @@ Status Tile::ComputeInternal(OpKernelContext* ctx) const {

// Calculate the shape of the output tensor
auto* repeats = repeats_tensor.template Data<int64_t>();
const auto& input_shape = input_tensor.Shape().GetDims();
std::vector<int64_t> output_dims(input_shape);
const auto& input_shape = input_tensor.Shape();
const auto& input_dims = input_shape.GetDims();
std::vector<int64_t> output_dims(input_dims);
for (auto axis = 0; axis < rank; axis++)
output_dims[axis] *= repeats[axis];
TensorShape outputShape(output_dims);
auto& output_tensor = *ctx->Output(0, outputShape);
TensorShape output_shape(output_dims);
auto& output_tensor = *ctx->Output(0, output_shape);

void* output_data = output_tensor.MutableDataRaw();
const void* input_data = input_tensor.DataRaw();

TensorPitches input_pitches(input_shape);
// Repeat tensor input can have 0 as a valid value
// check if the computed output_shape size is 0 and
// return an empty tensor if so.
if (output_shape.Size() == 0) {
return Status::OK();
}

// Repeat tensor has all 1s in it
if (output_shape == input_shape) {
cudaMemcpyAsync(output_tensor.MutableDataRaw(), input_tensor.DataRaw(), input_tensor.SizeInBytes(), cudaMemcpyDeviceToDevice);
return Status::OK();
}

size_t num_of_copies = 1;
if (TileOp::IsTileMemcpy(input_shape, repeats, rank, num_of_copies)) {
if (input_tensor.IsDataType<float>() ||
input_tensor.IsDataType<int32_t>()) {
TileMemcpyImpl(
reinterpret_cast<const typename ToCudaType<float>::MappedType*>(input_data),
input_shape.Size(),
reinterpret_cast<typename ToCudaType<float>::MappedType*>(output_data),
output_shape.Size());
} else if (input_tensor.IsDataType<double>() ||
input_tensor.IsDataType<int64_t>()) {
TileMemcpyImpl(
reinterpret_cast<const typename ToCudaType<double>::MappedType*>(input_data),
input_shape.Size(),
reinterpret_cast<typename ToCudaType<double>::MappedType*>(output_data),
output_shape.Size());
} else if (input_tensor.IsDataType<MLFloat16>()) {
TileMemcpyImpl(
reinterpret_cast<const typename ToCudaType<MLFloat16>::MappedType*>(input_data),
input_shape.Size(),
reinterpret_cast<typename ToCudaType<MLFloat16>::MappedType*>(output_data),
output_shape.Size());
} else {
// Won't hit this as the kernel doesn't claim support for any type that will trigger this
ORT_THROW("Tile doesn't have an implementation yet for the type: ", input_tensor.DataType());
}

return Status::OK();
}

TensorPitches input_pitches(input_dims);
TArray<int64_t> input_strides(input_pitches);

TArray<fast_divmod> fdm_input_shape(rank);
for (int32_t i = 0; i < input_shape.size(); ++i) {
fdm_input_shape[i] = fast_divmod(gsl::narrow_cast<int>(input_shape[i]));
for (int32_t i = 0; i < input_dims.size(); ++i) {
fdm_input_shape[i] = fast_divmod(gsl::narrow_cast<int>(input_dims[i]));
}

TArray<fast_divmod> fdm_output_strides(rank);
Expand Down
3 changes: 2 additions & 1 deletion onnxruntime/core/providers/cuda/tensor/tile.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@

#include "core/common/common.h"
#include "core/providers/cuda/cuda_kernel.h"
#include "core/providers/cpu/tensor/tile.h"

namespace onnxruntime {
namespace cuda {

struct Tile final : CudaKernel {
Tile(const OpKernelInfo& info) : CudaKernel(info) {
explicit Tile(const OpKernelInfo& info) : CudaKernel(info) {
}

Status ComputeInternal(OpKernelContext* context) const override;
Expand Down
27 changes: 25 additions & 2 deletions onnxruntime/core/providers/cuda/tensor/tile_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,31 @@ void TileImpl(
fdm_output_strides, output_data, (CUDA_LONG)N);
}

#define SPECIALIZED_IMPL(T) \
template void TileImpl<T>(const size_t shape_rank, const TArray<fast_divmod>& fdm_input_shape, const TArray<int64_t>& input_stride, const T* input_data, const TArray<fast_divmod>& fdm_output_strides, T* output_data, const size_t N);
// Element-wise kernel for the "output is the input buffer repeated" case.
// The thread handling output position `id` reads the input element at
// `id % num_input_elements`, so the input is cycled through until N output
// elements have been produced.
template <typename T>
__global__ void _TileMemcpyKernel(
    const T* input_data,
    const size_t num_input_elements,
    T* output_data,
    const size_t N) {
  // NOTE(review): project macro, not visible here - presumably declares `id`
  // as this thread's element index and returns early when it is not below N.
  CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N);
  output_data[id] = input_data[id % num_input_elements];
}

template <typename T>
void TileMemcpyImpl(
const T* input_data,
const size_t num_input_elements,
T* output_data,
const size_t num_output_elements) {
int blocksPerGrid = (int)(ceil(static_cast<float>(num_output_elements) / GridDim::maxThreadsPerBlock));
_TileMemcpyKernel<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
input_data, num_input_elements, output_data, (CUDA_LONG)num_output_elements);
}

// Explicitly instantiates, for element type T, both host entry points defined
// in this file: the general tiling launcher (TileImpl) and the repeated
// whole-buffer copy launcher (TileMemcpyImpl).
#define SPECIALIZED_IMPL(T) \
template void TileImpl<T>(const size_t shape_rank, const TArray<fast_divmod>& fdm_input_shape, const TArray<int64_t>& input_stride, const T* input_data, const TArray<fast_divmod>& fdm_output_strides, T* output_data, const size_t N); \
template void TileMemcpyImpl<T>(const T* input_data, const size_t num_input_elements, T* output_data, const size_t num_output_elements);

SPECIALIZED_IMPL(float)
SPECIALIZED_IMPL(double)
Expand Down
7 changes: 7 additions & 0 deletions onnxruntime/core/providers/cuda/tensor/tile_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,12 @@ void TileImpl(
T* output_data,
const size_t N);

// Tiling fast path for the case where the output is the input buffer repeated:
// produces `num_output_elements` values in `output_data` by cycling through the
// `num_input_elements` values of `input_data`
// (output[i] = input[i % num_input_elements]).
template <typename T>
void TileMemcpyImpl(
const T* input_data,
const size_t num_input_elements,
T* output_data,
const size_t num_output_elements);

} // namespace cuda
} // namespace onnxruntime
20 changes: 19 additions & 1 deletion onnxruntime/test/providers/cpu/tensor/tile_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ void RunTest(std::initializer_list<T> input,
test.AddInput<int64_t>("repeats", repeat_dims, repeat);
test.AddOutput<T>("output", output_dims, output);
if (std::is_same<T, int8_t>::value)
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); //TensorRT reports error: Assertion Error in makePaddedScale: 0 (regionRanges != nullptr)
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); //TensorRT reports error: Assertion Error in makePaddedScale: 0 (regionRanges != nullptr)
else
test.Run();
}
Expand All @@ -43,6 +43,15 @@ void RunTestWrapper() {

// Tile3D
RunTest<T>({111, 112, 113, 122, 123, 124}, {2, 1, 3}, {1, 2, 1}, {3}, {111, 112, 113, 111, 112, 113, 122, 123, 124, 122, 123, 124}, {2, 2, 3});

// Tile1DWithOneRepeats
RunTest<T>({111, 112, 113, 122, 123, 124}, {2, 1, 3}, {1, 1, 1}, {3}, {111, 112, 113, 122, 123, 124}, {2, 1, 3});

// TileWhichIsBasicallyCopiesOfInputBuffer - 1
RunTest<T>({111, 112, 113}, {1, 1, 3}, {2, 2, 1}, {3}, {111, 112, 113, 111, 112, 113, 111, 112, 113, 111, 112, 113}, {2, 2, 3});

// TileWhichIsBasicallyCopiesOfInputBuffer - 2
RunTest<T>({111, 112, 113}, {1, 1, 3}, {3, 1, 1}, {3}, {111, 112, 113, 111, 112, 113, 111, 112, 113}, {3, 1, 3});
}

template <>
Expand All @@ -64,6 +73,15 @@ void RunTestWrapper<bool>() {

// Tile3D
RunTest<bool>({true, false, true, false, true, false}, {2, 1, 3}, {1, 2, 1}, {3}, {true, false, true, true, false, true, false, true, false, false, true, false}, {2, 2, 3});

// Tile1DWithOneRepeats
RunTest<bool>({true, false, true, false, true, true}, {2, 1, 3}, {1, 1, 1}, {3}, {true, false, true, false, true, true}, {2, 1, 3});

// TileWhichIsBasicallyCopiesOfInputBuffer - 1
RunTest<bool>({true, false, true}, {1, 1, 3}, {2, 2, 1}, {3}, {true, false, true, true, false, true, true, false, true, true, false, true}, {2, 2, 3});

// TileWhichIsBasicallyCopiesOfInputBuffer - 2
RunTest<bool>({true, false, true}, {1, 1, 3}, {3, 1, 1}, {3}, {true, false, true, true, false, true, true, false, true}, {3, 1, 3});
}

TEST(TensorOpTest, TileFloatType) {
Expand Down