diff --git a/onnxruntime/core/providers/cpu/object_detection/roialign.cc b/onnxruntime/core/providers/cpu/object_detection/roialign.cc index d8c81e5cb63e5..6ecbfaa3993ca 100644 --- a/onnxruntime/core/providers/cpu/object_detection/roialign.cc +++ b/onnxruntime/core/providers/cpu/object_detection/roialign.cc @@ -294,6 +294,41 @@ Status CheckROIAlignValidInput(const Tensor* X_ptr, const Tensor* rois_ptr, cons return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "First dimension (num_rois) of batch_indices and rois don't match"); } + + // Validate batch_indices values are within [0, batch_size). + // Only check when the tensor data is accessible from the host (CPU). + // For GPU tensors (e.g. CUDA EP), Data() returns a device pointer + // that cannot be safely dereferenced on the host. A device-side bounds + // check for the CUDA path would require passing batch_size into the + // CUDA kernel — tracked as a follow-up. + if (batch_indices_ptr->Location().device.Type() == OrtDevice::CPU) { + const int64_t batch_size = X_ptr->Shape()[0]; + const int64_t num_rois = batch_indices_dims[0]; + + auto check_bounds = [batch_size, num_rois](const auto* batch_indices_data) -> Status { + for (int64_t i = 0; i < num_rois; ++i) { + if (batch_indices_data[i] < 0 || batch_indices_data[i] >= batch_size) { + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, + "batch_indices value " + std::to_string(batch_indices_data[i]) + + " at index " + std::to_string(i) + + " is out of range [0, " + std::to_string(batch_size) + ")"); + } + } + return Status::OK(); + }; + + if (batch_indices_ptr->IsDataType()) { + auto status = check_bounds(batch_indices_ptr->Data()); + if (!status.IsOK()) return status; + } else if (batch_indices_ptr->IsDataType()) { + auto status = check_bounds(batch_indices_ptr->Data()); + if (!status.IsOK()) return status; + } else { + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, + "batch_indices must be of type int64_t or int32_t"); + } + } + return Status::OK(); } diff --git a/onnxruntime/test/providers/cpu/object_detection/roialign_test.cc b/onnxruntime/test/providers/cpu/object_detection/roialign_test.cc index 58a616717316e..1a1c1b6cde3b5 100644 --- a/onnxruntime/test/providers/cpu/object_detection/roialign_test.cc +++ b/onnxruntime/test/providers/cpu/object_detection/roialign_test.cc @@ -812,5 +812,47 @@ TEST(RoiAlignTest, MismatchNumRois) { test.Run(OpTester::ExpectResult::kExpectFailure, "[ShapeInferenceError] Dimension mismatch in unification between 4 and 5"); } + +TEST(RoiAlignTest, BatchIndicesOutOfRange) { + OpTester test("RoiAlign", 16); + test.AddAttribute("output_height", 2); + test.AddAttribute("output_width", 2); + test.AddAttribute("sampling_ratio", 2); + test.AddAttribute("spatial_scale", 1.0f); + + test.AddInput("X", {1, 1, 4, 4}, + {0.f, 1.f, 2.f, 3.f, + 4.f, 5.f, 6.f, 7.f, + 8.f, 9.f, 10.f, 11.f, + 12.f, 13.f, 14.f, 15.f}); + test.AddInput("rois", {1, 4}, {0.f, 0.f, 3.f, 3.f}); + test.AddInput("batch_indices", {1}, {1}); // <-- failure condition + test.AddOutput("Y", {1, 1, 2, 2}, {0.f, 0.f, 0.f, 0.f}); + + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectFailure, "batch_indices value 1 at index 0 is out of range [0, 1)", {}, nullptr, &execution_providers); +} + +TEST(RoiAlignTest, BatchIndicesNegative) { + OpTester test("RoiAlign", 16); + test.AddAttribute("output_height", 2); + test.AddAttribute("output_width", 2); + test.AddAttribute("sampling_ratio", 2); + test.AddAttribute("spatial_scale", 1.0f); + + test.AddInput("X", {1, 1, 4, 4}, + {0.f, 1.f, 2.f, 3.f, + 4.f, 5.f, 6.f, 7.f, + 8.f, 9.f, 10.f, 11.f, + 12.f, 13.f, 14.f, 15.f}); + test.AddInput("rois", {1, 4}, {0.f, 0.f, 3.f, 3.f}); + test.AddInput("batch_indices", {1}, {-1}); // <-- failure condition + test.AddOutput("Y", {1, 1, 2, 2}, {0.f, 0.f, 0.f, 0.f}); + + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectFailure, "batch_indices value -1 at index 0 is out of range [0, 1)", {}, nullptr, &execution_providers); +} } // namespace test } // namespace onnxruntime