Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions onnxruntime/core/providers/cpu/tensor/gather_nd.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <atomic>
#include <core/common/safeint.h>
#include "gather_nd.h"
#include "core/platform/threadpool.h"
Expand Down Expand Up @@ -85,7 +86,7 @@ Status GatherNDBase::PrepareForCompute(const TensorShape& input_shape, const Ten
sizes_from_slice_dims[onnxruntime::narrow<size_t>(i)] = input_shape.SizeFromDimension(SafeInt<size_t>(batch_dims_) + i + 1);
}

int64_t err_index = 0;
std::atomic<const Tind*> invalid_index{nullptr};
p.element_bytes = bytes_per_value;
p.element_count_per_slice = slice_size;
p.bytes_per_slice = p.element_bytes * p.element_count_per_slice;
Expand All @@ -94,6 +95,8 @@ Status GatherNDBase::PrepareForCompute(const TensorShape& input_shape, const Ten

// Compute the element_offset
auto lambda = [&](ptrdiff_t slice_idx) {
if (invalid_index.load(std::memory_order_relaxed)) return;

const size_t batch_idx = onnxruntime::narrow<size_t>(slice_idx / num_slices_per_batch);
const size_t input_base_offset = batch_idx * SafeInt<size_t>(input_batch_stride);

Expand All @@ -104,7 +107,7 @@ Status GatherNDBase::PrepareForCompute(const TensorShape& input_shape, const Ten
const auto upper_limit = input_shape[SafeInt<size_t>(batch_dims_) + dim_idx];
const auto lower_limit = -upper_limit;
if (index < lower_limit || index >= upper_limit) {
err_index = index;
invalid_index.store(&slice_indices[dim_idx], std::memory_order_relaxed);
break;
}
if (index < 0) index += upper_limit;
Expand All @@ -123,8 +126,12 @@ Status GatherNDBase::PrepareForCompute(const TensorShape& input_shape, const Ten
}
});

return err_index == 0 ? Status::OK()
: ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "invalid index found, index = ", err_index);
if (const Tind* bad = invalid_index.load(std::memory_order_relaxed); bad != nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"invalid index found, index = ", static_cast<int64_t>(*bad));
}

return Status::OK();
}

template Status GatherNDBase::PrepareForCompute<int32_t>(const TensorShape&,
Expand Down
24 changes: 24 additions & 0 deletions onnxruntime/test/providers/cpu/tensor/gather_nd_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -372,5 +372,29 @@ TEST(GatherNDOpTest, GatherND_zero_batch_dims_error) {
&cpu_only_ep); // force CPU
}

// Regression test for the `err_index == 0` sentinel bug: an out-of-bounds index whose
// numeric value happens to be 0 must still be reported as an error. The data tensor has
// a zero-sized middle dimension, so the index {0, 0} is invalid in its second component
// even though that component's value is 0 — the old sentinel logic would miss it.
TEST(GatherNDOpTest, GatherND_zero_dim_error) {
  OpTester test("GatherND", 12, kOnnxDomain);

  // data shape {2, 0, 3}: dim 1 has size 0, so no index into it can be valid.
  // indices shape {1, 2}: each 2-element index addresses dims 0 and 1 of data.
  // The implied output shape {1, 3} is non-empty, so the kernel cannot
  // short-circuit on an empty output and must actually validate the index.
  test.AddInput<float>("data", {2, 0, 3}, {});
  test.AddInput<int64_t>("indices", {1, 2}, {0LL, 0LL});
  test.AddOutput<float>("output", {1, 3}, {0.f, 0.f, 0.f});  // placeholder; never produced

  // Pin execution to the CPU EP so the CPU kernel's validation path is exercised.
  std::vector<std::unique_ptr<onnxruntime::IExecutionProvider>> eps;
  eps.push_back(DefaultCpuExecutionProvider());

  test.Run(OpTester::ExpectResult::kExpectFailure,
           "invalid index found, index = 0",
           {},
           nullptr,
           &eps);
}

} // namespace test
} // namespace onnxruntime
Loading