75 changes: 62 additions & 13 deletions onnxruntime/core/providers/cpu/tensor/tile.cc
@@ -100,20 +100,32 @@ Status TileCoreForFixedSizeTypes(const Tensor& input_tensor, Tensor& output_tens
}

namespace TileOp {
// Find the first non-1 repeat and check the input shape to the left of that dimension,
// if the dim values are 1, then the tiling logic is essentially copying the input buffer
// multiple times. The number of times can be computed as the product of the repeat values.
// Find the first non-1 repeat and check the input shape to the left of that dimension:
// 1) If the dim values to the left are all 1s (or don't exist), then the tiling logic is essentially copying the input buffer
// multiple times. The number of times can be computed as the product of the repeat values. (OR)
// 2) Allow at most one non-1 dim value to the left (the batch dimension). In this case, the sub-tensor at each batch index
// is copied multiple times. This is still faster because it avoids the generic Tile operator machinery.
bool IsTileMemcpy(const TensorShape& input_shape,
const int64_t* repeats,
size_t rank,
/*out*/ size_t& num_of_copies) {
/*out*/ bool& is_batched_memcpy,
/*out*/ size_t& num_of_elements_per_batch,
/*out*/ size_t& num_of_copies_per_batch,
/*out*/ size_t& num_of_batch_copies) {
for (int64_t i = static_cast<int64_t>(rank) - 1; i >= 0; --i) {
if (repeats[i] != 1) {
if (input_shape.SizeToDimension(i) == 1) {
num_of_copies = 1;
num_of_copies_per_batch = 1;
for (int64_t j = 0; j <= i; ++j) {
num_of_copies *= repeats[j];
num_of_copies_per_batch *= repeats[j];
}
is_batched_memcpy = false;
return true;
} else if (i == 1) { // else check if the previous dim is just the batch dim
num_of_elements_per_batch = static_cast<size_t>(input_shape.SizeFromDimension(1));
num_of_copies_per_batch = repeats[i];
num_of_batch_copies = repeats[0];
is_batched_memcpy = true;
return true;
} else {
break;
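
Reviewer note: the two cases above can be exercised in isolation. Below is a minimal, self-contained C++ sketch of the same detection logic using plain `std::vector<int64_t>` in place of `TensorShape`; the helper `is_tile_memcpy_sketch` and the `main` driver are illustrative only and are not part of this change.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Mirrors TileOp::IsTileMemcpy: walk repeats right-to-left and classify the
// first non-1 repeat as either a whole-buffer memcpy tile or a batched memcpy tile.
bool is_tile_memcpy_sketch(const std::vector<int64_t>& dims,
                           const std::vector<int64_t>& repeats,
                           bool& is_batched,
                           int64_t& elements_per_batch,
                           int64_t& copies_per_batch,
                           int64_t& batch_copies) {
  for (int64_t i = static_cast<int64_t>(dims.size()) - 1; i >= 0; --i) {
    if (repeats[i] == 1) continue;
    // Product of dims strictly to the left of i (empty product == 1).
    int64_t left = 1;
    for (int64_t j = 0; j < i; ++j) left *= dims[j];
    if (left == 1) {
      // Case 1: everything to the left is 1 -> whole-buffer copies.
      copies_per_batch = 1;
      for (int64_t j = 0; j <= i; ++j) copies_per_batch *= repeats[j];
      is_batched = false;
      return true;
    } else if (i == 1) {
      // Case 2: only the batch dim (dim 0) is non-1 to the left.
      elements_per_batch = 1;
      for (size_t j = 1; j < dims.size(); ++j) elements_per_batch *= dims[j];
      copies_per_batch = repeats[1];
      batch_copies = repeats[0];
      is_batched = true;
      return true;
    } else {
      break;
    }
  }
  return false;
}

int main() {
  bool batched = false;
  int64_t elems = 1, copies = 1, batch_copies = 1;
  // Case 1: input [1, 1, 6], repeats [2, 2, 1] -> 4 whole-buffer copies.
  assert(is_tile_memcpy_sketch({1, 1, 6}, {2, 2, 1}, batched, elems, copies, batch_copies));
  assert(!batched && copies == 4);
  // Case 2: input [5, 1, 6], repeats [2, 200, 1] -> batched path.
  assert(is_tile_memcpy_sketch({5, 1, 6}, {2, 200, 1}, batched, elems, copies, batch_copies));
  assert(batched && elems == 6 && copies == 200 && batch_copies == 2);
  return 0;
}
```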
@@ -166,20 +178,57 @@ Status Tile::Compute(OpKernelContext* ctx) const {
return Status::OK();
}

size_t num_of_copies = 1;
if (TileOp::IsTileMemcpy(input_shape, repeats, input_rank, num_of_copies)) {
bool is_batched_memcpy = false;
size_t num_of_elements_per_batch = 1;
size_t num_of_copies_per_batch = 1;
size_t num_of_batch_copies = 1;
if (TileOp::IsTileMemcpy(input_shape,
repeats,
input_rank,
is_batched_memcpy,
num_of_elements_per_batch,
num_of_copies_per_batch,
num_of_batch_copies)) {
// TODO: Handle string copies when the kernel eventually supports string type.
// For now, the enforce shouldn't fire because the kernel doesn't claim string support
ORT_ENFORCE(!input_tensor.IsDataType<std::string>(), "Tile doesn't support string type yet");

int8_t* output_data_casted = reinterpret_cast<int8_t*>(output_tensor.MutableDataRaw());
const int8_t* input_data_casted = reinterpret_cast<const int8_t*>(input_tensor.DataRaw());
const void* input_data_raw = input_tensor.DataRaw();
size_t tensor_size_in_bytes = input_tensor.SizeInBytes();

// TODO: Add multi-threading logic if num_of_copies is large enough
for (size_t i = 0; i < num_of_copies; ++i) {
memcpy(static_cast<void*>(output_data_casted), input_data_raw, tensor_size_in_bytes);
output_data_casted += tensor_size_in_bytes;
if (!is_batched_memcpy) {
size_t copy_bytes = input_tensor.SizeInBytes();
// TODO: Add multi-threading logic if num_of_copies_per_batch is large enough
for (size_t i = 0; i < num_of_copies_per_batch; ++i) {
memcpy(static_cast<void*>(output_data_casted), input_data_raw, copy_bytes);
output_data_casted += copy_bytes;
}
} else {
size_t copy_bytes = num_of_elements_per_batch * input_tensor.DataType()->Size();
size_t batch_count = static_cast<size_t>(input_tensor.Shape()[0]); // The tensor is at least 1-D, so this is safe

// TODO: Multi-thread if needed
for (size_t batch = 0; batch < batch_count; ++batch) {
for (size_t i = 0; i < num_of_copies_per_batch; ++i) {
memcpy(static_cast<void*>(output_data_casted), static_cast<const void*>(input_data_casted), copy_bytes);
output_data_casted += copy_bytes;
}
input_data_casted += copy_bytes;
}

// Now account for batch dim repeat
if (num_of_batch_copies > 1) {
// Reset the output pointer to the start and scale copy_bytes to cover the whole block written above
output_data_casted = reinterpret_cast<int8_t*>(output_tensor.MutableDataRaw());
copy_bytes *= num_of_copies_per_batch * batch_count;
int8_t* copy_ptr = output_data_casted + copy_bytes;

for (size_t i = 1; i < num_of_batch_copies; ++i) {
memcpy(static_cast<void*>(copy_ptr), static_cast<const void*>(output_data_casted), copy_bytes);
copy_ptr += copy_bytes;
}
}
}

return Status::OK();
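
The two copy paths above are easier to follow on concrete values than on raw byte pointers. The following self-contained sketch replays the batched path on a small `std::vector<int>`; the values mirror the new unit test added below, and the variable names are illustrative only.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

int main() {
  // Input shape [2, 1, 3], repeats [2, 2, 1] (matches the new batched-memcpy unit test).
  const std::vector<int> input = {111, 112, 113, 11, 12, 13};
  const size_t batch_count = 2;          // input_shape[0]
  const size_t elements_per_batch = 3;   // product of dims after the batch dim
  const size_t copies_per_batch = 2;     // repeats[1]
  const size_t batch_copies = 2;         // repeats[0]

  std::vector<int> output;
  output.reserve(input.size() * copies_per_batch * batch_copies);

  // Step 1: copy each batch's sub-tensor copies_per_batch times, back to back.
  for (size_t b = 0; b < batch_count; ++b) {
    for (size_t c = 0; c < copies_per_batch; ++c) {
      output.insert(output.end(),
                    input.begin() + b * elements_per_batch,
                    input.begin() + (b + 1) * elements_per_batch);
    }
  }

  // Step 2: duplicate the whole intermediate result for the batch-dim repeat.
  const size_t intermediate_size = output.size();
  for (size_t r = 1; r < batch_copies; ++r) {
    for (size_t k = 0; k < intermediate_size; ++k) {
      output.push_back(output[k]);  // no reallocation thanks to reserve()
    }
  }

  const std::vector<int> expected = {111, 112, 113, 111, 112, 113, 11, 12, 13, 11, 12, 13,
                                     111, 112, 113, 111, 112, 113, 11, 12, 13, 11, 12, 13};
  assert(output == expected);
  return 0;
}
```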
16 changes: 15 additions & 1 deletion onnxruntime/core/providers/cpu/tensor/tile.h
@@ -14,10 +14,24 @@ namespace TileOp {
// repeats: [1, 200, 1]
// output shape: [1, 200, 256 * 50]

// As a slight extension, it also detects "batched" copies, where the sub-tensor at each batch index
// is copied multiple times (`is_batched_memcpy` will be set to true)
// E.g.: input_shape: [5, 1, 256 * 50]
// repeats: [1, 200, 1]
// output shape: [5, 200, 256 * 50]

// Repeating the batch is also supported
// E.g.: input_shape: [5, 1, 256 * 50]
// repeats: [2, 200, 1]
// output shape: [10, 200, 256 * 50]

bool IsTileMemcpy(const TensorShape& input_shape,
const int64_t* repeats,
size_t rank,
/*out*/ size_t& num_of_copies);
/*out*/ bool& is_batched_memcpy,
/*out*/ size_t& num_of_elements_per_batch,
/*out*/ size_t& num_of_copies_per_batch,
/*out*/ size_t& num_of_batch_copies);
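
For completeness, a hedged usage sketch of the extended signature. It assumes onnxruntime's `TensorShape` header and its `std::vector<int64_t>` constructor are available at these include paths; the expected values in the comments follow from the batched shape example above and are not asserted anywhere in this PR.

```cpp
#include <vector>

#include "core/framework/tensor_shape.h"
#include "core/providers/cpu/tensor/tile.h"

namespace onnxruntime {

void IsTileMemcpyUsageSketch() {
  // Batched example from the comment above: input [5, 1, 256 * 50], repeats [2, 200, 1].
  std::vector<int64_t> dims = {5, 1, 256 * 50};
  TensorShape input_shape(dims);
  const int64_t repeats[] = {2, 200, 1};

  bool is_batched_memcpy = false;
  size_t num_of_elements_per_batch = 1;
  size_t num_of_copies_per_batch = 1;
  size_t num_of_batch_copies = 1;

  if (TileOp::IsTileMemcpy(input_shape, repeats, 3,
                           is_batched_memcpy, num_of_elements_per_batch,
                           num_of_copies_per_batch, num_of_batch_copies)) {
    // Expected here: is_batched_memcpy == true, num_of_elements_per_batch == 256 * 50,
    // num_of_copies_per_batch == 200, num_of_batch_copies == 2.
  }
}

}  // namespace onnxruntime
```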
} // namespace TileOp

struct Tile : OpKernel {
90 changes: 66 additions & 24 deletions onnxruntime/core/providers/cuda/tensor/tile.cc
@@ -76,31 +76,73 @@ Status Tile::ComputeInternal(OpKernelContext* ctx) const {
return Status::OK();
}

size_t num_of_copies = 1;
if (TileOp::IsTileMemcpy(input_shape, repeats, rank, num_of_copies)) {
if (input_tensor.IsDataType<float>() ||
input_tensor.IsDataType<int32_t>()) {
TileMemcpyImpl(
reinterpret_cast<const typename ToCudaType<float>::MappedType*>(input_data),
input_shape.Size(),
reinterpret_cast<typename ToCudaType<float>::MappedType*>(output_data),
output_shape.Size());
} else if (input_tensor.IsDataType<double>() ||
input_tensor.IsDataType<int64_t>()) {
TileMemcpyImpl(
reinterpret_cast<const typename ToCudaType<double>::MappedType*>(input_data),
input_shape.Size(),
reinterpret_cast<typename ToCudaType<double>::MappedType*>(output_data),
output_shape.Size());
} else if (input_tensor.IsDataType<MLFloat16>()) {
TileMemcpyImpl(
reinterpret_cast<const typename ToCudaType<MLFloat16>::MappedType*>(input_data),
input_shape.Size(),
reinterpret_cast<typename ToCudaType<MLFloat16>::MappedType*>(output_data),
output_shape.Size());
bool is_batched_memcpy = false;
size_t num_of_elements_per_batch = 1;
size_t num_of_copies_per_batch = 1;
size_t num_of_batch_copies = 1;
if (TileOp::IsTileMemcpy(input_shape,
repeats,
rank,
is_batched_memcpy,
num_of_elements_per_batch,
num_of_copies_per_batch,
num_of_batch_copies)) {
if (!is_batched_memcpy) {
if (input_tensor.IsDataType<float>() ||
input_tensor.IsDataType<int32_t>()) {
TileMemcpyImpl(
reinterpret_cast<const typename ToCudaType<float>::MappedType*>(input_data),
input_shape.Size(),
reinterpret_cast<typename ToCudaType<float>::MappedType*>(output_data),
output_shape.Size());
} else if (input_tensor.IsDataType<double>() ||
input_tensor.IsDataType<int64_t>()) {
TileMemcpyImpl(
reinterpret_cast<const typename ToCudaType<double>::MappedType*>(input_data),
input_shape.Size(),
reinterpret_cast<typename ToCudaType<double>::MappedType*>(output_data),
output_shape.Size());
} else if (input_tensor.IsDataType<MLFloat16>()) {
TileMemcpyImpl(
reinterpret_cast<const typename ToCudaType<MLFloat16>::MappedType*>(input_data),
input_shape.Size(),
reinterpret_cast<typename ToCudaType<MLFloat16>::MappedType*>(output_data),
output_shape.Size());
} else {
// Won't hit this as the kernel doesn't claim support for any type that will trigger this
ORT_THROW("Tile doesn't have an implementation yet for the type: ", input_tensor.DataType());
}
} else {
// Won't hit this as the kernel doesn't claim support for any type that will trigger this
ORT_THROW("Tile doesn't have an implementation yet for the type: ", input_tensor.DataType());
if (input_tensor.IsDataType<float>() ||
input_tensor.IsDataType<int32_t>()) {
TileBatchedMemcpyImpl(
reinterpret_cast<const typename ToCudaType<float>::MappedType*>(input_data),
num_of_elements_per_batch,
input_shape[0], // The tensor is at least 1-D, so this is safe
fast_divmod(static_cast<int>(num_of_elements_per_batch * num_of_copies_per_batch)),
reinterpret_cast<typename ToCudaType<float>::MappedType*>(output_data),
output_shape.Size());
} else if (input_tensor.IsDataType<double>() ||
input_tensor.IsDataType<int64_t>()) {
TileBatchedMemcpyImpl(
reinterpret_cast<const typename ToCudaType<double>::MappedType*>(input_data),
num_of_elements_per_batch,
input_shape[0], // The tensor is at least 1-D, so this is safe
fast_divmod(static_cast<int>(num_of_elements_per_batch * num_of_copies_per_batch)),
reinterpret_cast<typename ToCudaType<double>::MappedType*>(output_data),
output_shape.Size());
} else if (input_tensor.IsDataType<MLFloat16>()) {
TileBatchedMemcpyImpl(
reinterpret_cast<const typename ToCudaType<MLFloat16>::MappedType*>(input_data),
num_of_elements_per_batch,
input_shape[0], // The tensor is at least 1-D, so this is safe
fast_divmod(static_cast<int>(num_of_elements_per_batch * num_of_copies_per_batch)),
reinterpret_cast<typename ToCudaType<MLFloat16>::MappedType*>(output_data),
output_shape.Size());
} else {
// Won't hit this as the kernel doesn't claim support for any type that will trigger this
ORT_THROW("Tile doesn't have an implementation yet for the type: ", input_tensor.DataType());
}
}

return Status::OK();
36 changes: 35 additions & 1 deletion onnxruntime/core/providers/cuda/tensor/tile_impl.cu
@@ -67,9 +67,43 @@ void TileMemcpyImpl(
input_data, num_input_elements, output_data, (CUDA_LONG)num_output_elements);
}

template <typename T>
__global__ void _TileBatchedMemcpyKernel(
const T* input_data,
const size_t num_of_elements_per_input_batch,
const size_t num_input_batch_count,
const fast_divmod num_of_elements_per_output_batch,
T* output_data,
const size_t N) {
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N);
CUDA_LONG batch_idx = 0;
CUDA_LONG element_idx = 0;
num_of_elements_per_output_batch.divmod(id, batch_idx, element_idx);
output_data[id] = input_data[(batch_idx % num_input_batch_count) * num_of_elements_per_input_batch + (element_idx % num_of_elements_per_input_batch)];
}

template <typename T>
void TileBatchedMemcpyImpl(
const T* input_data,
const size_t num_of_elements_per_input_batch,
const size_t num_input_batch_count,
const fast_divmod& num_of_elements_per_output_batch,
T* output_data,
const size_t num_output_elements) {
int blocksPerGrid = (int)(ceil(static_cast<float>(num_output_elements) / GridDim::maxThreadsPerBlock));
_TileBatchedMemcpyKernel<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
input_data,
num_of_elements_per_input_batch,
num_input_batch_count,
num_of_elements_per_output_batch,
output_data,
(CUDA_LONG)num_output_elements);
}
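
The divmod/mod index mapping in `_TileBatchedMemcpyKernel` can be sanity-checked on the host. Here is a small standalone C++ sketch that replays it for input shape [2, 1, 3] with repeats [1, 2, 1], matching one of the new CPU tests; the names are illustrative only.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

int main() {
  // Input [2, 1, 3], repeats [1, 2, 1] -> output [2, 2, 3].
  const std::vector<int> input = {111, 112, 113, 11, 12, 13};
  const size_t elems_per_input_batch = 3;   // num_of_elements_per_input_batch
  const size_t input_batch_count = 2;       // num_input_batch_count
  // num_of_elements_per_batch * num_of_copies_per_batch, i.e. the fast_divmod argument.
  const size_t elems_per_output_batch = 3 * 2;

  std::vector<int> output(12);
  for (size_t id = 0; id < output.size(); ++id) {
    const size_t batch_idx = id / elems_per_output_batch;    // fast_divmod quotient
    const size_t element_idx = id % elems_per_output_batch;  // fast_divmod remainder
    output[id] = input[(batch_idx % input_batch_count) * elems_per_input_batch +
                       (element_idx % elems_per_input_batch)];
  }

  const std::vector<int> expected = {111, 112, 113, 111, 112, 113, 11, 12, 13, 11, 12, 13};
  assert(output == expected);
  return 0;
}
```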

#define SPECIALIZED_IMPL(T) \
template void TileImpl<T>(const size_t shape_rank, const TArray<fast_divmod>& fdm_input_shape, const TArray<int64_t>& input_stride, const T* input_data, const TArray<fast_divmod>& fdm_output_strides, T* output_data, const size_t N); \
template void TileMemcpyImpl<T>(const T* input_data, const size_t num_input_elements, T* output_data, const size_t num_output_elements);
template void TileMemcpyImpl<T>(const T* input_data, const size_t num_input_elements, T* output_data, const size_t num_output_elements); \
template void TileBatchedMemcpyImpl<T>(const T* input_data, const size_t num_of_elements_per_input_batch, const size_t num_input_batch_count, const fast_divmod& num_of_elements_per_output_batch, T* output_data, const size_t num_output_elements);

SPECIALIZED_IMPL(float)
SPECIALIZED_IMPL(double)
9 changes: 9 additions & 0 deletions onnxruntime/core/providers/cuda/tensor/tile_impl.h
@@ -25,5 +25,14 @@ void TileMemcpyImpl(
T* output_data,
const size_t num_output_elements);

template <typename T>
void TileBatchedMemcpyImpl(
const T* input_data,
const size_t num_of_elements_per_input_batch,
const size_t num_input_batch_count,
const fast_divmod& num_of_elements_per_output_batch,
T* output_data,
const size_t num_output_elements);

} // namespace cuda
} // namespace onnxruntime
24 changes: 24 additions & 0 deletions onnxruntime/test/providers/cpu/tensor/tile_op_test.cc
@@ -48,10 +48,21 @@ void RunTestWrapper() {
RunTest<T>({111, 112, 113, 122, 123, 124}, {2, 1, 3}, {1, 1, 1}, {3}, {111, 112, 113, 122, 123, 124}, {2, 1, 3});

// TileWhichIsBasicallyCopiesOfInputBuffer - 1
// This will trigger the MemCpy optimization path
RunTest<T>({111, 112, 113}, {1, 1, 3}, {2, 2, 1}, {3}, {111, 112, 113, 111, 112, 113, 111, 112, 113, 111, 112, 113}, {2, 2, 3});

// TileWhichIsBasicallyCopiesOfInputBuffer - 2
// This will trigger the MemCpy optimization path
RunTest<T>({111, 112, 113}, {1, 1, 3}, {3, 1, 1}, {3}, {111, 112, 113, 111, 112, 113, 111, 112, 113}, {3, 1, 3});

// TileWhichIsBasicallyCopiesOfInputBuffer - 3 (batch > 1 and batch_repeat == 1)
// This will trigger the (Batched) MemCpy optimization path
RunTest<T>({111, 112, 113, 11, 12, 13}, {2, 1, 3}, {1, 2, 1}, {3}, {111, 112, 113, 111, 112, 113, 11, 12, 13, 11, 12, 13}, {2, 2, 3});

// TileWhichIsBasicallyCopiesOfInputBuffer - 4 (batch > 1 and batch_repeat > 1)
// This will trigger the (Batched) MemCpy optimization path
RunTest<T>({111, 112, 113, 11, 12, 13}, {2, 1, 3}, {2, 2, 1}, {3},
{111, 112, 113, 111, 112, 113, 11, 12, 13, 11, 12, 13, 111, 112, 113, 111, 112, 113, 11, 12, 13, 11, 12, 13}, {4, 2, 3});
}

template <>
@@ -78,10 +89,23 @@ void RunTestWrapper<bool>() {
RunTest<bool>({true, false, true, false, true, true}, {2, 1, 3}, {1, 1, 1}, {3}, {true, false, true, false, true, true}, {2, 1, 3});

// TileWhichIsBasicallyCopiesOfInputBuffer - 1
// This will trigger the MemCpy optimization path
RunTest<bool>({true, false, true}, {1, 1, 3}, {2, 2, 1}, {3}, {true, false, true, true, false, true, true, false, true, true, false, true}, {2, 2, 3});

// TileWhichIsBasicallyCopiesOfInputBuffer - 2
// This will trigger the MemCpy optimization path
RunTest<bool>({true, false, true}, {1, 1, 3}, {3, 1, 1}, {3}, {true, false, true, true, false, true, true, false, true}, {3, 1, 3});

// TileWhichIsBasicallyCopiesOfInputBuffer - 3 (batch > 1 and batch_repeat == 1)
// This will trigger the (Batched) MemCpy optimization path
RunTest<bool>({true, false, true, true, false, true}, {2, 1, 3}, {1, 2, 1}, {3},
{true, false, true, true, false, true, true, false, true, true, false, true}, {2, 2, 3});

// TileWhichIsBasicallyCopiesOfInputBuffer - 4 (batch > 1 and batch_repeat > 1)
// This will trigger the (Batched) MemCpy optimization path
RunTest<bool>({true, false, true, true, false, true}, {2, 1, 3}, {2, 2, 1}, {3},
{true, false, true, true, false, true, true, false, true, true, false, true, true, false, true, true, false, true, true, false, true, true, false, true},
{4, 2, 3});
}

TEST(TensorOpTest, TileFloatType) {