diff --git a/onnxruntime/core/providers/cpu/object_detection/roialign.cc b/onnxruntime/core/providers/cpu/object_detection/roialign.cc index 6ecbfaa3993ca..87958a9f7e2dd 100644 --- a/onnxruntime/core/providers/cpu/object_detection/roialign.cc +++ b/onnxruntime/core/providers/cpu/object_detection/roialign.cc @@ -295,13 +295,9 @@ Status CheckROIAlignValidInput(const Tensor* X_ptr, const Tensor* rois_ptr, cons "First dimension (num_rois) of batch_indices and rois don't match"); } - // Validate batch_indices values are within [0, batch_size). - // Only check when the tensor data is accessible from the host (CPU). - // For GPU tensors (e.g. CUDA EP), Data() returns a device pointer - // that cannot be safely dereferenced on the host. A device-side bounds - // check for the CUDA path would require passing batch_size into the - // CUDA kernel — tracked as a follow-up. if (batch_indices_ptr->Location().device.Type() == OrtDevice::CPU) { + // Validate batch_indices values are within [0, batch_size) when the tensor + // data is accessible from the host (CPU). const int64_t batch_size = X_ptr->Shape()[0]; const int64_t num_rois = batch_indices_dims[0]; diff --git a/onnxruntime/core/providers/cuda/object_detection/roialign.cc b/onnxruntime/core/providers/cuda/object_detection/roialign.cc index a6d1520d24184..71fb066c2898f 100644 --- a/onnxruntime/core/providers/cuda/object_detection/roialign.cc +++ b/onnxruntime/core/providers/cuda/object_detection/roialign.cc @@ -60,7 +60,8 @@ Status RoiAlign::ComputeInternal(OpKernelContext* context) const { reinterpret_cast::MappedType*>(Y.MutableData()), this->mode_ == RoiAlignMode::avg, this->half_pixel_, - batch_indices_ptr->Data()); + batch_indices_ptr->Data(), + x_dims[0]); // batch_size } return Status::OK(); diff --git a/onnxruntime/core/providers/cuda/object_detection/roialign_impl.cu b/onnxruntime/core/providers/cuda/object_detection/roialign_impl.cu index 3f56d197d6bd3..7acfd9d075461 100644 --- a/onnxruntime/core/providers/cuda/object_detection/roialign_impl.cu +++ b/onnxruntime/core/providers/cuda/object_detection/roialign_impl.cu @@ -95,7 +95,8 @@ __global__ void RoIAlignForward( T* top_data, const bool is_mode_avg, const bool half_pixel, - const int64_t* batch_indices_ptr) { + const int64_t* batch_indices_ptr, + const int64_t batch_size) { for (size_t index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) { // (n, c, ph, pw) is an element in the pooled output int pw = index % pooled_width; @@ -106,6 +107,13 @@ __global__ void RoIAlignForward( // RoI could have 4 or 5 columns const T* offset_bottom_rois = bottom_rois + n * roi_cols; const auto roi_batch_ind = batch_indices_ptr[n]; + // Validate batch_indices values are within [0, batch_size). + // If the index is out of range, we set the output to 0 for this RoI element. + if (roi_batch_ind < 0 || roi_batch_ind >= batch_size) { + CUDA_KERNEL_ASSERT(false && "batch_indices values are out of range"); + top_data[index] = 0; + continue; + } // Do not using rounding; this implementation detail is critical T roi_offset = half_pixel ? T(0.5) : T(0); @@ -189,7 +197,8 @@ void RoiAlignImpl( T* top_data, const bool is_mode_avg, const bool half_pixel, - const int64_t* batch_indices_ptr) { + const int64_t* batch_indices_ptr, + const int64_t batch_size) { int blocksPerGrid = (int)(ceil(static_cast(nthreads) / GridDim::maxThreadsPerBlock)); RoIAlignForward<<>>( nthreads, @@ -206,27 +215,29 @@ void RoiAlignImpl( top_data, is_mode_avg, half_pixel, - batch_indices_ptr); + batch_indices_ptr, + batch_size); } -#define SPECIALIZED_IMPL(T) \ - template void RoiAlignImpl( \ - cudaStream_t stream, \ - const int64_t nthreads, \ - const T* bottom_data, \ - const T spatial_scale, \ - const int64_t channels, \ - const int64_t height, \ - const int64_t width, \ - const int64_t pooled_height, \ - const int64_t pooled_width, \ - const int64_t sampling_ratio, \ - const T* bottom_rois, \ - int64_t roi_cols, \ - T* top_data, \ - const bool is_mode_avg, \ - const bool half_pixel, \ - const int64_t* batch_indices_ptr); +#define SPECIALIZED_IMPL(T) \ + template void RoiAlignImpl( \ + cudaStream_t stream, \ + const int64_t nthreads, \ + const T* bottom_data, \ + const T spatial_scale, \ + const int64_t channels, \ + const int64_t height, \ + const int64_t width, \ + const int64_t pooled_height, \ + const int64_t pooled_width, \ + const int64_t sampling_ratio, \ + const T* bottom_rois, \ + int64_t roi_cols, \ + T* top_data, \ + const bool is_mode_avg, \ + const bool half_pixel, \ + const int64_t* batch_indices_ptr, \ + const int64_t batch_size); SPECIALIZED_IMPL(float) SPECIALIZED_IMPL(double) diff --git a/onnxruntime/core/providers/cuda/object_detection/roialign_impl.h b/onnxruntime/core/providers/cuda/object_detection/roialign_impl.h index 3fd2f1804322f..0b68c23b811fc 100644 --- a/onnxruntime/core/providers/cuda/object_detection/roialign_impl.h +++ b/onnxruntime/core/providers/cuda/object_detection/roialign_impl.h @@ -25,7 +25,8 @@ void RoiAlignImpl( T* top_data, const bool is_mode_avg, const bool half_pixel, - const int64_t* batch_indices_ptr); + const int64_t* batch_indices_ptr, + const int64_t batch_size); // batch size of the input tensor X } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/object_detection/roialign_test.cc b/onnxruntime/test/providers/cpu/object_detection/roialign_test.cc index 1a1c1b6cde3b5..1eeb3683bc9aa 100644 --- a/onnxruntime/test/providers/cpu/object_detection/roialign_test.cc +++ b/onnxruntime/test/providers/cpu/object_detection/roialign_test.cc @@ -854,5 +854,57 @@ TEST(RoiAlignTest, BatchIndicesNegative) { execution_providers.push_back(DefaultCpuExecutionProvider()); test.Run(OpTester::ExpectResult::kExpectFailure, "batch_indices value -1 at index 0 is out of range [0, 1)", {}, nullptr, &execution_providers); } + +TEST(RoiAlignTest, BatchIndicesOutOfRange_CUDA) { +#if !defined(NDEBUG) + GTEST_SKIP() << "Skipping in Debug builds because CUDA device-side asserts poison the CUDA context."; +#else + auto cuda_ep = DefaultCudaExecutionProvider(); + if (cuda_ep.get() == nullptr) { + GTEST_SKIP() << "Skipping because there is no CUDA execution provider available."; + } + + OpTester test("RoiAlign", 10); + test.AddAttribute("output_height", 2); + test.AddAttribute("output_width", 2); + test.AddAttribute("sampling_ratio", 2); + test.AddAttribute("spatial_scale", 1.0f); + + test.AddInput("X", {1, 1, 4, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + test.AddInput("rois", {1, 4}, {0, 0, 3, 3}); + test.AddInput("batch_indices", {1}, {1}); // batch_size is 1, so 1 is out of range + test.AddOutput("Y", {1, 1, 2, 2}, {0.f, 0.f, 0.f, 0.f}); + + std::vector> execution_providers; + execution_providers.push_back(std::move(cuda_ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +#endif +} + +TEST(RoiAlignTest, BatchIndicesNegative_CUDA) { +#if !defined(NDEBUG) + GTEST_SKIP() << "Skipping in Debug builds because CUDA device-side asserts poison the CUDA context."; +#else + auto cuda_ep = DefaultCudaExecutionProvider(); + if (cuda_ep.get() == nullptr) { + GTEST_SKIP() << "Skipping because there is no CUDA execution provider available."; + } + + OpTester test("RoiAlign", 10); + test.AddAttribute("output_height", 2); + test.AddAttribute("output_width", 2); + test.AddAttribute("sampling_ratio", 2); + test.AddAttribute("spatial_scale", 1.0f); + + test.AddInput("X", {1, 1, 4, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + test.AddInput("rois", {1, 4}, {0, 0, 3, 3}); + test.AddInput("batch_indices", {1}, {-1}); + test.AddOutput("Y", {1, 1, 2, 2}, {0.f, 0.f, 0.f, 0.f}); + + std::vector> execution_providers; + execution_providers.push_back(std::move(cuda_ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +#endif +} } // namespace test } // namespace onnxruntime