diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index bea34bc860ce2..721659d7a92ba 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -887,7 +887,9 @@ Do not modify directly.* |||[11, 12]|**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)| |||10|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)| |ReverseSequence|*in* input:**T**
*in* sequence_lens:**tensor(int64)**
*out* Y:**T**|10+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|RoiAlign|*in* X:**T1**
*in* rois:**T1**
*in* batch_indices:**T2**
*out* Y:**T1**|10+|**T1** = tensor(double), tensor(float)
**T2** = tensor(int64)| +|RoiAlign|*in* X:**T1**
*in* rois:**T1**
*in* batch_indices:**T2**
*out* Y:**T1**|22+|**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T2** = tensor(int64)| +|||[16, 21]|**T1** = tensor(double), tensor(float), tensor(float16)
**T2** = tensor(int64)| +|||[10, 15]|**T1** = tensor(double), tensor(float)
**T2** = tensor(int64)| |RotaryEmbedding|*in* X:**T**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* position_ids:**M**
*out* Y:**T**|23+|**M** = tensor(int64)
**T** = tensor(bfloat16), tensor(float), tensor(float16)| |Round|*in* X:**T**
*out* Y:**T**|11+|**T** = tensor(double), tensor(float), tensor(float16)| |ScaledTanh|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| diff --git a/onnxruntime/core/providers/cpu/object_detection/roialign.h b/onnxruntime/core/providers/cpu/object_detection/roialign.h index bb97de158369b..4ce4825e1d78c 100644 --- a/onnxruntime/core/providers/cpu/object_detection/roialign.h +++ b/onnxruntime/core/providers/cpu/object_detection/roialign.h @@ -129,6 +129,10 @@ class RoiAlignBase { std::string coordinate_transformation_mode; if (info.template GetAttr("coordinate_transformation_mode", &coordinate_transformation_mode).IsOK()) { half_pixel_ = coordinate_transformation_mode == "half_pixel"; + } else { + // For opset 16+, the default is "half_pixel" per ONNX spec. + // For opset 10 (which has no coordinate_transformation_mode attribute), false is correct. + half_pixel_ = info.node().SinceVersion() >= 16; } if (mode_ == RoiAlignMode::max && sampling_ratio_ != 1) { diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 4c0bebfd2d864..5ba5295df8b89 100755 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -944,8 +944,11 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 10, int32_t, Resize); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 10, uint8_t, Resize); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, ReverseSequence); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, float, RoiAlign); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, double, RoiAlign); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 15, float, RoiAlign); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 15, double, RoiAlign); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 21, float, RoiAlign); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 21, double, RoiAlign); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 21, MLFloat16, RoiAlign); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 10, int32_t, Slice); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 10, int64_t, Slice); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, float, ThresholdedRelu); @@ -1601,6 +1604,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, float, GRU); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, double, GRU); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, MLFloat16, GRU); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, float, RoiAlign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, double, RoiAlign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, MLFloat16, RoiAlign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, BFloat16, RoiAlign); // Opset 23. class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 23, float, Attention); @@ -2042,8 +2049,11 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2700,6 +2710,10 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // Opset 23 BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cuda/object_detection/roialign.cc b/onnxruntime/core/providers/cuda/object_detection/roialign.cc index 71fb066c2898f..5d876ae5a2cc9 100644 --- a/onnxruntime/core/providers/cuda/object_detection/roialign.cc +++ b/onnxruntime/core/providers/cuda/object_detection/roialign.cc @@ -7,11 +7,37 @@ namespace onnxruntime { namespace cuda { -#define REGISTER_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ +#define ADD_VERSIONED_TYPED_ROIALIGN_OP_10(T) \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ RoiAlign, \ kOnnxDomain, \ 10, \ + 15, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \ + RoiAlign); + +#define ADD_VERSIONED_TYPED_ROIALIGN_OP_16(T) \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + RoiAlign, \ + kOnnxDomain, \ + 16, \ + 21, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \ + RoiAlign); + +#define ADD_TYPED_ROIALIGN_OP_22(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + RoiAlign, \ + kOnnxDomain, \ + 22, \ T, \ kCudaExecutionProvider, \ (*KernelDefBuilder::Create()) \ @@ -67,13 +93,22 @@ Status RoiAlign::ComputeInternal(OpKernelContext* context) const { return Status::OK(); } -#define SPECIALIZED_COMPUTE(T) \ - REGISTER_KERNEL_TYPED(T) \ +#define SPECIALIZED_COMPUTE(T) \ + ADD_VERSIONED_TYPED_ROIALIGN_OP_10(T) \ + ADD_VERSIONED_TYPED_ROIALIGN_OP_16(T) \ + ADD_TYPED_ROIALIGN_OP_22(T) \ template Status RoiAlign::ComputeInternal(OpKernelContext* ctx) const; SPECIALIZED_COMPUTE(float) SPECIALIZED_COMPUTE(double) -// SPECIALIZED_COMPUTE(MLFloat16) +// MLFloat16 is available for RoiAlign op from version 16 (not version 10): +ADD_VERSIONED_TYPED_ROIALIGN_OP_16(MLFloat16) +ADD_TYPED_ROIALIGN_OP_22(MLFloat16) +template Status RoiAlign::ComputeInternal(OpKernelContext* ctx) const; + +// BFloat16 is available for RoiAlign op from version 22: +ADD_TYPED_ROIALIGN_OP_22(BFloat16) +template Status RoiAlign::ComputeInternal(OpKernelContext* ctx) const; } // namespace cuda }; // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/object_detection/roialign_impl.cu b/onnxruntime/core/providers/cuda/object_detection/roialign_impl.cu index 7acfd9d075461..87f4aba8e45b2 100644 --- a/onnxruntime/core/providers/cuda/object_detection/roialign_impl.cu +++ b/onnxruntime/core/providers/cuda/object_detection/roialign_impl.cu @@ -17,64 +17,72 @@ #include "roialign_impl.h" #include "core/providers/cuda/cu_inc/common.cuh" +#include "core/providers/cuda/shared_inc/accumulation_type.h" namespace onnxruntime { namespace cuda { template -__device__ T bilinear_interpolate( +__device__ AccumulationType_t bilinear_interpolate( const T* bottom_data, const int height, const int width, - T y, - T x, + AccumulationType_t y, + AccumulationType_t x, const bool is_mode_avg, const int index /* index for debug only*/) { + using TAcc = AccumulationType_t; + // deal with cases that inverse elements are out of feature map boundary - if (y < -1.0 || y > height || x < -1.0 || x > width) { + if (y < static_cast(-1.0f) || y > static_cast(height) || + x < static_cast(-1.0f) || x > static_cast(width)) { // empty - return 0; + return static_cast(0.0f); } - if (y <= 0) { - y = 0; + if (y <= static_cast(0.0f)) { + y = static_cast(0.0f); } - if (x <= 0) { - x = 0; + if (x <= static_cast(0.0f)) { + x = static_cast(0.0f); } - int y_low = (int)y; - int x_low = (int)x; + int y_low = static_cast(y); + int x_low = static_cast(x); int y_high; int x_high; if (y_low >= height - 1) { y_high = y_low = height - 1; - y = (T)y_low; + y = static_cast(y_low); } else { y_high = y_low + 1; } if (x_low >= width - 1) { x_high = x_low = width - 1; - x = (T)x_low; + x = static_cast(x_low); } else { x_high = x_low + 1; } - T ly = y - y_low; - T lx = x - x_low; - T hy = 1. - ly, hx = 1. - lx; + TAcc ly = y - static_cast(y_low); + TAcc lx = x - static_cast(x_low); + TAcc hy = static_cast(1.0f) - ly; + TAcc hx = static_cast(1.0f) - lx; // do bilinear interpolation - T v1 = bottom_data[y_low * width + x_low]; - T v2 = bottom_data[y_low * width + x_high]; - T v3 = bottom_data[y_high * width + x_low]; - T v4 = bottom_data[y_high * width + x_high]; - T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; + TAcc v1 = static_cast(bottom_data[y_low * width + x_low]); + TAcc v2 = static_cast(bottom_data[y_low * width + x_high]); + TAcc v3 = static_cast(bottom_data[y_high * width + x_low]); + TAcc v4 = static_cast(bottom_data[y_high * width + x_high]); + TAcc w1 = hy * hx; + TAcc w2 = hy * lx; + TAcc w3 = ly * hx; + TAcc w4 = ly * lx; - T val = is_mode_avg - ? (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4) // mode Avg - : max(max(max(w1 * v1, w2 * v2), w3 * v3), w4 * v4); // mode Max + TAcc val = is_mode_avg + ? (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4) // mode Avg + : max(max(max(w1 * v1, w2 * v2), w3 * v3), w4 * v4); // mode Max return val; } @@ -97,6 +105,8 @@ __global__ void RoIAlignForward( const bool half_pixel, const int64_t* batch_indices_ptr, const int64_t batch_size) { + using TAcc = AccumulationType_t; + for (size_t index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) { // (n, c, ph, pw) is an element in the pooled output int pw = index % pooled_width; @@ -111,26 +121,27 @@ __global__ void RoIAlignForward( // If the index is out of range, we set the output to 0 for this RoI element. if (roi_batch_ind < 0 || roi_batch_ind >= batch_size) { CUDA_KERNEL_ASSERT(false && "batch_indices values are out of range"); - top_data[index] = 0; + top_data[index] = static_cast(0.0f); continue; } // Do not using rounding; this implementation detail is critical - T roi_offset = half_pixel ? T(0.5) : T(0); - T roi_start_w = offset_bottom_rois[0] * spatial_scale - roi_offset; - T roi_start_h = offset_bottom_rois[1] * spatial_scale - roi_offset; - T roi_end_w = offset_bottom_rois[2] * spatial_scale - roi_offset; - T roi_end_h = offset_bottom_rois[3] * spatial_scale - roi_offset; - - T roi_width = roi_end_w - roi_start_w; - T roi_height = roi_end_h - roi_start_h; + const TAcc spatial_scale_acc = static_cast(spatial_scale); + const TAcc roi_offset = half_pixel ? static_cast(0.5f) : static_cast(0.0f); + TAcc roi_start_w = static_cast(offset_bottom_rois[0]) * spatial_scale_acc - roi_offset; + TAcc roi_start_h = static_cast(offset_bottom_rois[1]) * spatial_scale_acc - roi_offset; + TAcc roi_end_w = static_cast(offset_bottom_rois[2]) * spatial_scale_acc - roi_offset; + TAcc roi_end_h = static_cast(offset_bottom_rois[3]) * spatial_scale_acc - roi_offset; + + TAcc roi_width = roi_end_w - roi_start_w; + TAcc roi_height = roi_end_h - roi_start_h; if (!half_pixel) { // backward compatibility // Force malformed ROIs to be 1x1 - roi_width = max(roi_width, (T)1.); - roi_height = max(roi_height, (T)1.); + roi_width = max(roi_width, static_cast(1.0f)); + roi_height = max(roi_height, static_cast(1.0f)); } - T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); - T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + const TAcc bin_size_h = roi_height / static_cast(pooled_height); + const TAcc bin_size_w = roi_width / static_cast(pooled_width); const T* offset_bottom_data = bottom_data + static_cast((roi_batch_ind * channels + c) * height * width); @@ -138,26 +149,27 @@ __global__ void RoIAlignForward( // We use roi_bin_grid to sample the grid and mimic integral int roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio - : _Ceil(roi_height / pooled_height); // e.g., = 2 + : static_cast(_Ceil(roi_height / static_cast(pooled_height))); // e.g., = 2 int roi_bin_grid_w = - (sampling_ratio > 0) ? sampling_ratio : _Ceil(roi_width / pooled_width); + (sampling_ratio > 0) ? sampling_ratio : static_cast(_Ceil(roi_width / static_cast(pooled_width))); // We do average (integral) pooling inside a bin - const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4 + const int grid_count = max(roi_bin_grid_h * roi_bin_grid_w, 1); + const TAcc count = static_cast(grid_count); // e.g. = 4 - T output_val = 0.; + TAcc output_val = static_cast(0.0f); bool max_flag = false; for (int iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1 { - const T y = roi_start_h + ph * bin_size_h + - static_cast(iy + .5f) * bin_size_h / - static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 + const TAcc y = roi_start_h + static_cast(ph) * bin_size_h + + (static_cast(iy) + static_cast(0.5f)) * bin_size_h / + static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 for (int ix = 0; ix < roi_bin_grid_w; ix++) { - const T x = roi_start_w + pw * bin_size_w + - static_cast(ix + .5f) * bin_size_w / - static_cast(roi_bin_grid_w); + const TAcc x = roi_start_w + static_cast(pw) * bin_size_w + + (static_cast(ix) + static_cast(0.5f)) * bin_size_w / + static_cast(roi_bin_grid_w); - T val = bilinear_interpolate( + const TAcc val = bilinear_interpolate( offset_bottom_data, height, width, y, x, is_mode_avg, index); if (is_mode_avg) { @@ -176,7 +188,7 @@ __global__ void RoIAlignForward( output_val /= count; } - top_data[index] = output_val; + top_data[index] = static_cast(output_val); } } @@ -241,6 +253,8 @@ void RoiAlignImpl( SPECIALIZED_IMPL(float) SPECIALIZED_IMPL(double) +SPECIALIZED_IMPL(half) +SPECIALIZED_IMPL(BFloat16) } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/object_detection/roialign_test.cc b/onnxruntime/test/providers/cpu/object_detection/roialign_test.cc index 1eeb3683bc9aa..b2abe353693a2 100644 --- a/onnxruntime/test/providers/cpu/object_detection/roialign_test.cc +++ b/onnxruntime/test/providers/cpu/object_detection/roialign_test.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "gtest/gtest.h" +#include "test/common/tensor_op_test_utils.h" #include "test/providers/provider_test_utils.h" #include "test/util/include/default_providers.h" #include "test/common/trt_op_test_utils.h" @@ -906,5 +907,138 @@ TEST(RoiAlignTest, BatchIndicesNegative_CUDA) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); #endif } + +TEST(RoiAlignTest, Float16_Opset16) { + auto cuda_ep = DefaultCudaExecutionProvider(); + if (cuda_ep.get() == nullptr) { + GTEST_SKIP() << "Skipping because there is no CUDA execution provider available."; + } + + OpTester test("RoiAlign", 16); + test.AddAttribute("output_height", 1); + test.AddAttribute("output_width", 1); + test.AddAttribute("sampling_ratio", 1); + test.AddAttribute("spatial_scale", 1.0f); + test.AddAttribute("coordinate_transformation_mode", "output_half_pixel"); + + test.AddInput("X", {1, 1, 2, 2}, ToFloat16({1., 2., 1., 1.})); + test.AddInput("rois", {1, 4}, ToFloat16({0., 0., 1., 1.})); + test.AddInput("batch_indices", {1}, {0}); + test.AddOutput("Y", {1, 1, 1, 1}, ToFloat16({1.25f})); + + test.SetOutputTolerance(0.01f); + std::vector> execution_providers; + execution_providers.push_back(std::move(cuda_ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(RoiAlignTest, Float16_Opset22) { + auto cuda_ep = DefaultCudaExecutionProvider(); + if (cuda_ep.get() == nullptr) { + GTEST_SKIP() << "Skipping because there is no CUDA execution provider available."; + } + + OpTester test("RoiAlign", 22); + test.AddAttribute("output_height", 1); + test.AddAttribute("output_width", 1); + test.AddAttribute("sampling_ratio", 1); + test.AddAttribute("spatial_scale", 1.0f); + test.AddAttribute("coordinate_transformation_mode", "output_half_pixel"); + + test.AddInput("X", {1, 1, 2, 2}, ToFloat16({1., 2., 1., 1.})); + test.AddInput("rois", {1, 4}, ToFloat16({0., 0., 1., 1.})); + test.AddInput("batch_indices", {1}, {0}); + test.AddOutput("Y", {1, 1, 1, 1}, ToFloat16({1.25f})); + + test.SetOutputTolerance(0.01f); + std::vector> execution_providers; + execution_providers.push_back(std::move(cuda_ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(RoiAlignTest, BFloat16_Opset22) { + auto cuda_ep = DefaultCudaExecutionProvider(); + if (cuda_ep.get() == nullptr) { + GTEST_SKIP() << "Skipping because there is no CUDA execution provider available."; + } + + OpTester test("RoiAlign", 22); + test.AddAttribute("output_height", 1); + test.AddAttribute("output_width", 1); + test.AddAttribute("sampling_ratio", 1); + test.AddAttribute("spatial_scale", 1.0f); + test.AddAttribute("coordinate_transformation_mode", "output_half_pixel"); + + test.AddInput("X", {1, 1, 2, 2}, ToBFloat16({1., 2., 1., 1.})); + test.AddInput("rois", {1, 4}, ToBFloat16({0., 0., 1., 1.})); + test.AddInput("batch_indices", {1}, {0}); + test.AddOutput("Y", {1, 1, 1, 1}, ToBFloat16({1.25f})); + + test.SetOutputTolerance(0.05f); + std::vector> execution_providers; + execution_providers.push_back(std::move(cuda_ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +// Test half_pixel mode (default for Opset 16+) with Float16 on larger spatial dimensions. +// Uses 8x8 input (0..63), ROI [0,0,7,7], output 2x2, sampling_ratio=2. +// Expected values from ONNX reference implementation: {11.25, 14.75, 39.25, 42.75} +TEST(RoiAlignTest, Float16_HalfPixel_Opset16) { + auto cuda_ep = DefaultCudaExecutionProvider(); + if (cuda_ep.get() == nullptr) { + GTEST_SKIP() << "Skipping because there is no CUDA execution provider available."; + } + + OpTester test("RoiAlign", 16); + test.AddAttribute("output_height", 2); + test.AddAttribute("output_width", 2); + test.AddAttribute("sampling_ratio", 2); + test.AddAttribute("spatial_scale", 1.0f); + test.AddAttribute("coordinate_transformation_mode", "half_pixel"); + + std::vector X_val(64); + for (int i = 0; i < 64; ++i) X_val[i] = static_cast(i); + test.AddInput("X", {1, 1, 8, 8}, ToFloat16(X_val)); + test.AddInput("rois", {1, 4}, ToFloat16({0., 0., 7., 7.})); + test.AddInput("batch_indices", {1}, {0}); + test.AddOutput("Y", {1, 1, 2, 2}, ToFloat16({11.25f, 14.75f, 39.25f, 42.75f})); + + test.SetOutputTolerance(0.01f); + std::vector> execution_providers; + execution_providers.push_back(std::move(cuda_ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +// Test adaptive sampling (sampling_ratio=0) with Float16 on larger spatial dimensions. +// Uses 8x8 input (0..63), ROI [0,0,7,7], output 2x2, half_pixel mode. +// Adaptive: ceil(3.0/2)=2 samples per dim. +// Expected values from ONNX reference implementation: {11.39062, 14.875, 39.26562, 42.75} +TEST(RoiAlignTest, Float16_AdaptiveSampling_Opset16) { + auto cuda_ep = DefaultCudaExecutionProvider(); + if (cuda_ep.get() == nullptr) { + GTEST_SKIP() << "Skipping because there is no CUDA execution provider available."; + } + + OpTester test("RoiAlign", 16); + test.AddAttribute("output_height", 2); + test.AddAttribute("output_width", 2); + test.AddAttribute("sampling_ratio", 0); // adaptive + test.AddAttribute("spatial_scale", 1.0f); + test.AddAttribute("coordinate_transformation_mode", "half_pixel"); + + std::vector X_val(64); + for (int i = 0; i < 64; ++i) X_val[i] = static_cast(i); + test.AddInput("X", {1, 1, 8, 8}, ToFloat16(X_val)); + test.AddInput("rois", {1, 4}, ToFloat16({0., 0., 7., 7.})); + test.AddInput("batch_indices", {1}, {0}); + test.AddOutput("Y", {1, 1, 2, 2}, + ToFloat16({11.39062f, 14.875f, 39.26562f, 42.75f})); + + test.SetOutputTolerance(0.01f); + std::vector> execution_providers; + execution_providers.push_back(std::move(cuda_ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + } // namespace test } // namespace onnxruntime