Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions onnxruntime/core/providers/cpu/tensor/tile.cc
Original file line number Diff line number Diff line change
Expand Up @@ -178,23 +178,32 @@ bool IsTileMemcpy(const TensorShape& input_shape,

Status Tile::Compute(OpKernelContext* ctx) const {
const auto* tensor_pointer = ctx->Input<Tensor>(0);
if (tensor_pointer == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "Input count of Tile OP mismatch, the first one is empty");
if (tensor_pointer == nullptr)
return Status(common::ONNXRUNTIME, common::FAIL, "Input count of Tile OP mismatch, the first one is empty");

const Tensor& input_tensor = *tensor_pointer;
const auto& input_shape = input_tensor.Shape();
const size_t input_rank = input_shape.NumDimensions();
tensor_pointer = ctx->Input<Tensor>(1);
if (tensor_pointer == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "Input count of Tile OP mismatch, the second one is empty");
if (tensor_pointer == nullptr)
return Status(common::ONNXRUNTIME, common::FAIL, "Input count of Tile OP mismatch, the second one is empty");

const Tensor& repeats_tensor = *tensor_pointer;
if (repeats_tensor.Shape().NumDimensions() != 1)
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "'repeat' input tensor must be 1 dimensional");

if (size_t(repeats_tensor.Shape().Size()) != input_rank)
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "'repeat' input tensor must have the same length as the 'input' tensor");

// Calculate the shape of the output tensor
const auto* repeats = repeats_tensor.Data<int64_t>();
auto output_dims = input_shape.AsShapeVector();
for (size_t axis = 0; axis < input_rank; axis++) {
output_dims[axis] *= repeats[axis];
if (repeats[axis] < 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Tile repeat value must be non-negative, got: ", repeats[axis]);
}
output_dims[axis] = SafeInt<int64_t>(output_dims[axis]) * repeats[axis];
}

TensorShape output_shape(output_dims);
Expand Down
24 changes: 15 additions & 9 deletions onnxruntime/core/providers/cuda/tensor/tile.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ ONNX_OPERATOR_KERNEL_EX(

#define CASE_TILE(type) \
case sizeof(type): { \
TileImpl(Stream(ctx), rank, fdm_input_shape, input_strides, \
TileImpl(Stream(ctx), input_rank, fdm_input_shape, input_strides, \
reinterpret_cast<const typename ToCudaType<type>::MappedType*>(input_data), fdm_output_strides, \
reinterpret_cast<typename ToCudaType<type>::MappedType*>(output_data), output_tensor.Shape().Size()); \
} break
Expand All @@ -66,20 +66,26 @@ ONNX_OPERATOR_KERNEL_EX(
Status Tile::ComputeInternal(OpKernelContext* ctx) const {
auto& input_tensor = *ctx->Input<Tensor>(0);
auto& repeats_tensor = *ctx->Input<Tensor>(1);
int32_t rank = static_cast<int32_t>(input_tensor.Shape().NumDimensions());
int32_t input_rank = static_cast<int32_t>(input_tensor.Shape().NumDimensions());

if (repeats_tensor.Shape().NumDimensions() != 1)
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "'repeat' input tensor must be 1 dimensional");
if (repeats_tensor.Shape().Size() != rank)
if (repeats_tensor.Shape().Size() != input_rank)
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "'repeat' input tensor must have the same length as the 'input' tensor");

// Calculate the shape of the output tensor
auto* repeats = repeats_tensor.Data<int64_t>();
const auto& input_shape = input_tensor.Shape();
const auto input_dims = input_shape.GetDims();
auto output_dims(input_shape.AsShapeVector());
for (auto axis = 0; axis < rank; axis++)
output_dims[axis] *= repeats[axis];
for (size_t axis = 0; axis < input_rank; axis++) {
if (repeats[axis] < 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Tile repeat value must be non-negative, got: ", repeats[axis]);
}
output_dims[axis] = SafeInt<int64_t>(output_dims[axis]) * repeats[axis];
}

TensorShape output_shape(output_dims);
auto& output_tensor = *ctx->Output(0, output_shape);

Expand All @@ -105,7 +111,7 @@ Status Tile::ComputeInternal(OpKernelContext* ctx) const {
size_t num_of_batch_copies = 1;
if (TileOp::IsTileMemcpy(input_shape,
repeats,
rank,
input_rank,
is_batched_memcpy,
num_of_elements_per_batch,
num_of_copies_per_batch,
Expand Down Expand Up @@ -136,14 +142,14 @@ Status Tile::ComputeInternal(OpKernelContext* ctx) const {
TensorPitches input_pitches(input_dims);
TArray<int64_t> input_strides(input_pitches);

TArray<fast_divmod> fdm_input_shape(rank);
TArray<fast_divmod> fdm_input_shape(input_rank);
for (size_t i = 0; i < input_dims.size(); ++i) {
fdm_input_shape[gsl::narrow_cast<int>(i)] = fast_divmod(gsl::narrow_cast<int>(input_dims[i]));
}

TArray<fast_divmod> fdm_output_strides(rank);
TArray<fast_divmod> fdm_output_strides(input_rank);
TensorPitches output_pitches(output_dims);
for (auto i = 0; i < rank; i++) {
for (auto i = 0; i < input_rank; i++) {
fdm_output_strides[i] = fast_divmod(static_cast<int>(output_pitches[i]));
}

Expand Down
15 changes: 9 additions & 6 deletions onnxruntime/core/providers/webgpu/tensor/tile.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,16 @@ Status Tile::ComputeInternal(ComputeContext& context) const {
const auto* repeats_data = repeats_tensor->Data<int64_t>();
std::vector<uint32_t> repeats;

for (size_t i = 0; i < static_cast<uint32_t>(repeats_tensor->Shape().Size()); i++) {
repeats.push_back(static_cast<uint32_t>(repeats_data[i]));
}

auto output_dims = input_shape.AsShapeVector();
for (size_t axis = 0; axis < input_rank; axis++) {
output_dims[axis] *= repeats[axis];
if (repeats_data[axis] < 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Tile repeat value must be non-negative, got: ", repeats_data[axis]);
}
output_dims[axis] = SafeInt<int64_t>(output_dims[axis]) * repeats_data[axis];
}
for (size_t i = 0; i < static_cast<size_t>(repeats_tensor->Shape().Size()); i++) {
repeats.push_back(static_cast<uint32_t>(repeats_data[i]));
}

TensorShape output_shape(output_dims);
Expand All @@ -83,4 +86,4 @@ Status Tile::ComputeInternal(ComputeContext& context) const {
}

} // namespace webgpu
} // namespace onnxruntime
} // namespace onnxruntime
90 changes: 90 additions & 0 deletions onnxruntime/test/providers/cpu/tensor/tile_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -267,5 +267,95 @@ TEST(TensorOpTest, TileBoolType) { RunTestWrapperForBool(); }
TEST(TensorOpTest, TileMLFloat16Type) { RunTestWrapper<MLFloat16>(); }
#endif

// Repeats must be non-negative: a single -1 has to fail the kernel with an
// error status instead of producing a bogus output shape.
TEST(TensorOpTest, TileNegativeRepeats) {
  constexpr int64_t kBadRepeat = -1;
  OpTester test("Tile", 13);
  test.AddInput<float>("input", {3}, {1.0f, 2.0f, 3.0f});
  test.AddInput<int64_t>("repeats", {1}, {kBadRepeat});
  test.AddOutput<float>("output", {0}, {});
  test.Run(OpTester::ExpectResult::kExpectFailure);
}

// A negative repeat on any axis of a multi-dimensional input must be rejected.
TEST(TensorOpTest, TileNegativeRepeats2D) {
  OpTester test("Tile", 13);
  test.AddInput<float>("input", {2, 3}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f});
  // Axis 0 repeat (2) is valid; axis 1 repeat (-3) is the offending value.
  test.AddInput<int64_t>("repeats", {2}, {2, -3});
  test.AddOutput<float>("output", {0, 0}, {});
  test.Run(OpTester::ExpectResult::kExpectFailure);
}

// Overflow in the output-shape computation must be detected instead of
// silently wrapping. With input_dim = 3 and repeat = 6148914691236517206,
// the product 3 * 6148914691236517206 equals 2^64 + 2, which wraps to 2
// under unchecked int64_t arithmetic; SafeInt is expected to catch this
// and the op must return a failure status.
TEST(TensorOpTest, TileOverflowRepeats1D) {
  // (2^64 + 2) / 3 — the value whose product with dim 3 wraps past INT64_MAX.
  constexpr int64_t kWrappingRepeat = 6148914691236517206;
  OpTester test("Tile", 13);
  test.AddInput<float>("input", {3}, {1.0f, 2.0f, 3.0f});
  test.AddInput<int64_t>("repeats", {1}, {kWrappingRepeat});
  test.AddOutput<float>("output", {0}, {});
  // TensorRT is skipped, matching the other overflow tests in this file.
  test.Run(OpTester::ExpectResult::kExpectFailure, "", {kTensorrtExecutionProvider});
}

// 2-D overflow case: only the second axis (dim 3 * huge repeat) overflows,
// while the first axis stays well within range.
TEST(TensorOpTest, TileOverflowRepeats2D) {
  // (2^64 + 2) / 3 — 3 * this value wraps past INT64_MAX.
  constexpr int64_t kWrappingRepeat = 6148914691236517206;
  OpTester test("Tile", 13);
  test.AddInput<float>("input", {2, 3}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f});
  test.AddInput<int64_t>("repeats", {2}, {1, kWrappingRepeat});
  test.AddOutput<float>("output", {0, 0}, {});
  // TensorRT is skipped, matching the other overflow tests in this file.
  test.Run(OpTester::ExpectResult::kExpectFailure, "", {kTensorrtExecutionProvider});
}

// Overflow on the leading axis: 2 * 2^62 == 2^63, one past INT64_MAX,
// so the output-shape computation for axis 0 must fail.
TEST(TensorOpTest, TileOverflowRepeatsFirstAxis) {
  constexpr int64_t kTwoPow62 = 4611686018427387904;  // 2^62
  OpTester test("Tile", 13);
  test.AddInput<float>("input", {2, 3}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f});
  test.AddInput<int64_t>("repeats", {2}, {kTwoPow62, 1});
  test.AddOutput<float>("output", {0, 0}, {});
  // TensorRT is skipped, matching the other overflow tests in this file.
  test.Run(OpTester::ExpectResult::kExpectFailure, "", {kTensorrtExecutionProvider});
}

// Same overflow scenario with an int32 input tensor so the fixed-size
// element-type paths of the kernel are also covered.
TEST(TensorOpTest, TileOverflowRepeatsInt32) {
  constexpr int64_t kWrappingRepeat = 6148914691236517206;  // (2^64 + 2) / 3
  OpTester test("Tile", 13);
  test.AddInput<int32_t>("input", {3}, {1, 2, 3});
  test.AddInput<int64_t>("repeats", {1}, {kWrappingRepeat});
  test.AddOutput<int32_t>("output", {0}, {});
  // TensorRT is skipped, matching the other overflow tests in this file.
  test.Run(OpTester::ExpectResult::kExpectFailure, "", {kTensorrtExecutionProvider});
}

// Test overflow with a string input type: 3 * 6148914691236517206 wraps past
// INT64_MAX, so the shape computation must fail before any tiling happens.
// NOTE(review): unlike the other overflow tests, no TensorRT exclusion here —
// presumably because string tensors never reach that EP; confirm.
TEST(TensorOpTest, TileOverflowRepeatsString) {
OpTester test("Tile", 13);
test.AddInput<std::string>("input", {3}, {"a", "b", "c"});
test.AddInput<int64_t>("repeats", {1}, {int64_t{6148914691236517206}});
test.AddOutput<std::string>("output", {0}, {});
test.Run(OpTester::ExpectResult::kExpectFailure);
}

// INT64_MAX as a repeat value overflows for any input dimension > 1:
// 2 * 9223372036854775807 exceeds int64_t, so SafeInt must reject it and
// the op must fail.
TEST(TensorOpTest, TileOverflowMaxRepeat) {
  OpTester test("Tile", 13);
  test.AddInput<float>("input", {2}, {1.0f, 2.0f});
  test.AddInput<int64_t>("repeats", {1}, {std::numeric_limits<int64_t>::max()});
  test.AddOutput<float>("output", {0}, {});
  // NOTE(review): kTensorrtExecutionProvider is now excluded for consistency
  // with the other float-input overflow tests in this file, which all skip
  // it; previously this test alone ran on every EP.
  test.Run(OpTester::ExpectResult::kExpectFailure, "", {kTensorrtExecutionProvider});
}

// Test overflow detection on a rank-3 input where the oversized repeat sits
// on the leading axis: 2 * 2^62 == 2^63, one past INT64_MAX.
TEST(TensorOpTest, TileOverflowMultipleAxes) {
OpTester test("Tile", 13);
test.AddInput<float>("input", {2, 2, 2}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});
// Only axis 0 overflows (dim 2 * 2^62); axes 1 and 2 use repeat 1.
int64_t large_repeat = int64_t{4611686018427387904};  // 2^62
test.AddInput<int64_t>("repeats", {3}, {large_repeat, 1, 1});
test.AddOutput<float>("output", {0, 0, 0}, {});
test.Run(OpTester::ExpectResult::kExpectFailure, "", {kTensorrtExecutionProvider});
}

} // namespace test
} // namespace onnxruntime
Loading