diff --git a/onnxruntime/core/providers/cpu/tensor/gather.cc b/onnxruntime/core/providers/cpu/tensor/gather.cc index b13fcd4135f67..f171b33ee5f4f 100644 --- a/onnxruntime/core/providers/cpu/tensor/gather.cc +++ b/onnxruntime/core/providers/cpu/tensor/gather.cc @@ -79,9 +79,9 @@ Status GatherCopyData(const Tensor* indices_tensor, const uint8_t* src_base, uin } } - auto lambda = [&](int64_t index) { - int64_t batch = index / N; - int64_t i = index % N; + auto lambda = [&](ptrdiff_t index) { + const int64_t batch = static_cast(index / N); + const int64_t i = static_cast(index % N); const int64_t src_offset_batch = batch * data_batch_bytes; const int64_t dst_offset_batch = batch * gathered_batch_bytes; @@ -97,12 +97,14 @@ Status GatherCopyData(const Tensor* indices_tensor, const uint8_t* src_base, uin memcpy(dst_base + dst_offset, src_base + src_offset, narrow(block_size)); } }; - concurrency::ThreadPool::TryParallelFor(tp, SafeInt(M) * N, static_cast(block_size), - [&lambda](ptrdiff_t first, ptrdiff_t last) { - for (int index = static_cast(first), end = static_cast(last); index < end; ++index) { - lambda(index); - } - }); + + concurrency::ThreadPool::TryParallelFor( + tp, SafeInt(M) * N, static_cast(block_size), + [&lambda](ptrdiff_t first, ptrdiff_t last) { + for (ptrdiff_t index = first; index < last; ++index) { + lambda(index); + } + }); return Status::OK(); } diff --git a/onnxruntime/test/providers/cpu/tensor/gather_op_test.cc b/onnxruntime/test/providers/cpu/tensor/gather_op_test.cc index c1a5a31667315..82a9d86a3630a 100644 --- a/onnxruntime/test/providers/cpu/tensor/gather_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/gather_op_test.cc @@ -341,6 +341,41 @@ TEST(GatherOpTest, Gather_axis1_indices2d_string) { test.Run(); } +TEST(GatherOpTest, Gather_overflow_check) { +// Skip on 32-bit platforms where size_t overflow would truncate the large expected +// output shape and where allocating the full reference tensor is infeasible. +#if SIZE_MAX <= UINT32_MAX + GTEST_SKIP() << "Gather_overflow_check skipped on 32-bit platforms."; +#endif + + // The test uses dimensions (65537, 2) and indices of length 65537, which produce an output + // shape of (65537, 65537). + // + // 65537 x 65537 = 4,295,098,369 which is greater than the maximum value of a 32-bit integer (2,147,483,647). + // + // This test is to verify CPU implementation of the Gather operator doesn't overflow when calculating + // the output shape and generating the output tensor. + + OpTester test("Gather"); + test.AddAttribute("axis", 1LL); + + // Inputs + const std::vector data_dims{65537, 2}; + const std::vector indices_dims{65537}; + std::vector data_values(static_cast(data_dims[0] * data_dims[1]), 1); + std::vector indices_values(static_cast(indices_dims[0]), 1); + std::vector expected_output_values(static_cast(65537) * static_cast(65537), 1); + + test.AddInput("data", {65537, 2}, data_values); + test.AddInput("indices", {65537}, indices_values); + test.AddOutput("output", {65537, 65537}, expected_output_values); + + std::vector> execution_providers; + execution_providers.emplace_back(DefaultCpuExecutionProvider()); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + TEST(GatherOpTest, Gather_axis1_indices2d_bool) { OpTester test("Gather"); test.AddAttribute("axis", 1LL);