Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions onnxruntime/core/providers/cpu/object_detection/roialign.cc
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,41 @@ Status CheckROIAlignValidInput(const Tensor* X_ptr, const Tensor* rois_ptr, cons
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
"First dimension (num_rois) of batch_indices and rois don't match");
}

// Validate batch_indices values are within [0, batch_size).
// Only check when the tensor data is accessible from the host (CPU).
// For GPU tensors (e.g. CUDA EP), Data<T>() returns a device pointer
// that cannot be safely dereferenced on the host. A device-side bounds
// check for the CUDA path would require passing batch_size into the
// CUDA kernel — tracked as a follow-up.
if (batch_indices_ptr->Location().device.Type() == OrtDevice::CPU) {
const int64_t batch_size = X_ptr->Shape()[0];
const int64_t num_rois = batch_indices_dims[0];

auto check_bounds = [batch_size, num_rois](const auto* batch_indices_data) -> Status {
for (int64_t i = 0; i < num_rois; ++i) {
if (batch_indices_data[i] < 0 || batch_indices_data[i] >= batch_size) {
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
"batch_indices value " + std::to_string(batch_indices_data[i]) +
" at index " + std::to_string(i) +
" is out of range [0, " + std::to_string(batch_size) + ")");
}
}
return Status::OK();
};

if (batch_indices_ptr->IsDataType<int64_t>()) {
auto status = check_bounds(batch_indices_ptr->Data<int64_t>());
if (!status.IsOK()) return status;
} else if (batch_indices_ptr->IsDataType<int32_t>()) {
auto status = check_bounds(batch_indices_ptr->Data<int32_t>());
if (!status.IsOK()) return status;
} else {
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
"batch_indices must be of type int64_t or int32_t");
}
}

return Status::OK();
}

Expand Down
42 changes: 42 additions & 0 deletions onnxruntime/test/providers/cpu/object_detection/roialign_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -812,5 +812,47 @@ TEST(RoiAlignTest, MismatchNumRois) {

test.Run(OpTester::ExpectResult::kExpectFailure, "[ShapeInferenceError] Dimension mismatch in unification between 4 and 5");
}

TEST(RoiAlignTest, BatchIndicesOutOfRange) {
OpTester test("RoiAlign", 16);
test.AddAttribute<int64_t>("output_height", 2);
test.AddAttribute<int64_t>("output_width", 2);
test.AddAttribute<int64_t>("sampling_ratio", 2);
test.AddAttribute<float>("spatial_scale", 1.0f);

test.AddInput<float>("X", {1, 1, 4, 4},
{0.f, 1.f, 2.f, 3.f,
4.f, 5.f, 6.f, 7.f,
8.f, 9.f, 10.f, 11.f,
12.f, 13.f, 14.f, 15.f});
test.AddInput<float>("rois", {1, 4}, {0.f, 0.f, 3.f, 3.f});
test.AddInput<int64_t>("batch_indices", {1}, {1}); // <-- failure condition
test.AddOutput<float>("Y", {1, 1, 2, 2}, {0.f, 0.f, 0.f, 0.f});

std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCpuExecutionProvider());
test.Run(OpTester::ExpectResult::kExpectFailure, "batch_indices value 1 at index 0 is out of range [0, 1)", {}, nullptr, &execution_providers);
}

TEST(RoiAlignTest, BatchIndicesNegative) {
OpTester test("RoiAlign", 16);
test.AddAttribute<int64_t>("output_height", 2);
test.AddAttribute<int64_t>("output_width", 2);
test.AddAttribute<int64_t>("sampling_ratio", 2);
test.AddAttribute<float>("spatial_scale", 1.0f);

test.AddInput<float>("X", {1, 1, 4, 4},
{0.f, 1.f, 2.f, 3.f,
4.f, 5.f, 6.f, 7.f,
8.f, 9.f, 10.f, 11.f,
12.f, 13.f, 14.f, 15.f});
test.AddInput<float>("rois", {1, 4}, {0.f, 0.f, 3.f, 3.f});
test.AddInput<int64_t>("batch_indices", {1}, {-1}); // <-- failure condition
test.AddOutput<float>("Y", {1, 1, 2, 2}, {0.f, 0.f, 0.f, 0.f});

std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCpuExecutionProvider());
test.Run(OpTester::ExpectResult::kExpectFailure, "batch_indices value -1 at index 0 is out of range [0, 1)", {}, nullptr, &execution_providers);
}
} // namespace test
} // namespace onnxruntime
Loading