Tile: memcpy fast path for CPU and CUDA when tiling is just repeated copies of the input

Review notes (text before the first "diff --git" line is ignored by git-apply/patch):
  * CUDA Tile and ScatterND now check the result of cudaMemcpyAsync() via
    CUDA_RETURN_IF_ERROR instead of discarding it.
  * TileMemcpyImpl derives the grid size with integer ceil-division; the floating-point
    ceil() form can round a large element count down and launch one block too few.
  * IsTileMemcpy: the comment now describes the actual right-to-left scan, and the
    signed/unsigned conversions are explicit.
  * Tests: the all-ones case is labelled Tile3D (its input is 3-D).
  * Hunk line counts are unchanged. The post-image hashes on the "index" lines are stale
    for files whose added lines were revised.
  * Leading whitespace of context lines was restored to two-space style; verify against
    the target tree before applying.

diff --git a/onnxruntime/core/providers/cpu/tensor/tile.cc b/onnxruntime/core/providers/cpu/tensor/tile.cc
index 92b3290bd4abd..9f26e4e88a5f8 100644
--- a/onnxruntime/core/providers/cpu/tensor/tile.cc
+++ b/onnxruntime/core/providers/cpu/tensor/tile.cc
@@ -99,6 +99,31 @@ Status TileCoreForFixedSizeTypes(const Tensor& input_tensor, Tensor& output_tens
   return Status::OK();
 }
 
+namespace TileOp {
+// Scanning from the innermost axis outwards, find the first repeat value that is not 1.
+// If the input dims to the left of that axis multiply to 1, tiling is just the whole input
+// buffer written back-to-back; the copy count is the product of repeats[0..axis].
+bool IsTileMemcpy(const TensorShape& input_shape,
+                  const int64_t* repeats,
+                  size_t rank,
+                  /*out*/ size_t& num_of_copies) {
+  for (int64_t i = static_cast<int64_t>(rank) - 1; i >= 0; --i) {
+    if (repeats[i] != 1) {
+      if (input_shape.SizeToDimension(static_cast<size_t>(i)) == 1) {
+        num_of_copies = 1;
+        for (int64_t j = 0; j <= i; ++j) {
+          num_of_copies *= static_cast<size_t>(repeats[j]);
+        }
+        return true;
+      } else {
+        break;
+      }
+    }
+  }
+  return false;
+}
+}  // namespace TileOp
+
 Status Tile::Compute(OpKernelContext* ctx) const {
   const auto* tensor_pointer = ctx->Input<Tensor>(0);
   if (tensor_pointer == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "Input count of Tile OP mismatch, the first one is empty");
@@ -116,19 +141,47 @@ Status Tile::Compute(OpKernelContext* ctx) const {
     return Status(ONNXRUNTIME, INVALID_ARGUMENT, "'repeat' input tensor must have the same length as the 'input' tensor");
 
   // Calculate the shape of the output tensor
-  auto* repeats = repeats_tensor.template Data<int64_t>();
+  const auto* repeats = repeats_tensor.template Data<int64_t>();
   std::vector<int64_t> output_dims = input_shape.GetDims();
   for (size_t axis = 0; axis < input_rank; axis++) {
     output_dims[axis] *= repeats[axis];
   }
 
-  TensorShape outputShape(output_dims);
-  auto& output_tensor = *ctx->Output(0, outputShape);
+  TensorShape output_shape(output_dims);
+  auto& output_tensor = *ctx->Output(0, output_shape);
 
   // Repeat tensor input can have 0 as a valid value
-  // check if the computed outputshape size is 0 and
+  // check if the computed output_shape size is 0 and
   // return an empty tensor if so.
-  if (outputShape.Size() == 0) {
+  if (output_shape.Size() == 0) {
+    return Status::OK();
+  }
+
+  // Repeat tensor has all 1s in it: the output is a single copy of the input
+  if (output_shape == input_shape) {
+    // TODO: Handle string copies when the kernel eventually supports string type.
+    // For now, it shouldn't throw in the enforce as the kernel doesn't claim string support
+    ORT_ENFORCE(!input_tensor.IsDataType<std::string>(), "Tile doesn't support string type yet");
+    memcpy(output_tensor.MutableDataRaw(), input_tensor.DataRaw(), input_tensor.SizeInBytes());
+    return Status::OK();
+  }
+
+  size_t num_of_copies = 1;
+  if (TileOp::IsTileMemcpy(input_shape, repeats, input_rank, num_of_copies)) {
+    // TODO: Handle string copies when the kernel eventually supports string type.
+    // For now, it shouldn't throw in the enforce as the kernel doesn't claim string support
+    ORT_ENFORCE(!input_tensor.IsDataType<std::string>(), "Tile doesn't support string type yet");
+
+    int8_t* output_data_casted = reinterpret_cast<int8_t*>(output_tensor.MutableDataRaw());
+    const void* input_data_raw = input_tensor.DataRaw();
+    size_t tensor_size_in_bytes = input_tensor.SizeInBytes();
+
+    // TODO: Add multi-threading logic if num_of_copies is large enough
+    for (size_t i = 0; i < num_of_copies; ++i) {
+      memcpy(static_cast<void*>(output_data_casted), input_data_raw, tensor_size_in_bytes);
+      output_data_casted += tensor_size_in_bytes;
+    }
+
     return Status::OK();
   }
 
diff --git a/onnxruntime/core/providers/cpu/tensor/tile.h b/onnxruntime/core/providers/cpu/tensor/tile.h
index e97afb07148d1..74d8ee7ded3ac 100644
--- a/onnxruntime/core/providers/cpu/tensor/tile.h
+++ b/onnxruntime/core/providers/cpu/tensor/tile.h
@@ -7,8 +7,21 @@
 
 namespace onnxruntime {
 
-struct Tile final : OpKernel {
-  Tile(const OpKernelInfo& info) : OpKernel(info) {
+namespace TileOp {
+// Returns true when the tiling operation is just `num_of_copies` back-to-back copies
+// of the input data buffer (`num_of_copies` is only written when true is returned).
+// E.g.: input_shape: [1, 1, 256 * 50]
+//       repeats: [1, 200, 1]
+//       output shape: [1, 200, 256 * 50]  -> num_of_copies == 200
+
+bool IsTileMemcpy(const TensorShape& input_shape,
+                  const int64_t* repeats,
+                  size_t rank,
+                  /*out*/ size_t& num_of_copies);
+}  // namespace TileOp
+
+struct Tile : OpKernel {
+  explicit Tile(const OpKernelInfo& info) : OpKernel(info) {
   }
 
   Status Compute(OpKernelContext* context) const override;
diff --git a/onnxruntime/core/providers/cuda/tensor/scatter_nd.cc b/onnxruntime/core/providers/cuda/tensor/scatter_nd.cc
index bec33de6dd141..07dd5df43bce8 100644
--- a/onnxruntime/core/providers/cuda/tensor/scatter_nd.cc
+++ b/onnxruntime/core/providers/cuda/tensor/scatter_nd.cc
@@ -48,7 +48,7 @@ Status ScatterND::ComputeInternal(OpKernelContext* context) const {
 
   if (input_data != output_data) {
     // TODO: Run benchmarks to determine if a dedicated kernel doing data copy will be faster than invoking cudaMemcpy ?
-    cudaMemcpyAsync(output_data, input_data, element_size * input_shape.Size(), cudaMemcpyDeviceToDevice);
+    CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output_data, input_data, input_tensor->SizeInBytes(), cudaMemcpyDeviceToDevice));
   }
 
   // Bail out early
diff --git a/onnxruntime/core/providers/cuda/tensor/tile.cc b/onnxruntime/core/providers/cuda/tensor/tile.cc
index 5ef2409a7f6aa..5aac0d0430525 100644
--- a/onnxruntime/core/providers/cuda/tensor/tile.cc
+++ b/onnxruntime/core/providers/cuda/tensor/tile.cc
@@ -4,6 +4,7 @@
 #include "core/providers/cuda/tensor/tile.h"
 #include "core/providers/cpu/tensor/utils.h"
 #include "tile_impl.h"
+
 using namespace onnxruntime::common;
 namespace onnxruntime {
 namespace cuda {
@@ -51,22 +52,66 @@ Status Tile::ComputeInternal(OpKernelContext* ctx) const {
 
   // Calculate the shape of the output tensor
   auto* repeats = repeats_tensor.template Data<int64_t>();
-  const auto& input_shape = input_tensor.Shape().GetDims();
-  std::vector<int64_t> output_dims(input_shape);
+  const auto& input_shape = input_tensor.Shape();
+  const auto& input_dims = input_shape.GetDims();
+  std::vector<int64_t> output_dims(input_dims);
   for (auto axis = 0; axis < rank; axis++)
     output_dims[axis] *= repeats[axis];
-  TensorShape outputShape(output_dims);
-  auto& output_tensor = *ctx->Output(0, outputShape);
+  TensorShape output_shape(output_dims);
+  auto& output_tensor = *ctx->Output(0, output_shape);
 
   void* output_data = output_tensor.MutableDataRaw();
   const void* input_data = input_tensor.DataRaw();
 
-  TensorPitches input_pitches(input_shape);
+  // Repeat tensor input can have 0 as a valid value
+  // check if the computed output_shape size is 0 and
+  // return an empty tensor if so.
+  if (output_shape.Size() == 0) {
+    return Status::OK();
+  }
+
+  // Repeat tensor has all 1s in it: the output is a single device-to-device copy of the input
+  if (output_shape == input_shape) {
+    CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output_data, input_data, input_tensor.SizeInBytes(), cudaMemcpyDeviceToDevice));
+    return Status::OK();
+  }
+
+  size_t num_of_copies = 1;
+  if (TileOp::IsTileMemcpy(input_shape, repeats, rank, num_of_copies)) {
+    if (input_tensor.IsDataType<float>() ||
+        input_tensor.IsDataType<int32_t>()) {
+      TileMemcpyImpl(
+          reinterpret_cast<const typename ToCudaType<float>::MappedType*>(input_data),
+          input_shape.Size(),
+          reinterpret_cast<typename ToCudaType<float>::MappedType*>(output_data),
+          output_shape.Size());
+    } else if (input_tensor.IsDataType<double>() ||
+               input_tensor.IsDataType<int64_t>()) {
+      TileMemcpyImpl(
+          reinterpret_cast<const typename ToCudaType<double>::MappedType*>(input_data),
+          input_shape.Size(),
+          reinterpret_cast<typename ToCudaType<double>::MappedType*>(output_data),
+          output_shape.Size());
+    } else if (input_tensor.IsDataType<MLFloat16>()) {
+      TileMemcpyImpl(
+          reinterpret_cast<const typename ToCudaType<MLFloat16>::MappedType*>(input_data),
+          input_shape.Size(),
+          reinterpret_cast<typename ToCudaType<MLFloat16>::MappedType*>(output_data),
+          output_shape.Size());
+    } else {
+      // Won't hit this as the kernel doesn't claim support for any type that will trigger this
+      ORT_THROW("Tile doesn't have an implementation yet for the type: ", input_tensor.DataType());
+    }
+
+    return Status::OK();
+  }
+
+  TensorPitches input_pitches(input_dims);
   TArray<int64_t> input_strides(input_pitches);
 
   TArray<fast_divmod> fdm_input_shape(rank);
-  for (int32_t i = 0; i < input_shape.size(); ++i) {
-    fdm_input_shape[i] = fast_divmod(gsl::narrow_cast<int>(input_shape[i]));
+  for (int32_t i = 0; i < input_dims.size(); ++i) {
+    fdm_input_shape[i] = fast_divmod(gsl::narrow_cast<int>(input_dims[i]));
   }
 
   TArray<fast_divmod> fdm_output_strides(rank);
diff --git a/onnxruntime/core/providers/cuda/tensor/tile.h b/onnxruntime/core/providers/cuda/tensor/tile.h
index 6dbacb268e6a6..5aea4574526fc 100644
--- a/onnxruntime/core/providers/cuda/tensor/tile.h
+++ b/onnxruntime/core/providers/cuda/tensor/tile.h
@@ -3,12 +3,13 @@
 
 #include "core/common/common.h"
 #include "core/providers/cuda/cuda_kernel.h"
+#include "core/providers/cpu/tensor/tile.h"
 
 namespace onnxruntime {
 namespace cuda {
 
 struct Tile final : CudaKernel {
-  Tile(const OpKernelInfo& info) : CudaKernel(info) {
+  explicit Tile(const OpKernelInfo& info) : CudaKernel(info) {
   }
 
   Status ComputeInternal(OpKernelContext* context) const override;
diff --git a/onnxruntime/core/providers/cuda/tensor/tile_impl.cu b/onnxruntime/core/providers/cuda/tensor/tile_impl.cu
index 7fa1834c7b4ae..33696d1a26076 100644
--- a/onnxruntime/core/providers/cuda/tensor/tile_impl.cu
+++ b/onnxruntime/core/providers/cuda/tensor/tile_impl.cu
@@ -45,8 +45,31 @@ void TileImpl(
       fdm_output_strides, output_data, (CUDA_LONG)N);
 }
 
+template <typename T>
+__global__ void _TileMemcpyKernel(
+    const T* input_data,
+    const size_t num_input_elements,
+    T* output_data,
+    const size_t N) {
+  CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N);
+  auto input_index = id % num_input_elements;
+  output_data[id] = input_data[input_index];
+}
+
+template <typename T>
+void TileMemcpyImpl(
+    const T* input_data,
+    const size_t num_input_elements,
+    T* output_data,
+    const size_t num_output_elements) {
+  const int blocksPerGrid = static_cast<int>((num_output_elements + GridDim::maxThreadsPerBlock - 1) / GridDim::maxThreadsPerBlock);
+  _TileMemcpyKernel<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
+      input_data, num_input_elements, output_data, static_cast<CUDA_LONG>(num_output_elements));
+}
+
 #define SPECIALIZED_IMPL(T) \
-  template void TileImpl<T>(const size_t shape_rank, const TArray<fast_divmod>& fdm_input_shape, const TArray<int64_t>& input_stride, const T* input_data, const TArray<fast_divmod>& fdm_output_strides, T* output_data, const size_t N);
+  template void TileImpl<T>(const size_t shape_rank, const TArray<fast_divmod>& fdm_input_shape, const TArray<int64_t>& input_stride, const T* input_data, const TArray<fast_divmod>& fdm_output_strides, T* output_data, const size_t N); \
+  template void TileMemcpyImpl<T>(const T* input_data, const size_t num_input_elements, T* output_data, const size_t num_output_elements);
 
 SPECIALIZED_IMPL(float)
 SPECIALIZED_IMPL(double)
diff --git a/onnxruntime/core/providers/cuda/tensor/tile_impl.h b/onnxruntime/core/providers/cuda/tensor/tile_impl.h
index 356d40d9495b1..cfe5391073f68 100644
--- a/onnxruntime/core/providers/cuda/tensor/tile_impl.h
+++ b/onnxruntime/core/providers/cuda/tensor/tile_impl.h
@@ -18,5 +18,12 @@ void TileImpl(
     T* output_data,
     const size_t N);
 
+template <typename T>
+void TileMemcpyImpl(
+    const T* input_data,
+    const size_t num_input_elements,
+    T* output_data,
+    const size_t num_output_elements);
+
 }  // namespace cuda
 }  // namespace onnxruntime
diff --git a/onnxruntime/test/providers/cpu/tensor/tile_op_test.cc b/onnxruntime/test/providers/cpu/tensor/tile_op_test.cc
index f1d634a24802c..4d5c2384f7aec 100644
--- a/onnxruntime/test/providers/cpu/tensor/tile_op_test.cc
+++ b/onnxruntime/test/providers/cpu/tensor/tile_op_test.cc
@@ -19,7 +19,7 @@ void RunTest(std::initializer_list<T> input,
   test.AddInput<int64_t>("repeats", repeat_dims, repeat);
   test.AddOutput<T>("output", output_dims, output);
   if (std::is_same<T, int8_t>::value)
-    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); //TensorRT reports error: Assertion Error in makePaddedScale: 0 (regionRanges != nullptr)
+    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});  //TensorRT reports error: Assertion Error in makePaddedScale: 0 (regionRanges != nullptr)
   else
     test.Run();
 }
@@ -43,6 +43,15 @@ void RunTestWrapper() {
 
   // Tile3D
   RunTest<T>({111, 112, 113, 122, 123, 124}, {2, 1, 3}, {1, 2, 1}, {3}, {111, 112, 113, 111, 112, 113, 122, 123, 124, 122, 123, 124}, {2, 2, 3});
+
+  // Tile3DWithAllOneRepeats
+  RunTest<T>({111, 112, 113, 122, 123, 124}, {2, 1, 3}, {1, 1, 1}, {3}, {111, 112, 113, 122, 123, 124}, {2, 1, 3});
+
+  // TileWhichIsBasicallyCopiesOfInputBuffer - 1
+  RunTest<T>({111, 112, 113}, {1, 1, 3}, {2, 2, 1}, {3}, {111, 112, 113, 111, 112, 113, 111, 112, 113, 111, 112, 113}, {2, 2, 3});
+
+  // TileWhichIsBasicallyCopiesOfInputBuffer - 2
+  RunTest<T>({111, 112, 113}, {1, 1, 3}, {3, 1, 1}, {3}, {111, 112, 113, 111, 112, 113, 111, 112, 113}, {3, 1, 3});
 }
 
 template <>
@@ -64,6 +73,15 @@ void RunTestWrapper<bool>() {
 
   // Tile3D
   RunTest<bool>({true, false, true, false, true, false}, {2, 1, 3}, {1, 2, 1}, {3}, {true, false, true, true, false, true, false, true, false, false, true, false}, {2, 2, 3});
+
+  // Tile3DWithAllOneRepeats
+  RunTest<bool>({true, false, true, false, true, true}, {2, 1, 3}, {1, 1, 1}, {3}, {true, false, true, false, true, true}, {2, 1, 3});
+
+  // TileWhichIsBasicallyCopiesOfInputBuffer - 1
+  RunTest<bool>({true, false, true}, {1, 1, 3}, {2, 2, 1}, {3}, {true, false, true, true, false, true, true, false, true, true, false, true}, {2, 2, 3});
+
+  // TileWhichIsBasicallyCopiesOfInputBuffer - 2
+  RunTest<bool>({true, false, true}, {1, 1, 3}, {3, 1, 1}, {3}, {true, false, true, true, false, true, true, false, true}, {3, 1, 3});
 }
 
 TEST(TensorOpTest, TileFloatType) {