Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions onnxruntime/core/providers/cpu/tensor/gather.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,9 @@ Status GatherCopyData(const Tensor* indices_tensor, const uint8_t* src_base, uin
}
}

auto lambda = [&](int64_t index) {
int64_t batch = index / N;
int64_t i = index % N;
auto lambda = [&](ptrdiff_t index) {
const int64_t batch = static_cast<int64_t>(index / N);
const int64_t i = static_cast<int64_t>(index % N);

const int64_t src_offset_batch = batch * data_batch_bytes;
const int64_t dst_offset_batch = batch * gathered_batch_bytes;
Expand All @@ -97,12 +97,14 @@ Status GatherCopyData(const Tensor* indices_tensor, const uint8_t* src_base, uin
memcpy(dst_base + dst_offset, src_base + src_offset, narrow<size_t>(block_size));
}
};
concurrency::ThreadPool::TryParallelFor(tp, SafeInt<ptrdiff_t>(M) * N, static_cast<double>(block_size),
[&lambda](ptrdiff_t first, ptrdiff_t last) {
for (int index = static_cast<int>(first), end = static_cast<int>(last); index < end; ++index) {
lambda(index);
}
});

concurrency::ThreadPool::TryParallelFor(
tp, SafeInt<ptrdiff_t>(M) * N, static_cast<double>(block_size),
[&lambda](ptrdiff_t first, ptrdiff_t last) {
for (ptrdiff_t index = first; index < last; ++index) {
lambda(index);
}
});

return Status::OK();
}
Expand Down
29 changes: 29 additions & 0 deletions onnxruntime/test/providers/cpu/tensor/gather_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,35 @@ TEST(GatherOpTest, Gather_axis1_indices2d_string) {
test.Run();
}

TEST(GatherOpTest, Gather_overflow_check) {
// The test uses dimensions (65537, 2) and indices of length 65537, which produce an output
// shape of (65537, 65537).
//
Comment thread
chilo-ms marked this conversation as resolved.
Outdated
// 65537 x 65537 = 4,295,098,369 which is greater than the maximum value of a 32-bit integer (2,147,483,647).
//
Comment thread
chilo-ms marked this conversation as resolved.
Outdated
// This test is to verify CPU implementation of the Gather operator doesn't overflow when calculating
// the output shape and generating the output tensor.

OpTester test("Gather");
test.AddAttribute<int64_t>("axis", 1LL);

// Inputs
const std::vector<int64_t> data_dims{65537, 2};
const std::vector<int64_t> indices_dims{65537};
std::vector<uint8_t> data_values(static_cast<size_t>(data_dims[0] * data_dims[1]), 1);
std::vector<int64_t> indices_values(static_cast<size_t>(indices_dims[0]), 1);
std::vector<uint8_t> expected_output_values(static_cast<size_t>(65537) * static_cast<size_t>(65537), 1);

test.AddInput<uint8_t>("data", {65537, 2}, data_values);
test.AddInput<int64_t>("indices", {65537}, indices_values);
test.AddOutput<uint8_t>("output", {65537, 65537}, expected_output_values);

std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.emplace_back(DefaultCpuExecutionProvider());

test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}

TEST(GatherOpTest, Gather_axis1_indices2d_bool) {
OpTester test("Gather");
test.AddAttribute<int64_t>("axis", 1LL);
Expand Down
Loading