Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/actions/setup-android-ndk/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,10 @@ runs:
set -e -x
python3 tools/python/run_android_emulator.py \
--android-sdk-root "${ANDROID_SDK_ROOT}" \
--start --emulator-extra-args="-partition-size 2047" \
--start --emulator-extra-args="-partition-size 2047 -memory 5120" \
--emulator-pid-file ./emulator.pid
echo "Emulator PID: `cat ./emulator.pid`"

- name: View Android ENVs
shell: bash
run: env | grep ANDROID
run: env | grep ANDROID
40 changes: 26 additions & 14 deletions onnxruntime/test/providers/cpu/tensor/gather_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -342,37 +342,49 @@ TEST(GatherOpTest, Gather_axis1_indices2d_string) {
}

TEST(GatherOpTest, Gather_overflow_check) {
// Skip on 32-bit platforms where size_t overflow would truncate the large expected
// output shape and where allocating the full reference tensor is infeasible.
// Skip on 32-bit platforms where allocating the full reference tensor is infeasible due
// to std::vector::max_size being limited to the size of ptrdiff_t (INT32_MAX on 32-bit).
// Also, peak memory usage for this test would be greater than what is addressable.
#if SIZE_MAX <= UINT32_MAX
GTEST_SKIP() << "Gather_overflow_check skipped on 32-bit platforms.";
#endif

// The test uses dimensions (65537, 2) and indices of length 65537, which produce an output
// shape of (65537, 65537).
// The test uses dimensions (46341, 2) and indices of length 46341, which produce an output
// shape of (46341, 46341).
//
// 65537 x 65537 = 4,295,098,369 which is greater than the maximum value of a 32-bit integer (2,147,483,647).
// 46341 x 46341 = 2,147,488,281 which is just greater than the maximum value of a 32-bit integer (2,147,483,647).
//
// This test is to verify CPU implementation of the Gather operator doesn't overflow when calculating
// the output shape and generating the output tensor.

constexpr int64_t dim_val = 46341;

OpTester test("Gather");
test.AddAttribute<int64_t>("axis", 1LL);

// Inputs
const std::vector<int64_t> data_dims{65537, 2};
const std::vector<int64_t> indices_dims{65537};
std::vector<uint8_t> data_values(static_cast<size_t>(data_dims[0] * data_dims[1]), 1);
std::vector<int64_t> indices_values(static_cast<size_t>(indices_dims[0]), 1);
std::vector<uint8_t> expected_output_values(static_cast<size_t>(65537) * static_cast<size_t>(65537), 1);
// Setup test inputs and outputs in a separate scope to ensure the large `expected_output_values` array
// is destroyed before we run the test via `test.Run()`.
{
const std::vector<int64_t> data_dims{dim_val, 2};
const std::vector<int64_t> indices_dims{dim_val};
std::vector<uint8_t> data_values(static_cast<size_t>(data_dims[0] * data_dims[1]), 1);
std::vector<int64_t> indices_values(static_cast<size_t>(indices_dims[0]), 1);
std::vector<uint8_t> expected_output_values(static_cast<size_t>(dim_val) * static_cast<size_t>(dim_val), 1);

test.AddInput<uint8_t>("data", {dim_val, 2}, data_values);
test.AddInput<int64_t>("indices", {dim_val}, indices_values);

test.AddInput<uint8_t>("data", {65537, 2}, data_values);
test.AddInput<int64_t>("indices", {65537}, indices_values);
test.AddOutput<uint8_t>("output", {65537, 65537}, expected_output_values);
// Note: the large ~2GiB `expected_output_values` array is copied into the OpTester.
test.AddOutput<uint8_t>("output", {dim_val, dim_val}, expected_output_values);
}

std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.emplace_back(DefaultCpuExecutionProvider());

// Note: peak memory usage will be in the order of multiple GiB:
// - OpTester holds expected outputs buffer of size ~2GiB
// - The session state allocates a buffer for the output of size ~2GiB
// - Other overhead and bookkeeping.
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}

Expand Down
Loading