From e1f43cb099c37ab485b515ceea2465253676daed Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Thu, 5 Mar 2026 11:14:29 -0800 Subject: [PATCH 1/5] Address overflow and generate tests --- onnxruntime/core/providers/cpu/tensor/tile.cc | 15 +++- .../core/providers/cuda/tensor/tile.cc | 24 +++-- .../test/providers/cpu/tensor/tile_op_test.cc | 90 +++++++++++++++++++ 3 files changed, 117 insertions(+), 12 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/tile.cc b/onnxruntime/core/providers/cpu/tensor/tile.cc index c53e68021e649..b73825e0f10fe 100644 --- a/onnxruntime/core/providers/cpu/tensor/tile.cc +++ b/onnxruntime/core/providers/cpu/tensor/tile.cc @@ -178,15 +178,20 @@ bool IsTileMemcpy(const TensorShape& input_shape, Status Tile::Compute(OpKernelContext* ctx) const { const auto* tensor_pointer = ctx->Input(0); - if (tensor_pointer == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "Input count of Tile OP mismatch, the first one is empty"); + if (tensor_pointer == nullptr) + return Status(common::ONNXRUNTIME, common::FAIL, "Input count of Tile OP mismatch, the first one is empty"); + const Tensor& input_tensor = *tensor_pointer; const auto& input_shape = input_tensor.Shape(); const size_t input_rank = input_shape.NumDimensions(); tensor_pointer = ctx->Input(1); - if (tensor_pointer == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "Input count of Tile OP mismatch, the second one is empty"); + if (tensor_pointer == nullptr) + return Status(common::ONNXRUNTIME, common::FAIL, "Input count of Tile OP mismatch, the second one is empty"); + const Tensor& repeats_tensor = *tensor_pointer; if (repeats_tensor.Shape().NumDimensions() != 1) return Status(ONNXRUNTIME, INVALID_ARGUMENT, "'repeat' input tensor must be 1 dimensional"); + if (size_t(repeats_tensor.Shape().Size()) != input_rank) return Status(ONNXRUNTIME, INVALID_ARGUMENT, "'repeat' input tensor must have the same length as the 'input' tensor"); @@ -194,7 +199,11 @@ Status 
Tile::Compute(OpKernelContext* ctx) const { const auto* repeats = repeats_tensor.Data(); auto output_dims = input_shape.AsShapeVector(); for (size_t axis = 0; axis < input_rank; axis++) { - output_dims[axis] *= repeats[axis]; + if (repeats[axis] < 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Tile repeat value must be non-negative, got: ", repeats[axis]); + } + output_dims[axis] = SafeInt(output_dims[axis]) * repeats[axis]; } TensorShape output_shape(output_dims); diff --git a/onnxruntime/core/providers/cuda/tensor/tile.cc b/onnxruntime/core/providers/cuda/tensor/tile.cc index 01522c71dc51b..7521f1bce7930 100644 --- a/onnxruntime/core/providers/cuda/tensor/tile.cc +++ b/onnxruntime/core/providers/cuda/tensor/tile.cc @@ -43,7 +43,7 @@ ONNX_OPERATOR_KERNEL_EX( #define CASE_TILE(type) \ case sizeof(type): { \ - TileImpl(Stream(ctx), rank, fdm_input_shape, input_strides, \ + TileImpl(Stream(ctx), input_rank, fdm_input_shape, input_strides, \ reinterpret_cast::MappedType*>(input_data), fdm_output_strides, \ reinterpret_cast::MappedType*>(output_data), output_tensor.Shape().Size()); \ } break @@ -66,11 +66,11 @@ ONNX_OPERATOR_KERNEL_EX( Status Tile::ComputeInternal(OpKernelContext* ctx) const { auto& input_tensor = *ctx->Input(0); auto& repeats_tensor = *ctx->Input(1); - int32_t rank = static_cast(input_tensor.Shape().NumDimensions()); + int32_t input_rank = static_cast(input_tensor.Shape().NumDimensions()); if (repeats_tensor.Shape().NumDimensions() != 1) return Status(ONNXRUNTIME, INVALID_ARGUMENT, "'repeat' input tensor must be 1 dimensional"); - if (repeats_tensor.Shape().Size() != rank) + if (repeats_tensor.Shape().Size() != input_rank) return Status(ONNXRUNTIME, INVALID_ARGUMENT, "'repeat' input tensor must have the same length as the 'input' tensor"); // Calculate the shape of the output tensor @@ -78,8 +78,14 @@ Status Tile::ComputeInternal(OpKernelContext* ctx) const { const auto& input_shape = input_tensor.Shape(); const auto input_dims = 
input_shape.GetDims(); auto output_dims(input_shape.AsShapeVector()); - for (auto axis = 0; axis < rank; axis++) - output_dims[axis] *= repeats[axis]; + for (size_t axis = 0; axis < input_rank; axis++) { + if (repeats[axis] < 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Tile repeat value must be non-negative, got: ", repeats[axis]); + } + output_dims[axis] = SafeInt(output_dims[axis]) * repeats[axis]; + } + TensorShape output_shape(output_dims); auto& output_tensor = *ctx->Output(0, output_shape); @@ -105,7 +111,7 @@ Status Tile::ComputeInternal(OpKernelContext* ctx) const { size_t num_of_batch_copies = 1; if (TileOp::IsTileMemcpy(input_shape, repeats, - rank, + input_rank, is_batched_memcpy, num_of_elements_per_batch, num_of_copies_per_batch, @@ -136,14 +142,14 @@ Status Tile::ComputeInternal(OpKernelContext* ctx) const { TensorPitches input_pitches(input_dims); TArray input_strides(input_pitches); - TArray fdm_input_shape(rank); + TArray fdm_input_shape(input_rank); for (size_t i = 0; i < input_dims.size(); ++i) { fdm_input_shape[gsl::narrow_cast(i)] = fast_divmod(gsl::narrow_cast(input_dims[i])); } - TArray fdm_output_strides(rank); + TArray fdm_output_strides(input_rank); TensorPitches output_pitches(output_dims); - for (auto i = 0; i < rank; i++) { + for (auto i = 0; i < input_rank; i++) { fdm_output_strides[i] = fast_divmod(static_cast(output_pitches[i])); } diff --git a/onnxruntime/test/providers/cpu/tensor/tile_op_test.cc b/onnxruntime/test/providers/cpu/tensor/tile_op_test.cc index 3e50b23353cb8..949890be31eb7 100644 --- a/onnxruntime/test/providers/cpu/tensor/tile_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/tile_op_test.cc @@ -267,5 +267,95 @@ TEST(TensorOpTest, TileBoolType) { RunTestWrapperForBool(); } TEST(TensorOpTest, TileMLFloat16Type) { RunTestWrapper(); } #endif +// Test that negative repeat values are rejected with an error +TEST(TensorOpTest, TileNegativeRepeats) { + OpTester test("Tile", 13); + 
test.AddInput("input", {3}, {1.0f, 2.0f, 3.0f}); + test.AddInput("repeats", {1}, {-1}); + test.AddOutput("output", {0}, {}); + test.Run(OpTester::ExpectResult::kExpectFailure, "Tile repeat value must be non-negative"); +} + +// Test that negative repeat values are rejected for multi-dimensional input +TEST(TensorOpTest, TileNegativeRepeats2D) { + OpTester test("Tile", 13); + test.AddInput("input", {2, 3}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); + test.AddInput("repeats", {2}, {2, -3}); + test.AddOutput("output", {0, 0}, {}); + test.Run(OpTester::ExpectResult::kExpectFailure, "Tile repeat value must be non-negative"); +} + +// Test that overflow in output dimension computation is caught by SafeInt. +// input_dim=3 * repeat=6148914691236517206 would overflow int64_t: +// 3 * 6148914691236517206 = 2^64 + 2, which wraps to 2 without overflow protection. +// SafeInt should detect this and throw. +TEST(TensorOpTest, TileOverflowRepeats1D) { + OpTester test("Tile", 13); + test.AddInput("input", {3}, {1.0f, 2.0f, 3.0f}); + // 6148914691236517206 = (2^64 + 2) / 3, so 3 * 6148914691236517206 overflows int64_t + test.AddInput("repeats", {1}, {int64_t{6148914691236517206}}); + test.AddOutput("output", {0}, {}); + test.Run(OpTester::ExpectResult::kExpectFailure, ""); +} + +// Test overflow detection with a 2D tensor where only one axis overflows +TEST(TensorOpTest, TileOverflowRepeats2D) { + OpTester test("Tile", 13); + test.AddInput("input", {2, 3}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); + // Second axis: 3 * 6148914691236517206 overflows + test.AddInput("repeats", {2}, {1, int64_t{6148914691236517206}}); + test.AddOutput("output", {0, 0}, {}); + test.Run(OpTester::ExpectResult::kExpectFailure, ""); +} + +// Test overflow detection with large repeat on first axis +TEST(TensorOpTest, TileOverflowRepeatsFirstAxis) { + OpTester test("Tile", 13); + test.AddInput("input", {2, 3}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); + // First axis: 2 * 4611686018427387904 = 2^63 which overflows 
signed int64_t + test.AddInput("repeats", {2}, {int64_t{4611686018427387904}, 1}); + test.AddOutput("output", {0, 0}, {}); + test.Run(OpTester::ExpectResult::kExpectFailure, ""); +} + +// Test overflow with int32 input type to ensure all fixed-size type paths are covered +TEST(TensorOpTest, TileOverflowRepeatsInt32) { + OpTester test("Tile", 13); + test.AddInput("input", {3}, {1, 2, 3}); + test.AddInput("repeats", {1}, {int64_t{6148914691236517206}}); + test.AddOutput("output", {0}, {}); + test.Run(OpTester::ExpectResult::kExpectFailure, ""); +} + +// Test overflow with string input type +TEST(TensorOpTest, TileOverflowRepeatsString) { + OpTester test("Tile", 13); + test.AddInput("input", {3}, {"a", "b", "c"}); + test.AddInput("repeats", {1}, {int64_t{6148914691236517206}}); + test.AddOutput("output", {0}, {}); + test.Run(OpTester::ExpectResult::kExpectFailure, ""); +} + +// Test with INT64_MAX as repeat value (should overflow for any input dim > 1) +TEST(TensorOpTest, TileOverflowMaxRepeat) { + OpTester test("Tile", 13); + test.AddInput("input", {2}, {1.0f, 2.0f}); + // INT64_MAX = 9223372036854775807; 2 * INT64_MAX overflows + test.AddInput("repeats", {1}, {std::numeric_limits::max()}); + test.AddOutput("output", {0}, {}); + test.Run(OpTester::ExpectResult::kExpectFailure, ""); +} + +// Test overflow where multiple axes combine to create an impossibly large output +TEST(TensorOpTest, TileOverflowMultipleAxes) { + OpTester test("Tile", 13); + test.AddInput("input", {2, 2, 2}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f}); + // First axis only: 2 * 2^62 = 2^63 overflows int64_t; the other axes use repeat = 1 + int64_t large_repeat = int64_t{4611686018427387904}; // 2^62 + test.AddInput("repeats", {3}, {large_repeat, 1, 1}); + test.AddOutput("output", {0, 0, 0}, {}); + test.Run(OpTester::ExpectResult::kExpectFailure, ""); +} + } // namespace test } // namespace onnxruntime From ac0413ef519fdb7439f676c73eace17e1c473818 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Thu, 5 Mar 2026 13:07:45 -0800
Subject: [PATCH 2/5] Update tests --- .../test/providers/cpu/tensor/tile_op_test.cc | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/onnxruntime/test/providers/cpu/tensor/tile_op_test.cc b/onnxruntime/test/providers/cpu/tensor/tile_op_test.cc index 949890be31eb7..62d28cd4748f4 100644 --- a/onnxruntime/test/providers/cpu/tensor/tile_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/tile_op_test.cc @@ -273,7 +273,7 @@ TEST(TensorOpTest, TileNegativeRepeats) { test.AddInput("input", {3}, {1.0f, 2.0f, 3.0f}); test.AddInput("repeats", {1}, {-1}); test.AddOutput("output", {0}, {}); - test.Run(OpTester::ExpectResult::kExpectFailure, "Tile repeat value must be non-negative"); + test.Run(OpTester::ExpectResult::kExpectFailure); } // Test that negative repeat values are rejected for multi-dimensional input @@ -282,7 +282,7 @@ TEST(TensorOpTest, TileNegativeRepeats2D) { test.AddInput("input", {2, 3}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); test.AddInput("repeats", {2}, {2, -3}); test.AddOutput("output", {0, 0}, {}); - test.Run(OpTester::ExpectResult::kExpectFailure, "Tile repeat value must be non-negative"); + test.Run(OpTester::ExpectResult::kExpectFailure); } // Test that overflow in output dimension computation is caught by SafeInt. 
@@ -295,7 +295,7 @@ TEST(TensorOpTest, TileOverflowRepeats1D) { // 6148914691236517206 = (2^64 + 2) / 3, so 3 * 6148914691236517206 overflows int64_t test.AddInput("repeats", {1}, {int64_t{6148914691236517206}}); test.AddOutput("output", {0}, {}); - test.Run(OpTester::ExpectResult::kExpectFailure, ""); + test.Run(OpTester::ExpectResult::kExpectFailure); } // Test overflow detection with a 2D tensor where only one axis overflows @@ -305,7 +305,7 @@ TEST(TensorOpTest, TileOverflowRepeats2D) { // Second axis: 3 * 6148914691236517206 overflows test.AddInput("repeats", {2}, {1, int64_t{6148914691236517206}}); test.AddOutput("output", {0, 0}, {}); - test.Run(OpTester::ExpectResult::kExpectFailure, ""); + test.Run(OpTester::ExpectResult::kExpectFailure); } // Test overflow detection with large repeat on first axis @@ -315,7 +315,7 @@ TEST(TensorOpTest, TileOverflowRepeatsFirstAxis) { // First axis: 2 * 4611686018427387904 = 2^63 which overflows signed int64_t test.AddInput("repeats", {2}, {int64_t{4611686018427387904}, 1}); test.AddOutput("output", {0, 0}, {}); - test.Run(OpTester::ExpectResult::kExpectFailure, ""); + test.Run(OpTester::ExpectResult::kExpectFailure); } // Test overflow with int32 input type to ensure all fixed-size type paths are covered @@ -324,7 +324,7 @@ TEST(TensorOpTest, TileOverflowRepeatsInt32) { test.AddInput("input", {3}, {1, 2, 3}); test.AddInput("repeats", {1}, {int64_t{6148914691236517206}}); test.AddOutput("output", {0}, {}); - test.Run(OpTester::ExpectResult::kExpectFailure, ""); + test.Run(OpTester::ExpectResult::kExpectFailure); } // Test overflow with string input type @@ -333,7 +333,7 @@ TEST(TensorOpTest, TileOverflowRepeatsString) { test.AddInput("input", {3}, {"a", "b", "c"}); test.AddInput("repeats", {1}, {int64_t{6148914691236517206}}); test.AddOutput("output", {0}, {}); - test.Run(OpTester::ExpectResult::kExpectFailure, ""); + test.Run(OpTester::ExpectResult::kExpectFailure); } // Test with INT64_MAX as repeat value (should 
overflow for any input dim > 1) @@ -343,7 +343,7 @@ TEST(TensorOpTest, TileOverflowMaxRepeat) { // INT64_MAX = 9223372036854775807; 2 * INT64_MAX overflows test.AddInput("repeats", {1}, {std::numeric_limits::max()}); test.AddOutput("output", {0}, {}); - test.Run(OpTester::ExpectResult::kExpectFailure, ""); + test.Run(OpTester::ExpectResult::kExpectFailure); } // Test overflow where multiple axes combine to create an impossibly large output @@ -354,7 +354,7 @@ TEST(TensorOpTest, TileOverflowMultipleAxes) { int64_t large_repeat = int64_t{4611686018427387904}; // 2^62 test.AddInput("repeats", {3}, {large_repeat, 1, 1}); test.AddOutput("output", {0, 0, 0}, {}); - test.Run(OpTester::ExpectResult::kExpectFailure, ""); + test.Run(OpTester::ExpectResult::kExpectFailure); } } // namespace test From a37ba215c0ec995342d4613fbfaab9a9ede31a2b Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Thu, 5 Mar 2026 13:38:52 -0800 Subject: [PATCH 3/5] check repeat range for webgpu --- onnxruntime/core/providers/webgpu/tensor/tile.cc | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/tensor/tile.cc b/onnxruntime/core/providers/webgpu/tensor/tile.cc index 841c36724df30..d29be56197235 100644 --- a/onnxruntime/core/providers/webgpu/tensor/tile.cc +++ b/onnxruntime/core/providers/webgpu/tensor/tile.cc @@ -55,13 +55,16 @@ Status Tile::ComputeInternal(ComputeContext& context) const { const auto* repeats_data = repeats_tensor->Data(); std::vector repeats; - for (size_t i = 0; i < static_cast(repeats_tensor->Shape().Size()); i++) { - repeats.push_back(static_cast(repeats_data[i])); - } - auto output_dims = input_shape.AsShapeVector(); for (size_t axis = 0; axis < input_rank; axis++) { - output_dims[axis] *= repeats[axis]; + if (repeats_data[axis] < 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Tile repeat value must be non-negative, got: ", repeats_data[axis]); + } + output_dims[axis] = 
SafeInt(output_dims[axis]) * repeats_data[axis]; + } + for (size_t i = 0; i < static_cast(repeats_tensor->Shape().Size()); i++) { + repeats.push_back(static_cast(repeats_data[i])); } TensorShape output_shape(output_dims); @@ -83,4 +86,4 @@ Status Tile::ComputeInternal(ComputeContext& context) const { } } // namespace webgpu -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime From 173ad15ac046c870c5ce4ccf9e14c1e95709e07d Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Thu, 5 Mar 2026 15:41:07 -0800 Subject: [PATCH 4/5] Exclude TRT EP --- onnxruntime/test/providers/cpu/tensor/tile_op_test.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/onnxruntime/test/providers/cpu/tensor/tile_op_test.cc b/onnxruntime/test/providers/cpu/tensor/tile_op_test.cc index 62d28cd4748f4..d5707cd1ba2db 100644 --- a/onnxruntime/test/providers/cpu/tensor/tile_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/tile_op_test.cc @@ -295,7 +295,7 @@ TEST(TensorOpTest, TileOverflowRepeats1D) { // 6148914691236517206 = (2^64 + 2) / 3, so 3 * 6148914691236517206 overflows int64_t test.AddInput("repeats", {1}, {int64_t{6148914691236517206}}); test.AddOutput("output", {0}, {}); - test.Run(OpTester::ExpectResult::kExpectFailure); + test.Run(OpTester::ExpectResult::kExpectFailure, "", {kTensorrtExecutionProvider}); } // Test overflow detection with a 2D tensor where only one axis overflows @@ -315,7 +315,7 @@ TEST(TensorOpTest, TileOverflowRepeatsFirstAxis) { // First axis: 2 * 4611686018427387904 = 2^63 which overflows signed int64_t test.AddInput("repeats", {2}, {int64_t{4611686018427387904}, 1}); test.AddOutput("output", {0, 0}, {}); - test.Run(OpTester::ExpectResult::kExpectFailure); + test.Run(OpTester::ExpectResult::kExpectFailure, "", {kTensorrtExecutionProvider}); } // Test overflow with int32 input type to ensure all fixed-size type paths are covered @@ -324,7 +324,7 @@ TEST(TensorOpTest, TileOverflowRepeatsInt32) {
test.AddInput("input", {3}, {1, 2, 3}); test.AddInput("repeats", {1}, {int64_t{6148914691236517206}}); test.AddOutput("output", {0}, {}); - test.Run(OpTester::ExpectResult::kExpectFailure); + test.Run(OpTester::ExpectResult::kExpectFailure, "", {kTensorrtExecutionProvider}); } // Test overflow with string input type @@ -354,7 +354,7 @@ TEST(TensorOpTest, TileOverflowMultipleAxes) { int64_t large_repeat = int64_t{4611686018427387904}; // 2^62 test.AddInput("repeats", {3}, {large_repeat, 1, 1}); test.AddOutput("output", {0, 0, 0}, {}); - test.Run(OpTester::ExpectResult::kExpectFailure); + test.Run(OpTester::ExpectResult::kExpectFailure, "", {kTensorrtExecutionProvider}); } } // namespace test From d1d66b87607db30f9bfd3bc941c78c6a0da0e8f6 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Tue, 10 Mar 2026 10:36:13 -0700 Subject: [PATCH 5/5] Exclude TRT from a test --- onnxruntime/test/providers/cpu/tensor/tile_op_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/test/providers/cpu/tensor/tile_op_test.cc b/onnxruntime/test/providers/cpu/tensor/tile_op_test.cc index d5707cd1ba2db..91da10ae03b85 100644 --- a/onnxruntime/test/providers/cpu/tensor/tile_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/tile_op_test.cc @@ -305,7 +305,7 @@ TEST(TensorOpTest, TileOverflowRepeats2D) { // Second axis: 3 * 6148914691236517206 overflows test.AddInput("repeats", {2}, {1, int64_t{6148914691236517206}}); test.AddOutput("output", {0, 0}, {}); - test.Run(OpTester::ExpectResult::kExpectFailure); + test.Run(OpTester::ExpectResult::kExpectFailure, "", {kTensorrtExecutionProvider}); } // Test overflow detection with large repeat on first axis